1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test configs for strided_slice operators.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21import tensorflow.compat.v1 as tf 22from tensorflow.lite.testing.zip_test_utils import create_tensor_data 23from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests 24from tensorflow.lite.testing.zip_test_utils import register_make_test_function 25from tensorflow.lite.testing.zip_test_utils import TF_TYPE_INFO 26 27 28def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0): 29 """Utility function to make strided_slice_tests based on parameters.""" 30 31 def build_graph(parameters): 32 """Build graph for stride_slice test.""" 33 input_tensor = tf.compat.v1.placeholder( 34 dtype=parameters["dtype"], 35 name="input", 36 shape=parameters["input_shape"]) 37 if parameters["constant_indices"]: 38 begin = parameters["begin"] 39 end = parameters["end"] 40 strides = parameters["strides"] 41 tensors = [input_tensor] 42 else: 43 begin = tf.compat.v1.placeholder( 44 dtype=parameters["index_type"], 45 name="begin", 46 shape=[len(parameters["begin"])]) 47 end = tf.compat.v1.placeholder( 48 dtype=parameters["index_type"], 49 name="end", 50 shape=[len(parameters["end"])]) 51 strides = None 52 if parameters["strides"] is not None: 53 strides = tf.compat.v1.placeholder( 54 dtype=parameters["index_type"], 55 name="strides", 56 shape=[len(parameters["strides"])]) 57 tensors = [input_tensor, begin, end] 58 if strides is not None: 59 tensors.append(strides) 60 out = tf.strided_slice( 61 input_tensor, 62 begin, 63 end, 64 strides, 65 begin_mask=parameters["begin_mask"], 66 end_mask=parameters["end_mask"], 67 shrink_axis_mask=parameters["shrink_axis_mask"]) 68 return tensors, [out] 69 70 def build_inputs(parameters, sess, inputs, outputs): 71 """Build inputs for stride_slice test.""" 72 input_values = create_tensor_data( 73 parameters["dtype"], 74 parameters["input_shape"], 75 min_value=-1, 76 max_value=1) 77 index_type = TF_TYPE_INFO[parameters["index_type"]][0] 78 values = [input_values] 79 if not parameters["constant_indices"]: 80 begin_values = np.array(parameters["begin"]).astype(index_type) 81 end_values = np.array(parameters["end"]).astype(index_type) 82 stride_values = ( 83 np.array(parameters["strides"]).astype(index_type) 84 if parameters["strides"] is not None else None) 85 values.append(begin_values) 86 values.append(end_values) 87 if stride_values is not None: 88 values.append(stride_values) 89 90 return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) 91 92 make_zip_of_tests( 93 options, 94 test_parameters, 95 build_graph, 96 build_inputs, 97 expected_tf_failures=expected_tf_failures) 98 99 100@register_make_test_function() 101def make_strided_slice_tests(options): 102 """Make a set of tests to do strided_slice.""" 103 104 # TODO(soroosh): add test/support for uint8. 105 test_parameters = [ 106 # 4-D (basic cases with const/non-const indices). 107 { 108 "dtype": [tf.float32, tf.int32, tf.int64, tf.bool], 109 "index_type": [tf.int32], 110 "input_shape": [[12, 2, 2, 5]], 111 "strides": [None, [2, 1, 3, 1]], 112 "begin": [[0, 0, 0, 0]], 113 "end": [[12, 2, 2, 5]], 114 "begin_mask": [None], 115 "end_mask": [None], 116 "shrink_axis_mask": [None], 117 "constant_indices": [False, True], 118 "fully_quantize": [False], 119 }, 120 # 4-D with non-trivial begin & end. 121 { 122 "dtype": [tf.float32], 123 "index_type": [tf.int32], 124 "input_shape": [[12, 2, 2, 5]], 125 "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], 126 "end": [[8, 2, 2, 3], [12, 2, 2, 5]], 127 "strides": [None, [2, 1, 3, 1]], 128 "begin_mask": [None, 8], 129 "end_mask": [None, 3], 130 "shrink_axis_mask": [None, 15, -1], 131 "constant_indices": [True], 132 "fully_quantize": [False], 133 }, 134 # Begin, end, strides dim are different from input shape 135 { 136 "dtype": [tf.float32], 137 "index_type": [tf.int32], 138 "input_shape": [[12, 2, 2, 5]], 139 "begin": [[0]], 140 "end": [[1]], 141 "strides": [None, [1]], 142 "begin_mask": [0], 143 "end_mask": [0], 144 "shrink_axis_mask": [1], 145 "constant_indices": [True, False], 146 "fully_quantize": [False], 147 }, 148 # 2-D 149 { 150 "dtype": [tf.float32], 151 "index_type": [tf.int32], 152 "input_shape": [[2, 3]], 153 "begin": [[0, 0]], 154 "end": [[2, 2]], 155 "strides": [None, [2, 2]], 156 "begin_mask": [None, 1, 2], 157 "end_mask": [None, 1, 2], 158 "shrink_axis_mask": [None, 1, 2, 3, -1], 159 "constant_indices": [False, True], 160 "fully_quantize": [False], 161 }, 162 # Negative strides 163 { 164 "dtype": [tf.float32], 165 "index_type": [tf.int32], 166 "input_shape": [[2, 3]], 167 "begin": [[0, -1]], 168 "end": [[2, -3]], 169 "strides": [[1, -1]], 170 "begin_mask": [None, 1, 2], 171 "end_mask": [None, 1, 2], 172 "shrink_axis_mask": [None, 1, 2, 3, -1], 173 "constant_indices": [False], 174 "fully_quantize": [False], 175 }, 176 # 4-D (cases with const indices and batchsize of 1). 177 { 178 "dtype": [tf.float32], 179 "index_type": [tf.int32], 180 "input_shape": [[1, 2, 2, 5]], 181 "strides": [None, [1, 1, 1, 1]], 182 "begin": [[0, 0, 0, 0], [0, 1, 1, 3]], 183 "end": [[1, 2, 2, 5], [1, 2, 2, 4]], 184 "begin_mask": [None], 185 "end_mask": [None], 186 "shrink_axis_mask": [None], 187 "constant_indices": [True], 188 "fully_quantize": [True], 189 }, 190 # Begin, end, strides dim are different from input shape 191 { 192 "dtype": [tf.float32], 193 "index_type": [tf.int32], 194 "input_shape": [[12, 2, 2, 5]], 195 "begin": [[0]], 196 "end": [[1]], 197 "strides": [None, [1]], 198 "begin_mask": [0], 199 "end_mask": [0], 200 "shrink_axis_mask": [1], 201 "constant_indices": [True], 202 "fully_quantize": [True], 203 }, 204 ] 205 206 if options.use_experimental_converter: 207 test_parameters = test_parameters + [ 208 # Begin equal to input dim. 209 { 210 "dtype": [tf.float32], 211 "index_type": [tf.int32], 212 "input_shape": [[1, 1, 2]], 213 "begin": [[1]], 214 "end": [[0]], 215 "strides": [[1]], 216 "begin_mask": [0], 217 "end_mask": [1], 218 "shrink_axis_mask": [0], 219 "constant_indices": [True, False], 220 "fully_quantize": [False], 221 }, 222 { 223 "dtype": [tf.float32], 224 "index_type": [tf.int32], 225 "input_shape": [[1, 1, 2]], 226 "begin": [[1, 0, 0]], 227 "end": [[0, -1, -1]], 228 "strides": [[1, 1, 1]], 229 "begin_mask": [6], 230 "end_mask": [7], 231 "shrink_axis_mask": [0], 232 "constant_indices": [True, False], 233 "fully_quantize": [False], 234 }, 235 # String input. 236 { 237 "dtype": [tf.string], 238 "index_type": [tf.int32], 239 "input_shape": [[12, 2, 2, 5]], 240 "begin": [[0, 0, 0, 0]], 241 "end": [[8, 2, 2, 3]], 242 "strides": [[2, 1, 3, 1]], 243 "begin_mask": [8], 244 "end_mask": [3], 245 "shrink_axis_mask": [None], 246 "constant_indices": [True, False], 247 "fully_quantize": [False], 248 } 249 ] 250 _make_strided_slice_tests(options, test_parameters, expected_tf_failures=29) 251 252 253@register_make_test_function() 254def make_strided_slice_1d_exhaustive_tests(options): 255 """Make a set of exhaustive tests for 1D strided_slice.""" 256 test_parameters = [ 257 # 1-D Exhaustive 258 { 259 "dtype": [tf.float32], 260 "index_type": [tf.int32], 261 "input_shape": [[3]], 262 "begin": [[-2], [-1], [0], [1], [2]], 263 "end": [[-2], [-1], [0], [1], [2]], 264 "strides": [[-2], [-1], [1], [2]], 265 "begin_mask": [0, 1], 266 "end_mask": [0, 1], 267 "shrink_axis_mask": [0], 268 "constant_indices": [False], 269 }, 270 ] 271 _make_strided_slice_tests(options, test_parameters) 272