# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for authoring package."""
# pylint: disable=g-direct-tensorflow-import

import tensorflow as tf

from tensorflow.lite.python.authoring import authoring


def _select_tf_ops_warning(op_name):
  """Returns the expected log line for an op that needs "Select TF Ops".

  Args:
    op_name: Op name as it appears in the log, e.g. "tf.Cosh".
  """
  return ("COMPATIBILITY WARNING: op '" + op_name + "' require(s) "
          "\"Select TF Ops\" for model conversion for TensorFlow Lite. "
          "https://www.tensorflow.org/lite/guide/ops_select")


def _custom_op_error(op_names):
  """Returns the expected log line for ops that need a custom operator.

  Args:
    op_names: Comma-separated op names as they appear in the log, e.g.
      "tf.RangeDataset, tf.ShuffleDatasetV3".
  """
  return ("COMPATIBILITY ERROR: op '" + op_names + "' is(are) not natively "
          "supported by TensorFlow Lite. You need to provide a custom "
          "operator. https://www.tensorflow.org/lite/guide/ops_custom")


class TFLiteAuthoringTest(tf.test.TestCase):
  """Tests for the `authoring.compatible` decorator and its compatibility log."""

  def test_simple_cosh(self):
    """tf.Cosh still runs, but is logged as needing Select TF Ops."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return tf.cosh(x)

    result = f(tf.constant([0.0]))
    log_messages = f.get_compatibility_log()
    self.assertAllEqual(result, tf.constant([1.0]))
    self.assertIn(_select_tf_ops_warning("tf.Cosh"), log_messages)

    # Check that the op location ends with the filename of this test.
    self.assertIn("authoring_test.py", log_messages[-1])

  def test_simple_cosh_raises_CompatibilityError(self):
    """With raise_exception=True a call raises, and the log is still filled."""
    @authoring.compatible(raise_exception=True)
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return tf.cosh(x)

    # Check if the CompatibilityError exception is raised.
    with self.assertRaises(authoring.CompatibilityError):
      f(tf.constant([0.0]))
    log_messages = f.get_compatibility_log()
    self.assertIn(_select_tf_ops_warning("tf.Cosh"), log_messages)

  def test_flex_compatibility(self):
    """tf.Erf inside a larger graph is reported as needing Select TF Ops."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
    ])
    def f(inp):
      tanh = tf.math.tanh(inp)
      conv3d = tf.nn.conv3d(
          tanh,
          tf.ones([3, 3, 3, 3, 3]),
          strides=[1, 1, 1, 1, 1],
          padding="SAME")
      erf = tf.math.erf(conv3d)
      output = tf.math.tanh(erf)
      return output

    f(tf.ones(shape=(3, 3, 3, 3, 3), dtype=tf.float32))
    log_messages = f.get_compatibility_log()
    self.assertIn(_select_tf_ops_warning("tf.Erf"), log_messages)

  def test_compatibility_error(self):
    """Dataset ops are reported as errors requiring custom operators."""
    @authoring.compatible
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertIn(
        _custom_op_error(
            "tf.DummySeedGenerator, tf.RangeDataset, tf.ShuffleDatasetV3"),
        log_messages)

  def test_simple_variable(self):
    """Capturing an external tf.Variable in a simple multiply logs nothing."""
    external_var = tf.Variable(1.0)

    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return x * external_var

    result = f(tf.constant(2.0, shape=(1,)))
    log_messages = f.get_compatibility_log()

    self.assertAllEqual(result, tf.constant([2.0]))
    self.assertEmpty(log_messages)

  def test_class_method(self):
    """The decorator works on a tf.Module method and logs via the method."""
    class Model(tf.Module):

      @authoring.compatible
      @tf.function(input_signature=[
          tf.TensorSpec(shape=[None], dtype=tf.float32)
      ])
      def eval(self, x):
        return tf.cosh(x)

    m = Model()
    result = m.eval(tf.constant([0.0]))
    log_messages = m.eval.get_compatibility_log()

    self.assertAllEqual(result, tf.constant([1.0]))
    self.assertIn(_select_tf_ops_warning("tf.Cosh"), log_messages)

  def test_decorated_function_type(self):
    """The decorated function keeps __name__ and can be converted."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def func(x):
      return tf.cos(x)

    result = func(tf.constant([0.0]))
    self.assertAllEqual(result, tf.constant([1.0]))

    # Check if the decorator keeps __name__ attribute.
    self.assertEqual(func.__name__, "func")

    # Check if the decorator works with get_concrete_function method.
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func.get_concrete_function()], func)
    converter.convert()

  def test_decorated_class_method_type(self):
    """A decorated tf.Module method keeps __name__ and can be converted."""
    class Model(tf.Module):

      @authoring.compatible
      @tf.function(input_signature=[
          tf.TensorSpec(shape=[None], dtype=tf.float32)
      ])
      def eval(self, x):
        return tf.cos(x)

    m = Model()
    result = m.eval(tf.constant([0.0]))
    self.assertAllEqual(result, tf.constant([1.0]))

    # Check if the decorator keeps __name__ attribute.
    self.assertEqual(m.eval.__name__, "eval")

    # Check if the decorator works with get_concrete_function method.
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [m.eval.get_concrete_function()], m)
    converter.convert()

  def test_simple_cosh_multiple(self):
    """Repeated calls do not append duplicate compatibility log entries."""
    @authoring.compatible
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.float32)
    ])
    def f(x):
      return tf.cosh(x)

    f(tf.constant([1.0]))
    f(tf.constant([2.0]))
    f(tf.constant([3.0]))
    warning_messages = f.get_compatibility_log()

    # Test that compatibility checks happen only once.
    # The log holds 2 entries: the warning plus its op location detail.
    self.assertLen(warning_messages, 2)

  def test_user_tf_ops_all_filtered(self):
    """Ops listed in experimental_select_user_tf_ops are not reported."""
    target_spec = tf.lite.TargetSpec()
    target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    target_spec.experimental_select_user_tf_ops = [
        "RangeDataset", "DummySeedGenerator", "ShuffleDatasetV3"
    ]

    @authoring.compatible(converter_target_spec=target_spec)
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertEmpty(log_messages)

  def test_user_tf_ops_partial_filtered(self):
    """Only ops missing from experimental_select_user_tf_ops are reported."""
    target_spec = tf.lite.TargetSpec()
    target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    target_spec.experimental_select_user_tf_ops = [
        "DummySeedGenerator"
    ]

    # Decorate exactly once; stacking `compatible` twice would wrap the
    # wrapper and test something other than the intended single decoration.
    @authoring.compatible(converter_target_spec=target_spec)
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertIn(
        _custom_op_error("tf.RangeDataset, tf.ShuffleDatasetV3"),
        log_messages)

  def test_allow_custom_ops(self):
    """With converter_allow_custom_ops=True, dataset ops are not reported."""
    @authoring.compatible(converter_allow_custom_ops=True)
    @tf.function
    def f():
      dataset = tf.data.Dataset.range(3)
      dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
      return dataset

    f()
    log_messages = f.get_compatibility_log()
    self.assertEmpty(log_messages)


if __name__ == "__main__":
  tf.test.main()