• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Test configs for strided_slice operators."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import tensorflow.compat.v1 as tf
22from tensorflow.lite.testing.zip_test_utils import create_tensor_data
23from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
24from tensorflow.lite.testing.zip_test_utils import register_make_test_function
25from tensorflow.lite.testing.zip_test_utils import TF_TYPE_INFO
26
27
def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0):
  """Utility function to make strided_slice tests based on parameters.

  Args:
    options: Options object, forwarded unchanged to `make_zip_of_tests`.
    test_parameters: List of dicts, each mapping a parameter name to the list
      of values to sweep. Keys read here: "dtype", "index_type",
      "input_shape", "begin", "end", "strides", "begin_mask", "end_mask",
      "shrink_axis_mask" and "constant_indices". "ellipsis_mask" and
      "new_axis_mask" are optional and default to 0 (tf.strided_slice's own
      default). Other keys (e.g. "fully_quantize") are not read here and are
      left to `make_zip_of_tests`.
    expected_tf_failures: Forwarded to `make_zip_of_tests`; presumably the
      number of parameter combinations TensorFlow itself is expected to
      reject -- TODO confirm against zip_test_utils.
  """

  def _index_placeholder(parameters, key):
    """Return a rank-1 placeholder of "index_type" sized like parameters[key]."""
    return tf.placeholder(
        dtype=parameters["index_type"],
        name=key,
        shape=[len(parameters[key])])

  def build_graph(parameters):
    """Build a graph containing a single strided_slice op.

    Returns:
      A pair (list of input tensors to feed, [output tensor]).
    """
    input_tensor = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    tensors = [input_tensor]
    if parameters["constant_indices"]:
      # Indices are baked into the graph as Python lists.
      begin = parameters["begin"]
      end = parameters["end"]
      strides = parameters["strides"]
    else:
      # Indices are fed at run time. The order begin, end, [strides] must
      # match the order in which build_inputs appends their values.
      begin = _index_placeholder(parameters, "begin")
      end = _index_placeholder(parameters, "end")
      tensors += [begin, end]
      strides = None
      if parameters["strides"] is not None:
        strides = _index_placeholder(parameters, "strides")
        tensors.append(strides)
    out = tf.strided_slice(
        input_tensor,
        begin,
        end,
        strides,
        begin_mask=parameters["begin_mask"],
        end_mask=parameters["end_mask"],
        ellipsis_mask=parameters.get("ellipsis_mask", 0),
        new_axis_mask=parameters.get("new_axis_mask", 0),
        shrink_axis_mask=parameters["shrink_axis_mask"])
    return tensors, [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Build feed values for a strided_slice test and run the TF reference.

    The input tensor gets `create_tensor_data` values in [-1, 1]. For
    non-constant indices, the begin/end/(strides) lists are cast to the numpy
    type registered for "index_type" and appended after the input.

    Returns:
      A pair (list of feed values, TensorFlow outputs for those values).
    """
    input_values = create_tensor_data(
        parameters["dtype"],
        parameters["input_shape"],
        min_value=-1,
        max_value=1)
    index_type = TF_TYPE_INFO[parameters["index_type"]][0]
    values = [input_values]
    if not parameters["constant_indices"]:
      values.append(np.array(parameters["begin"]).astype(index_type))
      values.append(np.array(parameters["end"]).astype(index_type))
      if parameters["strides"] is not None:
        values.append(np.array(parameters["strides"]).astype(index_type))

    return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))

  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=expected_tf_failures)
98
99
@register_make_test_function()
def make_strided_slice_tests(options):
  """Make the main set of strided_slice tests.

  Collects parameter grids and passes them to `_make_strided_slice_tests`.
  Three additional grids (begin at the input dimension, and string input) are
  added only when `options.use_experimental_converter` is set.
  """
  # TODO(soroosh): add test/support for uint8.
  param_grid = [
      # Rank 4 with default masks: every dtype, indices constant and fed.
      dict(
          dtype=[tf.float32, tf.int32, tf.int64, tf.bool],
          index_type=[tf.int32],
          input_shape=[[12, 2, 2, 5]],
          strides=[None, [2, 1, 3, 1]],
          begin=[[0, 0, 0, 0]],
          end=[[12, 2, 2, 5]],
          begin_mask=[None],
          end_mask=[None],
          shrink_axis_mask=[None],
          constant_indices=[False, True],
          fully_quantize=[False],
      ),
      # Rank 4 with non-trivial begin/end and masks, constant indices only.
      dict(
          dtype=[tf.float32],
          index_type=[tf.int32],
          input_shape=[[12, 2, 2, 5]],
          begin=[[0, 0, 0, 0], [1, 0, 1, 0]],
          end=[[8, 2, 2, 3], [12, 2, 2, 5]],
          strides=[None, [2, 1, 3, 1]],
          begin_mask=[None, 8],
          end_mask=[None, 3],
          shrink_axis_mask=[None, 15, -1],
          constant_indices=[True],
          fully_quantize=[False],
      ),
      # begin/end/strides have length 1 while the input has rank 4.
      dict(
          dtype=[tf.float32],
          index_type=[tf.int32],
          input_shape=[[12, 2, 2, 5]],
          begin=[[0]],
          end=[[1]],
          strides=[None, [1]],
          begin_mask=[0],
          end_mask=[0],
          shrink_axis_mask=[1],
          constant_indices=[True, False],
          fully_quantize=[False],
      ),
      # Rank 2.
      dict(
          dtype=[tf.float32],
          index_type=[tf.int32],
          input_shape=[[2, 3]],
          begin=[[0, 0]],
          end=[[2, 2]],
          strides=[None, [2, 2]],
          begin_mask=[None, 1, 2],
          end_mask=[None, 1, 2],
          shrink_axis_mask=[None, 1, 2, 3, -1],
          constant_indices=[False, True],
          fully_quantize=[False],
      ),
      # Rank 2 with a negative stride on the last axis, fed indices only.
      dict(
          dtype=[tf.float32],
          index_type=[tf.int32],
          input_shape=[[2, 3]],
          begin=[[0, -1]],
          end=[[2, -3]],
          strides=[[1, -1]],
          begin_mask=[None, 1, 2],
          end_mask=[None, 1, 2],
          shrink_axis_mask=[None, 1, 2, 3, -1],
          constant_indices=[False],
          fully_quantize=[False],
      ),
      # Rank 4 with batch size 1, fully quantized, constant indices.
      dict(
          dtype=[tf.float32],
          index_type=[tf.int32],
          input_shape=[[1, 2, 2, 5]],
          strides=[None, [1, 1, 1, 1]],
          begin=[[0, 0, 0, 0], [0, 1, 1, 3]],
          end=[[1, 2, 2, 5], [1, 2, 2, 4]],
          begin_mask=[None],
          end_mask=[None],
          shrink_axis_mask=[None],
          constant_indices=[True],
          fully_quantize=[True],
      ),
      # Fully quantized variant of the length-1 begin/end/strides case on a
      # rank-4 input, constant indices only.
      dict(
          dtype=[tf.float32],
          index_type=[tf.int32],
          input_shape=[[12, 2, 2, 5]],
          begin=[[0]],
          end=[[1]],
          strides=[None, [1]],
          begin_mask=[0],
          end_mask=[0],
          shrink_axis_mask=[1],
          constant_indices=[True],
          fully_quantize=[True],
      ),
  ]

  if options.use_experimental_converter:
    param_grid.extend([
        # begin[0] equals the size of input dimension 0.
        dict(
            dtype=[tf.float32],
            index_type=[tf.int32],
            input_shape=[[1, 1, 2]],
            begin=[[1]],
            end=[[0]],
            strides=[[1]],
            begin_mask=[0],
            end_mask=[1],
            shrink_axis_mask=[0],
            constant_indices=[True, False],
            fully_quantize=[False],
        ),
        # Same with full-length indices: begin_mask=6 and end_mask=7 set the
        # mask bits for axes 1-2 and 0-2 respectively.
        dict(
            dtype=[tf.float32],
            index_type=[tf.int32],
            input_shape=[[1, 1, 2]],
            begin=[[1, 0, 0]],
            end=[[0, -1, -1]],
            strides=[[1, 1, 1]],
            begin_mask=[6],
            end_mask=[7],
            shrink_axis_mask=[0],
            constant_indices=[True, False],
            fully_quantize=[False],
        ),
        # String input.
        dict(
            dtype=[tf.string],
            index_type=[tf.int32],
            input_shape=[[12, 2, 2, 5]],
            begin=[[0, 0, 0, 0]],
            end=[[8, 2, 2, 3]],
            strides=[[2, 1, 3, 1]],
            begin_mask=[8],
            end_mask=[3],
            shrink_axis_mask=[None],
            constant_indices=[True, False],
            fully_quantize=[False],
        ),
    ])
  # NOTE(review): 29 is forwarded as expected_tf_failures; presumably the
  # count of combinations TF rejects -- revisit whenever a grid changes.
  _make_strided_slice_tests(options, param_grid, expected_tf_failures=29)
251
252
@register_make_test_function()
def make_strided_slice_1d_exhaustive_tests(options):
  """Make an exhaustive sweep of 1-D strided_slice tests.

  Every combination of the begin, end, stride and begin/end mask values below
  is exercised on a length-3 float32 input, with indices fed at run time
  (constant_indices=False).
  """
  exhaustive_1d = dict(
      dtype=[tf.float32],
      index_type=[tf.int32],
      input_shape=[[3]],
      begin=[[-2], [-1], [0], [1], [2]],
      end=[[-2], [-1], [0], [1], [2]],
      strides=[[-2], [-1], [1], [2]],
      begin_mask=[0, 1],
      end_mask=[0, 1],
      shrink_axis_mask=[0],
      constant_indices=[False],
  )
  _make_strided_slice_tests(options, [exhaustive_1d])
272