• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for strided_slice_np_style."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow.compat.v1 as tf
21from tensorflow.lite.testing.zip_test_utils import create_tensor_data
22from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
23from tensorflow.lite.testing.zip_test_utils import register_make_test_function
24
25
26# TODO(b/137615945): Expand the test coverage of this one and remove the old
27# ones.
@register_make_test_function()
def make_strided_slice_np_style_tests(options):
  """Register the test cases that index a tensor with NumPy-style specs.

  Every entry of the parameter table lists dtypes, input shapes and index
  specs. A spec is the list of index items (slices, integers, `tf.newaxis`,
  `Ellipsis`) that is handed to the placeholder's `__getitem__`.

  Args:
    options: Forwarded unchanged to `make_zip_of_tests`.
  """
  # Slice objects that recur throughout the specs below.
  everything = slice(None)
  strided = slice(3, 7, 2)
  wide_strided = slice(1, 11, 3)

  test_parameters = [
      # 2-D case
      {
          "dtype": [tf.float32],
          "shape": [[12, 7], [33, 1]],
          "spec": [
              [strided, everything],
              [tf.newaxis, slice(3, 7, 1), tf.newaxis, everything],
              [slice(1, 5, 1), everything],
          ],
      },
      # 1-D case
      {
          "dtype": [tf.float32],
          "shape": [[44]],
          "spec": [
              [strided],
              [tf.newaxis, everything],
          ],
      },
      # Shrink mask.
      {
          "dtype": [tf.float32],
          "shape": [[21, 15, 7]],
          "spec": [
              [strided, everything, 2],
          ],
      },
      # Ellipsis 3d.
      {
          "dtype": [tf.float32],
          "shape": [[21, 15, 7]],
          "spec": [
              [strided, Ellipsis],
              [Ellipsis, strided],
              [wide_strided, Ellipsis, strided],
          ],
      },
      # Ellipsis 4d.
      {
          "dtype": [tf.float32],
          "shape": [[21, 15, 7, 9]],
          "spec": [
              [strided, Ellipsis],
              [Ellipsis, strided],
              [wide_strided, Ellipsis, strided],
          ],
      },
      # Ellipsis 5d.
      {
          "dtype": [tf.float32],
          "shape": [[11, 21, 15, 7, 9]],
          "spec": [
              # Fully spelled-out form: one strided axis, four full axes.
              [strided, everything, everything, everything, everything],
              [Ellipsis, strided],
          ],
      },
      # Ellipsis + Shrink Mask
      {
          "dtype": [tf.float32],
          "shape": [[22, 15, 7]],
          "spec": [
              [2, Ellipsis],  # shrink before ellipsis
              [Ellipsis, 2],  # shrink after ellipsis
          ],
      },
      # Ellipsis + New Axis Mask
      {
          "dtype": [tf.float32],
          "shape": [[23, 15, 7]],
          "spec": [
              # new_axis before ellipsis
              [tf.newaxis, strided, everything, Ellipsis],
              # new_axis after (and before) ellipsis
              [tf.newaxis, strided, everything, Ellipsis, tf.newaxis],
          ],
      },
  ]

  def build_graph(parameters):
    """Create a placeholder and index it with the parameterized spec.

    Returns:
      A pair of single-element lists: the placeholder and the indexed output.
    """
    placeholder = tf.compat.v1.placeholder(
        dtype=parameters["dtype"], shape=parameters["shape"])
    # The spec list is passed as a single argument, exactly as written in
    # the parameter table.
    sliced = placeholder.__getitem__(parameters["spec"])
    return [placeholder], [sliced]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create input data and evaluate `outputs` on it in `sess`.

    Returns:
      A pair: the list holding the generated input, and the result of
      running `outputs` with that input fed to `inputs`.
    """
    data = create_tensor_data(parameters["dtype"], parameters["shape"])
    feed = dict(zip(inputs, [data]))
    return [data], sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
118