# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to TensorFlow 2.0."""

import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import autotrackable


class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Base test class for TensorFlow Lite 2.x model tests."""

  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
    """Evaluates the model on the `input_data`.

    Args:
      tflite_model: TensorFlow Lite model.
      input_data: List of EagerTensor const ops containing the input data for
        each input tensor.
      input_shapes: List of tuples representing the `shape_signature` and the
        new shape of each input tensor that has unknown dimensions.

    Returns:
      [np.ndarray]
    """
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    if input_shapes:
      for idx, (shape_signature, final_shape) in enumerate(input_shapes):
        # assertAllEqual reports both values on mismatch (including a rank
        # mismatch) instead of the opaque failure of assertTrue(...all()).
        self.assertAllEqual(input_details[idx]['shape_signature'],
                            shape_signature)
        index = input_details[idx]['index']
        interpreter.resize_tensor_input(index, final_shape, strict=True)
    interpreter.allocate_tensors()

    output_details = interpreter.get_output_details()
    # Re-read the input details after resizing/allocation so the loop below
    # uses the interpreter's current view of the inputs.
    input_details = interpreter.get_input_details()

    for input_tensor, tensor_data in zip(input_details, input_data):
      interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
    interpreter.invoke()
    return [
        interpreter.get_tensor(details['index']) for details in output_details
    ]

  def _evaluateTFLiteModelUsingSignatureDef(self, tflite_model, signature_key,
                                            inputs):
    """Evaluates the model on the `inputs`.

    Args:
      tflite_model: TensorFlow Lite model.
      signature_key: Signature key.
      inputs: Map from input tensor names in the SignatureDef to tensor value.

    Returns:
      Dictionary of outputs.
      Key is the output name in the SignatureDef 'signature_key'
      Value is the output value
    """
    interpreter = Interpreter(model_content=tflite_model)
    signature_runner = interpreter.get_signature_runner(signature_key)
    return signature_runner(**inputs)

  def _getSimpleVariableModel(self):
    """Returns a trackable with variables v1=3., v2=2. and a function `f`.

    `f(x)` computes `v1 * v2 * x`.
    """
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    return root

  def _getSimpleModelWithVariables(self):
    """Returns a model holding a (1, 10) zero-initialized variable.

    Its `assign_add(x)` method adds `x` to the variable and returns it.
    """

    class SimpleModelWithOneVariable(autotrackable.AutoTrackable):
      """Basic model with 1 variable."""

      def __init__(self):
        super().__init__()
        self.var = variables.Variable(array_ops.zeros((1, 10), name='var'))

      @def_function.function
      def assign_add(self, x):
        self.var.assign_add(x)
        return self.var

    return SimpleModelWithOneVariable()

  def _getMultiFunctionModel(self):
    """Returns a model exposing `add`, `sub` and `mul_add` tf.functions.

    The variables `y` (2.) and `z` (3.) are created lazily on first trace;
    `sub` and `mul_add` share `z`.
    """

    class BasicModel(autotrackable.AutoTrackable):
      """Basic model with multiple functions."""

      def __init__(self):
        super().__init__()
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

      @def_function.function
      def mul_add(self, x, y):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x * self.z + y

    return BasicModel()

  def _getMultiFunctionModelWithSharedWeight(self):
    """Returns a model whose `add`, `sub` and `mul` share one constant weight.

    The weight is a float32 constant of shape (1, 512, 512, 1) filled with 1.0.
    """

    class BasicModelWithSharedWeight(autotrackable.AutoTrackable):
      """Model with multiple functions and a shared weight."""

      def __init__(self):
        super().__init__()
        self.weight = constant_op.constant([1.0],
                                           shape=(1, 512, 512, 1),
                                           dtype=dtypes.float32)

      @def_function.function
      def add(self, x):
        return x + self.weight

      @def_function.function
      def sub(self, x):
        return x - self.weight

      @def_function.function
      def mul(self, x):
        return x * self.weight

    return BasicModelWithSharedWeight()

  def _getMatMulModelWithSmallWeights(self):
    """Returns a model computing `x @ weight + bias` with a 2x2 weight."""

    class MatMulModelWithSmallWeights(autotrackable.AutoTrackable):
      """MatMul model with small weights and relatively large biases."""

      def __init__(self):
        super().__init__()
        self.weight = constant_op.constant([[1e-3, -1e-3], [-2e-4, 2e-4]],
                                           shape=(2, 2),
                                           dtype=dtypes.float32)
        self.bias = constant_op.constant([1.28, 2.55],
                                         shape=(2,),
                                         dtype=dtypes.float32)

      @def_function.function
      def matmul(self, x):
        return x @ self.weight + self.bias

    return MatMulModelWithSmallWeights()

  def _getSqrtModel(self):
    """Returns a model with only one sqrt op, to test non-quantizable op."""

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(1, 10), dtype=dtypes.float32)
    ])
    def sqrt(x):
      return math_ops.sqrt(x)

    def calibration_gen():
      for _ in range(5):
        yield [np.random.uniform(0, 16, size=(1, 10)).astype(np.float32)]

    return sqrt, calibration_gen

  def _assertValidDebugInfo(self, debug_info):
    """Verify the DebugInfo is valid."""
    file_names = {os.path.basename(path) for path in debug_info.files}
    # To make the test independent on how the nodes are created, we only assert
    # the name of this test file.
    self.assertIn('lite_v2_test.py', file_names)
    self.assertNotIn('lite_test.py', file_names)

  def _createV2QATLowBitKerasModel(self, shape, weight_only, num_bits, bit_min,
                                   bit_max):
    """Creates a simple QAT num_bits-Weight Keras Model.

    Args:
      shape: Input shape passed to `tf.keras.layers.Input`; its last entry is
        used as the input-channel count of the conv kernel.
      weight_only: If True, only the conv kernel is fake-quantized; otherwise
        the conv inputs and outputs are also fake-quantized to 8 bits in
        [0, 6].
      num_bits: Bit width used to fake-quantize the conv kernel.
      bit_min: Minimum of the kernel fake-quant range and of the initial
        weights.
      bit_max: Maximum of the kernel fake-quant range and of the initial
        weights.

    Returns:
      A tuple `(model, input_name, output_name)`.
    """
    input_name = 'input'
    output_name = 'scores'

    class ConvWrapper(tf.keras.layers.Wrapper):
      """A Wrapper for simulating QAT on Conv2D layers."""

      def build(self, input_shape):
        if not self.layer.built:
          self.layer.build(input_shape)
        self.quantized_weights = self.layer.kernel

      def call(self, inputs):
        self.layer.kernel = (
            tf.quantization.fake_quant_with_min_max_vars_per_channel(
                self.quantized_weights, min=[bit_min], max=[bit_max],
                num_bits=num_bits, narrow_range=True))
        if not weight_only:
          quant_inputs = tf.quantization.fake_quant_with_min_max_vars(
              inputs, min=0, max=6, num_bits=8)
          outputs = self.layer.call(quant_inputs)
          return tf.quantization.fake_quant_with_min_max_vars(
              outputs, min=0, max=6, num_bits=8)
        return self.layer.call(inputs)

    input_tensor = tf.keras.layers.Input(shape, name=input_name)
    kernel_shape = (shape[-1], 3, 3, 1)
    # Ensure constant weights contains the min and max.
    initial_weights = np.linspace(
        bit_min, bit_max, np.prod(kernel_shape)).reshape(kernel_shape)
    test_initializer = tf.constant_initializer(initial_weights)
    x = ConvWrapper(tf.keras.layers.Conv2D(
        1, (3, 3), kernel_initializer=test_initializer,
        activation='relu6'))(input_tensor)
    scores = tf.keras.layers.Flatten(name=output_name)(x)
    model = tf.keras.Model(input_tensor, scores)
    return model, input_name, output_name
