# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a TFLite model to a TFLite Micro model (C++ Source)."""

from absl import app
from absl import flags

from tensorflow.lite.python import util

FLAGS = flags.FLAGS

flags.DEFINE_string("input_tflite_file", None,
                    "Full path name to the input TFLite model file.")
flags.DEFINE_string(
    "output_source_file", None,
    "Full path name to the output TFLite Micro model (C++ Source) file.")
flags.DEFINE_string("output_header_file", None,
                    "Full filepath of the output C header file.")
flags.DEFINE_string("array_variable_name", None,
                    "Name to use for the C data array variable.")
flags.DEFINE_integer("line_width", 80, "Width to use for formatting.")
flags.DEFINE_string("include_guard", None,
                    "Name to use for the C header include guard.")
flags.DEFINE_string("include_path", None,
                    "Optional path to include in generated source file.")
flags.DEFINE_boolean(
    "use_tensorflow_license", False,
    "Whether to prefix the generated files with the TF Apache2 license.")

# The conversion cannot proceed without an input model, both output paths and
# a name for the generated array, so these flags are mandatory.
flags.mark_flag_as_required("input_tflite_file")
flags.mark_flag_as_required("output_source_file")
flags.mark_flag_as_required("output_header_file")
flags.mark_flag_as_required("array_variable_name")


def main(_):
  """Reads the input TFLite file and writes the generated source and header.

  The bytes of `--input_tflite_file` are passed to
  `util.convert_bytes_to_c_source` together with the formatting flags; the two
  strings it returns are written to `--output_source_file` and
  `--output_header_file` respectively.

  Args:
    _: Unused argument supplied by `app.run`.
  """
  # The model is opaque binary data, so it is read in binary mode.
  with open(FLAGS.input_tflite_file, "rb") as input_handle:
    input_data = input_handle.read()

  source, header = util.convert_bytes_to_c_source(
      data=input_data,
      array_name=FLAGS.array_variable_name,
      max_line_width=FLAGS.line_width,
      include_guard=FLAGS.include_guard,
      include_path=FLAGS.include_path,
      use_tensorflow_license=FLAGS.use_tensorflow_license)

  with open(FLAGS.output_source_file, "w") as source_handle:
    source_handle.write(source)

  with open(FLAGS.output_header_file, "w") as header_handle:
    header_handle.write(header)


if __name__ == "__main__":
  app.run(main)