# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Modify a quantized model's interface from float to integer.

Command-line flags:
  --input_tflite_file   (required) path of the TFLite model to read.
  --output_tflite_file  (required) path where the modified model is written.
  --input_type          interface type for the model inputs; must be one of
                        `mmi_constants.STR_TYPES`.
  --output_type         interface type for the model outputs; must be one of
                        `mmi_constants.STR_TYPES`.
"""

from absl import app
from absl import flags

from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
from tensorflow.lite.tools.optimize.python import modify_model_interface_lib as mmi_lib

FLAGS = flags.FLAGS

flags.DEFINE_string('input_tflite_file', None,
                    'Full path name to the input TFLite file.')
flags.DEFINE_string('output_tflite_file', None,
                    'Full path name to the output TFLite file.')
flags.DEFINE_enum('input_type', mmi_constants.DEFAULT_STR_TYPE,
                  mmi_constants.STR_TYPES,
                  'Modified input integer interface type.')
flags.DEFINE_enum('output_type', mmi_constants.DEFAULT_STR_TYPE,
                  mmi_constants.STR_TYPES,
                  'Modified output integer interface type.')

flags.mark_flag_as_required('input_tflite_file')
flags.mark_flag_as_required('output_tflite_file')


def main(_):
  """Rewrites the model's input/output interface types and reports success.

  Translates the `--input_type` / `--output_type` flag strings into TFLite
  types via `mmi_constants.STR_TO_TFLITE_TYPES`, then delegates the model
  rewrite to `mmi_lib.modify_model_interface`, reading from
  `--input_tflite_file` and writing to `--output_tflite_file`.

  Args:
    _: Remaining command-line arguments passed by `app.run`; unused.
  """
  input_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.input_type]
  output_type = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.output_type]

  # The file-path flags are defined above as `input_tflite_file` and
  # `output_tflite_file`; there are no `input_file` / `output_file` flags.
  mmi_lib.modify_model_interface(FLAGS.input_tflite_file,
                                 FLAGS.output_tflite_file, input_type,
                                 output_type)

  print('Successfully modified the model input type from FLOAT to '
        '{input_type} and output type from FLOAT to {output_type}.'.format(
            input_type=FLAGS.input_type, output_type=FLAGS.output_type))


if __name__ == '__main__':
  app.run(main)