Objective
We start from the Android sample project for image classification, which classifies Android’s camera images into ImageNet classes. We are going to modify the TFLite model and the Android Java code so that the app returns the extracted features along with the classification probabilities.
Customize the TFLite Model
The sample Android project automatically downloads and uses a predefined TFLite model. If you want to try a different TFLite model, you can download one from here. For classification, both quantized (smaller, faster, less accurate) and floating-point (bulkier, slower, more accurate) models are available. Download a model based on your requirements. The compressed file contains 7 files. I downloaded the MobileNetV2 floating-point model. Note that throughout this post you should adapt the commands to the model you downloaded.
(Checkpoints)
mobilenet_v2_1.0_224.ckpt.data-00000-of-00001
mobilenet_v2_1.0_224.ckpt.meta
mobilenet_v2_1.0_224.ckpt.index
(Protobuf files - contain the frozen model)
mobilenet_v2_1.0_224_frozen.pb
mobilenet_v2_1.0_224_eval.pbtxt
(TFLite model - can be obtained by converting the frozen model)
mobilenet_v2_1.0_224.tflite
mobilenet_v2_1.0_224_info.txt
The compressed file contains checkpoint files, so one can load the model and modify its layers. To load the model, a file named checkpoint with the following content is needed.
model_checkpoint_path: "mobilenet_v2_1.0_224.ckpt"
all_model_checkpoint_paths: "mobilenet_v2_1.0_224.ckpt"
Check out this post to learn about loading the checkpoint files and freezing them for later use. For the purpose of this post, no modification to the model’s architecture is needed. To convert a model to TFLite, we first need to freeze it. The frozen model can then be converted to TFLite using the tflite_convert script. Starting from TF 1.9, tflite_convert is installed as part of the TF Python package. Check out this official post for more information regarding tflite_convert. We are going to add an intermediate layer as an additional output. Adding an output to the TFLite model is possible even when the checkpoint files are not provided: we only need the *.pb file to modify the outputs.
To add an intermediate layer to the model’s outputs, we first need to choose a layer. The file mobilenet_v2_1.0_224_eval.pbtxt describes the layers, but it is very large and hard to read. Alternatively, you can load the pbtxt file in TensorBoard to get a better sense of the network’s architecture. I chose the layer MobilenetV2/Logits/AvgPool, which has the shape (1, 1, 1280).
Here is the tflite_convert command to export the new TFLite model with an additional output.
tflite_convert \
--output_file=customized.tflite \
--graph_def_file=mobilenet_v2_1.0_224_frozen.pb \
--input_arrays=input \
--input_shapes=1,224,224,3 \
--output_arrays=MobilenetV2/Logits/AvgPool,MobilenetV2/Predictions/Reshape_1
Customizing the Android Project
First, let’s modify the model-type spinner values in the file app/src/main/res/values/strings.xml and remove the Quantized option, since we only converted the floating-point model. We also need to change the model’s path in the file ClassifierFloatMobileNet.java.
@Override
protected String getModelPath() {
  // customized.tflite is the model converted above. Unlike the stock models,
  // it is not downloaded by build.gradle, so copy it into the app's assets
  // folder yourself.
  return "customized.tflite";
}
All other changes happen in the file Classifier.java. The function recognizeImage in Classifier.java returns a list of Recognition objects. Recognition is a class defined in the same file. To keep the modifications minimal, let us add a new field to the Recognition class to store the features from the AvgPool layer. We will save the intermediate features in the first Recognition of the list returned by recognizeImage.
/** A result returned by a Classifier describing what was recognized, plus optional features. */
public static class Recognition {
  private final String id;
  private final String title;
  private final Float confidence;
  private RectF location;
  // AvgPool features; null unless set on this Recognition.
  private float[] features;

  public Recognition(String id, String title, Float confidence, RectF location, float[] features) {
    this.id = id;
    this.title = title;
    this.confidence = confidence;
    this.location = location;
    this.features = features;
  }

  public Recognition(final String id, final String title, final Float confidence, final RectF location) {
    this.id = id;
    this.title = title;
    this.confidence = confidence;
    this.location = location;
  }

  public String getId() {
    return id;
  }

  public String getTitle() {
    return title;
  }

  public Float getConfidence() {
    return confidence;
  }

  public RectF getLocation() {
    return new RectF(location);
  }

  public float[] getFeatures() {
    return features;
  }

  public void setFeatures(float[] features) {
    this.features = features;
  }

  public void setLocation(RectF location) {
    this.location = location;
  }

  @Override
  public String toString() {
    String resultString = "";
    if (id != null) {
      resultString += "[" + id + "] ";
    }
    if (title != null) {
      resultString += title + " ";
    }
    if (confidence != null) {
      resultString += String.format("(%.1f%%) ", confidence * 100.0f);
    }
    if (location != null) {
      resultString += location + " ";
    }
    return resultString.trim();
  }
}
To extract the network’s output, the sample project runs inference with run(Object input, Object output) of the class org.tensorflow.lite.Interpreter, since the stock model takes only one input and provides only one output. Our network has two outputs, so we need the function runForMultipleInputsOutputs(@NonNull Object[] inputs, @NonNull Map<Integer,Object> outputs) instead. We will add a Map<Integer, Object> to store the output buffers.
private final Map<Integer, Object> outputBuffers = new HashMap<>();
private final TensorBuffer outputProbabilityBuffer;
private final TensorBuffer outputFeatureBuffer;
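We initialize them in the constructor. The output indices are assumed to follow the order given in --output_arrays of the tflite_convert command, so index 0 holds the AvgPool features and index 1 holds the class probabilities.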
// Collect the shape and data type of every output tensor.
List<int[]> outputShapes = new ArrayList<>();
List<DataType> outputTypes = new ArrayList<>();
for (int i = 0; i < tflite.getOutputTensorCount(); i++) {
  outputShapes.add(tflite.getOutputTensor(i).shape());
  outputTypes.add(tflite.getOutputTensor(i).dataType());
}
// Output 0: AvgPool features, output 1: class probabilities.
outputFeatureBuffer = TensorBuffer.createFixedSize(outputShapes.get(0), outputTypes.get(0));
outputProbabilityBuffer = TensorBuffer.createFixedSize(outputShapes.get(1), outputTypes.get(1));
outputBuffers.put(0, outputFeatureBuffer.getBuffer().rewind());
outputBuffers.put(1, outputProbabilityBuffer.getBuffer().rewind());
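The buffer wiring above can be pictured without any TFLite classes. The following is only a sketch under stated assumptions: OutputBufferSketch, allocateFloats, fakeRun and readFloats are hypothetical names, fakeRun merely stands in for the interpreter, and the 1001-class probability size is an assumption about the MobileNet ImageNet models. It shows why each map entry is a rewound buffer keyed by its output index.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.Map;

/** Pure-Java stand-in for the output-buffer wiring; no TFLite classes are used. */
class OutputBufferSketch {

  /** Allocates a direct, native-order buffer that can hold `count` floats. */
  static ByteBuffer allocateFloats(int count) {
    return ByteBuffer.allocateDirect(count * Float.BYTES).order(ByteOrder.nativeOrder());
  }

  /** Hypothetical stand-in for the interpreter: fills every output buffer with one value. */
  static void fakeRun(Map<Integer, Object> outputs, float value) {
    for (Object out : outputs.values()) {
      ByteBuffer buffer = (ByteBuffer) out;
      buffer.rewind(); // start writing at position 0
      while (buffer.remaining() >= Float.BYTES) {
        buffer.putFloat(value);
      }
    }
  }

  /** Reads every float back, starting from position 0. */
  static float[] readFloats(ByteBuffer buffer) {
    buffer.rewind();
    float[] values = new float[buffer.remaining() / Float.BYTES];
    buffer.asFloatBuffer().get(values);
    return values;
  }

  public static void main(String[] args) {
    Map<Integer, Object> outputBuffers = new HashMap<>();
    ByteBuffer features = allocateFloats(1280);      // AvgPool output, shape (1, 1, 1280)
    ByteBuffer probabilities = allocateFloats(1001); // assumed number of classes
    outputBuffers.put(0, features);
    outputBuffers.put(1, probabilities);

    fakeRun(outputBuffers, 0.5f);

    System.out.println(readFloats(features).length);
    System.out.println(readFloats(probabilities).length);
  }
}
```

In the real app the TensorBuffer objects own these buffers, so the values are read back through outputFeatureBuffer and outputProbabilityBuffer rather than from the map.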
In the recognizeImage function, we simply replace tflite.run with tflite.runForMultipleInputsOutputs.
Object[] inputs = {inputImageBuffer.getBuffer()};
tflite.runForMultipleInputsOutputs(inputs, outputBuffers);
We can extract the AvgPool features as follows:
float[] features = outputFeatureBuffer.getFloatArray();
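Since getFloatArray() returns a flat array, its length should equal the product of the output shape’s dimensions, which is 1280 for the (1, 1, 1280) AvgPool layer chosen in this post. Here is a minimal sketch of that relationship, useful as a sanity check after swapping in a different layer. FeatureShape, flatSize and matches are hypothetical helper names, not part of the sample project.

```java
/** Hypothetical helper relating a tensor shape to the length of its flat float array. */
class FeatureShape {

  /** Number of elements in a tensor of the given shape (product of all dimensions). */
  static int flatSize(int[] shape) {
    int size = 1;
    for (int dim : shape) {
      size *= dim;
    }
    return size;
  }

  /** Returns true when a flat feature array has the length implied by the shape. */
  static boolean matches(float[] features, int[] shape) {
    return features != null && features.length == flatSize(shape);
  }

  public static void main(String[] args) {
    int[] avgPoolShape = {1, 1, 1280}; // shape quoted in this post for the AvgPool layer
    System.out.println(flatSize(avgPoolShape));                 // prints 1280
    System.out.println(matches(new float[1280], avgPoolShape)); // prints true
  }
}
```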
The recognizeImage function returns the top-k classes with the highest probability by calling the function getTopKProbability. We will pass the AvgPool features to this function as an extra argument and attach them to the Recognition with the highest confidence.
/** Gets the top-k results. */
private static List<Recognition> getTopKProbability(Map<String, Float> labelProb, float[] features) {
  // Find the best classifications.
  PriorityQueue<Recognition> pq =
      new PriorityQueue<>(
          MAX_RESULTS,
          new Comparator<Recognition>() {
            @Override
            public int compare(Recognition lhs, Recognition rhs) {
              // Intentionally reversed to put high confidence at the head of the queue.
              return Float.compare(rhs.getConfidence(), lhs.getConfidence());
            }
          });
  for (Map.Entry<String, Float> entry : labelProb.entrySet()) {
    pq.add(new Recognition("" + entry.getKey(), entry.getKey(), entry.getValue(), null));
  }
  final ArrayList<Recognition> recognitions = new ArrayList<>();
  int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
  for (int i = 0; i < recognitionsSize; ++i) {
    Recognition toAdd = pq.poll();
    // Attach the features only to the most confident result.
    if (i == 0) {
      toAdd.setFeatures(features);
    }
    recognitions.add(toAdd);
  }
  return recognitions;
}
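To see the effect on the caller’s side, here is a self-contained sketch of the same top-k logic with a simplified stand-in for Recognition. TopKSketch, Result and topK are hypothetical names, and the labels and probabilities are made up for illustration. It shows that only the first element of the returned list carries the features.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

class TopKSketch {
  static final int MAX_RESULTS = 3;

  /** Simplified stand-in for Recognition: label, confidence, optional features. */
  static class Result {
    final String title;
    final float confidence;
    float[] features; // stays null except on the top-ranked result

    Result(String title, float confidence) {
      this.title = title;
      this.confidence = confidence;
    }
  }

  /** Returns up to MAX_RESULTS results, best first, with features on the first one only. */
  static List<Result> topK(Map<String, Float> labelProb, float[] features) {
    // Comparator is reversed so the highest confidence sits at the head of the queue.
    PriorityQueue<Result> pq =
        new PriorityQueue<>(
            MAX_RESULTS, (lhs, rhs) -> Float.compare(rhs.confidence, lhs.confidence));
    for (Map.Entry<String, Float> entry : labelProb.entrySet()) {
      pq.add(new Result(entry.getKey(), entry.getValue()));
    }
    List<Result> results = new ArrayList<>();
    int size = Math.min(pq.size(), MAX_RESULTS);
    for (int i = 0; i < size; i++) {
      Result next = pq.poll();
      if (i == 0) {
        next.features = features;
      }
      results.add(next);
    }
    return results;
  }

  public static void main(String[] args) {
    Map<String, Float> probs = new LinkedHashMap<>();
    probs.put("cat", 0.10f);
    probs.put("dog", 0.70f);
    probs.put("fox", 0.15f);
    probs.put("owl", 0.05f);

    List<Result> top = topK(probs, new float[1280]);
    System.out.println(top.get(0).title);            // prints dog
    System.out.println(top.get(0).features.length);  // prints 1280
    System.out.println(top.get(1).features == null); // prints true
  }
}
```

In the app itself, the equivalent read would be getFeatures() on the first Recognition of the list returned by the classifier.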
Please find the complete project forked and modified from the original repo here.