• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library to modify a quantized model's interface from float to integer."""
16
import os

from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.tools.optimize.python import _pywrap_modify_model_interface
from tensorflow.lite.tools.optimize.python import modify_model_interface_constants as mmi_constants
20
21
def _parse_type_to_int(dtype, flag):
  """Converts a tflite type to its integer representation.

  Args:
    dtype: tf.DType representing the inference type.
    flag: str representing the flag name; only used to build the error
      message when `dtype` is unsupported.

  Returns:
     integer, a tflite TensorType enum value.

  Raises:
    ValueError: Unsupported tflite type.
  """
  # Validate if dtype is supported in tflite and is a valid interface type.
  if dtype not in mmi_constants.TFLITE_TYPES:
    raise ValueError(
        "Unsupported value '{0}' for {1}. Only {2} are supported.".format(
            dtype, flag, mmi_constants.TFLITE_TYPES))

  # Translate the dtype to the name of the matching member of the generated
  # TensorType class, then read that member's value through the normal
  # attribute-lookup API rather than poking at the class __dict__.
  dtype_str = mmi_constants.TFLITE_TO_STR_TYPES[dtype]
  return getattr(schema_fb.TensorType, dtype_str)
45
46
def modify_model_interface(input_file, output_file, input_type, output_type):
  """Modify a quantized model's interface (input/output) from float to integer.

  Both requested types are validated and converted to tflite TensorType enum
  values first; the actual rewrite is then delegated to the native
  `_pywrap_modify_model_interface.modify_model_interface` binding, whose
  non-zero return status is reported as a RuntimeError.

  Args:
    input_file: Full path name to the input tflite file (str, bytes or
      os.PathLike).
    output_file: Full path name to the output tflite file (str, bytes or
      os.PathLike).
    input_type: Final input interface type (a tf.DType).
    output_type: Final output interface type (a tf.DType).

  Raises:
    RuntimeError: If the modification of the model interface was unsuccessful.
    ValueError: If the input_type or output_type is unsupported.
    TypeError: If input_file or output_file is not a str, bytes or
      os.PathLike object.
  """
  # Map the interface types to integer values. This runs before any path
  # handling so unsupported types are reported as ValueError first.
  input_type_int = _parse_type_to_int(input_type, 'input_type')
  output_type_int = _parse_type_to_int(output_type, 'output_type')

  # The native binding only understands str/bytes paths; os.fspath lets
  # callers pass pathlib.Path (or any os.PathLike) and is a no-op for str.
  input_path = os.fspath(input_file)
  output_path = os.fspath(output_file)

  # Invoke the function to modify the model interface.
  status = _pywrap_modify_model_interface.modify_model_interface(
      input_path, output_path, input_type_int, output_type_int)

  # Throw an exception if the return status is an error. The original message
  # prefix is kept intact; the status code is appended for diagnosis.
  if status != 0:
    raise RuntimeError(
        'Error occurred when trying to modify the model input type from float '
        'to {input_type} and output type from float to {output_type}. '
        'Status code: {status}.'.format(
            input_type=input_type, output_type=output_type, status=status))
75