# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to TensorFlow 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
from six.moves import zip

from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking


class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Base test class for TensorFlow Lite 2.x model tests."""

  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
    """Evaluates the model on the `input_data`.

    Args:
      tflite_model: TensorFlow Lite model.
      input_data: List of EagerTensor const ops containing the input data for
        each input tensor, in the order returned by `get_input_details()`.
      input_shapes: List of tuples representing the `shape_signature` and the
        new shape of each input tensor that has unknown dimensions. Entry `i`
        applies to input tensor `i`.

    Returns:
      [np.ndarray]: one array per output tensor, in the order returned by
      `get_output_details()`.
    """
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    if input_shapes:
      for idx, (shape_signature, final_shape) in enumerate(input_shapes):
        # assertAllEqual compares shapes before values, so a signature of the
        # wrong rank fails with a readable message. An elementwise `==`
        # followed by `.all()` on arrays of different lengths does not.
        self.assertAllEqual(input_details[idx]['shape_signature'],
                            shape_signature)
        index = input_details[idx]['index']
        interpreter.resize_tensor_input(index, final_shape, strict=True)
    interpreter.allocate_tensors()

    output_details = interpreter.get_output_details()
    # Re-read the input details after allocation so they reflect any resizing
    # performed above.
    input_details = interpreter.get_input_details()

    for input_tensor, tensor_data in zip(input_details, input_data):
      interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
    interpreter.invoke()
    return [
        interpreter.get_tensor(details['index']) for details in output_details
    ]

  def _evaluateTFLiteModelUsingSignatureDef(self, tflite_model, signature_key,
                                            inputs):
    """Evaluates the model on the `inputs`.

    Args:
      tflite_model: TensorFlow Lite model.
      signature_key: Signature key.
      inputs: Map from input tensor names in the SignatureDef to tensor value.

    Returns:
      Dictionary of outputs.
      Key is the output name in the SignatureDef 'signature_key'
      Value is the output value
    """
    interpreter = Interpreter(model_content=tflite_model)
    signature_runner = interpreter.get_signature_runner(signature_key)
    return signature_runner(**inputs)

  def _getSimpleVariableModel(self):
    """Returns a trackable whose `f(x)` computes `v1 * v2 * x`.

    `v1` and `v2` are variables initialized to 3. and 2. respectively.
    """
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    return root

  def _getSimpleModelWithVariables(self):
    """Returns a model whose `assign_add(x)` adds `x` to a (1, 10) variable.

    The variable starts as zeros. `assign_add` returns the variable after the
    update.
    """

    class SimpleModelWithOneVariable(tracking.AutoTrackable):
      """Basic model with 1 variable."""

      def __init__(self):
        super(SimpleModelWithOneVariable, self).__init__()
        self.var = variables.Variable(array_ops.zeros((1, 10), name='var'))

      @def_function.function
      def assign_add(self, x):
        self.var.assign_add(x)
        return self.var

    return SimpleModelWithOneVariable()

  def _getMultiFunctionModel(self):
    """Returns a model exposing several `tf.function`s that create variables.

    - `add(x)` returns `x + y`, where `y` is a variable initialized to 2.
    - `sub(x)` returns `x - z`, where `z` is a variable initialized to 3.
    - `mul_add(x, y)` returns `x * z + y`, using the same `z` as `sub`.

    The variables are created lazily, the first time a function that needs
    them is traced.
    """

    class BasicModel(tracking.AutoTrackable):
      """Basic model with multiple functions."""

      def __init__(self):
        super(BasicModel, self).__init__()
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

      @def_function.function
      def mul_add(self, x, y):
        # `z` is shared with `sub`; only create it if `sub` has not already.
        if self.z is None:
          self.z = variables.Variable(3.)
        return x * self.z + y

    return BasicModel()

  def _getMultiFunctionModelWithSharedWeight(self):
    """Returns a model whose `add`, `sub` and `mul` share one constant weight.

    The weight is a float32 constant of shape (1, 512, 512, 1) filled with 1.0.
    """

    class BasicModelWithSharedWeight(tracking.AutoTrackable):
      """Model with multiple functions and a shared weight."""

      def __init__(self):
        super(BasicModelWithSharedWeight, self).__init__()
        self.weight = constant_op.constant([1.0],
                                           shape=(1, 512, 512, 1),
                                           dtype=dtypes.float32)

      @def_function.function
      def add(self, x):
        return x + self.weight

      @def_function.function
      def sub(self, x):
        return x - self.weight

      @def_function.function
      def mul(self, x):
        return x * self.weight

    return BasicModelWithSharedWeight()

  def _assertValidDebugInfo(self, debug_info):
    """Verify the DebugInfo is valid.

    Args:
      debug_info: Object exposing a `files` iterable of file paths.
    """
    file_names = {os.path.basename(file_path) for file_path in debug_info.files}
    # To make the test independent on how the nodes are created, we only assert
    # the name of this test file.
    self.assertIn('lite_v2_test.py', file_names)
    self.assertNotIn('lite_test.py', file_names)