# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for post training quantization with calibration."""
import numpy as np

from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.framework import dtypes
from tensorflow.python.util.lazy_loader import LazyLoader

# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
# rule.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "_pywrap_tensorflow_lite_calibration_wrapper")


def add_intermediate_tensors(model_content):
  """Adds intermediate tensors to fused op if needed.

  Args:
    model_content: Model content handed unchanged to the native
      `AddIntermediateTensors` wrapper.

  Returns:
    Whatever the native `AddIntermediateTensors` call returns.
  """
  return _calibration_wrapper.AddIntermediateTensors(model_content)


class Calibrator:
  """Calibrates a floating point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self,
               model_content,
               custom_op_registerers_by_name=None,
               custom_op_registerers_by_func=None):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.
      custom_op_registerers_by_name: List of str (symbol names) that take a
        pointer to a MutableOpResolver and register custom ops.
      custom_op_registerers_by_func: List of functions that take a pointer to a
        MutableOpResolver and register custom ops.

    Raises:
      ValueError: If `model_content` is empty or the calibrator was unable to
        open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    if custom_op_registerers_by_name is None:
      custom_op_registerers_by_name = []
    if custom_op_registerers_by_func is None:
      custom_op_registerers_by_func = []
    try:
      self._calibrator = (
          _calibration_wrapper.CalibrationWrapper(
              model_content, custom_op_registerers_by_name,
              custom_op_registerers_by_func))
    except Exception as e:
      # Any wrapper failure is reported as a ValueError; chain the original
      # exception so its traceback is not lost.
      raise ValueError("Failed to parse the model: %s." % e) from e
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")
    self._model_content = model_content
    # Created on first use; only needed for name-keyed (dict) samples.
    self._interpreter = None

  def _create_input_array_from_dict(self, signature_key, inputs):
    """Converts name-keyed inputs to a list ordered by input tensor index.

    Args:
      signature_key: Signature key passed to the interpreter's
        `get_signature_runner`; may be None.
      inputs: Dictionary mapping input names to input values.

    Returns:
      A list with the values of `inputs`, sorted by the "index" entry of the
      corresponding signature input details.
    """
    if self._interpreter is None:
      # The interpreter is only used to look up the signature's input details.
      self._interpreter = Interpreter(model_content=self._model_content)
    signature_runner = self._interpreter.get_signature_runner(signature_key)
    input_details = sorted(
        signature_runner.get_input_details().items(),
        key=lambda item: item[1]["index"])
    return [inputs[input_name] for input_name, _ in input_details]

  def _feed_tensors(self, dataset_gen, resize_input):
    """Feed tensors to the calibrator.

    Each sample produced by `dataset_gen()` must be one of:
      * a tuple whose first element is a signature key and whose second
        element is a dictionary of input names to values,
      * a dictionary of input names to values, or
      * a list of input values in the order of the graph's input tensors.

    Args:
      dataset_gen: A callable returning an iterable of calibration samples.
      resize_input: A boolean. If True, the calibrator is prepared with the
        shapes of the first sample seen for each signature key.

    Raises:
      ValueError: If a sample has an unsupported form.
    """
    # Signature keys (None for the keyless case) already passed to Prepare.
    prepared = set()

    for sample in dataset_gen():
      if isinstance(sample, tuple):
        # Check the length first so a short tuple yields the intended
        # ValueError instead of an IndexError.
        if len(sample) < 2 or not isinstance(sample[1], dict):
          raise ValueError("You need to provide either a dictionary with input "
                           "names and values in the second argument in the "
                           "tuple")
        # Convert signature based inputs to the tensor index based data.
        signature_key = sample[0]
        input_array = self._create_input_array_from_dict(
            signature_key, sample[1])
      elif isinstance(sample, dict):
        # Convert signature based inputs to the tensor index based data.
        signature_key = None
        input_array = self._create_input_array_from_dict(None, sample)
      elif isinstance(sample, list):
        signature_key = None
        input_array = sample
      else:
        raise ValueError("You need to provide either a dictionary with input "
                         "names and values, a tuple with signature key and a "
                         "dictionary with input names and values, or an array "
                         "with input values in the order of input tensors of "
                         "the graph in the representative_dataset function. "
                         "Unsupported value from dataset: {}.".format(sample))

      # The signature key is passed as an extra trailing argument only when
      # one was given.
      key_args = () if signature_key is None else (signature_key,)

      if signature_key not in prepared:
        prepared.add(signature_key)
        if resize_input:
          input_shapes = [list(s.shape) for s in input_array]
          self._calibrator.Prepare(input_shapes, *key_args)
        else:
          self._calibrator.Prepare(*key_args)
      self._calibrator.FeedTensor(input_array, *key_args)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize(self,
                             dataset_gen,
                             input_type,
                             output_type,
                             allow_float,
                             activations_type=dtypes.int8,
                             bias_type=dtypes.int32,
                             resize_input=True,
                             disable_per_channel=False):
    """Calibrates the model with specified generator and then quantizes it.

    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A generator that generates calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False,
        an error will be thrown if an operation cannot be quantized, otherwise
        the model will fallback to float ops.
      activations_type: A tf.dtype representing the desired type for
        activations.
      bias_type: A tf.dtype representing the desired type for bias.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.
      disable_per_channel: A boolean. True if disabling per-channel
        quantization.

    Returns:
      A quantized model.
    """
    self._feed_tensors(dataset_gen, resize_input)
    # The native wrapper takes numpy type numbers rather than tf.dtype objects.
    return self._calibrator.QuantizeModel(
        np.dtype(input_type.as_numpy_dtype()).num,
        np.dtype(output_type.as_numpy_dtype()).num, allow_float,
        np.dtype(activations_type.as_numpy_dtype()).num,
        np.dtype(bias_type.as_numpy_dtype()).num, disable_per_channel)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize_single(self,
                                    dataset_gen,
                                    input_type,
                                    output_type,
                                    allow_float,
                                    op_output_name,
                                    resize_input=True):
    """Calibrates the model with specified generator and then quantizes it.

    Only the single op with output op_output_name will be quantized.
    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A generator that generates calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False, an
        error will be thrown if an operation cannot be quantized, otherwise the
        model will fallback to float ops.
      op_output_name: A string, only this op will be quantized.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.

    Returns:
      A quantized model.
    """
    self._feed_tensors(dataset_gen, resize_input)
    # The native wrapper takes numpy type numbers rather than tf.dtype objects.
    return self._calibrator.QuantizeModel(
        np.dtype(input_type.as_numpy_dtype()).num,
        np.dtype(output_type.as_numpy_dtype()).num, allow_float, op_output_name)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.CALIBRATE)
  def calibrate(self, dataset_gen):
    """Calibrates the model with specified generator.

    The input shapes of the calibrator are always resized with the calibration
    data.

    Args:
      dataset_gen: A generator that generates calibration samples.

    Returns:
      A model with min and max calibration stats.
    """
    self._feed_tensors(dataset_gen, resize_input=True)
    return self._calibrator.Calibrate()