# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for elementwise ops."""
import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function


def _make_elementwise_tests(op, allow_fully_quantize=False, min_value=-100,
                            max_value=100):
  """Make a set of tests to do element-wise operations.

  Args:
    op: Unary TensorFlow op applied to a single float32 placeholder.
    allow_fully_quantize: If True, also emit the parameter set with
      `fully_quantize=True`; otherwise only the `fully_quantize=False` set is
      generated.
    min_value: Lower bound passed to `create_tensor_data` and recorded as the
      low end of `input_range`.
    max_value: Upper bound passed to `create_tensor_data` and recorded as the
      high end of `input_range`.

  Returns:
    A function taking `options` that generates the examples via
    `make_zip_of_tests`.
  """

  def f(options):
    """Actual function that generates examples."""
    # The parameter sets differ only in `fully_quantize`; the fully-quantized
    # set is emitted only when the op opts in.
    fully_quantize_values = [False, True] if allow_fully_quantize else [False]
    test_parameters = [{
        "input_dtype": [tf.float32],
        "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
        "fully_quantize": [fully_quantize],
        "input_range": [[min_value, max_value]],
    } for fully_quantize in fully_quantize_values]

    def build_graph(parameters):
      """Build the unary op testing graph."""
      # `tf` is already the compat.v1 API, so no extra `compat.v1` hop.
      input_value = tf.placeholder(
          dtype=parameters["input_dtype"],
          name="input1",
          shape=parameters["input_shape"])
      out = op(input_value)
      return [input_value], [out]

    def build_inputs(parameters, sess, inputs, outputs):
      """Create input data in [min_value, max_value] and run the graph on it."""
      input_value = create_tensor_data(parameters["input_dtype"],
                                       parameters["input_shape"],
                                       min_value=min_value,
                                       max_value=max_value)
      return [input_value], sess.run(
          outputs, feed_dict={inputs[0]: input_value})

    make_zip_of_tests(options, test_parameters, build_graph, build_inputs)

  return f


@register_make_test_function()
def make_sin_tests(options):
  """Make a set of tests to do sin."""
  return _make_elementwise_tests(tf.sin)(options)


@register_make_test_function()
def make_log_tests(options):
  """Make a set of tests to do log."""
  # NOTE(review): uses the default [-100, 100] input range; if
  # create_tensor_data samples across that whole range, negative inputs give
  # NaN reference values — confirm this is intended.
  return _make_elementwise_tests(tf.math.log)(options)


@register_make_test_function()
def make_sqrt_tests(options):
  """Make a set of tests to do sqrt."""
  # NOTE(review): same default [-100, 100] input range as log; negative
  # inputs presumably give NaN reference values — confirm this is intended.
  return _make_elementwise_tests(tf.sqrt)(options)


@register_make_test_function()
def make_rsqrt_tests(options):
  """Make a set of tests to do 1/sqrt."""
  # Inputs are restricted to [0.1, 1] so 1/sqrt stays finite, and the
  # fully-quantized variant is also generated.
  return _make_elementwise_tests(tf.math.rsqrt, allow_fully_quantize=True,
                                 min_value=.1, max_value=1)(options)


@register_make_test_function()
def make_square_tests(options):
  """Make a set of tests to do square."""
  return _make_elementwise_tests(tf.square)(options)