1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ 17 #define TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow_lite_support/codegen/code_generator.h" 24 #include "tensorflow_lite_support/codegen/utils.h" 25 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" 26 #include "tensorflow/lite/schema/schema_generated.h" 27 28 namespace tflite { 29 namespace support { 30 namespace codegen { 31 32 namespace details_android_java { 33 34 /// The intermediate data structure for generating code from TensorMetadata. 35 /// Should only be used as const reference when created. 36 struct TensorInfo { 37 std::string name; 38 std::string upper_camel_name; 39 std::string content_type; 40 std::string wrapper_type; 41 std::string processor_type; 42 bool is_input; 43 /// Optional. Set to -1 if not applicable. 44 int normalization_unit; 45 /// Optional. Set to -1 if associated_axis_label is empty. 46 int associated_axis_label_index; 47 /// Optional. Set to -1 if associated_value_label is empty. 48 int associated_value_label_index; 49 }; 50 51 /// The intermediate data structure for generating code from ModelMetadata. 52 /// Should only be used as const reference when created. 53 struct ModelInfo { 54 std::string package_name; 55 std::string model_asset_path; 56 std::string model_class_name; 57 std::string model_versioned_name; 58 std::vector<TensorInfo> inputs; 59 std::vector<TensorInfo> outputs; 60 // Extra helper fields. For models with inputs "a", "b" and outputs "x", "y": 61 std::string input_type_param_list; 62 // e.g. "TensorImage a, TensorBuffer b" 63 std::string inputs_list; 64 // e.g. "a, b" 65 std::string postprocessor_type_param_list; 66 // e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor" 67 std::string postprocessors_list; 68 // e.g. "xPostprocessor, yPostprocessor" 69 }; 70 71 } // namespace details_android_java 72 73 constexpr char JAVA_EXT[] = ".java"; 74 75 /// Generates Android supporting codes and modules (in Java) based on TFLite 76 /// metadata. 77 class AndroidJavaGenerator : public CodeGenerator { 78 public: 79 /// Creates an AndroidJavaGenerator. 80 /// Args: 81 /// - module_root: The root of destination Java module. 82 explicit AndroidJavaGenerator(const std::string& module_root); 83 84 /// Generates files. Returns the file paths and contents. 85 /// Args: 86 /// - model: The TFLite model with Metadata filled. 87 /// - package_name: The name of the Java package which generated classes 88 /// belong to. 89 /// - model_class_name: A readable name of the generated wrapper class, such 90 /// as "ImageClassifier", "MobileNetV2" or "MyModel". 91 /// - model_asset_path: The relevant path to the model file in the asset. 92 // TODO(b/141225157): Automatically generate model_class_name. 93 GenerationResult Generate(const Model* model, const std::string& package_name, 94 const std::string& model_class_name, 95 const std::string& model_asset_path); 96 97 /// Generates files and returns the file paths and contents. 98 /// It's mostly identical with the previous one, but the model here is 99 /// provided as binary flatbuffer content without parsing. 100 GenerationResult Generate(const char* model_storage, 101 const std::string& package_name, 102 const std::string& model_class_name, 103 const std::string& model_asset_path); 104 105 std::string GetErrorMessage(); 106 107 private: 108 const std::string module_root_; 109 ErrorReporter err_; 110 }; 111 112 } // namespace codegen 113 } // namespace support 114 } // namespace tflite 115 116 #endif // TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ 117