• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to TensorFlow 2.0."""
16
17import os
18
19from absl.testing import parameterized
20import numpy as np
21import tensorflow as tf
22
23from tensorflow.lite.python.interpreter import Interpreter
24from tensorflow.python.eager import def_function
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import variables
32from tensorflow.python.trackable import autotrackable
33
34
35class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
36  """Base test class for TensorFlow Lite 2.x model tests."""
37
38  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
39    """Evaluates the model on the `input_data`.
40
41    Args:
42      tflite_model: TensorFlow Lite model.
43      input_data: List of EagerTensor const ops containing the input data for
44        each input tensor.
45      input_shapes: List of tuples representing the `shape_signature` and the
46        new shape of each input tensor that has unknown dimensions.
47
48    Returns:
49      [np.ndarray]
50    """
51    interpreter = Interpreter(model_content=tflite_model)
52    input_details = interpreter.get_input_details()
53    if input_shapes:
54      for idx, (shape_signature, final_shape) in enumerate(input_shapes):
55        self.assertTrue(
56            (input_details[idx]['shape_signature'] == shape_signature).all())
57        index = input_details[idx]['index']
58        interpreter.resize_tensor_input(index, final_shape, strict=True)
59    interpreter.allocate_tensors()
60
61    output_details = interpreter.get_output_details()
62    input_details = interpreter.get_input_details()
63
64    for input_tensor, tensor_data in zip(input_details, input_data):
65      interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
66    interpreter.invoke()
67    return [
68        interpreter.get_tensor(details['index']) for details in output_details
69    ]
70
71  def _evaluateTFLiteModelUsingSignatureDef(self, tflite_model, signature_key,
72                                            inputs):
73    """Evaluates the model on the `inputs`.
74
75    Args:
76      tflite_model: TensorFlow Lite model.
77      signature_key: Signature key.
78      inputs: Map from input tensor names in the SignatureDef to tensor value.
79
80    Returns:
81      Dictionary of outputs.
82      Key is the output name in the SignatureDef 'signature_key'
83      Value is the output value
84    """
85    interpreter = Interpreter(model_content=tflite_model)
86    signature_runner = interpreter.get_signature_runner(signature_key)
87    return signature_runner(**inputs)
88
89  def _getSimpleVariableModel(self):
90    root = autotrackable.AutoTrackable()
91    root.v1 = variables.Variable(3.)
92    root.v2 = variables.Variable(2.)
93    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
94    return root
95
96  def _getSimpleModelWithVariables(self):
97
98    class SimpleModelWithOneVariable(autotrackable.AutoTrackable):
99      """Basic model with 1 variable."""
100
101      def __init__(self):
102        super(SimpleModelWithOneVariable, self).__init__()
103        self.var = variables.Variable(array_ops.zeros((1, 10), name='var'))
104
105      @def_function.function
106      def assign_add(self, x):
107        self.var.assign_add(x)
108        return self.var
109
110    return SimpleModelWithOneVariable()
111
112  def _getMultiFunctionModel(self):
113
114    class BasicModel(autotrackable.AutoTrackable):
115      """Basic model with multiple functions."""
116
117      def __init__(self):
118        self.y = None
119        self.z = None
120
121      @def_function.function
122      def add(self, x):
123        if self.y is None:
124          self.y = variables.Variable(2.)
125        return x + self.y
126
127      @def_function.function
128      def sub(self, x):
129        if self.z is None:
130          self.z = variables.Variable(3.)
131        return x - self.z
132
133      @def_function.function
134      def mul_add(self, x, y):
135        if self.z is None:
136          self.z = variables.Variable(3.)
137        return x * self.z + y
138
139    return BasicModel()
140
141  def _getMultiFunctionModelWithSharedWeight(self):
142
143    class BasicModelWithSharedWeight(autotrackable.AutoTrackable):
144      """Model with multiple functions and a shared weight."""
145
146      def __init__(self):
147        self.weight = constant_op.constant([1.0],
148                                           shape=(1, 512, 512, 1),
149                                           dtype=dtypes.float32)
150
151      @def_function.function
152      def add(self, x):
153        return x + self.weight
154
155      @def_function.function
156      def sub(self, x):
157        return x - self.weight
158
159      @def_function.function
160      def mul(self, x):
161        return x * self.weight
162
163    return BasicModelWithSharedWeight()
164
165  def _getMatMulModelWithSmallWeights(self):
166
167    class MatMulModelWithSmallWeights(autotrackable.AutoTrackable):
168      """MatMul model with small weights and relatively large biases."""
169
170      def __init__(self):
171        self.weight = constant_op.constant([[1e-3, -1e-3], [-2e-4, 2e-4]],
172                                           shape=(2, 2),
173                                           dtype=dtypes.float32)
174        self.bias = constant_op.constant([1.28, 2.55],
175                                         shape=(2,),
176                                         dtype=dtypes.float32)
177
178      @def_function.function
179      def matmul(self, x):
180        return x @ self.weight + self.bias
181
182    return MatMulModelWithSmallWeights()
183
184  def _getSqrtModel(self):
185    """Returns a model with only one sqrt op, to test non-quantizable op."""
186
187    @def_function.function(input_signature=[
188        tensor_spec.TensorSpec(shape=(1, 10), dtype=dtypes.float32)
189    ])
190    def sqrt(x):
191      return math_ops.sqrt(x)
192
193    def calibration_gen():
194      for _ in range(5):
195        yield [np.random.uniform(0, 16, size=(1, 10)).astype(np.float32)]
196
197    return sqrt, calibration_gen
198
199  def _assertValidDebugInfo(self, debug_info):
200    """Verify the DebugInfo is valid."""
201    file_names = set()
202    for file_path in debug_info.files:
203      file_names.add(os.path.basename(file_path))
204    # To make the test independent on how the nodes are created, we only assert
205    # the name of this test file.
206    self.assertIn('lite_v2_test.py', file_names)
207    self.assertNotIn('lite_test.py', file_names)
208
  def _createV2QATLowBitKerasModel(self, shape, weight_only, num_bits, bit_min,
                                   bit_max):
    """Creates a small Keras model simulating QAT with `num_bits`-bit weights.

    The model is Input -> Conv2D (1 filter, 3x3, relu6) wrapped in a
    `ConvWrapper` that fake-quantizes its kernel -> Flatten.

    Args:
      shape: Input shape passed to `tf.keras.layers.Input` (Keras convention:
        excludes the batch dimension). Its last entry is used as the number
        of input channels when generating the initial kernel values.
      weight_only: If True, only the kernel is fake-quantized. If False, the
        wrapped layer's inputs and outputs are also fake-quantized to 8 bits
        over [0, 6].
      num_bits: Bit width used to fake-quantize the kernel.
      bit_min: Lower bound of the kernel fake-quantization range; also the
        first initial kernel value.
      bit_max: Upper bound of the kernel fake-quantization range; also the
        last initial kernel value.

    Returns:
      A tuple `(model, input_name, output_name)`, where `input_name` ('input')
      names the Input layer and `output_name` ('scores') names the Flatten
      output layer.
    """
    input_name = 'input'
    output_name = 'scores'

    class ConvWrapper(tf.keras.layers.Wrapper):
      """A Wrapper for simulating QAT on Conv2D layers.

      On every call, the wrapped layer's `kernel` attribute is replaced by a
      fake-quantized copy of the kernel captured in `build`.
      """

      def build(self, input_shape):
        if not self.layer.built:
          self.layer.build(input_shape)
        # Despite the name, this keeps a handle on the original (unquantized)
        # kernel; `call` overwrites `self.layer.kernel` with a fake-quantized
        # version derived from it.
        self.quantized_weights = self.layer.kernel

      def call(self, inputs):
        # Per-channel fake quantization of the kernel to `num_bits` bits over
        # [bit_min, bit_max]. min/max have a single entry, matching the one
        # output channel (filter) of the Conv2D built below.
        self.layer.kernel = (
            tf.quantization.fake_quant_with_min_max_vars_per_channel(
                self.quantized_weights, min=[bit_min], max=[bit_max],
                num_bits=num_bits, narrow_range=True))
        if not weight_only:
          # Also fake-quantize activations to 8 bits over [0, 6] (presumably
          # chosen to match the wrapped layer's relu6 range). The wrapped
          # layer's `call` is invoked directly, not through `__call__`.
          quant_inputs = tf.quantization.fake_quant_with_min_max_vars(
              inputs, min=0, max=6, num_bits=8)
          outputs = self.layer.call(quant_inputs)
          return tf.quantization.fake_quant_with_min_max_vars(
              outputs, min=0, max=6, num_bits=8)
        # Weight-only: inputs reach the wrapped layer unquantized.
        return self.layer.call(inputs)

    input_tensor = tf.keras.layers.Input(shape, name=input_name)
    kernel_shape = (shape[-1], 3, 3, 1)
    # Ensure the constant weights contain the min and max: np.linspace
    # includes both endpoints, so bit_min and bit_max are both present.
    # NOTE(review): Conv2D lays out its kernel as (3, 3, in_channels, filters);
    # kernel_shape has the same element count but a different axis order, so
    # this presumably relies on the initializer reshaping the values. Confirm.
    initial_weights = np.linspace(
        bit_min, bit_max, np.prod(kernel_shape)).reshape(kernel_shape)
    test_initializer = tf.constant_initializer(initial_weights)
    x = ConvWrapper(tf.keras.layers.Conv2D(
        1, (3, 3), kernel_initializer=test_initializer,
        activation='relu6'))(input_tensor)
    scores = tf.keras.layers.Flatten(name=output_name)(x)
    model = tf.keras.Model(input_tensor, scores)
    return model, input_name, output_name
248