• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for strided_slice_np_style."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow.compat.v1 as tf
21from tensorflow.lite.testing.zip_test_utils import create_tensor_data
22from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
23from tensorflow.lite.testing.zip_test_utils import register_make_test_function
24
25
26# TODO(b/137615945): Expand the test coverage of this one and remove the old
27# ones.
@register_make_test_function()
def make_strided_slice_np_style_tests(options):
  """Generate zip tests that index a float32 tensor with np-style specs.

  Every entry of the parameter table lists candidate dtypes, input shapes and
  index specs. A spec is a list mixing slices, integers, `tf.newaxis` and
  `Ellipsis`; it is applied to a placeholder with the subscript operator.

  Args:
    options: Forwarded to `make_zip_of_tests`. Its
      `use_experimental_converter` attribute decides whether the additional
      Ellipsis cases are appended to the table.
  """

  def _float_case(shapes, specs):
    """Return one parameter-table entry whose only dtype is float32."""
    return {"dtype": [tf.float32], "shape": shapes, "spec": specs}

  # Slice objects that recur across many specs.
  everything = slice(None)
  stride2 = slice(3, 7, 2)
  stride3 = slice(1, 11, 3)

  test_parameters = [
      # 2-D inputs: a strided slice, new axes mixed with a unit-step slice,
      # and a plain unit-step slice.
      _float_case(
          shapes=[[12, 7], [33, 1]],
          specs=[
              [stride2, everything],
              [tf.newaxis, slice(3, 7, 1), tf.newaxis, everything],
              [slice(1, 5, 1), everything],
          ]),
      # 1-D input.
      _float_case(
          shapes=[[44]],
          specs=[
              [stride2],
              [tf.newaxis, everything],
          ]),
      # Integer index on the last axis (shrink mask).
      _float_case(
          shapes=[[21, 15, 7]],
          specs=[
              [stride2, everything, 2],
          ]),
      # Ellipsis on a 3-D input.
      _float_case(
          shapes=[[21, 15, 7]],
          specs=[
              [stride2, Ellipsis],
              [stride3, Ellipsis, stride2],
          ]),
      # Ellipsis on a 4-D input.
      _float_case(
          shapes=[[21, 15, 7, 9]],
          specs=[
              [stride2, Ellipsis],
          ]),
      # 5-D input. Note that this spec spells out the four trailing axes with
      # explicit full slices instead of using Ellipsis.
      _float_case(
          shapes=[[11, 21, 15, 7, 9]],
          specs=[
              [stride2, everything, everything, everything, everything],
          ]),
      # Integer index (shrink) placed before an Ellipsis.
      _float_case(
          shapes=[[22, 15, 7]],
          specs=[
              [2, Ellipsis],
          ]),
      # New axis combined with an Ellipsis.
      _float_case(
          shapes=[[23, 15, 7]],
          specs=[
              # new_axis before the ellipsis only.
              [tf.newaxis, stride2, everything, Ellipsis],
              # new_axis both before and after the ellipsis.
              [tf.newaxis, stride2, everything, Ellipsis, tf.newaxis],
          ]),
  ]

  if options.use_experimental_converter:
    # The case when Ellipsis is expanded to multiple dimension is only supported
    # by MLIR converter (b/183902491).
    test_parameters.extend([
        # Leading Ellipsis on a 3-D input.
        _float_case(
            shapes=[[21, 15, 7]],
            specs=[
                [Ellipsis, stride2],
            ]),
        # Leading and interior Ellipsis on a 4-D input.
        _float_case(
            shapes=[[21, 15, 7, 9]],
            specs=[
                [Ellipsis, stride2],
                [stride3, Ellipsis, stride2],
            ]),
        # Leading Ellipsis on a 5-D input.
        _float_case(
            shapes=[[11, 21, 15, 7, 9]],
            specs=[
                [Ellipsis, stride2],
            ]),
        # Integer index (shrink) placed after an Ellipsis.
        _float_case(
            shapes=[[22, 15, 7]],
            specs=[
                [Ellipsis, 2],
            ]),
    ])

  def build_graph(parameters):
    """Create a placeholder and subscript it with the configured spec."""
    placeholder = tf.compat.v1.placeholder(
        shape=parameters["shape"], dtype=parameters["dtype"])
    # The spec list is handed to the subscript operator as a single key.
    sliced = placeholder[parameters["spec"]]
    return [placeholder], [sliced]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create input data and evaluate the graph outputs on it."""
    data = create_tensor_data(parameters["dtype"], parameters["shape"])
    feed = dict(zip(inputs, [data]))
    return [data], sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
160