• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file contains the logic of Android model wrapper generation.
//
// At the beginning are the helper functions handling metadata and the code
// writer.
//
// Code is generated in the `Generate{FOO}` functions. Gradle and Manifest
// files are simple. The wrapper file generation is a bit complex, so we divided
// it into several sub-functions.
//
// The structure of the wrapper file looks like:
//
// [ imports ]
// [ class ]
//   [ inner "Outputs" class ]
//   [ inner "Metadata" class ]
//   [ APIs ] ( including ctors, public APIs and private APIs )
//
// We tried to mostly write it in a "template-generation" way. `CodeWriter` does
// the job as a template renderer. To avoid repeatedly setting the token values,
// helper functions `SetCodeWriterWith{Foo}Info` set the token values with info
// structures (`TensorInfo` and `ModelInfo`) - the Info structures are
// intermediate data structures between Metadata (represented in Flatbuffers)
// and generated code.
38 
39 #include "tensorflow_lite_support/codegen/android_java_generator.h"
40 
41 #include <ctype.h>
42 
43 #include <algorithm>
44 #include <memory>
45 #include <string>
46 #include <vector>
47 
48 #include "tensorflow_lite_support/codegen/code_generator.h"
49 #include "tensorflow_lite_support/codegen/metadata_helper.h"
50 #include "tensorflow_lite_support/codegen/utils.h"
51 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
52 
53 namespace tflite {
54 namespace support {
55 namespace codegen {
56 
57 namespace {
58 
59 using details_android_java::ModelInfo;
60 using details_android_java::TensorInfo;
61 
62 // Helper class to organize the C++ code block as a generated code block.
63 // Using ctor and dtor to simulate an enter/exit schema like `with` in Python.
64 class AsBlock {
65  public:
AsBlock(CodeWriter * code_writer,const std::string & before,bool trailing_blank_line=false)66   AsBlock(CodeWriter* code_writer, const std::string& before,
67           bool trailing_blank_line = false)
68       : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
69     code_writer_->AppendNoNewLine(before);
70     code_writer_->Append(" {");
71     code_writer_->Indent();
72   }
~AsBlock()73   ~AsBlock() {
74     code_writer_->Outdent();
75     code_writer_->Append("}");
76     if (trailing_blank_line_) {
77       code_writer_->NewLine();
78     }
79   }
80 
81  private:
82   CodeWriter* code_writer_;
83   bool trailing_blank_line_;
84 };
85 
86 // Declare the functions first, so that the functions can follow a logical
87 // order.
88 bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*);
89 bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*);
90 bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
91 bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
92 bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*);
93 bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*);
94 
GetModelVersionedName(const ModelMetadata * metadata)95 std::string GetModelVersionedName(const ModelMetadata* metadata) {
96   std::string model_name = "MyModel";
97   if (metadata->name() != nullptr && !(metadata->name()->str().empty())) {
98     model_name = metadata->name()->str();
99   }
100   std::string model_version = "unknown";
101   if (metadata->version() != nullptr && !(metadata->version()->str().empty())) {
102     model_version = metadata->version()->str();
103   }
104   return model_name + " (Version: " + model_version + ")";
105 }
106 
CreateTensorInfo(const TensorMetadata * metadata,const std::string & name,bool is_input,int index,ErrorReporter * err)107 TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
108                             const std::string& name, bool is_input, int index,
109                             ErrorReporter* err) {
110   TensorInfo tensor_info;
111   std::string tensor_identifier = is_input ? "input" : "output";
112   tensor_identifier += " " + std::to_string(index);
113   tensor_info.associated_axis_label_index = FindAssociatedFile(
114       metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err);
115   tensor_info.associated_value_label_index = FindAssociatedFile(
116       metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err);
117   if (is_input && (tensor_info.associated_axis_label_index >= 0 ||
118                    tensor_info.associated_value_label_index >= 0)) {
119     err->Warning(
120         "Found label file on input tensor (%s). Label file for input "
121         "tensor is not supported yet. The "
122         "file will be ignored.",
123         tensor_identifier.c_str());
124   }
125   if (tensor_info.associated_axis_label_index >= 0 &&
126       tensor_info.associated_value_label_index >= 0) {
127     err->Warning(
128         "Found both axis label file and value label file for tensor (%s), "
129         "which is not supported. Only the axis label file will be used.",
130         tensor_identifier.c_str());
131   }
132   tensor_info.is_input = is_input;
133   tensor_info.name = SnakeCaseToCamelCase(name);
134   tensor_info.upper_camel_name = tensor_info.name;
135   tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
136   tensor_info.normalization_unit =
137       FindNormalizationUnit(metadata, tensor_identifier, err);
138   if (metadata->content() != nullptr &&
139       metadata->content()->content_properties() != nullptr) {
140     // Enter tensor wrapper type inferring
141     if (metadata->content()->content_properties_type() ==
142         ContentProperties_ImageProperties) {
143       if (metadata->content()
144               ->content_properties_as_ImageProperties()
145               ->color_space() == ColorSpaceType_RGB) {
146         tensor_info.content_type = "image";
147         tensor_info.wrapper_type = "TensorImage";
148         tensor_info.processor_type = "ImageProcessor";
149         return tensor_info;
150       } else {
151         err->Warning(
152             "Found Non-RGB image on tensor (%s). Codegen currently does not "
153             "support it, and regard it as a plain numeric tensor.",
154             tensor_identifier.c_str());
155       }
156     }
157   }
158   tensor_info.content_type = "tensor";
159   tensor_info.wrapper_type = "TensorBuffer";
160   tensor_info.processor_type = "TensorProcessor";
161   return tensor_info;
162 }
163 
CreateModelInfo(const ModelMetadata * metadata,const std::string & package_name,const std::string & model_class_name,const std::string & model_asset_path,ErrorReporter * err)164 ModelInfo CreateModelInfo(const ModelMetadata* metadata,
165                           const std::string& package_name,
166                           const std::string& model_class_name,
167                           const std::string& model_asset_path,
168                           ErrorReporter* err) {
169   ModelInfo model_info;
170   if (!CodeGenerator::VerifyMetadata(metadata, err)) {
171     // TODO(b/150116380): Create dummy model info.
172     err->Error("Validating metadata failed.");
173     return model_info;
174   }
175   model_info.package_name = package_name;
176   model_info.model_class_name = model_class_name;
177   model_info.model_asset_path = model_asset_path;
178   model_info.model_versioned_name = GetModelVersionedName(metadata);
179   const auto* graph = metadata->subgraph_metadata()->Get(0);
180   auto names = CodeGenerator::NameInputsAndOutputs(
181       graph->input_tensor_metadata(), graph->output_tensor_metadata());
182   std::vector<std::string> input_tensor_names = std::move(names.first);
183   std::vector<std::string> output_tensor_names = std::move(names.second);
184 
185   for (int i = 0; i < input_tensor_names.size(); i++) {
186     model_info.inputs.push_back(
187         CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
188                          input_tensor_names[i], true, i, err));
189     if (i < input_tensor_names.size() - 1) {
190       model_info.inputs_list += ", ";
191       model_info.input_type_param_list += ", ";
192     }
193     model_info.inputs_list += model_info.inputs[i].name;
194     model_info.input_type_param_list +=
195         model_info.inputs[i].wrapper_type + " " + model_info.inputs[i].name;
196   }
197   for (int i = 0; i < output_tensor_names.size(); i++) {
198     model_info.outputs.push_back(
199         CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
200                          output_tensor_names[i], false, i, err));
201     if (i < output_tensor_names.size() - 1) {
202       model_info.postprocessor_type_param_list += ", ";
203       model_info.postprocessors_list += ", ";
204     }
205     model_info.postprocessors_list +=
206         model_info.outputs[i].name + "Postprocessor";
207     model_info.postprocessor_type_param_list +=
208         model_info.outputs[i].processor_type + " " +
209         model_info.outputs[i].name + "Postprocessor";
210   }
211   return model_info;
212 }
213 
SetCodeWriterWithTensorInfo(CodeWriter * code_writer,const TensorInfo & tensor_info)214 void SetCodeWriterWithTensorInfo(CodeWriter* code_writer,
215                                  const TensorInfo& tensor_info) {
216   code_writer->SetTokenValue("NAME", tensor_info.name);
217   code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name);
218   code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type);
219   code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type);
220   std::string wrapper_name = tensor_info.wrapper_type;
221   wrapper_name[0] = tolower(wrapper_name[0]);
222   code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name);
223   code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type);
224   code_writer->SetTokenValue("NORMALIZATION_UNIT",
225                              std::to_string(tensor_info.normalization_unit));
226   code_writer->SetTokenValue(
227       "ASSOCIATED_AXIS_LABEL_INDEX",
228       std::to_string(tensor_info.associated_axis_label_index));
229   code_writer->SetTokenValue(
230       "ASSOCIATED_VALUE_LABEL_INDEX",
231       std::to_string(tensor_info.associated_value_label_index));
232 }
233 
SetCodeWriterWithModelInfo(CodeWriter * code_writer,const ModelInfo & model_info)234 void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
235                                 const ModelInfo& model_info) {
236   code_writer->SetTokenValue("PACKAGE", model_info.package_name);
237   code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
238   code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
239   // Extra info, half generated.
240   code_writer->SetTokenValue("INPUT_TYPE_PARAM_LIST",
241                              model_info.input_type_param_list);
242   code_writer->SetTokenValue("INPUTS_LIST", model_info.inputs_list);
243   code_writer->SetTokenValue("POSTPROCESSORS_LIST",
244                              model_info.postprocessors_list);
245   code_writer->SetTokenValue("POSTPROCESSOR_TYPE_PARAM_LIST",
246                              model_info.postprocessor_type_param_list);
247 }
248 
// Package name that is mapped to the empty path by ConvertPackageToPath().
constexpr char JAVA_DEFAULT_PACKAGE[] = "default";

