• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converts a TFLite model to a TFLite Micro model (C++ Source)."""
16
17from absl import app
18from absl import flags
19
20from tensorflow.lite.python import util
21
FLAGS = flags.FLAGS

# --- Required inputs/outputs (enforced by mark_flag_as_required below). ---

# Path of the TFLite model file whose bytes are embedded in the output.
flags.DEFINE_string("input_tflite_file", None,
                    "Full path name to the input TFLite model file.")
# Destination paths for the generated C++ source and its C header.
flags.DEFINE_string(
    "output_source_file", None,
    "Full path name to the output TFLite Micro model (C++ Source) file.")
flags.DEFINE_string("output_header_file", None,
                    "Full filepath of the output C header file.")
# Identifier used for the generated C data array.
flags.DEFINE_string("array_variable_name", None,
                    "Name to use for the C data array variable.")

# --- Optional controls, passed through to util.convert_bytes_to_c_source. ---

flags.DEFINE_integer("line_width", 80, "Width to use for formatting.")
flags.DEFINE_string("include_guard", None,
                    "Name to use for the C header include guard.")
flags.DEFINE_string("include_path", None,
                    "Optional path to include in generated source file.")
flags.DEFINE_boolean(
    "use_tensorflow_license", False,
    "Whether to prefix the generated files with the TF Apache2 license.")

# These four flags default to None, so they are marked required; absl flag
# parsing then fails if any of them is omitted from the command line.
flags.mark_flag_as_required("input_tflite_file")
flags.mark_flag_as_required("output_source_file")
flags.mark_flag_as_required("output_header_file")
flags.mark_flag_as_required("array_variable_name")
46
47
def main(_):
  """Converts FLAGS.input_tflite_file into a C++ source/header pair.

  Reads the TFLite model as raw bytes and renders it as a C data array via
  `util.convert_bytes_to_c_source`. The returned source and header text are
  written to FLAGS.output_source_file and FLAGS.output_header_file.

  Args:
    _: Leftover command-line arguments handed over by `app.run`; unused.
  """
  with open(FLAGS.input_tflite_file, "rb") as input_handle:
    input_data = input_handle.read()

  source, header = util.convert_bytes_to_c_source(
      data=input_data,
      array_name=FLAGS.array_variable_name,
      max_line_width=FLAGS.line_width,
      include_guard=FLAGS.include_guard,
      include_path=FLAGS.include_path,
      use_tensorflow_license=FLAGS.use_tensorflow_license)

  # Use an explicit encoding so the generated files do not depend on the
  # host's locale-default text encoding (see PEP 597).
  with open(FLAGS.output_source_file, "w", encoding="utf-8") as source_handle:
    source_handle.write(source)

  with open(FLAGS.output_header_file, "w", encoding="utf-8") as header_handle:
    header_handle.write(header)
65
66
if __name__ == "__main__":
  # absl parses the flags declared at module level, then calls main() with
  # whatever argv remains.
  app.run(main=main)
69