• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Modify a quantized model's interface from float to integer."""
16
17from absl import app
18from absl import flags
19
20from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
21from tensorflow.lite.tools.optimize.python import modify_model_interface_lib as mmi_lib
22
FLAGS = flags.FLAGS

# Locations of the model to read and of the model to write.
flags.DEFINE_string(
    name='input_tflite_file',
    default=None,
    help='Full path name to the input TFLite file.')
flags.DEFINE_string(
    name='output_tflite_file',
    default=None,
    help='Full path name to the output TFLite file.')

# Requested interface types; accepted values are restricted to
# mmi_constants.STR_TYPES and default to mmi_constants.DEFAULT_STR_TYPE.
flags.DEFINE_enum(
    name='input_type',
    default=mmi_constants.DEFAULT_STR_TYPE,
    enum_values=mmi_constants.STR_TYPES,
    help='Modified input integer interface type.')
flags.DEFINE_enum(
    name='output_type',
    default=mmi_constants.DEFAULT_STR_TYPE,
    enum_values=mmi_constants.STR_TYPES,
    help='Modified output integer interface type.')

# Neither path has a usable default (None), so both must be supplied.
flags.mark_flags_as_required(['input_tflite_file', 'output_tflite_file'])
38
39
def main(_):
  """Converts the model's float interface to the requested integer types.

  Looks up the TFLite types selected by --input_type / --output_type and
  passes them, together with the --input_tflite_file and
  --output_tflite_file paths, to mmi_lib.modify_model_interface. Prints a
  confirmation message afterwards.

  Args:
    _: Unused argument passed in by `app.run`.
  """
  # Translate the user-facing type names (already constrained to
  # mmi_constants.STR_TYPES by DEFINE_enum) into the TFLite type values the
  # library expects.
  input_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.input_type]
  output_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.output_type]

  # Read the paths from the flags that are actually registered
  # (`input_tflite_file` / `output_tflite_file`). No `input_file` or
  # `output_file` flag exists, so accessing those names would fail.
  mmi_lib.modify_model_interface(FLAGS.input_tflite_file,
                                 FLAGS.output_tflite_file, input_type,
                                 output_type)

  print('Successfully modified the model input type from FLOAT to '
        '{input_type} and output type from FLOAT to {output_type}.'.format(
            input_type=FLAGS.input_type, output_type=FLAGS.output_type))
50
51
if __name__ == '__main__':
  # Script entry point: hand control to absl's app runner, which invokes main.
  app.run(main=main)
54