• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for strided_slice operators."""
16import numpy as np
17import tensorflow.compat.v1 as tf
18
19from tensorflow.lite.testing.zip_test_utils import create_tensor_data
20from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
21from tensorflow.lite.testing.zip_test_utils import MAP_TF_TO_NUMPY_TYPE
22from tensorflow.lite.testing.zip_test_utils import register_make_test_function
23
24
def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0):
  """Generate a zip of strided_slice tests from parameter grids.

  Args:
    options: Options object, forwarded unchanged to `make_zip_of_tests`.
    test_parameters: List of dicts mapping parameter names (e.g. "dtype",
      "begin", "strides", "constant_indices") to lists of candidate values.
      Forwarded to `make_zip_of_tests`, which hands individual parameter
      combinations to the nested builders below.
    expected_tf_failures: Forwarded to `make_zip_of_tests` as the number of
      combinations expected to fail in TensorFlow.
  """

  def build_graph(parameters):
    """Build a strided_slice graph for one parameter combination.

    With `constant_indices` set, begin/end/strides are passed to
    `tf.strided_slice` as-is; otherwise each of them is fed through its own
    rank-1 placeholder of dtype `index_type`.

    Returns:
      A `(inputs, outputs)` pair: the list of placeholders to feed and a
      one-element list holding the strided_slice result.
    """
    input_tensor = tf.compat.v1.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    graph_inputs = [input_tensor]

    if parameters["constant_indices"]:
      begin = parameters["begin"]
      end = parameters["end"]
      strides = parameters["strides"]
    else:

      def index_placeholder(key):
        # Placeholder named after its parameter key, sized to the length of
        # the index vector it will receive.
        return tf.compat.v1.placeholder(
            dtype=parameters["index_type"],
            name=key,
            shape=[len(parameters[key])])

      begin = index_placeholder("begin")
      end = index_placeholder("end")
      if parameters["strides"] is None:
        strides = None
      else:
        strides = index_placeholder("strides")
      graph_inputs.extend([begin, end])
      if strides is not None:
        graph_inputs.append(strides)

    # Only forward the optional masks when they are present and truthy, so
    # tf.strided_slice keeps its own defaults otherwise.
    optional_masks = {
        mask_name: parameters[mask_name]
        for mask_name in ("ellipsis_mask", "new_axis_mask")
        if parameters.get(mask_name, None)
    }

    sliced = tf.strided_slice(
        input_tensor,
        begin,
        end,
        strides,
        begin_mask=parameters["begin_mask"],
        end_mask=parameters["end_mask"],
        shrink_axis_mask=parameters["shrink_axis_mask"],
        **optional_masks)
    return graph_inputs, [sliced]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create feed values for the graph and evaluate it in `sess`.

    Returns:
      A `(values, results)` pair: the fed arrays (input data, followed by
      begin/end/strides when indices are non-constant) and the output of
      running `outputs` with those values fed to `inputs`.
    """
    feed_values = [
        create_tensor_data(
            parameters["dtype"],
            parameters["input_shape"],
            min_value=-1,
            max_value=1)
    ]
    index_dtype = MAP_TF_TO_NUMPY_TYPE[parameters["index_type"]]
    if not parameters["constant_indices"]:
      # Same order as the placeholders created in build_graph.
      index_keys = ["begin", "end"]
      if parameters["strides"] is not None:
        index_keys.append("strides")
      feed_values.extend(
          np.array(parameters[key]).astype(index_dtype) for key in index_keys)

    feed_dict = dict(zip(inputs, feed_values))
    return feed_values, sess.run(outputs, feed_dict=feed_dict)

  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=expected_tf_failures)
