# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Calibrator."""

from absl.testing import parameterized
import numpy as np

from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test

class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for the TFLite `Calibrator` post-training quantization wrapper."""

  def _read_model(self, filename):
    """Reads a serialized test model and returns its raw bytes.

    Args:
      filename: Model path, relative to this test's datafile root.

    Returns:
      The model file contents as `bytes`.
    """
    model_path = resource_loader.get_path_to_datafile(filename)
    # Context manager ensures the handle is closed deterministically; the
    # previous open(...).read() pattern leaked the handle until GC.
    with open(model_path, 'rb') as model_file:
      return model_file.read()

  def _ones_input_gen(self, shape=(1, 5, 5, 3), num_inputs=1, num_steps=10):
    """Builds a calibration-input generator of all-ones float32 arrays.

    Args:
      shape: Shape of each generated input array.
      num_inputs: Number of input arrays yielded per calibration step.
      num_steps: Number of calibration steps to generate.

    Returns:
      A zero-argument generator function yielding lists of numpy arrays, as
      expected by `Calibrator.calibrate_and_quantize`.
    """

    def input_gen():
      for _ in range(num_steps):
        yield [
            np.ones(shape=shape, dtype=np.float32) for _ in range(num_inputs)
        ]

    return input_gen

  @parameterized.named_parameters(
      # Activation type Int8
      ('UseActivationTypeInt8', dtypes.int8),
      # Activation type Int16
      ('UseActivationTypeInt16', dtypes.int16))
  def test_calibration_with_quantization(self, activations_type):
    float_model = self._read_model('test_data/mobilenet_like_model.bin')
    quantizer = _calibrator.Calibrator(float_model)

    quantized_model = quantizer.calibrate_and_quantize(
        self._ones_input_gen(), dtypes.float32, dtypes.float32, False,
        activations_type)
    self.assertIsNotNone(quantized_model)

  @parameterized.named_parameters(
      # Activation type Int8
      ('UseActivationTypeInt8', dtypes.int8),
      # Activation type Int16
      ('UseActivationTypeInt16', dtypes.int16))
  def test_calibration_with_quantization_allow_float(self, activations_type):
    float_model = self._read_model('test_data/mobilenet_like_model.bin')
    quantizer = _calibrator.Calibrator(float_model)

    # allow_float=True: ops without a quantized kernel stay in float.
    quantized_model = quantizer.calibrate_and_quantize(
        self._ones_input_gen(), dtypes.float32, dtypes.float32, True,
        activations_type)
    self.assertIsNotNone(quantized_model)

  def test_calibration_with_quantization_single_op(self):
    float_model = self._read_model('test_data/mobilenet_like_model.bin')
    quantizer = _calibrator.Calibrator(float_model)

    # Quantizes only the op producing the named output tensor.
    quantized_model = quantizer.calibrate_and_quantize_single(
        self._ones_input_gen(), dtypes.float32, dtypes.float32, True,
        'conv2d_8/BiasAdd')
    self.assertIsNotNone(quantized_model)

  def test_calibration_with_string_input(self):
    model_with_string_input = self._read_model(
        'test_data/string_input_flex_model.bin')
    quantizer = _calibrator.Calibrator(model_with_string_input)

    # String inputs cannot use the all-ones helper; yield distinct strings.
    def input_gen():
      for i in range(10):
        yield [np.array(u'Test' + str(i))]

    quantized_model = quantizer.calibrate_and_quantize_single(
        input_gen, dtypes.float32, dtypes.float32, True, 'Identity')
    self.assertIsNotNone(quantized_model)

  @parameterized.named_parameters(
      # Activation type Int8
      ('UseActivationTypeInt8 - EnableMlirQuantizer', dtypes.int8),
      # Activation type Int16
      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', dtypes.int16))
  def test_calibration_with_quantization_multiple_inputs(
      self, activations_type):
    # Load multi add model from test data.
    # This model has 4 inputs of size (1, 8, 8, 3).
    float_model = self._read_model('../../testdata/multi_add.bin')
    quantizer = _calibrator.Calibrator(float_model)

    quantized_model = quantizer.calibrate_and_quantize(
        self._ones_input_gen(shape=(1, 8, 8, 3), num_inputs=4),
        dtypes.float32, dtypes.float32, False, activations_type)
    self.assertIsNotNone(quantized_model)

  def test_invalid_model_buffer(self):
    # 100 zero bytes is not a parseable flatbuffer model.
    float_model = b'\0' * 100
    with self.assertRaisesRegex(ValueError, 'Failed to parse the model'):
      _calibrator.Calibrator(float_model)

  # TODO(fengliuai): enable mlir quantizer
  def test_empty_calibrator_gen(self):
    float_model = self._read_model('test_data/mobilenet_like_model.bin')
    quantizer = _calibrator.Calibrator(float_model)

    def empty_input_gen():
      # Yields nothing: calibration requires at least one sample.
      for i in ():
        yield i

    with self.assertRaises(RuntimeError):
      quantizer.calibrate_and_quantize(empty_input_gen, dtypes.float32,
                                       dtypes.float32, False)

  def test_invalid_shape_calibrator_gen(self):
    float_model = self._read_model('test_data/mobilenet_like_model.bin')
    quantizer = _calibrator.Calibrator(float_model)

    # Input shape (1, 2, 2, 3) does not match the model's expected input,
    # and resize_input=False prevents automatic resizing.
    with self.assertRaisesRegex(ValueError, 'Size mismatch'):
      quantizer.calibrate_and_quantize(
          self._ones_input_gen(shape=(1, 2, 2, 3)),
          dtypes.float32,
          dtypes.float32,
          False,
          activations_type=dtypes.int8,
          bias_type=dtypes.int32,
          resize_input=False)

  def test_invalid_type_calibrator_gen(self):
    float_model = self._read_model('test_data/mobilenet_like_model.bin')
    quantizer = _calibrator.Calibrator(float_model)

    # Input generator with incorrect type: int32 instead of float32.
    def input_gen():
      for _ in range(10):
        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.int32)]

    with self.assertRaises(ValueError):
      quantizer.calibrate_and_quantize(input_gen, dtypes.float32,
                                       dtypes.float32, False, dtypes.int8)

  def test_calibration(self):
    float_model = self._read_model('test_data/mobilenet_like_model.bin')
    quantizer = _calibrator.Calibrator(float_model)

    # Calibration only — no quantization pass.
    quantized_model = quantizer.calibrate(self._ones_input_gen())
    self.assertIsNotNone(quantized_model)

  def test_add_intermediate_tensors(self):
    model = self._read_model('test_data/mobilenet_like_model.bin')
    added_model = _calibrator.add_intermediate_tensors(model)
    self.assertIsNotNone(added_model)


# Standard TensorFlow test entry point: runs all tests in this file when it
# is executed directly.
if __name__ == '__main__':
  test.main()