// Converts a Java package name into a relative directory path by turning
// every '.' into '/'. The package named JAVA_DEFAULT_PACKAGE yields an empty
// path.
std::string ConvertPackageToPath(const std::string& package) {
  std::string path;
  if (package != JAVA_DEFAULT_PACKAGE) {
    path.reserve(package.size());
    for (const char c : package) {
      path.push_back(c == '.' ? '/' : c);
    }
  }
  return path;
}
259 
IsImageUsed(const ModelInfo & model)260 bool IsImageUsed(const ModelInfo& model) {
261   for (const auto& input : model.inputs) {
262     if (input.content_type == "image") {
263       return true;
264     }
265   }
266   for (const auto& output : model.outputs) {
267     if (output.content_type == "image") {
268       return true;
269     }
270   }
271   return false;
272 }
273 
// The following functions generate the wrapper Java code for a model.
275 
GenerateWrapperFileContent(CodeWriter * code_writer,const ModelInfo & model,ErrorReporter * err)276 bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
277                                 ErrorReporter* err) {
278   code_writer->Append("// Generated by TFLite Support.");
279   code_writer->Append("package {{PACKAGE}};");
280   code_writer->NewLine();
281 
282   if (!GenerateWrapperImports(code_writer, model, err)) {
283     err->Error("Fail to generate imports for wrapper class.");
284     return false;
285   }
286   if (!GenerateWrapperClass(code_writer, model, err)) {
287     err->Error("Fail to generate wrapper class.");
288     return false;
289   }
290   code_writer->NewLine();
291   return true;
292 }
293 
GenerateWrapperImports(CodeWriter * code_writer,const ModelInfo & model,ErrorReporter * err)294 bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
295                             ErrorReporter* err) {
296   const std::string support_pkg = "org.tensorflow.lite.support.";
297   std::vector<std::string> imports{
298       "android.content.Context",
299       "java.io.IOException",
300       "java.nio.ByteBuffer",
301       "java.nio.FloatBuffer",
302       "java.util.Arrays",
303       "java.util.HashMap",
304       "java.util.List",
305       "java.util.Map",
306       "org.tensorflow.lite.DataType",
307       "org.tensorflow.lite.Tensor",
308       "org.tensorflow.lite.Tensor.QuantizationParams",
309       support_pkg + "common.FileUtil",
310       support_pkg + "common.TensorProcessor",
311       support_pkg + "common.ops.CastOp",
312       support_pkg + "common.ops.DequantizeOp",
313       support_pkg + "common.ops.NormalizeOp",
314       support_pkg + "common.ops.QuantizeOp",
315       support_pkg + "label.Category",
316       support_pkg + "label.TensorLabel",
317       support_pkg + "metadata.MetadataExtractor",
318       support_pkg + "metadata.schema.NormalizationOptions",
319       support_pkg + "model.Model",
320       support_pkg + "tensorbuffer.TensorBuffer",
321   };
322   if (IsImageUsed(model)) {
323     for (const auto& target :
324          {"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp",
325           "image.ops.ResizeOp.ResizeMethod"}) {
326       imports.push_back(support_pkg + target);
327     }
328   }
329 
330   std::sort(imports.begin(), imports.end());
331   for (const auto& target : imports) {
332     code_writer->SetTokenValue("TARGET", target);
333     code_writer->Append("import {{TARGET}};");
334   }
335   code_writer->NewLine();
336   return true;
337 }
338 
GenerateWrapperClass(CodeWriter * code_writer,const ModelInfo & model,ErrorReporter * err)339 bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
340                           ErrorReporter* err) {
341   code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
342                              model.model_versioned_name);
343   code_writer->Append(
344       R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)");
345   const auto code_block =
346       AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}");
347   code_writer->Append(R"(private final Metadata metadata;
348 private final Model model;
349 private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
350   for (const auto& tensor : model.inputs) {
351     SetCodeWriterWithTensorInfo(code_writer, tensor);
352     code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
353   }
354   for (const auto& tensor : model.outputs) {
355     SetCodeWriterWithTensorInfo(code_writer, tensor);
356     code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
357   }
358   code_writer->NewLine();
359   if (!GenerateWrapperOutputs(code_writer, model, err)) {
360     err->Error("Failed to generate output classes");
361     return false;
362   }
363   code_writer->NewLine();
364   if (!GenerateWrapperMetadata(code_writer, model, err)) {
365     err->Error("Failed to generate the metadata class");
366     return false;
367   }
368   code_writer->NewLine();
369   if (!GenerateWrapperAPI(code_writer, model, err)) {
370     err->Error("Failed to generate the common APIs");
371     return false;
372   }
373   return true;
374 }
375 
GenerateWrapperOutputs(CodeWriter * code_writer,const ModelInfo & model,ErrorReporter * err)376 bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
377                             ErrorReporter* err) {
378   code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
379   auto class_block = AsBlock(code_writer, "public static class Outputs");
380   for (const auto& tensor : model.outputs) {
381     SetCodeWriterWithTensorInfo(code_writer, tensor);
382     code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
383     if (tensor.associated_axis_label_index >= 0) {
384       code_writer->Append("private final List<String> {{NAME}}Labels;");
385     }
386     code_writer->Append(
387         "private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
388   }
389   // Getters
390   for (const auto& tensor : model.outputs) {
391     SetCodeWriterWithTensorInfo(code_writer, tensor);
392     code_writer->NewLine();
393     if (tensor.associated_axis_label_index >= 0) {
394       if (tensor.content_type == "tensor") {
395         code_writer->Append(
396             R"(public List<Category> get{{NAME_U}}AsCategoryList() {
397   return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
398 })");
399       } else {  // image
400         err->Warning(
401             "Axis label for images is not supported. The labels will "
402             "be ignored.");
403       }
404     } else {  // no label
405       code_writer->Append(
406           R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
407   return postprocess{{NAME_U}}({{NAME}});
408 })");
409     }
410   }
411   code_writer->NewLine();
412   {
413     const auto ctor_block = AsBlock(
414         code_writer,
415         "Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
416     for (const auto& tensor : model.outputs) {
417       SetCodeWriterWithTensorInfo(code_writer, tensor);
418       if (tensor.content_type == "image") {
419         code_writer->Append(
420             R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
421 {{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
422       } else {  // FEATURE, UNKNOWN
423         code_writer->Append(
424             "{{NAME}} = "
425             "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
426             "metadata.get{{NAME_U}}Type());");
427       }
428       if (tensor.associated_axis_label_index >= 0) {
429         code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
430       }
431       code_writer->Append(
432           "this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
433     }
434   }
435   code_writer->NewLine();
436   {
437     const auto get_buffer_block =
438         AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
439     code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
440     for (int i = 0; i < model.outputs.size(); i++) {
441       SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
442       code_writer->SetTokenValue("ID", std::to_string(i));
443       code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
444     }
445     code_writer->Append("return outputs;");
446   }
447   for (const auto& tensor : model.outputs) {
448     SetCodeWriterWithTensorInfo(code_writer, tensor);
449     code_writer->NewLine();
450     {
451       auto processor_block =
452           AsBlock(code_writer,
453                   "private {{WRAPPER_TYPE}} "
454                   "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
455       code_writer->Append(
456           "return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
457     }
458   }
459   return true;
460 }
461 
GenerateWrapperMetadata(CodeWriter * code_writer,const ModelInfo & model,ErrorReporter * err)462 bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
463                              ErrorReporter* err) {
464   code_writer->Append(
465       "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
466   const auto class_block = AsBlock(code_writer, "public static class Metadata");
467   for (const auto& tensor : model.inputs) {
468     SetCodeWriterWithTensorInfo(code_writer, tensor);
469     code_writer->Append(R"(private final int[] {{NAME}}Shape;
470 private final DataType {{NAME}}DataType;
471 private final QuantizationParams {{NAME}}QuantizationParams;)");
472     if (tensor.normalization_unit >= 0) {
473       code_writer->Append(R"(private final float[] {{NAME}}Mean;
474 private final float[] {{NAME}}Stddev;)");
475     }
476   }
477   for (const auto& tensor : model.outputs) {
478     SetCodeWriterWithTensorInfo(code_writer, tensor);
479     code_writer->Append(R"(private final int[] {{NAME}}Shape;
480 private final DataType {{NAME}}DataType;
481 private final QuantizationParams {{NAME}}QuantizationParams;)");
482     if (tensor.normalization_unit >= 0) {
483       code_writer->Append(R"(private final float[] {{NAME}}Mean;
484 private final float[] {{NAME}}Stddev;)");
485     }
486     if (tensor.associated_axis_label_index >= 0 ||
487         tensor.associated_value_label_index >= 0) {
488       code_writer->Append("private final List<String> {{NAME}}Labels;");
489     }
490   }
491   code_writer->NewLine();
492   {
493     const auto ctor_block = AsBlock(
494         code_writer,
495         "public Metadata(ByteBuffer buffer, Model model) throws IOException");
496     code_writer->Append(
497         "MetadataExtractor extractor = new MetadataExtractor(buffer);");
498     for (int i = 0; i < model.inputs.size(); i++) {
499       SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
500       code_writer->SetTokenValue("ID", std::to_string(i));
501       code_writer->Append(
502           R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}});
503 {{NAME}}Shape = {{NAME}}Tensor.shape();
504 {{NAME}}DataType = {{NAME}}Tensor.dataType();
505 {{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
506       if (model.inputs[i].normalization_unit >= 0) {
507         code_writer->Append(
508             R"(NormalizationOptions {{NAME}}NormalizationOptions =
509     (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
510 FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
511 {{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
512 {{NAME}}MeanBuffer.get({{NAME}}Mean);
513 FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
514 {{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
515 {{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
516       }
517     }
518     for (int i = 0; i < model.outputs.size(); i++) {
519       SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
520       code_writer->SetTokenValue("ID", std::to_string(i));
521       code_writer->Append(
522           R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}});
523 {{NAME}}Shape = {{NAME}}Tensor.shape();
524 {{NAME}}DataType = {{NAME}}Tensor.dataType();
525 {{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
526       if (model.outputs[i].normalization_unit >= 0) {
527         code_writer->Append(
528             R"(NormalizationOptions {{NAME}}NormalizationOptions =
529     (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
530 FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
531 {{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
532 {{NAME}}MeanBuffer.get({{NAME}}Mean);
533 FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
534 {{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
535 {{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
536       }
537       if (model.outputs[i].associated_axis_label_index >= 0) {
538         code_writer->Append(R"(String {{NAME}}LabelsFileName =
539     extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
540 {{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
541       } else if (model.outputs[i].associated_value_label_index >= 0) {
542         code_writer->Append(R"(String {{NAME}}LabelsFileName =
543     extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
544 {{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
545       }
546     }
547   }
548   for (const auto& tensor : model.inputs) {
549     SetCodeWriterWithTensorInfo(code_writer, tensor);
550     code_writer->Append(R"(
551 public int[] get{{NAME_U}}Shape() {
552   return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
553 }
554 
555 public DataType get{{NAME_U}}Type() {
556   return {{NAME}}DataType;
557 }
558 
559 public QuantizationParams get{{NAME_U}}QuantizationParams() {
560   return {{NAME}}QuantizationParams;
561 })");
562     if (tensor.normalization_unit >= 0) {
563       code_writer->Append(R"(
564 public float[] get{{NAME_U}}Mean() {
565   return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
566 }
567 
568 public float[] get{{NAME_U}}Stddev() {
569   return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
570 })");
571     }
572   }
573   for (const auto& tensor : model.outputs) {
574     SetCodeWriterWithTensorInfo(code_writer, tensor);
575     code_writer->Append(R"(
576 public int[] get{{NAME_U}}Shape() {
577   return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
578 }
579 
580 public DataType get{{NAME_U}}Type() {
581   return {{NAME}}DataType;
582 }
583 
584 public QuantizationParams get{{NAME_U}}QuantizationParams() {
585   return {{NAME}}QuantizationParams;
586 })");
587     if (tensor.normalization_unit >= 0) {
588       code_writer->Append(R"(
589 public float[] get{{NAME_U}}Mean() {
590   return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
591 }
592 
593 public float[] get{{NAME_U}}Stddev() {
594   return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
595 })");
596     }
597     if (tensor.associated_axis_label_index >= 0 ||
598         tensor.associated_value_label_index >= 0) {
599       code_writer->Append(R"(
600 public List<String> get{{NAME_U}}Labels() {
601   return {{NAME}}Labels;
602 })");
603     }
604   }
605   return true;
606 }
607 
GenerateWrapperAPI(CodeWriter * code_writer,const ModelInfo & model,ErrorReporter * err)608 bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
609                         ErrorReporter* err) {
610   code_writer->Append(R"(public Metadata getMetadata() {
611   return metadata;
612 }
613 )");
614   code_writer->Append(R"(/**
615  * Creates interpreter and loads associated files if needed.
616  *
617  * @throws IOException if an I/O error occurs when loading the tflite model.
618  */
619 public static {{MODEL_CLASS_NAME}} newInstance(Context context) throws IOException {
620   return newInstance(context, MODEL_NAME, new Model.Options.Builder().build());
621 }
622 
623 /**
624  * Creates interpreter and loads associated files if needed, but loading another model in the same
625  * input / output structure with the original one.
626  *
627  * @throws IOException if an I/O error occurs when loading the tflite model.
628  */
629 public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath) throws IOException {
630   return newInstance(context, modelPath, new Model.Options.Builder().build());
631 }
632 
633 /**
634  * Creates interpreter and loads associated files if needed, with running options configured.
635  *
636  * @throws IOException if an I/O error occurs when loading the tflite model.
637  */
638 public static {{MODEL_CLASS_NAME}} newInstance(Context context, Model.Options runningOptions) throws IOException {
639   return newInstance(context, MODEL_NAME, runningOptions);
640 }
641 
642 /**
643  * Creates interpreter for a user-specified model.
644  *
645  * @throws IOException if an I/O error occurs when loading the tflite model.
646  */
647 public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException {
648   Model model = Model.createModel(context, modelPath, runningOptions);
649   Metadata metadata = new Metadata(model.getData(), model);
650   MyImageClassifier instance = new MyImageClassifier(model, metadata);)");
651   for (const auto& tensor : model.inputs) {
652     SetCodeWriterWithTensorInfo(code_writer, tensor);
653     code_writer->Append(
654         R"(  instance.reset{{NAME_U}}Preprocessor(
655       instance.buildDefault{{NAME_U}}Preprocessor());)");
656   }
657   for (const auto& tensor : model.outputs) {
658     SetCodeWriterWithTensorInfo(code_writer, tensor);
659     code_writer->Append(
660         R"(  instance.reset{{NAME_U}}Postprocessor(
661       instance.buildDefault{{NAME_U}}Postprocessor());)");
662   }
663   code_writer->Append(R"(  return instance;
664 }
665 )");
666 
667   // Pre, post processor setters
668   for (const auto& tensor : model.inputs) {
669     SetCodeWriterWithTensorInfo(code_writer, tensor);
670     code_writer->Append(R"(
671 public void reset{{NAME_U}}Preprocessor({{PROCESSOR_TYPE}} processor) {
672   {{NAME}}Preprocessor = processor;
673 })");
674   }
675   for (const auto& tensor : model.outputs) {
676     SetCodeWriterWithTensorInfo(code_writer, tensor);
677     code_writer->Append(R"(
678 public void reset{{NAME_U}}Postprocessor({{PROCESSOR_TYPE}} processor) {
679   {{NAME}}Postprocessor = processor;
680 })");
681   }
682   // Process method
683   code_writer->Append(R"(
684 /** Triggers the model. */
685 public Outputs process({{INPUT_TYPE_PARAM_LIST}}) {
686   Outputs outputs = new Outputs(metadata, {{POSTPROCESSORS_LIST}});
687   Object[] inputBuffers = preprocessInputs({{INPUTS_LIST}});
688   model.run(inputBuffers, outputs.getBuffer());
689   return outputs;
690 }
691 
692 /** Closes the model. */
693 public void close() {
694   model.close();
695 }
696 )");
697   {
698     auto block =
699         AsBlock(code_writer,
700                 "private {{MODEL_CLASS_NAME}}(Model model, Metadata metadata)");
701     code_writer->Append(R"(this.model = model;
702 this.metadata = metadata;)");
703   }
704   for (const auto& tensor : model.inputs) {
705     code_writer->NewLine();
706     SetCodeWriterWithTensorInfo(code_writer, tensor);
707     auto block = AsBlock(
708         code_writer,
709         "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Preprocessor()");
710     code_writer->Append(
711         "{{PROCESSOR_TYPE}}.Builder builder = new "
712         "{{PROCESSOR_TYPE}}.Builder()");
713     if (tensor.content_type == "image") {
714       code_writer->Append(R"(    .add(new ResizeOp(
715         metadata.get{{NAME_U}}Shape()[1],
716         metadata.get{{NAME_U}}Shape()[2],
717         ResizeMethod.NEAREST_NEIGHBOR)))");
718     }
719     if (tensor.normalization_unit >= 0) {
720       code_writer->Append(
721           R"(    .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
722     }
723     code_writer->Append(
724         R"(    .add(new QuantizeOp(
725         metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
726         metadata.get{{NAME_U}}QuantizationParams().getScale()))
727     .add(new CastOp(metadata.get{{NAME_U}}Type()));
728 return builder.build();)");
729   }
730   for (const auto& tensor : model.outputs) {
731     code_writer->NewLine();
732     SetCodeWriterWithTensorInfo(code_writer, tensor);
733     auto block = AsBlock(
734         code_writer,
735         "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Postprocessor()");
736     code_writer->AppendNoNewLine(
737         R"({{PROCESSOR_TYPE}}.Builder builder = new {{PROCESSOR_TYPE}}.Builder()
738     .add(new DequantizeOp(
739         metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
740         metadata.get{{NAME_U}}QuantizationParams().getScale())))");
741     if (tensor.normalization_unit >= 0) {
742       code_writer->AppendNoNewLine(R"(
743     .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
744     }
745     code_writer->Append(R"(;
746 return builder.build();)");
747   }
748   code_writer->NewLine();
749   {
750     const auto block =
751         AsBlock(code_writer,
752                 "private Object[] preprocessInputs({{INPUT_TYPE_PARAM_LIST}})");
753     CodeWriter param_list_gen(err);
754     for (const auto& tensor : model.inputs) {
755       SetCodeWriterWithTensorInfo(code_writer, tensor);
756       code_writer->Append("{{NAME}} = {{NAME}}Preprocessor.process({{NAME}});");
757       SetCodeWriterWithTensorInfo(&param_list_gen, tensor);
758       param_list_gen.AppendNoNewLine("{{NAME}}.getBuffer(), ");
759     }
760     param_list_gen.Backspace(2);
761     code_writer->AppendNoNewLine("return new Object[] {");
762     code_writer->AppendNoNewLine(param_list_gen.ToString());
763     code_writer->Append("};");
764   }
765   return true;
766 }
767 
GenerateBuildGradleContent(CodeWriter * code_writer,const ModelInfo & model_info)768 bool GenerateBuildGradleContent(CodeWriter* code_writer,
769                                 const ModelInfo& model_info) {
770   code_writer->Append(R"(buildscript {
771     repositories {
772         google()
773         jcenter()
774     }
775     dependencies {
776         classpath 'com.android.tools.build:gradle:3.2.1'
777     }
778 }
779 
780 allprojects {
781     repositories {
782         google()
783         jcenter()
784         flatDir {
785             dirs 'libs'
786         }
787     }
788 }
789 
790 apply plugin: 'com.android.library'
791 
792 android {
793     compileSdkVersion 29
794     defaultConfig {
795         targetSdkVersion 29
796         versionCode 1
797         versionName "1.0"
798     }
799     aaptOptions {
800         noCompress "tflite"
801     }
802     compileOptions {
803         sourceCompatibility = '1.8'
804         targetCompatibility = '1.8'
805     }
806     lintOptions {
807         abortOnError false
808     }
809 }
810 
811 configurations {
812     libMetadata
813 }
814 
815 dependencies {
816     libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic'
817 }
818 
819 task downloadLibs(type: Sync) {
820     from configurations.libMetadata
821     into "$buildDir/libs"
822     rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar"
823 }
824 
825 preBuild.dependsOn downloadLibs
826 
827 dependencies {
828     compileOnly 'org.checkerframework:checker-qual:2.5.8'
829     api 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
830     api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
831     api files("$buildDir/libs/tensorflow-lite-support-metadata.jar")
832     implementation 'org.apache.commons:commons-compress:1.19'
833 })");
834   return true;
835 }
836 
GenerateAndroidManifestContent(CodeWriter * code_writer,const ModelInfo & model_info)837 bool GenerateAndroidManifestContent(CodeWriter* code_writer,
838                                     const ModelInfo& model_info) {
839   code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?>
840 <manifest xmlns:android="http://schemas.android.com/apk/res/android"
841     package="{{PACKAGE}}">
842 </manifest>)");
843   return true;
844 }
845 
846 bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
847   code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
848   // TODO(b/158651848) Generate imports for TFLS util types like TensorImage.
849   code_writer->AppendNoNewLine(R"(
850 ```
851 import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
852 
853 // 1. Initialize the Model
854 {{MODEL_CLASS_NAME}} model = null;
855 
856 try {
857     model = {{MODEL_CLASS_NAME}}.newInstance(context);  // android.content.Context
858 } catch (IOException e) {
859     e.printStackTrace();
860 }
861 
862 if (model != null) {
863 
864     // 2. Set the inputs)");
865   for (const auto& t : model_info.inputs) {
866     SetCodeWriterWithTensorInfo(code_writer, t);
867     if (t.content_type == "image") {
868       code_writer->Append(R"(
869     // Prepare tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
870     Bitmap bitmap = ...;
871     TensorImage {{NAME}} = TensorImage.fromBitmap(bitmap);
872     // Alternatively, load the input tensor "{{NAME}}" from pixel values.
873     // Check out TensorImage documentation to load other image data structures.
874     // int[] pixelValues = ...;
875     // int[] shape = ...;
876     // TensorImage {{NAME}} = new TensorImage();
877     // {{NAME}}.load(pixelValues, shape);)");
878     } else {
879       code_writer->Append(R"(
880     // Prepare input tensor "{{NAME}}" from an array.
881     // Check out TensorBuffer documentation to load other data structures.
882     TensorBuffer {{NAME}} = ...;
883     int[] values = ...;
884     int[] shape = ...;
885     {{NAME}}.load(values, shape);)");
886     }
887   }
888   code_writer->Append(R"(
889     // 3. Run the model
890     {{MODEL_CLASS_NAME}}.Outputs outputs = model.process({{INPUTS_LIST}});)");
891   code_writer->Append(R"(
892     // 4. Retrieve the results)");
893   for (const auto& t : model_info.outputs) {
894     SetCodeWriterWithTensorInfo(code_writer, t);
895     if (t.associated_axis_label_index >= 0) {
896       code_writer->SetTokenValue("WRAPPER_TYPE", "List<Category>");
897       code_writer->Append(
898           "    List<Category> {{NAME}} = "
899           "outputs.get{{NAME_U}}AsCategoryList();");
900     } else {
901       code_writer->Append(
902           "    {{WRAPPER_TYPE}} {{NAME}} = "
903           "outputs.get{{NAME_U}}As{{WRAPPER_TYPE}}();");
904     }
905   }
906   code_writer->Append(R"(}
907 ```)");
908   return true;
909 }
910 
911 GenerationResult::File GenerateWrapperFile(const std::string& module_root,
912                                            const ModelInfo& model_info,
913                                            ErrorReporter* err) {
914   const auto java_path = JoinPath(module_root, "src/main/java");
915   const auto package_path =
916       JoinPath(java_path, ConvertPackageToPath(model_info.package_name));
917   const auto file_path =
918       JoinPath(package_path, model_info.model_class_name + JAVA_EXT);
919 
920   CodeWriter code_writer(err);
921   code_writer.SetIndentString("  ");
922   SetCodeWriterWithModelInfo(&code_writer, model_info);
923 
924   if (!GenerateWrapperFileContent(&code_writer, model_info, err)) {
925     err->Error("Generating Java wrapper content failed.");
926   }
927 
928   const auto java_file = code_writer.ToString();
929   return GenerationResult::File{file_path, java_file};
930 }
931 
932 GenerationResult::File GenerateBuildGradle(const std::string& module_root,
933                                            const ModelInfo& model_info,
934                                            ErrorReporter* err) {
935   const auto file_path = JoinPath(module_root, "build.gradle");
936   CodeWriter code_writer(err);
937   SetCodeWriterWithModelInfo(&code_writer, model_info);
938   if (!GenerateBuildGradleContent(&code_writer, model_info)) {
939     err->Error("Generating build.gradle failed.");
940   }
941   const auto content = code_writer.ToString();
942   return GenerationResult::File{file_path, content};
943 }
944 
945 GenerationResult::File GenerateAndroidManifest(const std::string& module_root,
946                                                const ModelInfo& model_info,
947                                                ErrorReporter* err) {
948   const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml");
949   CodeWriter code_writer(err);
950   SetCodeWriterWithModelInfo(&code_writer, model_info);
951   if (!GenerateAndroidManifestContent(&code_writer, model_info)) {
952     err->Error("Generating AndroidManifest.xml failed.");
953   }
954   return GenerationResult::File{file_path, code_writer.ToString()};
955 }
956 
957 GenerationResult::File GenerateDoc(const std::string& module_root,
958                                    const ModelInfo& model_info,
959                                    ErrorReporter* err) {
960   std::string lower = model_info.model_class_name;
961   for (int i = 0; i < lower.length(); i++) {
962     lower[i] = std::tolower(lower[i]);
963   }
964   const auto file_path = JoinPath(module_root, lower + ".md");
965   CodeWriter code_writer(err);
966   SetCodeWriterWithModelInfo(&code_writer, model_info);
967   if (!GenerateDocContent(&code_writer, model_info)) {
968     err->Error("Generating doc failed.");
969   }
970   return GenerationResult::File{file_path, code_writer.ToString()};
971 }
972 
973 }  // namespace
974 
AndroidJavaGenerator(const std::string & module_root)975 AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
976     : CodeGenerator(), module_root_(module_root) {}
977 
Generate(const Model * model,const std::string & package_name,const std::string & model_class_name,const std::string & model_asset_path)978 GenerationResult AndroidJavaGenerator::Generate(
979     const Model* model, const std::string& package_name,
980     const std::string& model_class_name, const std::string& model_asset_path) {
981   GenerationResult result;
982   if (model == nullptr) {
983     err_.Error(
984         "Cannot read model from the buffer. Codegen will generate nothing.");
985     return result;
986   }
987   const ModelMetadata* metadata = GetMetadataFromModel(model);
988   if (metadata == nullptr) {
989     err_.Error(
990         "Cannot find TFLite Metadata in the model. Codegen will generate "
991         "nothing.");
992     return result;
993   }
994   details_android_java::ModelInfo model_info = CreateModelInfo(
995       metadata, package_name, model_class_name, model_asset_path, &err_);
996   result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_));
997   result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_));
998   result.files.push_back(
999       GenerateAndroidManifest(module_root_, model_info, &err_));
1000   result.files.push_back(GenerateDoc(module_root_, model_info, &err_));
1001   return result;
1002 }
1003 
Generate(const char * model_storage,const std::string & package_name,const std::string & model_class_name,const std::string & model_asset_path)1004 GenerationResult AndroidJavaGenerator::Generate(
1005     const char* model_storage, const std::string& package_name,
1006     const std::string& model_class_name, const std::string& model_asset_path) {
1007   const Model* model = GetModel(model_storage);
1008   return Generate(model, package_name, model_class_name, model_asset_path);
1009 }
1010 
GetErrorMessage()1011 std::string AndroidJavaGenerator::GetErrorMessage() {
1012   return err_.GetMessage();
1013 }
1014 
1015 }  // namespace codegen
1016 }  // namespace support
1017 }  // namespace tflite
1018