• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Python wrapper for post training quantization with calibration."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.lite.python.convert_phase import Component
23from tensorflow.lite.python.convert_phase import convert_phase
24from tensorflow.lite.python.convert_phase import SubComponent
25from tensorflow.python.framework import dtypes
26from tensorflow.python.util.lazy_loader import LazyLoader
27
# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
# rule.
#
# Handle to the native `_pywrap_tensorflow_lite_calibration_wrapper` extension;
# this module uses its `CalibrationWrapper` and `AddIntermediateTensors`
# attributes. The module path is written as two adjacent string literals that
# Python concatenates at compile time — keep this exact textual form so the
# rewrite rule mentioned above still matches.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "_pywrap_tensorflow_lite_calibration_wrapper")
35
36
def add_intermediate_tensors(model_content):
  """Adds intermediate tensors to fused ops of a model, if needed.

  Thin wrapper that forwards to the native calibration wrapper's
  `AddIntermediateTensors`.

  Args:
    model_content: Content of a TF-Lite Flatbuffer file.

  Returns:
    Whatever the native `AddIntermediateTensors` returns (presumably the
    updated flatbuffer content — confirm against the native implementation).
  """
  native_wrapper = _calibration_wrapper
  return native_wrapper.AddIntermediateTensors(model_content)
40
41
class Calibrator(object):
  """Calibrates a floating point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self,
               model_content,
               custom_op_registerers_by_name=None,
               custom_op_registerers_by_func=None):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.
      custom_op_registerers_by_name: List of str (symbol names) that take a
        pointer to a MutableOpResolver and register custom ops.
      custom_op_registerers_by_func: List of functions that take a pointer to a
        MutableOpResolver and register custom ops.

    Raises:
      ValueError: If `model_content` is empty, or if the calibrator was unable
        to open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    if custom_op_registerers_by_name is None:
      custom_op_registerers_by_name = []
    if custom_op_registerers_by_func is None:
      custom_op_registerers_by_func = []
    try:
      self._calibrator = (
          _calibration_wrapper.CalibrationWrapper(
              model_content, custom_op_registerers_by_name,
              custom_op_registerers_by_func))
    except Exception as e:
      # Any failure raised while constructing the native wrapper is reported to
      # callers as a ValueError, matching the documented contract.
      raise ValueError("Failed to parse the model: %s." % e)
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")

  @staticmethod
  def _dtype_num(tf_dtype):
    """Returns the numpy type number the native wrapper expects for `tf_dtype`.

    Args:
      tf_dtype: A tf.dtype.

    Returns:
      The integer `numpy.dtype(...).num` of the dtype's numpy counterpart.
    """
    # Kept identical to the original inline expression: the numpy type is
    # called to build a scalar, and `np.dtype` reads that scalar's dtype.
    return np.dtype(tf_dtype.as_numpy_dtype()).num

  def _feed_tensors(self, dataset_gen, resize_input):
    """Prepares the calibrator and feeds it every calibration sample.

    `Prepare` is called exactly once, right before the first sample is fed.
    When `resize_input` is True, the model inputs are resized to the shapes of
    the arrays in that first sample; later samples are fed as-is. If
    `dataset_gen()` yields nothing, neither `Prepare` nor `FeedTensor` is
    called.

    Args:
      dataset_gen: A callable taking no arguments that returns an iterable of
        calibration samples. Each sample is an iterable of array-like objects
        exposing `.shape` (presumably one per model input — confirm with the
        caller).
      resize_input: A boolean. True to resize the model inputs to the shapes
        of the first sample.
    """
    prepared = False
    for sample in dataset_gen():
      if not prepared:
        prepared = True
        if resize_input:
          self._calibrator.Prepare([list(s.shape) for s in sample])
        else:
          self._calibrator.Prepare()
      self._calibrator.FeedTensor(sample)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize(self,
                             dataset_gen,
                             input_type,
                             output_type,
                             allow_float,
                             activations_type=dtypes.int8,
                             resize_input=True,
                             disable_per_channel=False):
    """Calibrates the model with specified generator and then quantizes it.

    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A callable taking no arguments that returns an iterable of
        calibration samples (e.g. a generator function).
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False, an
        error will be thrown if an operation cannot be quantized, otherwise the
        model will fallback to float ops.
      activations_type: A tf.dtype representing the desired type for
        activations.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.
      disable_per_channel: A boolean. True if disabling per-channel
        quantization.

    Returns:
      A quantized model.
    """
    self._feed_tensors(dataset_gen, resize_input)
    return self._calibrator.QuantizeModel(
        self._dtype_num(input_type),
        self._dtype_num(output_type), allow_float,
        self._dtype_num(activations_type),
        disable_per_channel)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL,
                 SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER)
  def calibrate_and_quantize_single(self,
                                    dataset_gen,
                                    input_type,
                                    output_type,
                                    allow_float,
                                    op_output_name,
                                    resize_input=True):
    """Calibrates the model with specified generator and then quantizes it.

    Only the single op with output `op_output_name` will be quantized.
    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A callable taking no arguments that returns an iterable of
        calibration samples (e.g. a generator function).
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False, an
        error will be thrown if an operation cannot be quantized, otherwise the
        model will fallback to float ops.
      op_output_name: A string, only this op will be quantized.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.

    Returns:
      A quantized model.
    """
    self._feed_tensors(dataset_gen, resize_input)
    return self._calibrator.QuantizeModel(
        self._dtype_num(input_type),
        self._dtype_num(output_type), allow_float, op_output_name)

  @convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.CALIBRATE)
  def calibrate(self, dataset_gen):
    """Calibrates the model with specified generator.

    The input shapes of the calibrator are always resized to the shapes of the
    first calibration sample.

    Args:
      dataset_gen: A callable taking no arguments that returns an iterable of
        calibration samples (e.g. a generator function).

    Returns:
      A model with min and max calibration stats.
    """
    self._feed_tensors(dataset_gen, resize_input=True)
    return self._calibrator.Calibrate()
187