/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains the logic of Android model wrapper generation.
//
// It starts with helper functions that handle the metadata and the code
// writer.
//
// Code is generated in the `Generate{FOO}` functions. The Gradle and Manifest
// files are simple. The wrapper file generation is more complex, so it is
// divided into several sub-functions.
//
// The structure of the wrapper file looks like:
//
// [ imports ]
// [ class ]
//   [ inner "Outputs" class ]
//   [ inner "Metadata" class ]
//   [ APIs ] ( including ctors, public APIs and private APIs )
//
// The file is written mostly in a "template-generation" style: `CodeWriter`
// acts as the template renderer. To avoid setting token values repeatedly,
// the helper functions `SetCodeWriterWith{Foo}Info` set the token values from
// info structures (`TensorInfo` and `ModelInfo`) - the info structures are
// intermediate data structures between the metadata (represented in
// FlatBuffers) and the generated code.
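//
// As a minimal illustration of that token-substitution flow (the token value
// below is purely for exposition):
//
//   code_writer->SetTokenValue("NAME", "image");
//   code_writer->Append("private TensorImage {{NAME}};");
//   // renders as: private TensorImage image;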

#include "tensorflow_lite_support/codegen/android_java_generator.h"

#include <ctype.h>

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow_lite_support/codegen/code_generator.h"
#include "tensorflow_lite_support/codegen/metadata_helper.h"
#include "tensorflow_lite_support/codegen/utils.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace support {
namespace codegen {

namespace {

using details_android_java::ModelInfo;
using details_android_java::TensorInfo;

// Helper class to organize a C++ code block as a generated code block. Its
// constructor and destructor simulate an enter/exit scheme, like `with` in
// Python.
class AsBlock {
 public:
  AsBlock(CodeWriter* code_writer, const std::string& before,
          bool trailing_blank_line = false)
      : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
    code_writer_->AppendNoNewLine(before);
    code_writer_->Append(" {");
    code_writer_->Indent();
  }
  ~AsBlock() {
    code_writer_->Outdent();
    code_writer_->Append("}");
    if (trailing_blank_line_) {
      code_writer_->NewLine();
    }
  }

 private:
  CodeWriter* code_writer_;
  bool trailing_blank_line_;
};

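// Example usage (mirroring `GenerateWrapperOutputs` below): the statement
//
//   auto class_block = AsBlock(code_writer, "public static class Outputs");
//
// emits "public static class Outputs {" and indents; when `class_block` goes
// out of scope, the destructor outdents and emits the matching "}".
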
// Forward-declare the functions so that their definitions can follow a
// logical order below.
bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*);

std::string GetModelVersionedName(const ModelMetadata* metadata) {
  std::string model_name = "MyModel";
  if (metadata->name() != nullptr && !(metadata->name()->str().empty())) {
    model_name = metadata->name()->str();
  }
  std::string model_version = "unknown";
  if (metadata->version() != nullptr && !(metadata->version()->str().empty())) {
    model_version = metadata->version()->str();
  }
  return model_name + " (Version: " + model_version + ")";
}
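
// For example, a metadata name "MobileNetV1" with version "1.0" yields
// "MobileNetV1 (Version: 1.0)"; missing fields fall back to "MyModel" and
// "unknown".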

TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
                            const std::string& name, bool is_input, int index,
                            ErrorReporter* err) {
  TensorInfo tensor_info;
  std::string tensor_identifier = is_input ? "input" : "output";
  tensor_identifier += " " + std::to_string(index);
  tensor_info.associated_axis_label_index = FindAssociatedFile(
      metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err);
  tensor_info.associated_value_label_index = FindAssociatedFile(
      metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err);
  if (is_input && (tensor_info.associated_axis_label_index >= 0 ||
                   tensor_info.associated_value_label_index >= 0)) {
    err->Warning(
        "Found a label file on input tensor (%s). Label files for input "
        "tensors are not supported yet. The file will be ignored.",
        tensor_identifier.c_str());
  }
  if (tensor_info.associated_axis_label_index >= 0 &&
      tensor_info.associated_value_label_index >= 0) {
    err->Warning(
        "Found both an axis label file and a value label file for tensor "
        "(%s), which is not supported. Only the axis label file will be "
        "used.",
        tensor_identifier.c_str());
  }
  tensor_info.is_input = is_input;
  tensor_info.name = SnakeCaseToCamelCase(name);
  tensor_info.upper_camel_name = tensor_info.name;
  tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
  tensor_info.normalization_unit =
      FindNormalizationUnit(metadata, tensor_identifier, err);
  if (metadata->content() != nullptr &&
      metadata->content()->content_properties() != nullptr) {
    // Infer the tensor wrapper type from the content properties.
    if (metadata->content()->content_properties_type() ==
        ContentProperties_ImageProperties) {
      if (metadata->content()
              ->content_properties_as_ImageProperties()
              ->color_space() == ColorSpaceType_RGB) {
        tensor_info.content_type = "image";
        tensor_info.wrapper_type = "TensorImage";
        tensor_info.processor_type = "ImageProcessor";
        return tensor_info;
      } else {
        err->Warning(
            "Found a non-RGB image on tensor (%s). Codegen currently does "
            "not support it, and regards it as a plain numeric tensor.",
            tensor_identifier.c_str());
      }
    }
  }
  tensor_info.content_type = "tensor";
  tensor_info.wrapper_type = "TensorBuffer";
  tensor_info.processor_type = "TensorProcessor";
  return tensor_info;
}

ModelInfo CreateModelInfo(const ModelMetadata* metadata,
                          const std::string& package_name,
                          const std::string& model_class_name,
                          const std::string& model_asset_path,
                          ErrorReporter* err) {
  ModelInfo model_info;
  if (!CodeGenerator::VerifyMetadata(metadata, err)) {
    // TODO(b/150116380): Create dummy model info.
    err->Error("Validating metadata failed.");
    return model_info;
  }
  model_info.package_name = package_name;
  model_info.model_class_name = model_class_name;
  model_info.model_asset_path = model_asset_path;
  model_info.model_versioned_name = GetModelVersionedName(metadata);
  const auto* graph = metadata->subgraph_metadata()->Get(0);
  auto names = CodeGenerator::NameInputsAndOutputs(
      graph->input_tensor_metadata(), graph->output_tensor_metadata());
  std::vector<std::string> input_tensor_names = std::move(names.first);
  std::vector<std::string> output_tensor_names = std::move(names.second);

  for (int i = 0; i < input_tensor_names.size(); i++) {
    model_info.inputs.push_back(
        CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
                         input_tensor_names[i], true, i, err));
    // Prepend a ", " separator before every item except the first.
    if (i > 0) {
      model_info.inputs_list += ", ";
      model_info.input_type_param_list += ", ";
    }
    model_info.inputs_list += model_info.inputs[i].name;
    model_info.input_type_param_list +=
        model_info.inputs[i].wrapper_type + " " + model_info.inputs[i].name;
  }
  for (int i = 0; i < output_tensor_names.size(); i++) {
    model_info.outputs.push_back(
        CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
                         output_tensor_names[i], false, i, err));
    if (i > 0) {
      model_info.postprocessor_type_param_list += ", ";
      model_info.postprocessors_list += ", ";
    }
    model_info.postprocessors_list +=
        model_info.outputs[i].name + "Postprocessor";
    model_info.postprocessor_type_param_list +=
        model_info.outputs[i].processor_type + " " +
        model_info.outputs[i].name + "Postprocessor";
  }
  return model_info;
}

void SetCodeWriterWithTensorInfo(CodeWriter* code_writer,
                                 const TensorInfo& tensor_info) {
  code_writer->SetTokenValue("NAME", tensor_info.name);
  code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name);
  code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type);
  code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type);
  std::string wrapper_name = tensor_info.wrapper_type;
  wrapper_name[0] = tolower(wrapper_name[0]);
  code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name);
  code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type);
  code_writer->SetTokenValue("NORMALIZATION_UNIT",
                             std::to_string(tensor_info.normalization_unit));
  code_writer->SetTokenValue(
      "ASSOCIATED_AXIS_LABEL_INDEX",
      std::to_string(tensor_info.associated_axis_label_index));
  code_writer->SetTokenValue(
      "ASSOCIATED_VALUE_LABEL_INDEX",
      std::to_string(tensor_info.associated_value_label_index));
}
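
// For an input tensor named "image" wrapped as a TensorImage, the calls above
// set, e.g., NAME to "image", NAME_U to "Image", WRAPPER_TYPE to
// "TensorImage", and WRAPPER_NAME to "tensorImage".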

void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
                                const ModelInfo& model_info) {
  code_writer->SetTokenValue("PACKAGE", model_info.package_name);
  code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
  code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
  // Extra info; these list tokens are pre-rendered in CreateModelInfo.
  code_writer->SetTokenValue("INPUT_TYPE_PARAM_LIST",
                             model_info.input_type_param_list);
  code_writer->SetTokenValue("INPUTS_LIST", model_info.inputs_list);
  code_writer->SetTokenValue("POSTPROCESSORS_LIST",
                             model_info.postprocessors_list);
  code_writer->SetTokenValue("POSTPROCESSOR_TYPE_PARAM_LIST",
                             model_info.postprocessor_type_param_list);
}

constexpr char JAVA_DEFAULT_PACKAGE[] = "default";

std::string ConvertPackageToPath(const std::string& package) {
  if (package == JAVA_DEFAULT_PACKAGE) {
    return "";
  }
  std::string path = package;
  std::replace(path.begin(), path.end(), '.', '/');
  return path;
}
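
// For example, "com.example.classifier" becomes "com/example/classifier"; the
// default package maps to an empty relative path.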

bool IsImageUsed(const ModelInfo& model) {
  for (const auto& input : model.inputs) {
    if (input.content_type == "image") {
      return true;
    }
  }
  for (const auto& output : model.outputs) {
    if (output.content_type == "image") {
      return true;
    }
  }
  return false;
}

// The following functions generate the wrapper Java code for a model.

bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
                                ErrorReporter* err) {
  code_writer->Append("// Generated by TFLite Support.");
  code_writer->Append("package {{PACKAGE}};");
  code_writer->NewLine();

  if (!GenerateWrapperImports(code_writer, model, err)) {
    err->Error("Failed to generate imports for the wrapper class.");
    return false;
  }
  if (!GenerateWrapperClass(code_writer, model, err)) {
    err->Error("Failed to generate the wrapper class.");
    return false;
  }
  code_writer->NewLine();
  return true;
}

bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
                            ErrorReporter* err) {
  const std::string support_pkg = "org.tensorflow.lite.support.";
  std::vector<std::string> imports{
      "android.content.Context",
      "java.io.IOException",
      "java.nio.ByteBuffer",
      "java.nio.FloatBuffer",
      "java.util.Arrays",
      "java.util.HashMap",
      "java.util.List",
      "java.util.Map",
      "org.tensorflow.lite.DataType",
      "org.tensorflow.lite.Tensor",
      "org.tensorflow.lite.Tensor.QuantizationParams",
      support_pkg + "common.FileUtil",
      support_pkg + "common.TensorProcessor",
      support_pkg + "common.ops.CastOp",
      support_pkg + "common.ops.DequantizeOp",
      support_pkg + "common.ops.NormalizeOp",
      support_pkg + "common.ops.QuantizeOp",
      support_pkg + "label.Category",
      support_pkg + "label.TensorLabel",
      support_pkg + "metadata.MetadataExtractor",
      support_pkg + "metadata.schema.NormalizationOptions",
      support_pkg + "model.Model",
      support_pkg + "tensorbuffer.TensorBuffer",
  };
  if (IsImageUsed(model)) {
    for (const auto& target :
         {"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp",
          "image.ops.ResizeOp.ResizeMethod"}) {
      imports.push_back(support_pkg + target);
    }
  }

  std::sort(imports.begin(), imports.end());
  for (const auto& target : imports) {
    code_writer->SetTokenValue("TARGET", target);
    code_writer->Append("import {{TARGET}};");
  }
  code_writer->NewLine();
  return true;
}

bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
                          ErrorReporter* err) {
  code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
                             model.model_versioned_name);
  code_writer->Append(
      R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)");
  const auto code_block =
      AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}");
  code_writer->Append(R"(private final Metadata metadata;
private final Model model;
private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
  }
  code_writer->NewLine();
  if (!GenerateWrapperOutputs(code_writer, model, err)) {
    err->Error("Failed to generate the output classes.");
    return false;
  }
  code_writer->NewLine();
  if (!GenerateWrapperMetadata(code_writer, model, err)) {
    err->Error("Failed to generate the metadata class.");
    return false;
  }
  code_writer->NewLine();
  if (!GenerateWrapperAPI(code_writer, model, err)) {
    err->Error("Failed to generate the common APIs.");
    return false;
  }
  return true;
}

bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
                            ErrorReporter* err) {
  code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
  auto class_block = AsBlock(code_writer, "public static class Outputs");
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
    if (tensor.associated_axis_label_index >= 0) {
      code_writer->Append("private final List<String> {{NAME}}Labels;");
    }
    code_writer->Append(
        "private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
  }
  // Getters
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->NewLine();
    if (tensor.associated_axis_label_index >= 0) {
      if (tensor.content_type == "tensor") {
        code_writer->Append(
            R"(public List<Category> get{{NAME_U}}AsCategoryList() {
  return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
})");
      } else {  // image
        err->Warning(
            "Axis labels for images are not supported. The labels will be "
            "ignored.");
      }
    } else {  // no label
      code_writer->Append(
          R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
  return postprocess{{NAME_U}}({{NAME}});
})");
    }
  }
  code_writer->NewLine();
  {
    const auto ctor_block = AsBlock(
        code_writer,
        "Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
    for (const auto& tensor : model.outputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      if (tensor.content_type == "image") {
        code_writer->Append(
            R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
      } else {  // FEATURE, UNKNOWN
        code_writer->Append(
            "{{NAME}} = "
            "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
            "metadata.get{{NAME_U}}Type());");
      }
      if (tensor.associated_axis_label_index >= 0) {
        code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
      }
      code_writer->Append(
          "this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
    }
  }
  code_writer->NewLine();
  {
    const auto get_buffer_block =
        AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
    code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
    for (int i = 0; i < model.outputs.size(); i++) {
      SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
    }
    code_writer->Append("return outputs;");
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->NewLine();
    {
      auto processor_block =
          AsBlock(code_writer,
                  "private {{WRAPPER_TYPE}} "
                  "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
      code_writer->Append(
          "return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
    }
  }
  return true;
}

bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
                             ErrorReporter* err) {
  code_writer->Append(
      "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
  const auto class_block =
      AsBlock(code_writer, "public static class Metadata");
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
    }
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
    }
    if (tensor.associated_axis_label_index >= 0 ||
        tensor.associated_value_label_index >= 0) {
      code_writer->Append("private final List<String> {{NAME}}Labels;");
    }
  }
  code_writer->NewLine();
  {
    const auto ctor_block = AsBlock(
        code_writer,
        "public Metadata(ByteBuffer buffer, Model model) throws IOException");
    code_writer->Append(
        "MetadataExtractor extractor = new MetadataExtractor(buffer);");
    for (int i = 0; i < model.inputs.size(); i++) {
      SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append(
          R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
      if (model.inputs[i].normalization_unit >= 0) {
        code_writer->Append(
            R"(NormalizationOptions {{NAME}}NormalizationOptions =
    (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
      }
    }
    for (int i = 0; i < model.outputs.size(); i++) {
      SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append(
          R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
      if (model.outputs[i].normalization_unit >= 0) {
        code_writer->Append(
            R"(NormalizationOptions {{NAME}}NormalizationOptions =
    (NormalizationOptions) extractor.getOutputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
      }
      if (model.outputs[i].associated_axis_label_index >= 0) {
        code_writer->Append(R"(String {{NAME}}LabelsFileName =
    extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
      } else if (model.outputs[i].associated_value_label_index >= 0) {
        code_writer->Append(R"(String {{NAME}}LabelsFileName =
    extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
      }
    }
  }
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
  return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}

public DataType get{{NAME_U}}Type() {
  return {{NAME}}DataType;
}

public QuantizationParams get{{NAME_U}}QuantizationParams() {
  return {{NAME}}QuantizationParams;
})");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
  return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}

public float[] get{{NAME_U}}Stddev() {
  return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
    }
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
  return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}

public DataType get{{NAME_U}}Type() {
  return {{NAME}}DataType;
}

public QuantizationParams get{{NAME_U}}QuantizationParams() {
  return {{NAME}}QuantizationParams;
})");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
  return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}

public float[] get{{NAME_U}}Stddev() {
  return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
    }
    if (tensor.associated_axis_label_index >= 0 ||
        tensor.associated_value_label_index >= 0) {
      code_writer->Append(R"(
public List<String> get{{NAME_U}}Labels() {
  return {{NAME}}Labels;
})");
    }
  }
  return true;
}

bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
                        ErrorReporter* err) {
  code_writer->Append(R"(public Metadata getMetadata() {
  return metadata;
}
)");
  code_writer->Append(R"(/**
 * Creates interpreter and loads associated files if needed.
 *
 * @throws IOException if an I/O error occurs when loading the tflite model.
 */
public static {{MODEL_CLASS_NAME}} newInstance(Context context) throws IOException {
  return newInstance(context, MODEL_NAME, new Model.Options.Builder().build());
}

/**
 * Creates interpreter and loads associated files if needed, but loads a
 * different model file that has the same input/output structure as the
 * original one.
 *
 * @throws IOException if an I/O error occurs when loading the tflite model.
 */
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath) throws IOException {
  return newInstance(context, modelPath, new Model.Options.Builder().build());
}

/**
 * Creates interpreter and loads associated files if needed, with running options configured.
 *
 * @throws IOException if an I/O error occurs when loading the tflite model.
 */
public static {{MODEL_CLASS_NAME}} newInstance(Context context, Model.Options runningOptions) throws IOException {
  return newInstance(context, MODEL_NAME, runningOptions);
}

/**
 * Creates interpreter for a user-specified model.
 *
 * @throws IOException if an I/O error occurs when loading the tflite model.
 */
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException {
  Model model = Model.createModel(context, modelPath, runningOptions);
  Metadata metadata = new Metadata(model.getData(), model);
  {{MODEL_CLASS_NAME}} instance = new {{MODEL_CLASS_NAME}}(model, metadata);)");
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(
        R"(  instance.reset{{NAME_U}}Preprocessor(
      instance.buildDefault{{NAME_U}}Preprocessor());)");
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(
        R"(  instance.reset{{NAME_U}}Postprocessor(
      instance.buildDefault{{NAME_U}}Postprocessor());)");
  }
  code_writer->Append(R"(  return instance;
}
)");

  // Pre, post processor setters
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public void reset{{NAME_U}}Preprocessor({{PROCESSOR_TYPE}} processor) {
  {{NAME}}Preprocessor = processor;
})");
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public void reset{{NAME_U}}Postprocessor({{PROCESSOR_TYPE}} processor) {
  {{NAME}}Postprocessor = processor;
})");
  }
  // Process method
  code_writer->Append(R"(
/** Triggers the model. */
public Outputs process({{INPUT_TYPE_PARAM_LIST}}) {
  Outputs outputs = new Outputs(metadata, {{POSTPROCESSORS_LIST}});
  Object[] inputBuffers = preprocessInputs({{INPUTS_LIST}});
  model.run(inputBuffers, outputs.getBuffer());
  return outputs;
}

/** Closes the model. */
public void close() {
  model.close();
}
)");
  {
    auto block =
        AsBlock(code_writer,
                "private {{MODEL_CLASS_NAME}}(Model model, Metadata metadata)");
    code_writer->Append(R"(this.model = model;
this.metadata = metadata;)");
  }
  for (const auto& tensor : model.inputs) {
    code_writer->NewLine();
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    auto block = AsBlock(
        code_writer,
        "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Preprocessor()");
    code_writer->Append(
        "{{PROCESSOR_TYPE}}.Builder builder = new "
        "{{PROCESSOR_TYPE}}.Builder()");
    if (tensor.content_type == "image") {
      code_writer->Append(R"(    .add(new ResizeOp(
        metadata.get{{NAME_U}}Shape()[1],
        metadata.get{{NAME_U}}Shape()[2],
        ResizeMethod.NEAREST_NEIGHBOR)))");
    }
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(
          R"(    .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
    }
    code_writer->Append(
        R"(    .add(new QuantizeOp(
        metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
        metadata.get{{NAME_U}}QuantizationParams().getScale()))
    .add(new CastOp(metadata.get{{NAME_U}}Type()));
return builder.build();)");
  }
  for (const auto& tensor : model.outputs) {
    code_writer->NewLine();
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    auto block = AsBlock(
        code_writer,
        "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Postprocessor()");
    code_writer->AppendNoNewLine(
        R"({{PROCESSOR_TYPE}}.Builder builder = new {{PROCESSOR_TYPE}}.Builder()
    .add(new DequantizeOp(
        metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
        metadata.get{{NAME_U}}QuantizationParams().getScale())))");
    if (tensor.normalization_unit >= 0) {
      code_writer->AppendNoNewLine(R"(
    .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
    }
    code_writer->Append(R"(;
return builder.build();)");
  }
  code_writer->NewLine();
  {
    const auto block =
        AsBlock(code_writer,
                "private Object[] preprocessInputs({{INPUT_TYPE_PARAM_LIST}})");
    CodeWriter param_list_gen(err);
    for (const auto& tensor : model.inputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      code_writer->Append("{{NAME}} = {{NAME}}Preprocessor.process({{NAME}});");
      SetCodeWriterWithTensorInfo(&param_list_gen, tensor);
      param_list_gen.AppendNoNewLine("{{NAME}}.getBuffer(), ");
    }
    param_list_gen.Backspace(2);
    code_writer->AppendNoNewLine("return new Object[] {");
    code_writer->AppendNoNewLine(param_list_gen.ToString());
    code_writer->Append("};");
  }
  return true;
}

bool GenerateBuildGradleContent(CodeWriter* code_writer,
                                const ModelInfo& model_info) {
  code_writer->Append(R"(buildscript {
  repositories {
    google()
    jcenter()
  }
  dependencies {
    classpath 'com.android.tools.build:gradle:3.2.1'
  }
}

allprojects {
  repositories {
    google()
    jcenter()
    flatDir {
      dirs 'libs'
    }
  }
}

apply plugin: 'com.android.library'

android {
  compileSdkVersion 29
  defaultConfig {
    targetSdkVersion 29
    versionCode 1
    versionName "1.0"
  }
  aaptOptions {
    noCompress "tflite"
  }
  compileOptions {
    sourceCompatibility = '1.8'
    targetCompatibility = '1.8'
  }
  lintOptions {
    abortOnError false
  }
}

configurations {
  libMetadata
}

dependencies {
  libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic'
}

task downloadLibs(type: Sync) {
  from configurations.libMetadata
  into "$buildDir/libs"
  rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar"
}

preBuild.dependsOn downloadLibs

dependencies {
  compileOnly 'org.checkerframework:checker-qual:2.5.8'
  api 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
  api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
  api files("$buildDir/libs/tensorflow-lite-support-metadata.jar")
  implementation 'org.apache.commons:commons-compress:1.19'
})");
  return true;
}

bool GenerateAndroidManifestContent(CodeWriter* code_writer,
                                    const ModelInfo& model_info) {
  code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="{{PACKAGE}}">
</manifest>)");
  return true;
}

bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
  code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
  // TODO(b/158651848) Generate imports for TFLS util types like TensorImage.
  code_writer->AppendNoNewLine(R"(
```
import {{PACKAGE}}.{{MODEL_CLASS_NAME}};

// 1. Initialize the Model
{{MODEL_CLASS_NAME}} model = null;

try {
  model = {{MODEL_CLASS_NAME}}.newInstance(context);  // android.content.Context
} catch (IOException e) {
  e.printStackTrace();
}

if (model != null) {

  // 2. Set the inputs)");
  for (const auto& t : model_info.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, t);
    if (t.content_type == "image") {
      code_writer->Append(R"(
  // Prepare tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
  Bitmap bitmap = ...;
  TensorImage {{NAME}} = TensorImage.fromBitmap(bitmap);
  // Alternatively, load the input tensor "{{NAME}}" from pixel values.
  // Check out TensorImage documentation to load other image data structures.
  // int[] pixelValues = ...;
  // int[] shape = ...;
  // TensorImage {{NAME}} = new TensorImage();
  // {{NAME}}.load(pixelValues, shape);)");
    } else {
      code_writer->Append(R"(
  // Prepare input tensor "{{NAME}}" from an array.
  // Check out TensorBuffer documentation to load other data structures.
  TensorBuffer {{NAME}} = ...;
  int[] values = ...;
  int[] shape = ...;
  {{NAME}}.load(values, shape);)");
    }
  }
  code_writer->Append(R"(
  // 3. Run the model
  {{MODEL_CLASS_NAME}}.Outputs outputs = model.process({{INPUTS_LIST}});)");
  code_writer->Append(R"(
  // 4. Retrieve the results)");
  for (const auto& t : model_info.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, t);
    if (t.associated_axis_label_index >= 0) {
      code_writer->SetTokenValue("WRAPPER_TYPE", "List<Category>");
      code_writer->Append(
          "  List<Category> {{NAME}} = "
          "outputs.get{{NAME_U}}AsCategoryList();");
    } else {
      code_writer->Append(
          "  {{WRAPPER_TYPE}} {{NAME}} = "
          "outputs.get{{NAME_U}}As{{WRAPPER_TYPE}}();");
    }
  }
  code_writer->Append(R"(}
```)");
  return true;
}

GenerationResult::File GenerateWrapperFile(const std::string& module_root,
                                           const ModelInfo& model_info,
                                           ErrorReporter* err) {
  const auto java_path = JoinPath(module_root, "src/main/java");
  const auto package_path =
      JoinPath(java_path, ConvertPackageToPath(model_info.package_name));
  const auto file_path =
      JoinPath(package_path, model_info.model_class_name + JAVA_EXT);

  CodeWriter code_writer(err);
  code_writer.SetIndentString("  ");
  SetCodeWriterWithModelInfo(&code_writer, model_info);

  if (!GenerateWrapperFileContent(&code_writer, model_info, err)) {
    err->Error("Generating Java wrapper content failed.");
  }

  const auto java_file = code_writer.ToString();
  return GenerationResult::File{file_path, java_file};
}

GenerationResult::File GenerateBuildGradle(const std::string& module_root,
                                           const ModelInfo& model_info,
                                           ErrorReporter* err) {
  const auto file_path = JoinPath(module_root, "build.gradle");
  CodeWriter code_writer(err);
  SetCodeWriterWithModelInfo(&code_writer, model_info);
  if (!GenerateBuildGradleContent(&code_writer, model_info)) {
    err->Error("Generating build.gradle failed.");
  }
  const auto content = code_writer.ToString();
  return GenerationResult::File{file_path, content};
}

GenerationResult::File GenerateAndroidManifest(const std::string& module_root,
                                               const ModelInfo& model_info,
                                               ErrorReporter* err) {
  const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml");
  CodeWriter code_writer(err);
  SetCodeWriterWithModelInfo(&code_writer, model_info);
  if (!GenerateAndroidManifestContent(&code_writer, model_info)) {
    err->Error("Generating AndroidManifest.xml failed.");
  }
  return GenerationResult::File{file_path, code_writer.ToString()};
}

GenerationResult::File GenerateDoc(const std::string& module_root,
                                   const ModelInfo& model_info,
                                   ErrorReporter* err) {
  std::string lower = model_info.model_class_name;
  for (int i = 0; i < lower.length(); i++) {
    lower[i] = std::tolower(lower[i]);
  }
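  // e.g., a model class named "MyModel" yields the doc file "mymodel.md".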
  const auto file_path = JoinPath(module_root, lower + ".md");
  CodeWriter code_writer(err);
  SetCodeWriterWithModelInfo(&code_writer, model_info);
  if (!GenerateDocContent(&code_writer, model_info)) {
    err->Error("Generating doc failed.");
  }
  return GenerationResult::File{file_path, code_writer.ToString()};
}

}  // namespace

AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
    : CodeGenerator(), module_root_(module_root) {}

GenerationResult AndroidJavaGenerator::Generate(
    const Model* model, const std::string& package_name,
    const std::string& model_class_name, const std::string& model_asset_path) {
  GenerationResult result;
  if (model == nullptr) {
    err_.Error(
        "Cannot read model from the buffer. Codegen will generate nothing.");
    return result;
  }
  const ModelMetadata* metadata = GetMetadataFromModel(model);
  if (metadata == nullptr) {
    err_.Error(
        "Cannot find TFLite Metadata in the model. Codegen will generate "
        "nothing.");
    return result;
  }
  details_android_java::ModelInfo model_info = CreateModelInfo(
      metadata, package_name, model_class_name, model_asset_path, &err_);
  result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_));
  result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_));
  result.files.push_back(
      GenerateAndroidManifest(module_root_, model_info, &err_));
  result.files.push_back(GenerateDoc(module_root_, model_info, &err_));
  return result;
}

GenerationResult AndroidJavaGenerator::Generate(
    const char* model_storage, const std::string& package_name,
    const std::string& model_class_name, const std::string& model_asset_path) {
  const Model* model = GetModel(model_storage);
  return Generate(model, package_name, model_class_name, model_asset_path);
}

std::string AndroidJavaGenerator::GetErrorMessage() {
  return err_.GetMessage();
}

}  // namespace codegen
}  // namespace support
}  // namespace tflite