103
104
@register_make_test_function()
def make_strided_slice_tests(options):
  """Make a set of tests to do strided_slice.

  Each dict in `test_parameters` is a parameter grid handed to
  `_make_strided_slice_tests` and from there to `make_zip_of_tests`. That
  helper presumably expands each grid into one test per combination of the
  listed values (TODO confirm against zip_test_utils). Dict key order and
  entry order are left untouched because they may affect generated test
  order and labels.

  Args:
    options: Options forwarded unchanged to `_make_strided_slice_tests`.
  """

  # TODO(soroosh): add test/support for uint8.
  test_parameters = [
      # 4-D (basic cases with const/non-const indices).
      {
          "dtype": [tf.float32, tf.int32, tf.int64, tf.bool],
          "index_type": [tf.int32],
          "input_shape": [[12, 2, 2, 5]],
          "strides": [None, [2, 1, 3, 1]],
          "begin": [[0, 0, 0, 0]],
          "end": [[12, 2, 2, 5]],
          "begin_mask": [None],
          "end_mask": [None],
          "shrink_axis_mask": [None],
          "constant_indices": [False, True],
          "fully_quantize": [False],
      },
      # 4-D with non-trivial begin & end.
      # shrink_axis_mask 15 sets all four axis bits; -1 sets every bit.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[12, 2, 2, 5]],
          "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
          "end": [[8, 2, 2, 3], [12, 2, 2, 5]],
          "strides": [None, [2, 1, 3, 1]],
          "begin_mask": [None, 8],
          "end_mask": [None, 3],
          "shrink_axis_mask": [None, 15, -1],
          "constant_indices": [True],
          "fully_quantize": [False],
      },
      # Begin, end, strides dim are different from input shape: 1-element
      # index vectors applied to a 4-D input.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[12, 2, 2, 5]],
          "begin": [[0]],
          "end": [[1]],
          "strides": [None, [1]],
          "begin_mask": [0],
          "end_mask": [0],
          "shrink_axis_mask": [1],
          "constant_indices": [True, False],
          "fully_quantize": [False],
      },
      # 2-D
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[2, 3]],
          "begin": [[0, 0]],
          "end": [[2, 2]],
          "strides": [None, [2, 2]],
          "begin_mask": [None, 1, 2],
          "end_mask": [None, 1, 2],
          "shrink_axis_mask": [None, 1, 2, 3, -1],
          "constant_indices": [False, True],
          "fully_quantize": [False],
      },
      # Negative strides (negative begin/end on axis 1).
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[2, 3]],
          "begin": [[0, -1]],
          "end": [[2, -3]],
          "strides": [[1, -1]],
          "begin_mask": [None, 1, 2],
          "end_mask": [None, 1, 2],
          "shrink_axis_mask": [None, 1, 2, 3, -1],
          "constant_indices": [False],
          "fully_quantize": [False],
      },
      # 4-D (cases with const indices and batchsize of 1), fully quantized.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[1, 2, 2, 5]],
          "strides": [None, [1, 1, 1, 1]],
          "begin": [[0, 0, 0, 0], [0, 1, 1, 3]],
          "end": [[1, 2, 2, 5], [1, 2, 2, 4]],
          "begin_mask": [None],
          "end_mask": [None],
          "shrink_axis_mask": [None],
          "constant_indices": [True],
          "fully_quantize": [True],
      },
      # Begin, end, strides dim are different from input shape: same grid as
      # the earlier mismatched-rank case, but constant indices only and
      # fully_quantize=True.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[12, 2, 2, 5]],
          "begin": [[0]],
          "end": [[1]],
          "strides": [None, [1]],
          "begin_mask": [0],
          "end_mask": [0],
          "shrink_axis_mask": [1],
          "constant_indices": [True],
          "fully_quantize": [True],
      },
      # 3-D input sliced with a single index entry (begin=1, end=0) and
      # end_mask bit 0 set.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[1, 1, 2]],
          "begin": [[1]],
          "end": [[0]],
          "strides": [[1]],
          "begin_mask": [0],
          "end_mask": [1],
          "shrink_axis_mask": [0],
          "constant_indices": [True, False],
          "fully_quantize": [False],
      },
      # 3-D input with begin=1 on axis 0; begin_mask 6 covers axes 1-2 and
      # end_mask 7 covers all three axes.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[1, 1, 2]],
          "begin": [[1, 0, 0]],
          "end": [[0, -1, -1]],
          "strides": [[1, 1, 1]],
          "begin_mask": [6],
          "end_mask": [7],
          "shrink_axis_mask": [0],
          "constant_indices": [True, False],
          "fully_quantize": [False],
      },
      # String input.
      {
          "dtype": [tf.string],
          "index_type": [tf.int32],
          "input_shape": [[12, 2, 2, 5]],
          "begin": [[0, 0, 0, 0]],
          "end": [[8, 2, 2, 3]],
          "strides": [[2, 1, 3, 1]],
          "begin_mask": [8],
          "end_mask": [3],
          "shrink_axis_mask": [None],
          "constant_indices": [True, False],
          "fully_quantize": [False],
      },
      # ellipsis_mask and new_axis_mask, crossed with optional begin/end/shrink
      # masks on a 4-D input.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[5, 5, 7, 7]],
          "begin": [[0, 0, 0, 0]],
          "end": [[2, 3, 4, 5]],
          "strides": [[1, 1, 1, 1]],
          "begin_mask": [0, 8],
          "end_mask": [0, 2],
          "shrink_axis_mask": [0, 4],
          "ellipsis_mask": [2, 4],
          "new_axis_mask": [1, 6],
          "constant_indices": [True],
          "fully_quantize": [False],
      },
      # 3-D input: ellipsis at index position 1 (mask 2) combined with each
      # new_axis_mask in 1..5, indices fed as placeholders.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[5, 6, 7]],
          "begin": [[0, 0, 0]],
          "end": [[2, 3, 4]],
          "strides": [[1, 1, 1]],
          "begin_mask": [0],
          "end_mask": [0],
          "shrink_axis_mask": [0, 2],
          "ellipsis_mask": [2],
          "new_axis_mask": [1, 2, 3, 4, 5],
          "constant_indices": [False],
          "fully_quantize": [False],
      },
      # shrink_axis_mask and new_axis_mask both set: 4 index entries on a 3-D
      # input, new axes at index positions 1 and 3 (mask 10), shrink at 0.
      {
          "dtype": [tf.float32],
          "index_type": [tf.int32],
          "input_shape": [[6, 7, 8]],
          "begin": [[0, 0, 0, 0]],
          "end": [[2, 3, 4, 5]],
          "strides": [[1, 1, 1, 1]],
          "begin_mask": [0],
          "end_mask": [0],
          "new_axis_mask": [10],
          "shrink_axis_mask": [1],
          "constant_indices": [True],
          "fully_quantize": [False],
      },
  ]
  # NOTE(review): 29 is forwarded as expected_tf_failures; it must be kept in
  # sync with the grid above when entries are added or removed — TODO confirm
  # the current count.
  _make_strided_slice_tests(options, test_parameters, expected_tf_failures=29)
297
298
@register_make_test_function()
def make_strided_slice_1d_exhaustive_tests(options):
  """Make an exhaustive grid of 1-D strided_slice tests.

  Sweeps every begin/end index in [-2, 2] and every non-zero stride in
  {-2, -1, 1, 2} over a length-3 vector, with begin_mask and end_mask each
  toggled between 0 and 1. Indices are always fed through placeholders
  (constant_indices=False).

  Args:
    options: Options forwarded unchanged to `_make_strided_slice_tests`.
  """
  # 1-D Exhaustive
  exhaustive_grid = {
      "dtype": [tf.float32],
      "index_type": [tf.int32],
      "input_shape": [[3]],
      "begin": [[position] for position in range(-2, 3)],
      "end": [[position] for position in range(-2, 3)],
      # Zero is skipped: it is not a usable stride.
      "strides": [[step] for step in (-2, -1, 1, 2)],
      "begin_mask": [0, 1],
      "end_mask": [0, 1],
      "shrink_axis_mask": [0],
      "constant_indices": [False],
  }
  _make_strided_slice_tests(options, [exhaustive_grid])
318