# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for unsorted_segment ops."""

import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function


def make_unsorted_segment_tests(options, unsorted_segment_op):
  """Make a set of tests for given unsorted_segment op.

  Every parameter set is either a single-node case (`multi_node` == 0) or a
  multi-node case (`multi_node` == 1). In a single-node case the segment ids
  are a graph constant. In a multi-node case two ops are chained and the
  first op's segment ids are fed at runtime through a placeholder.

  Args:
    options: Options object forwarded unchanged to `make_zip_of_tests`.
    unsorted_segment_op: Callable invoked as
      `unsorted_segment_op(data, segment_ids, num_segments)`, e.g.
      `tf.math.unsorted_segment_sum`.
  """
  test_parameters = [
      # 1-D data, all segment ids in range.
      {
          "data_shape": [[5]],
          "segment_id": [[0, 1, 1, 0, 1]],
          "num_segments": [2],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      # Rank-3 data with a 1-D id vector covering the first dimension.
      {
          "data_shape": [[2, 3, 4], [2, 5, 2]],
          "segment_id": [[0, 1]],
          "num_segments": [2],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      # num_segments exceeds the number of distinct ids (gaps between ids).
      {
          "data_shape": [[4]],
          "segment_id": [[0, 0, 1, 8]],
          "num_segments": [9],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      # The next four cases contain negative segment ids.
      {
          "data_shape": [[3]],
          "segment_id": [[-1, -2, -1]],
          "num_segments": [1],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      {
          "data_shape": [[3]],
          "segment_id": [[-1, 0, 1]],
          "num_segments": [2],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      {
          "data_shape": [[3, 2]],
          "segment_id": [[-1, 0, 0]],
          "num_segments": [1],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      {
          "data_shape": [[3, 2]],
          "segment_id": [[-1, -2, -1]],
          "num_segments": [1],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      # Two chained ops; see build_graph_multi_node.
      {
          "data_shape": [[4]],
          "segment_id_shape": [[4]],
          # Bounds handed to create_tensor_data for the runtime segment ids.
          "segment_id_min": [0],
          "segment_id_max": [1],
          "num_segments": [2],
          "dtype": [tf.int32, tf.float32],
          # Ids for the second op; its length (2) equals num_segments above.
          "segment_id_2": [[0, 0]],
          "num_segments_2": [1],
          "multi_node": [1]
      },
      # Segment ids of rank 1, 2 and 3 against rank-3 data.
      {
          "data_shape": [[2, 2, 3]],
          "segment_id": [[[1, 2], [3, 4]], [4, 5],
                         [[[1, 2, 3], [3, 4, 5]], [[1, 2, 4], [0, 0, -1]]]],
          "num_segments": [10],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      },
      # Data with a zero-sized dimension.
      {
          "data_shape": [[2, 0, 3]],
          "segment_id": [[1, 1]],
          "num_segments": [2],
          "dtype": [tf.int32, tf.float32],
          "multi_node": [0]
      }
  ]

  def build_graph_one_node(parameters):
    """Build one op whose segment ids and num_segments are constants.

    Returns:
      A `([data_placeholder], [output])` pair.
    """
    data_tensor = tf.compat.v1.placeholder(
        dtype=parameters["dtype"], name="data", shape=parameters["data_shape"])
    segment_ids_tensor = tf.constant(
        parameters["segment_id"], dtype=tf.int32, name="segment_ids")
    num_segments_tensor = tf.constant(
        parameters["num_segments"],
        dtype=tf.int32,
        shape=[],
        name="num_segments")
    output = unsorted_segment_op(data_tensor, segment_ids_tensor,
                                 num_segments_tensor)
    return [data_tensor], [output]

  def build_graph_multi_node(parameters):
    """Build two chained ops; test cases for handling dynamically shaped input.

    The first op takes both the data and its segment ids from placeholders.
    Its output is then re-segmented by a second op that uses constant
    `segment_id_2` / `num_segments_2`.

    Returns:
      A `([data_placeholder, segment_ids_placeholder], [output])` pair.
    """
    data_tensor = tf.compat.v1.placeholder(
        dtype=parameters["dtype"], name="data", shape=parameters["data_shape"])
    segment_ids_tensor = tf.compat.v1.placeholder(
        dtype=tf.int32,
        name="segment_ids",
        shape=parameters["segment_id_shape"])
    num_segments_tensor = tf.constant(
        parameters["num_segments"],
        dtype=tf.int32,
        shape=[],
        name="num_segments")
    intermediate_tensor = unsorted_segment_op(data_tensor, segment_ids_tensor,
                                              num_segments_tensor)
    segment_ids_tensor_2 = tf.constant(
        parameters["segment_id_2"], dtype=tf.int32, name="segment_ids_2")
    num_segments_tensor_2 = tf.constant(
        parameters["num_segments_2"],
        dtype=tf.int32,
        shape=[],
        name="num_segments_2")
    output = unsorted_segment_op(intermediate_tensor, segment_ids_tensor_2,
                                 num_segments_tensor_2)
    return [data_tensor, segment_ids_tensor], [output]

  def build_graph(parameters):
    """Dispatch to the single- or multi-node graph builder."""
    if parameters["multi_node"]:
      return build_graph_multi_node(parameters)
    return build_graph_one_node(parameters)

  def build_inputs_one_node(parameters, sess, inputs, outputs):
    """Create random data and evaluate `outputs` for a single-node graph."""
    data_value = create_tensor_data(
        dtype=parameters["dtype"], shape=parameters["data_shape"])
    return [data_value], sess.run(
        outputs, feed_dict=dict(zip(inputs, [data_value])))

  def build_inputs_multi_node(parameters, sess, inputs, outputs):
    """Create random data and bounded segment ids for a multi-node graph."""
    data_value = create_tensor_data(
        dtype=parameters["dtype"], shape=parameters["data_shape"])
    segment_id_value = create_tensor_data(
        dtype=tf.int32,
        shape=parameters["segment_id_shape"],
        min_value=parameters["segment_id_min"],
        max_value=parameters["segment_id_max"])
    return [data_value, segment_id_value], sess.run(
        outputs, feed_dict=dict(zip(inputs, [data_value, segment_id_value])))

  def build_inputs(parameters, sess, inputs, outputs):
    """Dispatch to the single- or multi-node input builder."""
    if parameters["multi_node"]:
      return build_inputs_multi_node(parameters, sess, inputs, outputs)
    return build_inputs_one_node(parameters, sess, inputs, outputs)

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)


@register_make_test_function()
def make_unsorted_segment_prod_tests(options):
  """Make the shared unsorted_segment tests for unsorted_segment_prod."""
  make_unsorted_segment_tests(options, tf.math.unsorted_segment_prod)


@register_make_test_function()
def make_unsorted_segment_max_tests(options):
  """Make the shared unsorted_segment tests for unsorted_segment_max."""
  make_unsorted_segment_tests(options, tf.math.unsorted_segment_max)


@register_make_test_function()
def make_unsorted_segment_min_tests(options):
  """Register the shared unsorted_segment zip tests for the min reduction."""
  make_unsorted_segment_tests(
      options=options, unsorted_segment_op=tf.math.unsorted_segment_min)


@register_make_test_function()
def make_unsorted_segment_sum_tests(options):
  """Register the shared unsorted_segment zip tests for the sum reduction."""
  make_unsorted_segment_tests(
      options=options, unsorted_segment_op=tf.math.unsorted_segment_sum)