• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Unit tests for authoring package."""
16# pylint: disable=g-direct-tensorflow-import
17
18import tensorflow as tf
19
20from tensorflow.lite.python.authoring import authoring
21
22
class TFLiteAuthoringTest(tf.test.TestCase):
  """Tests for the `authoring.compatible` decorator.

  Each test decorates a `tf.function` with `authoring.compatible`, calls it,
  and inspects the messages returned by `get_compatibility_log()`.
  """

  # Expected log entry when tf.Cosh is reported as requiring Select TF Ops.
  _COSH_SELECT_TF_OPS_WARNING = (
      "COMPATIBILITY WARNING: op 'tf.Cosh' require(s) \"Select TF Ops\" for "
      "model conversion for TensorFlow Lite. "
      "https://www.tensorflow.org/lite/guide/ops_select")

  def test_simple_cosh(self):
    """tf.Cosh still computes its result but is logged as a Select TF op."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return tf.cosh(x)

    result = f(tf.constant([0.0]))
    log_messages = f.get_compatibility_log()
    self.assertAllEqual(result, tf.constant([1.0]))
    self.assertIn(self._COSH_SELECT_TF_OPS_WARNING, log_messages)

    # Check that the last message (the op location) mentions this test file.
    self.assertIn("authoring_test.py", log_messages[-1])

  def test_simple_cosh_raises_CompatibilityError(self):
    """With raise_exception=True the call raises, but the log is kept."""
    @authoring.compatible(raise_exception=True)
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return tf.cosh(x)

    # Check if the CompatibilityError exception is raised.
    with self.assertRaises(authoring.CompatibilityError):
      f(tf.constant([0.0]))
    log_messages = f.get_compatibility_log()
    self.assertIn(self._COSH_SELECT_TF_OPS_WARNING, log_messages)

  def test_flex_compatibility(self):
    """tf.Erf inside a larger graph is logged as a Select TF op."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
    ])
    def f(inp):
      tanh = tf.math.tanh(inp)
      conv3d = tf.nn.conv3d(
          tanh,
          tf.ones([3, 3, 3, 3, 3]),
          strides=[1, 1, 1, 1, 1],
          padding="SAME")
      erf = tf.math.erf(conv3d)
      output = tf.math.tanh(erf)
      return output

    f(tf.ones(shape=(3, 3, 3, 3, 3), dtype=tf.float32))
    log_messages = f.get_compatibility_log()
    self.assertIn(
        "COMPATIBILITY WARNING: op 'tf.Erf' require(s) \"Select TF Ops\" for "
        "model conversion for TensorFlow Lite. "
        "https://www.tensorflow.org/lite/guide/ops_select", log_messages)

  def test_compatibility_error(self):
    """Dataset ops are logged as errors that need a custom operator."""
    @authoring.compatible
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertIn(
        "COMPATIBILITY ERROR: op 'tf.DummySeedGenerator, tf.RangeDataset, "
        "tf.ShuffleDatasetV3' is(are) not natively supported by "
        "TensorFlow Lite. You need to provide a custom operator. "
        "https://www.tensorflow.org/lite/guide/ops_custom", log_messages)

  def test_simple_variable(self):
    """A function reading a captured tf.Variable produces no log messages."""
    external_var = tf.Variable(1.0)

    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return x * external_var

    # shape=(1,) is an explicit one-element tuple; (1) would just be the int 1.
    result = f(tf.constant(2.0, shape=(1,)))
    log_messages = f.get_compatibility_log()

    self.assertAllEqual(result, tf.constant([2.0]))
    self.assertEmpty(log_messages)

  def test_class_method(self):
    """The decorator also works on a tf.Module method."""
    class Model(tf.Module):

      @authoring.compatible
      @tf.function(input_signature=[
          tf.TensorSpec(shape=[None], dtype=tf.float32)
      ])
      def eval(self, x):
        return tf.cosh(x)

    m = Model()
    result = m.eval(tf.constant([0.0]))
    log_messages = m.eval.get_compatibility_log()

    self.assertAllEqual(result, tf.constant([1.0]))
    self.assertIn(self._COSH_SELECT_TF_OPS_WARNING, log_messages)

  def test_decorated_function_type(self):
    """A decorated function keeps __name__ and supports conversion."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def func(x):
      return tf.cos(x)

    result = func(tf.constant([0.0]))
    self.assertAllEqual(result, tf.constant([1.0]))

    # Check if the decorator keeps __name__ attribute.
    self.assertEqual(func.__name__, "func")

    # Check if the decorator works with get_concrete_function method.
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func.get_concrete_function()], func)
    converter.convert()

  def test_decorated_class_method_type(self):
    """A decorated method keeps __name__ and supports conversion."""
    class Model(tf.Module):

      @authoring.compatible
      @tf.function(input_signature=[
          tf.TensorSpec(shape=[None], dtype=tf.float32)
      ])
      def eval(self, x):
        return tf.cos(x)

    m = Model()
    result = m.eval(tf.constant([0.0]))
    self.assertAllEqual(result, tf.constant([1.0]))

    # Check if the decorator keeps __name__ attribute.
    self.assertEqual(m.eval.__name__, "eval")

    # Check if the decorator works with get_concrete_function method.
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [m.eval.get_concrete_function()], m)
    converter.convert()

  def test_simple_cosh_multiple(self):
    """Repeated calls do not append repeated compatibility messages."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return tf.cosh(x)

    f(tf.constant([1.0]))
    f(tf.constant([2.0]))
    f(tf.constant([3.0]))
    warning_messages = f.get_compatibility_log()

    # Test that the compatibility check happens only once.
    # There are 2 messages: the warning plus its op-location detail.
    self.assertEqual(2, len(warning_messages))

  def test_user_tf_ops_all_filtered(self):
    """Allow-listing every dataset op as a user TF op silences the log."""
    target_spec = tf.lite.TargetSpec()
    target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    target_spec.experimental_select_user_tf_ops = [
        "RangeDataset", "DummySeedGenerator", "ShuffleDatasetV3"
    ]

    @authoring.compatible(converter_target_spec=target_spec)
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertEmpty(log_messages)

  def test_user_tf_ops_partial_filtered(self):
    """Ops not in the user TF op allow-list are still reported."""
    target_spec = tf.lite.TargetSpec()
    target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    target_spec.experimental_select_user_tf_ops = [
        "DummySeedGenerator"
    ]

    # Apply the decorator exactly once; stacking it twice would run the
    # compatibility check on a wrapper instead of on the tf.function itself.
    @authoring.compatible(converter_target_spec=target_spec)
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertIn(
        "COMPATIBILITY ERROR: op 'tf.RangeDataset, tf.ShuffleDatasetV3' is(are) "
        "not natively supported by TensorFlow Lite. You need to provide a "
        "custom operator. https://www.tensorflow.org/lite/guide/ops_custom",
        log_messages)

  def test_allow_custom_ops(self):
    """With converter_allow_custom_ops=True, dataset ops are not logged."""
    @authoring.compatible(converter_allow_custom_ops=True)
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertEmpty(log_messages)
251
if __name__ == "__main__":
  # Run all tests in this module through TensorFlow's test runner.
  tf.test.main()