• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for post training quantization with calibration."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.util.lazy_loader import LazyLoader
24
# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
# rule.
# NOTE: the native module path below is written as two adjacent string
# literals, which Python concatenates into a single module name; keep the
# literal text as-is so the rewrite rule above still matches it.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "_pywrap_tensorflow_lite_calibration_wrapper")
32
33
def add_intermediate_tensors(model_content):
  """Inserts the intermediate tensors a fused op requires, when it needs them.

  This is a direct pass-through to the native calibration wrapper.

  Args:
    model_content: Content of a TF-Lite Flatbuffer file.

  Returns:
    The value returned by the native `AddIntermediateTensors` call
    (presumably the updated model content -- TODO confirm).
  """
  add_tensors_fn = _calibration_wrapper.AddIntermediateTensors
  return add_tensors_fn(model_content)
37
38
class Calibrator(object):
  """Calibrates a floating point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self, model_content):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.

    Raises:
      ValueError: If `model_content` is empty, or if the calibrator was unable
        to open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    try:
      self._calibrator = (
          _calibration_wrapper.CalibrationWrapper(model_content))
    except Exception as e:
      # Any failure from the native wrapper is surfaced as a parse failure;
      # the underlying error text is preserved in the message.
      raise ValueError("Failed to parse the model: %s." % e)
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")

  @staticmethod
  def _numpy_type_num(tf_dtype):
    """Returns the NumPy type number (`np.dtype(...).num`) for `tf_dtype`.

    The NumPy type itself is handed to `np.dtype` instead of an instance of
    it: instantiating the type is unnecessary, and for `np.object_` the
    "instance" is `None`, which `np.dtype` would silently map to float64.

    Args:
      tf_dtype: A tf.dtype.

    Returns:
      An int, the NumPy type number passed on to the native wrapper.
    """
    return np.dtype(tf_dtype.as_numpy_dtype).num

  def _feed_tensors(self, dataset_gen, resize_input):
    """Prepares the calibrator on the first sample, then feeds every sample.

    Args:
      dataset_gen: A callable returning an iterable of calibration samples;
        each sample is a sequence of input arrays exposing `.shape`.
      resize_input: If True, the calibrator inputs are resized to the shapes of
        the first sample before any data is fed.

    Raises:
      ValueError: If `dataset_gen` yields no samples.
    """
    prepared = False
    for sample in dataset_gen():
      if not prepared:
        prepared = True
        # Preparation happens exactly once, driven by the first sample only.
        if resize_input:
          self._calibrator.Prepare([list(s.shape) for s in sample])
        else:
          self._calibrator.Prepare()
      self._calibrator.FeedTensor(sample)
    if not prepared:
      # Without any sample the calibrator is never prepared and no statistics
      # are collected, so fail here with a clear message instead of inside
      # the native wrapper.
      raise ValueError("`dataset_gen` must yield at least one sample.")

  def calibrate_and_quantize(self,
                             dataset_gen,
                             input_type,
                             output_type,
                             allow_float,
                             activations_type=dtypes.int8,
                             resize_input=True):
    """Calibrates the model with specified generator and then quantizes it.

    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A callable returning a generator of calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False,
        an error will be thrown if an operation cannot be quantized, otherwise
        the model will fallback to float ops.
      activations_type: A tf.dtype representing the desired type for
        activations.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.

    Returns:
      A quantized model.

    Raises:
      ValueError: If `dataset_gen` yields no samples.
    """
    self._feed_tensors(dataset_gen, resize_input)
    return self._calibrator.QuantizeModel(
        self._numpy_type_num(input_type),
        self._numpy_type_num(output_type), allow_float,
        self._numpy_type_num(activations_type))

  def calibrate_and_quantize_single(self,
                                    dataset_gen,
                                    input_type,
                                    output_type,
                                    allow_float,
                                    op_output_name,
                                    resize_input=True):
    """Calibrates the model with specified generator and then quantizes it.

    Only the single op with output op_output_name will be quantized.
    The input shapes of the calibrator are resized with the calibration data if
    `resize_input` is set.

    Args:
      dataset_gen: A callable returning a generator of calibration samples.
      input_type: A tf.dtype representing the desired real-value input type.
      output_type: A tf.dtype representing the desired real-value output type.
      allow_float: A boolean. False if the resulting model cannot perform float
        computation, useful when targeting an integer-only backend. If False,
        an error will be thrown if an operation cannot be quantized, otherwise
        the model will fallback to float ops.
      op_output_name: A string, only this op will be quantized.
      resize_input: A boolean. True if the shape of the sample data is different
        from the input.

    Returns:
      A quantized model.

    Raises:
      ValueError: If `dataset_gen` yields no samples.
    """
    self._feed_tensors(dataset_gen, resize_input)
    return self._calibrator.QuantizeModel(
        self._numpy_type_num(input_type),
        self._numpy_type_num(output_type), allow_float, op_output_name)

  def calibrate(self, dataset_gen, resize_input=True):
    """Calibrates the model with specified generator.

    Args:
      dataset_gen: A callable returning a generator of calibration samples.
      resize_input: A boolean. True (the default) if the calibrator inputs
        should be resized to the shapes of the first sample.

    Returns:
      A model with min and max calibration stats.

    Raises:
      ValueError: If `dataset_gen` yields no samples.
    """
    self._feed_tensors(dataset_gen, resize_input)
    return self._calibrator.Calibrate()
162