To implement deep learning on a smartphone, one can refer to the TensorFlow Android example apps here. These examples include a classifier, a detector, voice recognition, etc. Based on these frameworks, one can develop one's own app with a customized deep learning model. During development, however, on-target debugging is often needed to ensure that the Android implementation generates the same results as offline processing. In this blog, we provide a few tips for this debugging process.
One important debugging step is knowing what happens on the device. For this purpose, Android provides a logging facility, the android.util.Log class. Taking TensorFlowImageClassifier.java as an example, it reads in the image data as follows:
for (int i = 0; i < intValues.length; ++i) {
    final int val = intValues[i];
    // unpack the packed ARGB pixel into normalized R, G, B floats
    floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
    floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
    floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
}
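Because the offline flow must receive exactly the same input tensor as the app, it helps to mirror this loop in Python. Each pixel returned by Android's Bitmap.getPixels() is a packed ARGB int, so shifting by 16, 8 and 0 bits and masking with 0xFF extracts the red, green and blue bytes. The sketch below is a minimal version under stated assumptions: normalize_pixels, IMAGE_MEAN and IMAGE_STD are hypothetical names, and the mean/std values are placeholders to replace with the imageMean and imageStd your classifier was actually created with. Python computes in double precision while the Java code uses 32-bit floats, so tiny differences in the last digits are expected.

```python
# Hypothetical offline replica of the Android preprocessing loop.
# IMAGE_MEAN and IMAGE_STD are placeholders (assumption): use the
# imageMean/imageStd values your classifier was constructed with.
IMAGE_MEAN = 128.0
IMAGE_STD = 128.0


def normalize_pixels(int_values, image_mean=IMAGE_MEAN, image_std=IMAGE_STD):
    """Unpack packed ARGB ints into normalized R, G, B floats, interleaved
    in the same order as floatValues on Android."""
    float_values = []
    for val in int_values:
        # Java ints are signed (0xFF808080 is -8355712); Python's >> is also
        # an arithmetic shift, and the & 0xFF mask gives the same byte either way.
        r = (val >> 16) & 0xFF  # red channel
        g = (val >> 8) & 0xFF   # green channel
        b = val & 0xFF          # blue channel; alpha (bits 24-31) is ignored
        float_values.extend([
            (r - image_mean) / image_std,
            (g - image_mean) / image_std,
            (b - image_mean) / image_std,
        ])
    return float_values
```

For example, with the placeholder mean and std of 128, normalize_pixels([0xFF808080, 0xFFFF0000]) returns [0.0, 0.0, 0.0, 0.9921875, -1.0, -1.0].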
To inspect the values written to floatValues, we can add a log call inside the loop, right after these assignments:
Log.i(TAG, "floatValues " + floatValues[i * 3 + 0] + " , " + floatValues[i * 3 + 1]+ " , " + floatValues[i * 3 + 2]);
Then, with the device connected to the PC by a USB cable, logcat shows this log information each time we open the app. The same approach can be extended to any variable we are interested in. Android Studio provides a Logcat window for viewing the log. Another way to collect the log is to run adb from a command window. On my PC, the command is:
C:\Users\Nan\AppData\Local\Android\Sdk\platform-tools>adb logcat > logcat.txt
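Rather than reading logcat.txt by eye, the values logged by the Log.i call above can be pulled back into Python and compared with the offline flow. The following is a minimal sketch; parse_float_values and read_logcat are hypothetical helper names, and the pattern assumes the exact message format of that Log.i call ("floatValues a , b , c"), ignoring whatever prefix (date, PID, tag) logcat adds to each line.

```python
import re

# Matches the message written by the Log.i call, e.g.
#   "... I TensorFlowImageClassifier: floatValues 0.5 , -0.25 , 1.0"
# Only the text after "floatValues" is parsed; the logcat prefix is ignored.
FLOAT_LINE = re.compile(
    r"floatValues\s+([^\s,]+)\s*,\s*([^\s,]+)\s*,\s*([^\s,]+)"
)


def parse_float_values(lines):
    """Return a flat list of floats, three per matching line, in log order."""
    values = []
    for line in lines:
        match = FLOAT_LINE.search(line)
        if match:
            # float() also accepts Java spellings such as "1.0E-4" and "NaN"
            values.extend(float(g) for g in match.groups())
    return values


def read_logcat(path="logcat.txt"):
    """Parse a log file saved with 'adb logcat > logcat.txt' (hypothetical path)."""
    # ignore undecodable bytes coming from unrelated log lines
    with open(path, encoding="utf-8", errors="ignore") as f:
        return parse_float_values(f)
```

The resulting list has the same interleaved R, G, B order as floatValues, so it can be compared directly with the offline input after flattening.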
To run the neural network, TensorFlowImageClassifier.java loads a TensorFlow model for image classification. To test our own model, we can replace the default TensorFlow model with our own. Sometimes, however, the model does not work as expected. To debug it, the final output of the model is not always enough; we also want to dump out intermediate results. This can be done with the fetch() method of the TensorFlowInferenceInterface class. Note that fetch() can only read nodes whose names were passed to run() in the preceding inference call, so the intermediate node must be added to that list of output names. In the instruction below, batch_normalization_1/FusedBatchNorm_1 is an intermediate node in our customized TensorFlow model, and fetch() copies the output of this node into the array outputs_intermediate, which has been initialized earlier in the code. To find the names of intermediate nodes such as batch_normalization_1/FusedBatchNorm_1, TensorBoard can be used.
inferenceInterface.fetch("batch_normalization_1/FusedBatchNorm_1", outputs_intermediate);
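The array filled by fetch() is flat, whereas the offline intermediate tensor has a shape such as [N, H, W, C] (TensorFlow's default NHWC layout for convolution layers). Since TensorFlow stores tensor data in row-major order, flat index k can be mapped back to a coordinate to locate where the Android and offline results first diverge. A minimal sketch, assuming the offline tensor is converted with numpy's tolist() and its shape is known; first_mismatch, unravel_index and flatten_row_major are hypothetical helpers (numpy users can use numpy.unravel_index instead), and the default tolerance is an arbitrary choice.

```python
def flatten_row_major(nested):
    """Flatten a nested list (e.g. numpy_array.tolist()) in row-major order."""
    if isinstance(nested, (list, tuple)):
        return [x for item in nested for x in flatten_row_major(item)]
    return [nested]


def unravel_index(flat_index, shape):
    """Convert a row-major flat index into a coordinate tuple for `shape`."""
    coords = []
    for dim in reversed(shape):
        coords.append(flat_index % dim)
        flat_index //= dim
    return tuple(reversed(coords))


def first_mismatch(android_flat, offline_nested, shape, atol=1e-4):
    """Return (coordinate, android_value, offline_value) for the first element
    whose absolute difference exceeds atol, or None if all elements agree."""
    offline_flat = flatten_row_major(offline_nested)
    if len(android_flat) < len(offline_flat):
        raise ValueError("Android array is smaller than the offline tensor")
    # The destination array may be larger than the tensor; only the first
    # len(offline_flat) elements are meaningful, so zip() stops there.
    for k, (a, b) in enumerate(zip(android_flat, offline_flat)):
        if abs(a - b) > atol:
            return unravel_index(k, shape), a, b
    return None
```

Knowing the coordinate (for example, whether all mismatches share one channel index) often points directly at the cause, such as a swapped color channel or a wrong normalization constant.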
To support this Android debugging process, there is often an offline processing flow that performs the same computation in parallel and serves as the reference for Android debugging. Below is a snippet of Python code used for TensorFlow-based offline processing. Note that g_in holds the kernel weight tensor of layer conv2d_1, which is evaluated into kernel_conv2d_1, a numpy array. A good first debugging step is to feed the same constant input to both the Android app and the offline flow and compare their outputs; after that, we can feed both the same image.
import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.platform import gfile

# wkdir, pb_filename and image_test are defined earlier in the script
with tf.Session() as sess:
    # load model from pb file
    with gfile.FastGFile(wkdir + '/' + pb_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # import into sess.graph (already the default graph inside this block);
    # imported names get the prefix 'import/', and g_in holds the kernel tensor
    g_in = tf.import_graph_def(graph_def, return_elements=['conv2d_1/kernel/read:0'])
    # write to tensorboard (check tensorboard for each op name)
    writer = tf.summary.FileWriter(wkdir + '/log/')
    writer.add_graph(sess.graph)
    writer.flush()
    writer.close()
    # inference by the model (tensor names need ':0' to select the op's first output)
    tensor_output = sess.graph.get_tensor_by_name('import/dense_2/Softmax:0')
    tensor_input = sess.graph.get_tensor_by_name('import/input_1:0')
    tensor_intermediate = sess.graph.get_tensor_by_name('import/batch_normalization_1/FusedBatchNorm_1:0')
    kernel_conv2d_1 = g_in[0].eval(session=sess)
    predictions = sess.run(tensor_output, {tensor_input: image_test})
    predictions_0 = sess.run(tensor_intermediate, {tensor_input: image_test})
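When the final outputs are compared, exact equality is usually too strict: the phone and the PC may use different instruction sets and accumulation orders, so floating-point results typically differ in the last few digits even when both implementations are correct. Below is a minimal sketch of a tolerance-based check, assuming the Android softmax output has been collected as a list of floats (for example from logcat) and the offline one is predictions[0].tolist(); compare_outputs is a hypothetical helper, and the rtol/atol defaults are assumptions to tune for your model.

```python
import math


def _argmax(values):
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def compare_outputs(android, offline, rtol=1e-3, atol=1e-5):
    """Compare two equally long, non-empty float sequences element by element.

    Returns the largest absolute difference, whether every pair satisfies
    math.isclose with the given tolerances, and whether both sequences
    select the same top-1 class."""
    if not android or len(android) != len(offline):
        raise ValueError(f"length mismatch: {len(android)} vs {len(offline)}")
    max_abs_diff = max(abs(a - b) for a, b in zip(android, offline))
    all_close = all(
        math.isclose(a, b, rel_tol=rtol, abs_tol=atol)
        for a, b in zip(android, offline)
    )
    return {
        "max_abs_diff": max_abs_diff,
        "all_close": all_close,
        "same_top1": _argmax(android) == _argmax(offline),
    }
```

Agreement on the top-1 class is a useful sanity check, but it can hide real discrepancies in the other scores, which is why the maximum absolute difference is reported as well.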