1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.tf.gather.""" 16 17from absl.testing import parameterized 18import numpy as np 19 20from tensorflow.python.eager import backprop 21from tensorflow.python.eager import context 22from tensorflow.python.eager import def_function 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors 26from tensorflow.python.framework import indexed_slices 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_spec 29from tensorflow.python.framework import test_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import gradient_checker_v2 32from tensorflow.python.ops import gradients_impl 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import resource_variable_ops 35from tensorflow.python.ops import variables 36from tensorflow.python.platform import test 37 38_TEST_TYPES = (dtypes.int64, dtypes.float32, 39 dtypes.complex64, dtypes.complex128) 40 41# TODO(virimia): Add a benchmark for gather_v2, with batch_dims and axis set. 
def _to_str_elements(values):
  """Converts the inner list elements to strings."""
  if isinstance(values, list):
    return [_to_str_elements(value) for value in values]
  else:
    return str(values).encode("utf-8")


class GatherTest(test.TestCase, parameterized.TestCase):
  """Tests for `array_ops.gather` (values, shapes, gradients, errors)."""

  def _buildParams(self, data, dtype):
    """Casts `data` to `dtype`, adding an imaginary part for complex dtypes.

    Args:
      data: A numpy array of real values.
      dtype: The target `DType`.

    Returns:
      `data` cast to `dtype.as_numpy_dtype`. For complex dtypes, `10j * data`
      is added on top.
    """
    data = data.astype(dtype.as_numpy_dtype)
    # For complex types, add an index-dependent imaginary component so we can
    # tell we got the right value.
    if dtype.is_complex:
      return data + 10j * data
    return data

  def _expectedParamsGrad(self, shape, dtype, indices, axis, gather_grad):
    """Computes the reference gradient of a gather w.r.t. its params.

    The gradient of a gather is a scatter-add. Each slice of the upstream
    gradient is accumulated into the params slice it was gathered from, so
    repeated indices contribute more than once.

    Args:
      shape: Shape of the gathered params, as a tuple.
      dtype: The `DType` of the params.
      indices: numpy array of the indices used for the gather.
      axis: The non-negative axis the gather was performed along.
      gather_grad: numpy array holding the upstream gradient, with shape
        `shape[:axis] + indices.shape + shape[axis + 1:]`.

    Returns:
      A numpy array with shape `shape` and dtype `dtype` holding the expected
      params gradient.
    """
    correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
    outer_dims = axis
    inner_dims = len(shape) - axis - 1
    # Collapse the indices dimensions into one, so that every gathered slice
    # has a single position along `axis` matching `indices.flat`.
    gather_grad = gather_grad.reshape(
        shape[:axis] + (indices.size,) + shape[axis + 1:])
    for source_index, dest_index in enumerate(indices.flat):
      dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                    (slice(None),) * inner_dims)
      source_slice = ((slice(None),) * outer_dims + (source_index,) +
                      (slice(None),) * inner_dims)
      correct_params_grad[dest_slice] += gather_grad[source_slice]
    return correct_params_grad

  def testScalar1D(self):
    with self.cached_session():
      data = np.array([0, 1, 2, 3, 7, 5])
      for dtype in _TEST_TYPES:
        for indices in 4, [1, 2, 2, 4, 5]:
          with self.subTest(dtype=dtype, indices=indices):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            indices_tf = constant_op.constant(indices)
            gather_t = array_ops.gather(params, indices_tf)
            gather_val = self.evaluate(gather_t)
            np_val = params_np[indices]
            self.assertAllEqual(np_val, gather_val)
            self.assertEqual(np_val.shape, gather_t.get_shape())

  def testScalar2D(self):
    with self.session():
      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                       [9, 10, 11], [12, 13, 14]])
      for dtype in _TEST_TYPES:
        for axis in range(data.ndim):
          with self.subTest(dtype=dtype, axis=axis):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            indices = constant_op.constant(2)
            gather_t = array_ops.gather(params, indices, axis=axis)
            gather_val = self.evaluate(gather_t)
            self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val)
            expected_shape = data.shape[:axis] + data.shape[axis + 1:]
            self.assertEqual(expected_shape, gather_t.get_shape())

  def testSimpleTwoD32(self):
    with self.session():
      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                       [9, 10, 11], [12, 13, 14]])
      for dtype in _TEST_TYPES:
        for axis in range(data.ndim):
          with self.subTest(dtype=dtype, axis=axis):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            # The indices must be in bounds for any axis.
            indices = constant_op.constant([0, 1, 0, 2])
            gather_t = array_ops.gather(params, indices, axis=axis)
            gather_val = self.evaluate(gather_t)
            self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis),
                                gather_val)
            expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
            self.assertEqual(expected_shape, gather_t.get_shape())

  def testHigherRank(self):
    with ops.Graph().as_default():
      # We check that scalar and empty indices shapes work as well
      shape = (2, 1, 3, 2)
      for indices_shape in (), (0,), (2, 0), (2, 3):
        for dtype in _TEST_TYPES:
          for axis in range(len(shape)):
            params = self._buildParams(np.random.randn(*shape), dtype)
            indices = np.random.randint(shape[axis], size=indices_shape)
            with self.subTest(
                indices_shape=indices_shape,
                dtype=dtype,
                axis=axis,
                indices=indices):
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both positive and negative indices for axis work.
              tf_axis = constant_op.constant(axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
              gather_value, gather_negative_axis_value = self.evaluate(
                  [gather, gather_negative_axis])
              gather_np = np.take(params, indices, axis)
              self.assertAllEqual(gather_np, gather_value)
              self.assertAllEqual(gather_np, gather_negative_axis_value)
              expected_shape = (params.shape[:axis] + indices.shape +
                                params.shape[axis + 1:])
              self.assertEqual(expected_shape, gather.shape)
              self.assertEqual(expected_shape, gather_negative_axis.shape)

              # Test gradients
              gather_grad = np.random.randn(
                  *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
              if dtype.is_complex:
                gather_grad -= 1j * gather_grad
              params_grad, indices_grad, axis_grad = gradients_impl.gradients(
                  gather, [tf_params, tf_indices, tf_axis], gather_grad)
              self.assertIsNone(indices_grad)
              self.assertIsNone(axis_grad)
              if dtype.is_integer:
                self.assertIsNone(params_grad)
                continue
              # For axis 0, we are able to create an efficient IndexedSlices for
              # the gradient.
              if axis == 0:
                self.assertIsInstance(params_grad,
                                      indexed_slices.IndexedSlices)
                params_grad = ops.convert_to_tensor(params_grad)
              correct_params_grad = self._expectedParamsGrad(
                  shape, dtype, indices, axis, gather_grad)
              self.assertAllClose(
                  correct_params_grad,
                  self.evaluate(params_grad),
                  atol=2e-6,
                  rtol=2e-6)

  def testHigherRankGradientTape(self):
    # We check that scalar and empty indices shapes work as well
    shape = (2, 1, 3, 2)
    for indices_shape in (), (0,), (2, 0), (2, 3):
      for dtype in _TEST_TYPES:
        for axis in range(len(shape)):
          params = self._buildParams(np.random.randn(*shape), dtype)
          indices = np.random.randint(shape[axis], size=indices_shape)
          with self.subTest(
              indices_shape=indices_shape,
              dtype=dtype,
              axis=axis,
              indices=indices):
            with backprop.GradientTape() as tape:
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both positive and negative indices for axis work.
              tf_axis = constant_op.constant(axis)
              tape.watch(tf_params)
              tape.watch(tf_indices)
              tape.watch(tf_axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
            gather_value, gather_negative_axis_value = self.evaluate(
                [gather, gather_negative_axis])
            gather_np = np.take(params, indices, axis)
            self.assertAllEqual(gather_np, gather_value)
            self.assertAllEqual(gather_np, gather_negative_axis_value)
            expected_shape = (
                params.shape[:axis] + indices.shape + params.shape[axis + 1:])
            self.assertEqual(expected_shape, gather.shape)
            self.assertEqual(expected_shape, gather_negative_axis.shape)

            # Test gradients
            gather_grad = np.random.randn(
                *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
            if dtype.is_complex:
              gather_grad -= 1j * gather_grad
            # The tape is non-persistent, so take the gradient once, after the
            # recording context has closed.
            params_grad, indices_grad, axis_grad = tape.gradient(
                gather, [tf_params, tf_indices, tf_axis], gather_grad)
            self.assertIsNone(indices_grad)
            self.assertIsNone(axis_grad)
            if dtype.is_integer:
              self.assertIsNone(params_grad)
              continue
            # For axis 0, we are able to create an efficient IndexedSlices for
            # the gradient.
            if axis == 0:
              self.assertIsInstance(params_grad, indexed_slices.IndexedSlices)
              params_grad = ops.convert_to_tensor(params_grad)
            correct_params_grad = self._expectedParamsGrad(
                shape, dtype, indices, axis, gather_grad)
            self.assertAllClose(
                correct_params_grad,
                self.evaluate(params_grad),
                atol=2e-6,
                rtol=2e-6)

  def testString(self):
    params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
    self.assertAllEqual([b"qwer", b"uiop"], array_ops.gather(params, 1, axis=0))
    self.assertAllEqual([b"asdf", b"qwer"], array_ops.gather(params, 0, axis=1))

  def testUInt32AndUInt64(self):
    for unsigned_type in (dtypes.uint32, dtypes.uint64):
      with self.subTest(unsigned_type=unsigned_type):
        params = self._buildParams(
            np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
        with self.cached_session():
          self.assertAllEqual([7, 8, 9], array_ops.gather(params, 1, axis=0))
          self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1))

  def testUnknownIndices(self):
    # This test is purely a test for placeholder inputs which is only applicable
    # in graph mode.
    with ops.Graph().as_default():
      params = constant_op.constant([[0, 1, 2]])
      indices = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices)
      # With indices of unknown rank, the output rank is unknown too.
      self.assertIsNone(gather_t.shape.rank)

  def testUnknownAxis(self):
    # This test is purely a test for placeholder inputs which is only applicable
    # in graph mode.
    with ops.Graph().as_default():
      params = constant_op.constant([[0, 1, 2]])
      indices = constant_op.constant([[0, 0], [0, 0]])
      axis = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices, axis=axis)
      # Rank 2 params with rank 2 indices results in a rank 3 shape.
      self.assertEqual([None, None, None], gather_t.shape.as_list())

      # If indices is also unknown the result rank is unknown.
      indices = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices, axis=axis)
      self.assertIsNone(gather_t.shape.rank)

  def testBadIndicesType(self):
    with self.assertRaisesRegex(
        (TypeError, errors.InvalidArgumentError),
        "float.* not in.* list of allowed values: int16, int32, int64"):
      self.evaluate(array_ops.gather([0], 0.))

  @test_util.disable_xla(
      "Assertion inside an op is not supported in XLA. Instead XLA clamps the "
      "index to be in bounds and returns the indexed value there (Don't rely "
      "on this behavior).")
  def testBadIndicesCPU(self):
    with test_util.force_cpu():
      params = [[0, 1, 2], [3, 4, 5]]
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
        self.evaluate(array_ops.gather(params, [[7]], axis=0))
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
        self.evaluate(array_ops.gather(params, [[7]], axis=1))

  def _disabledTestBadIndicesGPU(self):
    # TODO disabled due to different behavior on GPU and CPU
    # On GPU the bad indices do not raise error but fetch 0 values
    if not test.is_gpu_available():
      return
    with self.session():
      params = [[0, 1, 2], [3, 4, 5]]
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
        array_ops.gather(params, [[7]], axis=0).eval()
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
        array_ops.gather(params, [[7]], axis=1).eval()

  def testBadAxis(self):

    @def_function.function(autograph=False, jit_compile=False)
    def gather(x, indices, axis):
      return array_ops.gather(x, indices, axis=axis)

    @def_function.function(
        autograph=False,
        jit_compile=False,
        input_signature=[
            tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
        ] * 3)
    def gather_shape_inf_disabled(x, indices, axis):
      return array_ops.gather(x, indices, axis=axis)

    @def_function.function(
        autograph=False,
        jit_compile=True,
        input_signature=[
            tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
        ] * 3)
    def xla_gather(x, indices, axis):
      return array_ops.gather(x, indices, axis=axis)

    params = [0, 1, 2]
    indices = 0
    functions = [("array_ops.gather", array_ops.gather), ("gather", gather),
                 ("gather_shape_inf_disabled", gather_shape_inf_disabled),
                 ("xla_gather", xla_gather)]
    for bad_axis in (1, 2, -2):
      for fn_name, fn in functions:
        # Shape inference can validate axis for known params rank.
        with self.subTest(bad_axis=bad_axis, msg=fn_name, fn=fn):
          with self.assertRaisesRegex(
              (ValueError, errors.InvalidArgumentError),
              "Shape must be at least rank .* but is rank 1"):
            fn(params, indices, axis=bad_axis)

  def testEmptySlices(self):
    for dtype in _TEST_TYPES:
      for itype in np.int32, np.int64:
        with self.subTest(dtype=dtype, itype=itype):
          # Leading axis gather.
          params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
          indices = np.array([3, 4], dtype=itype)
          gather = array_ops.gather(params, indices, axis=0)
          self.assertAllEqual(gather, np.zeros((2, 0, 0)))

          # Middle axis gather.
          params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
          gather = array_ops.gather(params, indices, axis=1)
          self.assertAllEqual(gather, np.zeros((0, 2, 0)))

          # Trailing axis gather.
          params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
          gather = array_ops.gather(params, indices, axis=2)
          self.assertAllEqual(gather, np.zeros((0, 0, 2)))

  @parameterized.parameters([
      # batch_dims=0 (equivalent to tf.gather)
      dict(  # 2D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[2, 1], [0, 3]],
          expected=[[8, 7], [6, 9]]),
      dict(  # 3D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[9, 7], [8, 6]], [[6, 9], [8, 8]]]),
      dict(  # 4D indices
          batch_dims=0,
          params=[8, 9],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[8, 9], [9, 8]], [[8, 8], [9, 9]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),

      # batch_dims=indices.shape.ndims - 1
      # (equivalent to tf.compat.v1.batch_gather)
      dict(  # 2D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=2,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),
      dict(  # 2D indices (1 batch dim)
          batch_dims=-1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=-1,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),

      # batch_dims=indices.shape.ndims
      dict(  # 1D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[2, 1],
          expected=[12, 21]),
      dict(  # 2D indices (2 batch dims)
          batch_dims=2,
          params=[[[100, 101, 102, 103], [110, 111, 112, 113]],
                  [[200, 201, 202, 203], [210, 211, 212, 213]]],
          indices=[[2, 1], [0, 3]],
          expected=[[102, 111], [200, 213]]),

      # 0 < batch_dims < indices.shape.ndims - 1
      dict(  # 3D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[13, 11], [12, 10]], [[20, 23], [22, 22]]]),
      dict(  # 4D indices (1 batch dim)
          batch_dims=1,
          params=[[6, 7], [8, 9]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[6, 7], [7, 6]], [[6, 6], [7, 7]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),
      dict(  # 4D indices (2 batch dims)
          batch_dims=2,
          params=[[[2, 3], [4, 5]], [[6, 7], [8, 9]]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[2, 3], [3, 2]], [[4, 4], [5, 5]]],
                    [[[7, 7], [6, 6]], [[8, 9], [9, 8]]]]),

      # axis > 0
      dict(  # 3D indices, batch_dims=1, axis=2
          # params.shape  = [I1, J1, J2] = [2, 2, 3]
          # indices.shape = [I1, K1, K2] = [2, 1, 5]
          # result.shape  = [I1, J1, K1, K2] = [2, 2, 1, 5]
          batch_dims=1,
          axis=2,
          params=[[[10, 11, 12], [13, 14, 15]], [[20, 21, 22], [23, 24, 25]]],
          indices=[[[0, 1, 2, 1, 0]], [[0, 1, 2, 1, 0]]],
          expected=[[[[10, 11, 12, 11, 10]], [[13, 14, 15, 14, 13]]],
                    [[[20, 21, 22, 21, 20]], [[23, 24, 25, 24, 23]]]]),
      dict(  # 1D indices, batch_dims=None, axis=1
          batch_dims=None,
          axis=1,
          params=[[10, 11, 12], [13, 14, 15]],
          indices=[1, 0],
          expected=[[11, 10], [14, 13]]),
      dict(  # 3D indices, batch_dims=-3, axis=1
          batch_dims=-3,
          axis=1,
          params=[[0, 1, 2], [3, 4, 5]],
          indices=[[[0, 1], [1, 0]]],
          expected=[[[[0, 1], [1, 0]]], [[[3, 4], [4, 3]]]]),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDims(self, params, indices, batch_dims, expected=None,
                    axis=None):
    result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    self.assertAllEqual(expected, result)

    # Test gradients
    f64_params = math_ops.cast(params, dtypes.float64)

    def gather(params):
      return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)

    theoretical, numerical = gradient_checker_v2.compute_gradient(
        gather, [f64_params])
    self.assertAllClose(theoretical, numerical)

    # Test gradients when input shapes are unknown
    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float64),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
    ])
    def gather_unknown_shapes(params, indices):
      return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)

    if batch_dims is None or batch_dims >= 0:
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          lambda p: gather_unknown_shapes(p, indices), [f64_params])
      self.assertAllClose(theoretical, numerical)
    else:
      # Negative batch_dims cannot be resolved without a known indices rank.
      with self.assertRaisesRegex(
          ValueError,
          "Currently, it is unsupported to take the gradient of tf.gather"):
        gradient_checker_v2.compute_gradient(
            lambda p: gather_unknown_shapes(p, indices), [f64_params])

    # Test the gradients shape.
    with backprop.GradientTape() as tape:
      zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
      tape.watch(zeros)
      values = zeros * 2 + zeros
      result = array_ops.gather(
          values, indices, axis=axis, batch_dims=batch_dims)
    gradients = tape.gradient(result, zeros)

    self.assertAllEqual(array_ops.shape(params), array_ops.shape(gradients))

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected)
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(expected, result)

  @parameterized.parameters([
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=2,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=4,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=5,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-4,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-2,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-1,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
      ),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDimsMatchesPythonBatching(self, params_shape, indices_shape,
                                         batch_dims, axis, output_shape):
    """Checks that batch_dims matches multiple calls to tf.gather()."""
    # Generate a `params` tensor with the indicated shape.
    params_size = np.prod(params_shape)
    params = np.reshape(np.arange(params_size), params_shape)

    # Generate an `indices` tensor with the indicated shape, where each index
    # is within the appropriate range.
    indices_size = np.prod(indices_shape)
    indices = np.reshape(np.arange(indices_size), indices_shape)
    indices = indices % params_shape[axis]

    # Perform repeated (batched) gather operations with numpy, to find the
    # expected result.
    expected = self._batchNumpyGather(params, indices, axis, batch_dims)

    # On Windows, we get an exception if we pass in the transformed numpy
    # arrays ("Failed to convert numpy ndarray to a Tensor (Unsupported
    # feed type)."); so convert them back to lists before calling tf.gather.
    params = params.tolist()
    indices = indices.tolist()

    result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected.tolist())
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

  def _batchNumpyGather(self, params, indices, axis, batch_dims):
    """Performs a batch gather by making recursive calls to np.take().

    This is used by testBatchDimsMatchesPythonBatching() to construct the
    expected value.

    Each recursion level peels off one leading batch dimension from both
    `params` and `indices`, which must have the same size. A positive `axis`
    is shifted down by one to account for the removed dimension. A negative
    `axis` already counts from the end and is left unchanged.

    Args:
      params: A numpy array
      indices: A numpy array
      axis: An integer
      batch_dims: An integer
    Returns:
      A numpy array
    """
    if batch_dims == 0:
      return np.take(params, indices, axis=axis)
    self.assertEqual(params.shape[0], indices.shape[0])
    if axis > 0:
      axis -= 1
    return np.stack([
        self._batchNumpyGather(params[i], indices[i], axis, batch_dims - 1)
        for i in range(params.shape[0])
    ])

  @test_util.run_v1_only("RefVariable is not supported in v2")
  def testGatherRefVariable(self):
    with self.cached_session():
      v = variables.RefVariable(constant_op.constant([[1, 2], [3, 4], [5, 6]]))
      self.evaluate(variables.global_variables_initializer())
      gather = array_ops.gather(v, [0, 2])
      if not context.executing_eagerly():  # .op doesn't make sense in Eager
        self.assertEqual("GatherV2", gather.op.name)
      self.assertAllEqual([[1, 2], [5, 6]], gather)

  @test_util.run_in_graph_and_eager_modes
  def testGatherResourceVariable(self):
    with self.cached_session():
      v = resource_variable_ops.ResourceVariable(
          constant_op.constant([[1, 2], [3, 4], [5, 6]]))
      self.evaluate(variables.global_variables_initializer())
      gather = array_ops.gather(v, [0, 2])
      if not context.executing_eagerly():  # .op doesn't make sense in Eager
        self.assertEqual("ResourceGather", gather.op.inputs[0].op.type)
      self.assertAllEqual([[1, 2], [5, 6]], gather)


if __name__ == "__main__":
  test.main()