# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.gather."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

_TEST_TYPES = (dtypes.int64, dtypes.float32,
               dtypes.complex64, dtypes.complex128)

# TODO(virimia): Add a benchmark for gather_v2, with batch_dims and axis set.


def _to_str_elements(values):
  """Converts the inner list elements to strings."""
  if isinstance(values, list):
    return [_to_str_elements(value) for value in values]
  else:
    return str(values).encode("utf-8")


class GatherTest(test.TestCase, parameterized.TestCase):

  def _buildParams(self, data, dtype):
    data = data.astype(dtype.as_numpy_dtype)
    # For complex types, add an index-dependent imaginary component so we can
    # tell we got the right value.
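    # For illustration (hypothetical values): data = [1., 2.] would become
    # [1.+10.j, 2.+20.j], so each element carries a distinguishing imaginary
    # part and a wrong gather result cannot masquerade as the right one.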
    if dtype.is_complex:
      return data + 10j * data
    return data

  def testScalar1D(self):
    with self.cached_session():
      data = np.array([0, 1, 2, 3, 7, 5])
      for dtype in _TEST_TYPES:
        for indices in 4, [1, 2, 2, 4, 5]:
          with self.subTest(dtype=dtype, indices=indices):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            indices_tf = constant_op.constant(indices)
            gather_t = array_ops.gather(params, indices_tf)
            gather_val = self.evaluate(gather_t)
            np_val = params_np[indices]
            self.assertAllEqual(np_val, gather_val)
            self.assertEqual(np_val.shape, gather_t.get_shape())

  def testScalar2D(self):
    with self.session():
      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                       [9, 10, 11], [12, 13, 14]])
      for dtype in _TEST_TYPES:
        for axis in range(data.ndim):
          with self.subTest(dtype=dtype, axis=axis):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            indices = constant_op.constant(2)
            gather_t = array_ops.gather(params, indices, axis=axis)
            gather_val = self.evaluate(gather_t)
            self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val)
            expected_shape = data.shape[:axis] + data.shape[axis + 1:]
            self.assertEqual(expected_shape, gather_t.get_shape())

  def testSimpleTwoD32(self):
    with self.session():
      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                       [9, 10, 11], [12, 13, 14]])
      for dtype in _TEST_TYPES:
        for axis in range(data.ndim):
          with self.subTest(dtype=dtype, axis=axis):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            # The indices must be in bounds for any axis.
            indices = constant_op.constant([0, 1, 0, 2])
            gather_t = array_ops.gather(params, indices, axis=axis)
            gather_val = self.evaluate(gather_t)
            self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis),
                                gather_val)
            expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
            self.assertEqual(expected_shape, gather_t.get_shape())

  def testHigherRank(self):
    with ops.Graph().as_default():
      # We check that scalar and empty indices shapes work as well.
      shape = (2, 1, 3, 2)
      for indices_shape in (), (0,), (2, 0), (2, 3):
        for dtype in _TEST_TYPES:
          for axis in range(len(shape)):
            params = self._buildParams(np.random.randn(*shape), dtype)
            indices = np.random.randint(shape[axis], size=indices_shape)
            with self.subTest(
                indices_shape=indices_shape,
                dtype=dtype,
                axis=axis,
                indices=indices):
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both positive and negative indices for axis work.
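              # For example, with the rank-4 params here, axis=1 and axis=-3
              # address the same dimension, so both gathers below must agree.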
              tf_axis = constant_op.constant(axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
              gather_value, gather_negative_axis_value = self.evaluate(
                  [gather, gather_negative_axis])
              gather_np = np.take(params, indices, axis)
              self.assertAllEqual(gather_np, gather_value)
              self.assertAllEqual(gather_np, gather_negative_axis_value)
              expected_shape = (params.shape[:axis] + indices.shape +
                                params.shape[axis + 1:])
              self.assertEqual(expected_shape, gather.shape)
              self.assertEqual(expected_shape, gather_negative_axis.shape)

              # Test gradients.
              gather_grad = np.random.randn(
                  *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
              if dtype.is_complex:
                gather_grad -= 1j * gather_grad
              params_grad, indices_grad, axis_grad = gradients_impl.gradients(
                  gather, [tf_params, tf_indices, tf_axis], gather_grad)
              self.assertIsNone(indices_grad)
              self.assertIsNone(axis_grad)
              if dtype.is_integer:
                self.assertIsNone(params_grad)
                continue
              # For axis 0, we are able to create an efficient IndexedSlices
              # for the gradient.
              if axis == 0:
                self.assertEqual(type(params_grad), ops.IndexedSlices)
                params_grad = ops.convert_to_tensor(params_grad)
              correct_params_grad = np.zeros(shape).astype(
                  dtype.as_numpy_dtype)
              outer_dims = axis
              inner_dims = len(shape) - axis - 1
              gather_grad = gather_grad.reshape(
                  shape[:axis] + (indices.size,) + shape[axis + 1:])
              for source_index, dest_index in enumerate(indices.flat):
                dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                              (slice(None),) * inner_dims)
                source_slice = ((slice(None),) * outer_dims + (source_index,) +
                                (slice(None),) * inner_dims)
                correct_params_grad[dest_slice] += gather_grad[source_slice]
              self.assertAllClose(
                  correct_params_grad,
                  self.evaluate(params_grad),
                  atol=2e-6,
                  rtol=2e-6)

  def testHigherRankGradientTape(self):
    # We check that scalar and empty indices shapes work as well.
    shape = (2, 1, 3, 2)
    for indices_shape in (), (0,), (2, 0), (2, 3):
      for dtype in _TEST_TYPES:
        for axis in range(len(shape)):
          params = self._buildParams(np.random.randn(*shape), dtype)
          indices = np.random.randint(shape[axis], size=indices_shape)
          with self.subTest(
              indices_shape=indices_shape,
              dtype=dtype,
              axis=axis,
              indices=indices):
            with backprop.GradientTape() as tape:
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both positive and negative indices for axis work.
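              # tf_indices and tf_axis are also watched below; the assertions
              # after tape.gradient() confirm that no gradient flows to these
              # integer inputs.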
              tf_axis = constant_op.constant(axis)
              tape.watch(tf_params)
              tape.watch(tf_indices)
              tape.watch(tf_axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
              gather_value, gather_negative_axis_value = self.evaluate(
                  [gather, gather_negative_axis])
              gather_np = np.take(params, indices, axis)
              self.assertAllEqual(gather_np, gather_value)
              self.assertAllEqual(gather_np, gather_negative_axis_value)
              expected_shape = (
                  params.shape[:axis] + indices.shape + params.shape[axis + 1:])
              self.assertEqual(expected_shape, gather.shape)
              self.assertEqual(expected_shape, gather_negative_axis.shape)

              # Test gradients.
              gather_grad = np.random.randn(
                  *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
              if dtype.is_complex:
                gather_grad -= 1j * gather_grad
            params_grad, indices_grad, axis_grad = tape.gradient(
                gather, [tf_params, tf_indices, tf_axis], gather_grad)
            self.assertIsNone(indices_grad)
            self.assertIsNone(axis_grad)
            if dtype.is_integer:
              self.assertIsNone(params_grad)
              continue
            # For axis 0, we are able to create an efficient IndexedSlices for
            # the gradient.
            if axis == 0:
              self.assertEqual(type(params_grad), ops.IndexedSlices)
              params_grad = ops.convert_to_tensor(params_grad)
            correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
            outer_dims = axis
            inner_dims = len(shape) - axis - 1
            gather_grad = gather_grad.reshape(shape[:axis] + (indices.size,) +
                                              shape[axis + 1:])
            for source_index, dest_index in enumerate(indices.flat):
              dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                            (slice(None),) * inner_dims)
              source_slice = ((slice(None),) * outer_dims + (source_index,) +
                              (slice(None),) * inner_dims)
              correct_params_grad[dest_slice] += gather_grad[source_slice]
            self.assertAllClose(
                correct_params_grad,
                self.evaluate(params_grad),
                atol=2e-6,
                rtol=2e-6)

  def testString(self):
    params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
    self.assertAllEqual([b"qwer", b"uiop"], array_ops.gather(params, 1, axis=0))
    self.assertAllEqual([b"asdf", b"qwer"], array_ops.gather(params, 0, axis=1))

  def testUInt32AndUInt64(self):
    for unsigned_type in (dtypes.uint32, dtypes.uint64):
      with self.subTest(unsigned_type=unsigned_type):
        params = self._buildParams(
            np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
        with self.cached_session():
          self.assertAllEqual([7, 8, 9], array_ops.gather(params, 1, axis=0))
          self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1))

  def testUnknownIndices(self):
    # This test is purely a test for placeholder inputs, which are only
    # applicable in graph mode.
    with ops.Graph().as_default():
      params = constant_op.constant([[0, 1, 2]])
      indices = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices)
      self.assertEqual(None, gather_t.get_shape())

  def testUnknownAxis(self):
    # This test is purely a test for placeholder inputs, which are only
    # applicable in graph mode.
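    # Even with a placeholder axis, gather's result rank is still inferable
    # as rank(params) + rank(indices) - 1; the first check below verifies
    # 2 + 2 - 1 = 3, with all dimension sizes unknown.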
    with ops.Graph().as_default():
      params = constant_op.constant([[0, 1, 2]])
      indices = constant_op.constant([[0, 0], [0, 0]])
      axis = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices, axis=axis)
      # Rank 2 params with rank 2 indices results in a rank 3 shape.
      self.assertEqual([None, None, None], gather_t.shape.as_list())

      # If indices is also unknown the result rank is unknown.
      indices = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices, axis=axis)
      self.assertEqual(None, gather_t.shape)

  def testBadIndicesType(self):
    with self.assertRaisesRegex(
        (TypeError, errors.InvalidArgumentError),
        "float.* not in.* list of allowed values: int32, int64"):
      self.evaluate(array_ops.gather([0], 0.))

  @test_util.disable_xla(
      "Assertion inside an op is not supported in XLA. Instead XLA clamps the "
      "index to be in bounds and returns the indexed value there (Don't rely "
      "on this behavior).")
  def testBadIndicesCPU(self):
    with test_util.force_cpu():
      params = [[0, 1, 2], [3, 4, 5]]
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
        self.evaluate(array_ops.gather(params, [[7]], axis=0))
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
        self.evaluate(array_ops.gather(params, [[7]], axis=1))

  def _disabledTestBadIndicesGPU(self):
    # TODO: Disabled due to different behavior on GPU and CPU: on GPU,
    # out-of-bounds indices do not raise an error but fetch 0 values instead.
    if not test.is_gpu_available():
      return
    with self.session():
      params = [[0, 1, 2], [3, 4, 5]]
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
        array_ops.gather(params, [[7]], axis=0).eval()
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
        array_ops.gather(params, [[7]], axis=1).eval()

  def testBadAxis(self):
    params = [0, 1, 2]
    indices = 0
    for bad_axis in (1, 2, -2):
      # Shape inference can validate axis for a known params rank.
      with self.subTest(bad_axis=bad_axis):
        with self.assertRaisesRegex(
            (ValueError, errors.InvalidArgumentError),
            "Shape must be at least rank .* but is rank 1"):
          array_ops.gather(params, indices, axis=bad_axis)

  def testEmptySlices(self):
    for dtype in _TEST_TYPES:
      for itype in np.int32, np.int64:
        # Leading axis gather.
        with self.subTest(dtype=dtype, itype=itype):
          params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
          indices = np.array([3, 4], dtype=itype)
          gather = array_ops.gather(params, indices, axis=0)
          self.assertAllEqual(gather, np.zeros((2, 0, 0)))

          # Middle axis gather.
          params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
          gather = array_ops.gather(params, indices, axis=1)
          self.assertAllEqual(gather, np.zeros((0, 2, 0)))

          # Trailing axis gather.
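          # Only the gathered dimension changes here (7 -> 2); the empty
          # dimensions pass through unchanged.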
          params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
          gather = array_ops.gather(params, indices, axis=2)
          self.assertAllEqual(gather, np.zeros((0, 0, 2)))

  @parameterized.parameters([
      # batch_dims=0 (equivalent to tf.gather)
      dict(  # 2D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[2, 1], [0, 3]],
          expected=[[8, 7], [6, 9]]),
      dict(  # 3D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[9, 7], [8, 6]], [[6, 9], [8, 8]]]),
      dict(  # 4D indices
          batch_dims=0,
          params=[8, 9],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[8, 9], [9, 8]], [[8, 8], [9, 9]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),

      # batch_dims=indices.shape.ndims - 1
      # (equivalent to tf.compat.v1.batch_gather)
      dict(  # 2D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=2,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),
      dict(  # 2D indices (1 batch dim)
          batch_dims=-1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=-1,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),

      # batch_dims=indices.shape.ndims
      dict(  # 1D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[2, 1],
          expected=[12, 21]),
      dict(  # 2D indices (2 batch dims)
          batch_dims=2,
          params=[[[100, 101, 102, 103], [110, 111, 112, 113]],
                  [[200, 201, 202, 203], [210, 211, 212, 213]]],
          indices=[[2, 1], [0, 3]],
          expected=[[102, 111], [200, 213]]),

      # 0 < batch_dims < indices.shape.ndims - 1
      dict(  # 3D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[13, 11], [12, 10]], [[20, 23], [22, 22]]]),
      dict(  # 4D indices (1 batch dim)
          batch_dims=1,
          params=[[6, 7], [8, 9]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[6, 7], [7, 6]], [[6, 6], [7, 7]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),
      dict(  # 4D indices (2 batch dims)
          batch_dims=2,
          params=[[[2, 3], [4, 5]], [[6, 7], [8, 9]]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[2, 3], [3, 2]], [[4, 4], [5, 5]]],
                    [[[7, 7], [6, 6]], [[8, 9], [9, 8]]]]),

      # axis > 0
      dict(  # 3D indices, batch_dims=1, axis=2
          # params.shape  = [I1, J1, J2] = [2, 2, 3]
          # indices.shape = [I1, K1, K2] = [2, 1, 5]
          # result.shape  = [I1, J1, K1, K2] = [2, 2, 1, 5]
          batch_dims=1,
          axis=2,
          params=[[[10, 11, 12], [13, 14, 15]], [[20, 21, 22], [23, 24, 25]]],
          indices=[[[0, 1, 2, 1, 0]], [[0, 1, 2, 1, 0]]],
          expected=[[[[10, 11, 12, 11, 10]], [[13, 14, 15, 14, 13]]],
                    [[[20, 21, 22, 21, 20]], [[23, 24, 25, 24, 23]]]]),
      dict(  # 1D indices, batch_dims=None, axis=1
          batch_dims=None,
          axis=1,
          params=[[10, 11, 12], [13, 14, 15]],
          indices=[1, 0],
          expected=[[11, 10], [14, 13]]),
      dict(  # 3D indices, batch_dims=-3, axis=1
          batch_dims=-3,
          axis=1,
          params=[[0, 1, 2], [3, 4, 5]],
          indices=[[[0, 1], [1, 0]]],
          expected=[[[[0, 1], [1, 0]]], [[[3, 4], [4, 3]]]]),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDims(self, params, indices, batch_dims, expected=None,
                    axis=None):
    result = array_ops.gather(params, indices, axis=axis,
                              batch_dims=batch_dims)
    self.assertAllEqual(expected, result)

    # Test the gradients shape.
    with backprop.GradientTape() as tape:
      zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
      tape.watch(zeros)
      values = zeros * 2 + zeros
      result = array_ops.gather(
          values, indices, axis=axis, batch_dims=batch_dims)
    gradients = tape.gradient(result, zeros)

    self.assertAllEqual(array_ops.shape(params), array_ops.shape(gradients))

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected)
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(expected, result)

  @parameterized.parameters([
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=2,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=4,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=5,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-4,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-2,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
      ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-1,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
      ),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDimsMatchesPythonBatching(self, params_shape, indices_shape,
                                         batch_dims, axis, output_shape):
    """Checks that batch_dims matches multiple calls to tf.gather()."""
    # Generate a `params` tensor with the indicated shape.
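    # For example, a (hypothetical) params_shape=[2, 3] would yield
    # np.array([[0, 1, 2], [3, 4, 5]]).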
    params_size = np.prod(params_shape)
    params = np.reshape(np.arange(params_size), params_shape)

    # Generate an `indices` tensor with the indicated shape, where each index
    # is within the appropriate range.
    indices_size = np.prod(indices_shape)
    indices = np.reshape(np.arange(indices_size), indices_shape)
    indices = indices % params_shape[axis]

    # Perform repeated (batched) gather operations with numpy, to find the
    # expected result.
    expected = self._batchNumpyGather(params, indices, axis, batch_dims)

    # On Windows, we get an exception if we pass in the transformed numpy
    # arrays ("Failed to convert numpy ndarray to a Tensor (Unsupported
    # feed type)."); so convert them back to lists before calling tf.gather.
    params = params.tolist()
    indices = indices.tolist()

    result = array_ops.gather(params, indices, axis=axis,
                              batch_dims=batch_dims)
    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected.tolist())
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

  def _batchNumpyGather(self, params, indices, axis, batch_dims):
    """Performs a batch gather by making recursive calls to np.take().

    This is used by testBatchDimsMatchesPythonBatching() to construct the
    expected value.

    Args:
      params: A numpy array.
      indices: A numpy array.
      axis: An integer.
      batch_dims: An integer.

    Returns:
      A numpy array.
    """
    if batch_dims == 0:
      return np.take(params, indices, axis=axis)
    self.assertEqual(params.shape[0], indices.shape[0])
    if axis > 0:
      axis -= 1
    return np.stack([
        self._batchNumpyGather(params[i], indices[i], axis, batch_dims - 1)
        for i in range(params.shape[0])
    ])

  @test_util.run_v1_only("RefVariable is not supported in v2")
  def testGatherRefVariable(self):
    with self.cached_session():
      v = variables.RefVariable(constant_op.constant([[1, 2], [3, 4], [5, 6]]))
      self.evaluate(variables.global_variables_initializer())
      gather = array_ops.gather(v, [0, 2])
      if not context.executing_eagerly():  # .op doesn't make sense in Eager.
        self.assertEqual("GatherV2", gather.op.name)
      self.assertAllEqual([[1, 2], [5, 6]], gather)

  @test_util.run_in_graph_and_eager_modes
  def testGatherResourceVariable(self):
    with self.cached_session():
      v = resource_variable_ops.ResourceVariable(
          constant_op.constant([[1, 2], [3, 4], [5, 6]]))
      self.evaluate(variables.global_variables_initializer())
      gather = array_ops.gather(v, [0, 2])
      if not context.executing_eagerly():  # .op doesn't make sense in Eager.
        self.assertEqual("ResourceGather", gather.op.inputs[0].op.type)
      self.assertAllEqual([[1, 2], [5, 6]], gather)


if __name__ == "__main__":
  test.main()