# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to modify a quantized model's interface from float to integer."""

from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.tools.optimize.python import _pywrap_modify_model_interface
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants


def _parse_type_to_int(dtype, flag):
  """Converts a tflite type to its integer representation.

  Args:
    dtype: tf.DType representing the inference type.
    flag: str naming the flag being parsed; only used to build the error
      message when `dtype` is unsupported.

  Returns:
    integer, a tflite TensorType enum value.

  Raises:
    ValueError: Unsupported tflite type.
  """
  # Validate if dtype is supported in tflite and is a valid interface type.
  if dtype not in mmi_constants.TFLITE_TYPES:
    raise ValueError(
        "Unsupported value '{0}' for {1}. Only {2} are supported.".format(
            dtype, flag, mmi_constants.TFLITE_TYPES))

  # Map the dtype to the name of an attribute on schema_fb.TensorType, then
  # read that attribute to obtain the integer enum value.
  dtype_str = mmi_constants.TFLITE_TO_STR_TYPES[dtype]
  return getattr(schema_fb.TensorType, dtype_str)


def modify_model_interface(input_file, output_file, input_type, output_type):
  """Modify a quantized model's interface (input/output) from float to integer.

  Args:
    input_file: Full path name to the input tflite file.
    output_file: Full path name to the output tflite file.
    input_type: Final input interface type.
    output_type: Final output interface type.

  Raises:
    RuntimeError: If the modification of the model interface was unsuccessful.
    ValueError: If the input_type or output_type is unsupported.
  """
  # Map the interface types to integer values; raises ValueError on an
  # unsupported type before any native work is attempted.
  input_type_int = _parse_type_to_int(input_type, 'input_type')
  output_type_int = _parse_type_to_int(output_type, 'output_type')

  # Invoke the native wrapper to modify the model interface.
  status = _pywrap_modify_model_interface.modify_model_interface(
      input_file, output_file, input_type_int, output_type_int)

  # The wrapper signals success with a status of 0; anything else is an error.
  if status != 0:
    raise RuntimeError(
        'Error occurred when trying to modify the model input type from float '
        'to {input_type} and output type from float to {output_type}.'.format(
            input_type=input_type, output_type=output_type))