• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for unsorted_segment ops."""
16
17import tensorflow.compat.v1 as tf
18from tensorflow.lite.testing.zip_test_utils import create_tensor_data
19from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
20from tensorflow.lite.testing.zip_test_utils import register_make_test_function
21
22
def make_unsorted_segment_tests(options, unsorted_segment_op):
  """Make a set of zipped tests for the given unsorted_segment op.

  Args:
    options: Test-generation options, forwarded unchanged to
      `make_zip_of_tests`.
    unsorted_segment_op: The op under test. It is invoked as
      `unsorted_segment_op(data, segment_ids, num_segments)`.
  """
  # One parameter grid per dict. Grids with "multi_node": [0] build a single
  # op whose segment ids are a graph constant. The grid with "multi_node": [1]
  # chains two ops and feeds the first op's segment ids as a graph input.
  # The grids also cover:
  #   - negative segment ids;
  #   - segment ids of rank 1, 2 and 3 against 3-D data;
  #   - a zero-sized data dimension ([2, 0, 3]).
  test_parameters = [{
      "data_shape": [[5]],
      "segment_id": [[0, 1, 1, 0, 1]],
      "num_segments": [2],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[2, 3, 4], [2, 5, 2]],
      "segment_id": [[0, 1]],
      "num_segments": [2],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[4]],
      "segment_id": [[0, 0, 1, 8]],
      "num_segments": [9],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[3]],
      "segment_id": [[-1, -2, -1]],
      "num_segments": [1],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[3]],
      "segment_id": [[-1, 0, 1]],
      "num_segments": [2],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[3, 2]],
      "segment_id": [[-1, 0, 0]],
      "num_segments": [1],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[3, 2]],
      "segment_id": [[-1, -2, -1]],
      "num_segments": [1],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[4]],
      "segment_id_shape": [[4]],
      "segment_id_min": [0],
      "segment_id_max": [1],
      "num_segments": [2],
      "dtype": [tf.int32, tf.float32],
      "segment_id_2": [[0, 0]],
      "num_segments_2": [1],
      "multi_node": [1]
  }, {
      "data_shape": [[2, 2, 3]],
      "segment_id": [[[1, 2], [3, 4]], [4, 5],
                     [[[1, 2, 3], [3, 4, 5]], [[1, 2, 4], [0, 0, -1]]]],
      "num_segments": [10],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }, {
      "data_shape": [[2, 0, 3]],
      "segment_id": [[1, 1]],
      "num_segments": [2],
      "dtype": [tf.int32, tf.float32],
      "multi_node": [0]
  }]

  def scalar_int32(value, tensor_name):
    """Return `value` as a rank-0 int32 constant named `tensor_name`."""
    return tf.constant(value, dtype=tf.int32, shape=[], name=tensor_name)

  def data_placeholder(params):
    """Return the "data" placeholder with the dtype and shape from `params`."""
    return tf.compat.v1.placeholder(
        dtype=params["dtype"], name="data", shape=params["data_shape"])

  def build_graph_one_node(params):
    """Single op: constant segment ids and constant num_segments."""
    data = data_placeholder(params)
    ids = tf.constant(params["segment_id"], dtype=tf.int32, name="segment_ids")
    # The num_segments constant is created before the op (argument evaluation
    # order), mirroring the original graph-construction order.
    result = unsorted_segment_op(
        data, ids, scalar_int32(params["num_segments"], "num_segments"))
    return [data], [result]

  def build_graph_multi_node(params):
    """Two chained ops; the first op's segment ids are a graph input.

    This variant is intended to cover dynamically shaped input tensors.
    The second op reduces the first op's output with constant segment ids.
    """
    data = data_placeholder(params)
    ids = tf.compat.v1.placeholder(
        dtype=tf.int32, name="segment_ids", shape=params["segment_id_shape"])
    first = unsorted_segment_op(
        data, ids, scalar_int32(params["num_segments"], "num_segments"))
    second_ids = tf.constant(
        params["segment_id_2"], dtype=tf.int32, name="segment_ids_2")
    result = unsorted_segment_op(
        first, second_ids,
        scalar_int32(params["num_segments_2"], "num_segments_2"))
    return [data, ids], [result]

  def build_graph(params):
    """Dispatch on the "multi_node" flag of the current parameter set."""
    builder = (
        build_graph_multi_node if params["multi_node"] else
        build_graph_one_node)
    return builder(params)

  def build_inputs(params, sess, inputs, outputs):
    """Create random input values and run the graph to get reference outputs.

    A value for the "data" placeholder is always generated first. In
    multi-node mode, int32 segment ids bounded by "segment_id_min" /
    "segment_id_max" are generated next. This order matches `inputs`, so the
    values can be zipped into the feed dict.
    """
    values = [create_tensor_data(params["dtype"], shape=params["data_shape"])]
    if params["multi_node"]:
      values.append(
          create_tensor_data(
              tf.int32,
              shape=params["segment_id_shape"],
              min_value=params["segment_id_min"],
              max_value=params["segment_id_max"]))
    return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
165
166
@register_make_test_function()
def make_unsorted_segment_prod_tests(options):
  """Register zipped tests for `tf.math.unsorted_segment_prod`."""
  make_unsorted_segment_tests(
      options, unsorted_segment_op=tf.math.unsorted_segment_prod)
170
171
@register_make_test_function()
def make_unsorted_segment_max_tests(options):
  """Register zipped tests for `tf.math.unsorted_segment_max`."""
  make_unsorted_segment_tests(
      options, unsorted_segment_op=tf.math.unsorted_segment_max)
175
176
@register_make_test_function()
def make_unsorted_segment_min_tests(options):
  """Register zipped tests for `tf.math.unsorted_segment_min`."""
  make_unsorted_segment_tests(
      options, unsorted_segment_op=tf.math.unsorted_segment_min)
180
181
@register_make_test_function()
def make_unsorted_segment_sum_tests(options):
  """Register zipped tests for `tf.math.unsorted_segment_sum`."""
  make_unsorted_segment_tests(
      options, unsorted_segment_op=tf.math.unsorted_segment_sum)
185