1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrapper for post training quantization with calibration.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.framework import dtypes 23from tensorflow.python.util.lazy_loader import LazyLoader 24 25# Lazy load since some of the performance benchmark skylark rules 26# break dependencies. Must use double quotes to match code internal rewrite 27# rule. 28_calibration_wrapper = LazyLoader( 29 "_calibration_wrapper", globals(), 30 "tensorflow.lite.python.optimize." 31 "_pywrap_tensorflow_lite_calibration_wrapper") 32 33 34def add_intermediate_tensors(model_content): 35 """Adds intermediate tensors to fused op if needed.""" 36 return _calibration_wrapper.AddIntermediateTensors(model_content) 37 38 39class Calibrator(object): 40 """Calibrates a floating point model and then quantizes it. 41 42 This is an internal class, not a public interface. 43 """ 44 45 def __init__(self, model_content): 46 """Constructor. 47 48 Args: 49 model_content: Content of a TF-Lite Flatbuffer file. 50 51 Raises: 52 ValueError: If the calibrator was unable to open the model. 53 """ 54 if not model_content: 55 raise ValueError("`model_content` must be specified.") 56 try: 57 self._calibrator = ( 58 _calibration_wrapper.CalibrationWrapper(model_content)) 59 except Exception as e: 60 raise ValueError("Failed to parse the model: %s." % e) 61 if not self._calibrator: 62 raise ValueError("Failed to parse the model.") 63 64 def calibrate_and_quantize(self, 65 dataset_gen, 66 input_type, 67 output_type, 68 allow_float, 69 activations_type=dtypes.int8, 70 resize_input=True): 71 """Calibrates the model with specified generator and then quantizes it. 72 73 The input shapes of the calibrator are resized with the calibration data if 74 `resize_input` is set. 75 76 Returns: 77 A quantized model. 78 79 Args: 80 dataset_gen: A generator that generates calibration samples. 81 input_type: A tf.dtype representing the desired real-value input type. 82 output_type: A tf.dtype representing the desired real-value output type. 83 allow_float: A boolean. False if the resulting model cannot perform float 84 computation, useful when targeting an integer-only backend. 85 If False, an error will be thrown if an operation cannot be 86 quantized, otherwise the model will fallback to float ops. 87 activations_type: A tf.dtype representing the desired type for 88 activations. 89 resize_input: A boolean. True if the shape of the sample data is different 90 from the input. 91 """ 92 initialized = False 93 for sample in dataset_gen(): 94 if not initialized: 95 initialized = True 96 if resize_input: 97 self._calibrator.Prepare([list(s.shape) for s in sample]) 98 else: 99 self._calibrator.Prepare() 100 self._calibrator.FeedTensor(sample) 101 return self._calibrator.QuantizeModel( 102 np.dtype(input_type.as_numpy_dtype()).num, 103 np.dtype(output_type.as_numpy_dtype()).num, allow_float, 104 np.dtype(activations_type.as_numpy_dtype()).num) 105 106 def calibrate_and_quantize_single(self, 107 dataset_gen, 108 input_type, 109 output_type, 110 allow_float, 111 op_output_name, 112 resize_input=True): 113 """Calibrates the model with specified generator and then quantizes it. 114 115 Only the single op with output op_output_name will be quantized. 116 The input shapes of the calibrator are resized with the calibration data. 117 118 Returns: 119 A quantized model. 120 121 Args: 122 dataset_gen: A generator that generates calibration samples. 123 input_type: A tf.dtype representing the desired real-value input type. 124 output_type: A tf.dtype representing the desired real-value output type. 125 allow_float: A boolean. False if the resulting model cannot perform float 126 computation, useful when targeting an integer-only backend. If False, an 127 error will be thrown if an operation cannot be quantized, otherwise the 128 model will fallback to float ops. 129 op_output_name: A string, only this op will be quantized. 130 resize_input: A boolean. True if the shape of the sample data is different 131 from the input. 132 """ 133 initialized = False 134 for sample in dataset_gen(): 135 if not initialized: 136 initialized = True 137 if resize_input: 138 self._calibrator.Prepare([list(s.shape) for s in sample]) 139 else: 140 self._calibrator.Prepare() 141 self._calibrator.FeedTensor(sample) 142 return self._calibrator.QuantizeModel( 143 np.dtype(input_type.as_numpy_dtype()).num, 144 np.dtype(output_type.as_numpy_dtype()).num, allow_float, op_output_name) 145 146 def calibrate(self, dataset_gen): 147 """Calibrates the model with specified generator. 148 149 Returns: 150 A model with min and max calibration stats. 151 152 Args: 153 dataset_gen: A generator that generates calibration samples. 154 """ 155 initialized = False 156 for sample in dataset_gen(): 157 if not initialized: 158 initialized = True 159 self._calibrator.Prepare([list(s.shape) for s in sample]) 160 self._calibrator.FeedTensor(sample) 161 return self._calibrator.Calibrate() 162