# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for post-training quantization with calibration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.python.framework import dtypes
from tensorflow.python.util.lazy_loader import LazyLoader

# Lazy load since some of the performance benchmark Skylark rules
# break dependencies. Must use double quotes to match the code-internal rewrite
# rule.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "_pywrap_tensorflow_lite_calibration_wrapper")


def add_intermediate_tensors(model_content):
  """Adds intermediate tensors to fused ops if needed."""
  return _calibration_wrapper.AddIntermediateTensors(model_content)


class Calibrator(object):
  """Calibrates a floating-point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self,
               model_content,
               custom_op_registerers_by_name=None,
               custom_op_registerers_by_func=None):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.
      custom_op_registerers_by_name: List of str (symbol names of functions)
        that take a pointer to a MutableOpResolver and register custom ops.
      custom_op_registerers_by_func: List of functions that take a pointer to a
        MutableOpResolver and register custom ops.

    Raises:
      ValueError: If the calibrator was unable to open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    if custom_op_registerers_by_name is None:
      custom_op_registerers_by_name = []
    if custom_op_registerers_by_func is None:
      custom_op_registerers_by_func = []
    try:
      self._calibrator = (
          _calibration_wrapper.CalibrationWrapper(
              model_content, custom_op_registerers_by_name,
              custom_op_registerers_by_func))
    except Exception as e:
      raise ValueError("Failed to parse the model: %s." % e)
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize(self,
                             dataset_gen,
                             input_type,
                             output_type,
                             allow_float,
                             activations_type=dtypes.int8,
                             resize_input=True,
                             disable_per_channel=False):
    """Calibrates the model with the specified generator and then quantizes it.

    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A generator that generates calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. If False, the resulting model may not perform
        float computation (useful when targeting an integer-only backend) and
        an error is raised when an operation cannot be quantized; if True, the
        model falls back to float ops for operations that cannot be quantized.
      activations_type: A tf.dtype representing the desired type for
        activations.
      resize_input: A boolean. True to resize the model inputs to the shape of
        the calibration sample data.
      disable_per_channel: A boolean. True to disable per-channel quantization.

    Returns:
      A quantized model.
    """
    initialized = False
    for sample in dataset_gen():
      if not initialized:
        initialized = True
        if resize_input:
          self._calibrator.Prepare([list(s.shape) for s in sample])
        else:
          self._calibrator.Prepare()
      self._calibrator.FeedTensor(sample)
    return self._calibrator.QuantizeModel(
        np.dtype(input_type.as_numpy_dtype()).num,
        np.dtype(output_type.as_numpy_dtype()).num, allow_float,
        np.dtype(activations_type.as_numpy_dtype()).num,
        disable_per_channel)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize_single(self,
                                    dataset_gen,
                                    input_type,
                                    output_type,
                                    allow_float,
                                    op_output_name,
                                    resize_input=True):
    """Calibrates the model with the specified generator and then quantizes it.

    Only the single op whose output is `op_output_name` will be quantized.
    The input shapes of the calibrator are resized with the calibration data.

    Args:
      dataset_gen: A generator that generates calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. If False, the resulting model may not perform
        float computation (useful when targeting an integer-only backend) and
        an error is raised when an operation cannot be quantized; if True, the
        model falls back to float ops for operations that cannot be quantized.
      op_output_name: A string; only the op producing this output is quantized.
      resize_input: A boolean. True to resize the model inputs to the shape of
        the calibration sample data.

    Returns:
      A quantized model.
    """
    initialized = False
    for sample in dataset_gen():
      if not initialized:
        initialized = True
        if resize_input:
          self._calibrator.Prepare([list(s.shape) for s in sample])
        else:
          self._calibrator.Prepare()
      self._calibrator.FeedTensor(sample)
    return self._calibrator.QuantizeModel(
        np.dtype(input_type.as_numpy_dtype()).num,
        np.dtype(output_type.as_numpy_dtype()).num, allow_float, op_output_name)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.CALIBRATE)
  def calibrate(self, dataset_gen):
    """Calibrates the model with the specified generator.

    Args:
      dataset_gen: A generator that generates calibration samples.

    Returns:
      A model with min and max calibration stats.
    """
    initialized = False
    for sample in dataset_gen():
      if not initialized:
        initialized = True
        self._calibrator.Prepare([list(s.shape) for s in sample])
      self._calibrator.FeedTensor(sample)
    return self._calibrator.Calibrate()
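

# The block below is a minimal, illustrative sketch of how this internal
# Calibrator might be driven end to end; it is not part of the public API.
# The model paths ("/tmp/model.tflite", "/tmp/model_quantized.tflite") and the
# (1, 224, 224, 3) input shape used by the representative dataset generator
# are assumptions made for illustration only.
if __name__ == "__main__":
  # Load a float TFLite flatbuffer from a hypothetical path.
  with open("/tmp/model.tflite", "rb") as f:
    float_model = f.read()

  def _representative_dataset():
    # Yields lists of numpy arrays, one array per model input; the shape and
    # zero-valued data here stand in for real calibration samples.
    for _ in range(100):
      yield [np.zeros((1, 224, 224, 3), dtype=np.float32)]

  calibrator = Calibrator(float_model)
  quantized_model = calibrator.calibrate_and_quantize(
      _representative_dataset,
      input_type=dtypes.float32,
      output_type=dtypes.float32,
      allow_float=True)

  # Write out the quantized flatbuffer (hypothetical destination path).
  with open("/tmp/model_quantized.tflite", "wb") as f:
    f.write(quantized_model)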