• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for lite.py functionality related to TensorFlow 2.0."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from absl.testing import parameterized
25from six.moves import zip
26
27from tensorflow.lite.python.interpreter import Interpreter
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import test_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.training.tracking import tracking
35
36
class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Base test class for TensorFlow Lite 2.x model tests.

  Bundles three kinds of helpers for subclasses:
    * runners that execute a serialized TFLite model with the Python
      `Interpreter` (by positional inputs or through a SignatureDef),
    * factories for small trackable models used as conversion inputs,
    * an assertion that sanity-checks a DebugInfo object.
  """

  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
    """Evaluates the model on the `input_data`.

    Args:
      tflite_model: TensorFlow Lite model.
      input_data: List of EagerTensor const ops containing the input data for
        each input tensor. Entries are paired with the interpreter's input
        details by position; pairing stops at the shorter of the two lists.
      input_shapes: Optional list of `(shape_signature, final_shape)` tuples,
        one per input tensor (matched by position). `shape_signature` is the
        expected signature of the input and `final_shape` is the shape the
        input is resized to before tensors are allocated.

    Returns:
      List of np.ndarray, one per entry of `get_output_details()`, in that
      order.
    """
    interpreter = Interpreter(model_content=tflite_model)
    if input_shapes:
      details_before_resize = interpreter.get_input_details()
      for idx, (shape_signature, final_shape) in enumerate(input_shapes):
        tensor_details = details_before_resize[idx]
        # assertAllEqual checks both shape and values and reports the two
        # operands on failure, unlike assertTrue on an elementwise comparison.
        self.assertAllEqual(shape_signature, tensor_details['shape_signature'])
        interpreter.resize_tensor_input(
            tensor_details['index'], final_shape, strict=True)
    interpreter.allocate_tensors()

    # Details are queried only after allocation, i.e. after any resize above.
    output_details = interpreter.get_output_details()
    input_details = interpreter.get_input_details()

    for input_tensor, tensor_data in zip(input_details, input_data):
      interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
    interpreter.invoke()
    return [
        interpreter.get_tensor(details['index']) for details in output_details
    ]

  def _evaluateTFLiteModelUsingSignatureDef(self, tflite_model, signature_key,
                                            inputs):
    """Evaluates the model on the `inputs`.

    Args:
      tflite_model: TensorFlow Lite model.
      signature_key: Signature key.
      inputs: Map from input tensor names in the SignatureDef to tensor value.
        The map is expanded into keyword arguments of the signature runner.

    Returns:
      Dictionary of outputs.
      Key is the output name in the SignatureDef 'signature_key'
      Value is the output value
    """
    interpreter = Interpreter(model_content=tflite_model)
    signature_runner = interpreter.get_signature_runner(signature_key)
    return signature_runner(**inputs)

  def _getSimpleVariableModel(self):
    """Returns a trackable object with two variables and a function `f`.

    `f(x)` computes `v1 * v2 * x`, where `v1` is initialized to 3. and `v2`
    to 2.
    """
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    return root

  def _getSimpleModelWithVariables(self):
    """Returns a model holding one (1, 10) zero-initialized variable."""

    class SimpleModelWithOneVariable(tracking.AutoTrackable):
      """Basic model with 1 variable."""

      def __init__(self):
        super(SimpleModelWithOneVariable, self).__init__()
        self.var = variables.Variable(array_ops.zeros((1, 10), name='var'))

      @def_function.function
      def assign_add(self, x):
        """Adds `x` to the variable in place and returns the variable."""
        self.var.assign_add(x)
        return self.var

    return SimpleModelWithOneVariable()

  def _getMultiFunctionModel(self):
    """Returns a model exposing `add`, `sub` and `mul_add` functions."""

    class BasicModel(tracking.AutoTrackable):
      """Basic model with multiple functions.

      The variables are not created in the constructor; each function creates
      the one it needs the first time its body runs while the attribute is
      still None. `sub` and `mul_add` share the same variable `z`.
      """

      def __init__(self):
        # Matches the initialization style of the other models in this class.
        super(BasicModel, self).__init__()
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        """Returns `x + y`, with `y` initialized to 2."""
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        """Returns `x - z`, with `z` initialized to 3."""
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

      @def_function.function
      def mul_add(self, x, y):
        """Returns `x * z + y`, with `z` initialized to 3."""
        if self.z is None:
          self.z = variables.Variable(3.)
        return x * self.z + y

    return BasicModel()

  def _getMultiFunctionModelWithSharedWeight(self):
    """Returns a model whose three functions all use one constant weight."""

    class BasicModelWithSharedWeight(tracking.AutoTrackable):
      """Model with multiple functions and a shared weight."""

      def __init__(self):
        # Matches the initialization style of the other models in this class.
        super(BasicModelWithSharedWeight, self).__init__()
        # float32 constant built from the value 1.0 with shape
        # (1, 512, 512, 1); referenced by add, sub and mul below.
        self.weight = constant_op.constant([1.0],
                                           shape=(1, 512, 512, 1),
                                           dtype=dtypes.float32)

      @def_function.function
      def add(self, x):
        """Returns `x + weight`."""
        return x + self.weight

      @def_function.function
      def sub(self, x):
        """Returns `x - weight`."""
        return x - self.weight

      @def_function.function
      def mul(self, x):
        """Returns `x * weight`."""
        return x * self.weight

    return BasicModelWithSharedWeight()

  def _assertValidDebugInfo(self, debug_info):
    """Verify the DebugInfo is valid.

    Args:
      debug_info: Object whose `files` attribute is an iterable of file paths.
    """
    file_names = {
        os.path.basename(file_path) for file_path in debug_info.files
    }
    # To make the test independent on how the nodes are created, we only assert
    # the name of this test file.
    self.assertIn('lite_v2_test.py', file_names)
    self.assertNotIn('lite_test.py', file_names)
176