• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for post training quantization with calibration."""
16import numpy as np
17
18from tensorflow.lite.python.convert_phase import Component
19from tensorflow.lite.python.convert_phase import convert_phase
20from tensorflow.lite.python.convert_phase import SubComponent
21from tensorflow.lite.python.interpreter import Interpreter
22from tensorflow.python.framework import dtypes
23from tensorflow.python.util.lazy_loader import LazyLoader
24
# Load the native calibration wrapper lazily, since some of the performance
# benchmark Skylark rules break dependencies. The string literals below must
# use double quotes to match an internal code rewrite rule.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "_pywrap_tensorflow_lite_calibration_wrapper")
32
33
def add_intermediate_tensors(model_content):
  """Adds intermediate tensors to the fused ops of a model if needed.

  Args:
    model_content: The model to process; it is handed unchanged to the
      calibration wrapper's `AddIntermediateTensors`.

  Returns:
    Whatever `AddIntermediateTensors` returns for `model_content`.
  """
  updated_model = _calibration_wrapper.AddIntermediateTensors(model_content)
  return updated_model
37
38
class Calibrator:
  """Calibrates a floating point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self,
               model_content,
               custom_op_registerers_by_name=None,
               custom_op_registerers_by_func=None):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.
      custom_op_registerers_by_name: List of str (symbol names) that take a
        pointer to a MutableOpResolver and register custom ops.
      custom_op_registerers_by_func: List of functions that take a pointer to a
        MutableOpResolver and register custom ops.

    Raises:
      ValueError: If `model_content` is empty or the calibrator was unable to
        open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    if custom_op_registerers_by_name is None:
      custom_op_registerers_by_name = []
    if custom_op_registerers_by_func is None:
      custom_op_registerers_by_func = []
    try:
      self._calibrator = (
          _calibration_wrapper.CalibrationWrapper(
              model_content, custom_op_registerers_by_name,
              custom_op_registerers_by_func))
    except Exception as e:
      # Chain explicitly so the wrapper's original error stays attached as the
      # cause of the ValueError.
      raise ValueError("Failed to parse the model: %s." % e) from e
    self._model_content = model_content
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")
    # Created on demand; only needed for samples that name their inputs.
    self._interpreter = None

  def _get_interpreter(self):
    """Returns the `Interpreter` for the model, creating it on first use."""
    if self._interpreter is None:
      self._interpreter = Interpreter(model_content=self._model_content)
    return self._interpreter

  def _create_input_array_from_dict(self, signature_key, inputs):
    """Converts name-keyed inputs to a list ordered by input tensor index.

    Args:
      signature_key: Key passed to the interpreter's `get_signature_runner`.
      inputs: A dictionary mapping each signature input name to its value.

    Returns:
      The values of `inputs`, sorted by the "index" entry of the matching
      signature input details.

    Raises:
      KeyError: If `inputs` lacks an entry for one of the signature inputs.
    """
    signature_runner = self._get_interpreter().get_signature_runner(
        signature_key)
    input_details = sorted(
        signature_runner.get_input_details().items(),
        key=lambda item: item[1]["index"])
    return [inputs[input_name] for input_name, _ in input_details]

  def _feed_tensors(self, dataset_gen, resize_input):
    """Feeds every sample produced by `dataset_gen` to the calibrator.

    Each sample must be one of:
      * a tuple whose first element is a signature key and whose second
        element is a dictionary of input names and values;
      * a dictionary of input names and values (signature key `None`);
      * a list of input values, which is fed as given.

    The calibrator's `Prepare` is called once per distinct signature key, when
    the first sample for that key is seen.

    Args:
      dataset_gen: A callable returning an iterable of calibration samples.
      resize_input: A boolean. If True, `Prepare` receives the shapes of the
        first sample's values (each value must then expose `.shape`).

    Raises:
      ValueError: If a sample has an unsupported type, or a tuple sample does
        not carry a dictionary as its second element.
    """
    prepared_signatures = set()

    for sample in dataset_gen():
      if isinstance(sample, tuple):
        # A too-short tuple gets the same documented error as a tuple with a
        # non-dict payload, rather than a bare IndexError.
        if len(sample) < 2 or not isinstance(sample[1], dict):
          raise ValueError("You need to provide either a dictionary with input "
                           "names and values in the second arugment in the "
                           "tuple")
        # Convert signature based inputs to the tensor index based data.
        signature_key = sample[0]
        input_array = self._create_input_array_from_dict(
            signature_key, sample[1])
      elif isinstance(sample, dict):
        # Convert signature based inputs to the tensor index based data.
        signature_key = None
        input_array = self._create_input_array_from_dict(None, sample)
      elif isinstance(sample, list):
        signature_key = None
        input_array = sample
      else:
        raise ValueError("You need to provide either a dictionary with input "
                         "names and values, a tuple with signature key and a "
                         "dictionary with input names and values, or an array "
                         "with input values in the order of input tensors of "
                         "the graph in the representative_dataset function. "
                         "Unsupported value from dataset: {}.".format(sample))

      # The signature key is only passed to the calibrator when one was given.
      signature_args = () if signature_key is None else (signature_key,)

      if signature_key not in prepared_signatures:
        prepared_signatures.add(signature_key)
        if resize_input:
          self._calibrator.Prepare([list(s.shape) for s in input_array],
                                   *signature_args)
        else:
          self._calibrator.Prepare(*signature_args)
      self._calibrator.FeedTensor(input_array, *signature_args)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize(self,
                             dataset_gen,
                             input_type,
                             output_type,
                             allow_float,
                             activations_type=dtypes.int8,
                             bias_type=dtypes.int32,
                             resize_input=True,
                             disable_per_channel=False):
    """Calibrates the model with specified generator and then quantizes it.

    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A generator that generates calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False, an
        error will be thrown if an operation cannot be quantized, otherwise the
        model will fallback to float ops.
      activations_type: A tf.dtype representing the desired type for
        activations.
      bias_type: A tf.dtype representing the desired type for bias.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.
      disable_per_channel: A boolean. True if disabling per-channel
        quantization.

    Returns:
      A quantized model.
    """
    self._feed_tensors(dataset_gen, resize_input)
    return self._calibrator.QuantizeModel(
        np.dtype(input_type.as_numpy_dtype()).num,
        np.dtype(output_type.as_numpy_dtype()).num, allow_float,
        np.dtype(activations_type.as_numpy_dtype()).num,
        np.dtype(bias_type.as_numpy_dtype()).num, disable_per_channel)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize_single(self,
                                    dataset_gen,
                                    input_type,
                                    output_type,
                                    allow_float,
                                    op_output_name,
                                    resize_input=True):
    """Calibrates the model with specified generator and then quantizes it.

    Only the single op with output op_output_name will be quantized.
    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A generator that generates calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False, an
        error will be thrown if an operation cannot be quantized, otherwise the
        model will fallback to float ops.
      op_output_name: A string, only this op will be quantized.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.

    Returns:
      A quantized model.
    """
    self._feed_tensors(dataset_gen, resize_input)
    return self._calibrator.QuantizeModel(
        np.dtype(input_type.as_numpy_dtype()).num,
        np.dtype(output_type.as_numpy_dtype()).num, allow_float, op_output_name)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.CALIBRATE)
  def calibrate(self, dataset_gen):
    """Calibrates the model with specified generator.

    The calibration samples are always fed with `resize_input=True`.

    Args:
      dataset_gen: A generator that generates calibration samples.

    Returns:
      A model with min and max calibration stats.
    """
    self._feed_tensors(dataset_gen, resize_input=True)
    return self._calibrator.Calibrate()
228