# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for where_v2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function


@register_make_test_function()
def make_where_v2_tests(options):
  """Generate the zipped TFLite test suite for `tf.where_v2`.

  Every row of the shape table below pairs one or more condition shapes with
  the shapes of the two value operands (`x` and `y`), and each row is run for
  both float32 and int32 values. The rows mix exactly-matching shapes with
  shapes that rely on broadcasting between the condition and the values.

  Args:
    options: Test-generation options handed to `make_zip_of_tests`. This
      function sets `options.use_experimental_converter` to True first.
  """
  # Rows of (condition shapes, x shape, y shape), kept in their original order
  # so the generated test cases are enumerated the same way.
  shape_table = (
      ([[1, 2, 3, 4]], [1, 2, 3, 4], [1, 1, 1, 1]),
      ([[2], [1]], [2, 1, 2, 1], [2, 1, 2, 1]),
      ([[1, 4, 2]], [1, 3, 4, 2], [1, 3, 4, 2]),
      ([[1, 2]], [1, 2, 2], [1, 2, 2]),
      ([[1, 1]], [1, 1, 2, 2], [1, 1, 2, 2]),
      ([[4]], [4, 4], [4, 4]),
      # A [2] condition does not broadcast against [2, 3] (trailing dims 2 vs
      # 3); presumably this row, once per dtype, accounts for the two
      # expected TF failures passed to make_zip_of_tests below.
      ([[2]], [2, 3], [2, 3]),
      ([[1, 2]], [1, 2, 2], [1, 2]),
  )
  test_parameters = [
      {
          "input_condition_shape": condition_shapes,
          "input_dtype": [tf.float32, tf.int32],
          "input_shape_set": [(x_shape, y_shape)],
      }
      for condition_shapes, x_shape, y_shape in shape_table
  ]

  def build_graph(parameters):
    """Create the condition/x/y placeholders and the where_v2 op over them."""
    x_shape, y_shape = parameters["input_shape_set"]
    value_dtype = parameters["input_dtype"]
    condition = tf.placeholder(
        dtype=tf.bool,
        name="input_condition",
        shape=parameters["input_condition_shape"])
    x = tf.placeholder(dtype=value_dtype, name="input_x", shape=x_shape)
    y = tf.placeholder(dtype=value_dtype, name="input_y", shape=y_shape)
    selected = tf.where_v2(condition, x, y)
    return [condition, x, y], [selected]

  def build_inputs(parameters, sess, inputs, outputs):
    """Make random feed data and compute the reference TF outputs for it."""
    x_shape, y_shape = parameters["input_shape_set"]
    value_dtype = parameters["input_dtype"]
    # Generated in condition, x, y order to match the placeholders in
    # build_graph.
    feed_values = [
        create_tensor_data(tf.bool, parameters["input_condition_shape"]),
        create_tensor_data(value_dtype, x_shape),
        create_tensor_data(value_dtype, y_shape),
    ]
    reference_outputs = sess.run(
        outputs, feed_dict=dict(zip(inputs, feed_values)))
    return feed_values, reference_outputs

  options.use_experimental_converter = True
  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=2)