• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.tf.gather."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python.eager import backprop
25from tensorflow.python.eager import context
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import test_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import gradients_impl
33from tensorflow.python.ops import resource_variable_ops
34from tensorflow.python.ops import variables
35from tensorflow.python.platform import test
36
# Dtypes iterated over by the per-dtype loops in the tests of this module. For
# the complex entries, GatherTest._buildParams adds an imaginary component.
_TEST_TYPES = (dtypes.int64, dtypes.float32,
               dtypes.complex64, dtypes.complex128)
39
40# TODO(virimia): Add a benchmark for gather_v2, with batch_dims and axis set.
41
42
43def _to_str_elements(values):
44  """Converts the inner list elements to strings."""
45  if isinstance(values, list):
46    return [_to_str_elements(value) for value in values]
47  else:
48    return str(values).encode("utf-8")
49
50
51class GatherTest(test.TestCase, parameterized.TestCase):
52
53  def _buildParams(self, data, dtype):
54    data = data.astype(dtype.as_numpy_dtype)
55    # For complex types, add an index-dependent imaginary component so we can
56    # tell we got the right value.
57    if dtype.is_complex:
58      return data + 10j * data
59    return data
60
61  def testScalar1D(self):
62    with self.cached_session():
63      data = np.array([0, 1, 2, 3, 7, 5])
64      for dtype in _TEST_TYPES:
65        for indices in 4, [1, 2, 2, 4, 5]:
66          with self.subTest(dtype=dtype, indices=indices):
67            params_np = self._buildParams(data, dtype)
68            params = constant_op.constant(params_np)
69            indices_tf = constant_op.constant(indices)
70            gather_t = array_ops.gather(params, indices_tf)
71            gather_val = self.evaluate(gather_t)
72            np_val = params_np[indices]
73            self.assertAllEqual(np_val, gather_val)
74            self.assertEqual(np_val.shape, gather_t.get_shape())
75
76  def testScalar2D(self):
77    with self.session():
78      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
79                       [9, 10, 11], [12, 13, 14]])
80      for dtype in _TEST_TYPES:
81        for axis in range(data.ndim):
82          with self.subTest(dtype=dtype, axis=axis):
83            params_np = self._buildParams(data, dtype)
84            params = constant_op.constant(params_np)
85            indices = constant_op.constant(2)
86            gather_t = array_ops.gather(params, indices, axis=axis)
87            gather_val = self.evaluate(gather_t)
88            self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val)
89            expected_shape = data.shape[:axis] + data.shape[axis + 1:]
90            self.assertEqual(expected_shape, gather_t.get_shape())
91
92  def testSimpleTwoD32(self):
93    with self.session():
94      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
95                       [9, 10, 11], [12, 13, 14]])
96      for dtype in _TEST_TYPES:
97        for axis in range(data.ndim):
98          with self.subTest(dtype=dtype, axis=axis):
99            params_np = self._buildParams(data, dtype)
100            params = constant_op.constant(params_np)
101            # The indices must be in bounds for any axis.
102            indices = constant_op.constant([0, 1, 0, 2])
103            gather_t = array_ops.gather(params, indices, axis=axis)
104            gather_val = self.evaluate(gather_t)
105            self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis),
106                                gather_val)
107            expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
108            self.assertEqual(expected_shape, gather_t.get_shape())
109
  def testHigherRank(self):
    """Checks values, static shapes and graph gradients on rank-4 params.

    For every combination of indices shape, dtype in `_TEST_TYPES` and axis,
    the gather result is compared against `np.take`, for both the positive
    axis and the equivalent negative axis. The gradient returned by
    `gradients_impl.gradients` is then compared against a gradient accumulated
    by hand in numpy.
    """
    with ops.Graph().as_default():
      # We check that scalar and empty indices shapes work as well
      shape = (2, 1, 3, 2)
      for indices_shape in (), (0,), (2, 0), (2, 3):
        for dtype in _TEST_TYPES:
          for axis in range(len(shape)):
            params = self._buildParams(np.random.randn(*shape), dtype)
            # Random indices that are in bounds for the gathered axis.
            indices = np.random.randint(shape[axis], size=indices_shape)
            with self.subTest(
                indices_shape=indices_shape,
                dtype=dtype,
                axis=axis,
                indices=indices):
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both the positive axis and the equivalent negative
              # axis work.
              tf_axis = constant_op.constant(axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
              gather_value, gather_negative_axis_value = self.evaluate(
                  [gather, gather_negative_axis])
              gather_np = np.take(params, indices, axis)
              self.assertAllEqual(gather_np, gather_value)
              self.assertAllEqual(gather_np, gather_negative_axis_value)
              # The gathered axis is replaced by the full indices shape.
              expected_shape = (params.shape[:axis] + indices.shape +
                                params.shape[axis + 1:])
              self.assertEqual(expected_shape, gather.shape)
              self.assertEqual(expected_shape, gather_negative_axis.shape)

              # Test gradients
              gather_grad = np.random.randn(
                  *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
              if dtype.is_complex:
                gather_grad -= 1j * gather_grad
              params_grad, indices_grad, axis_grad = gradients_impl.gradients(
                  gather, [tf_params, tf_indices, tf_axis], gather_grad)
              self.assertIsNone(indices_grad)
              self.assertIsNone(axis_grad)
              if dtype.is_integer:
                # Integer params are expected to have no gradient, so there is
                # nothing left to compare for this case.
                self.assertIsNone(params_grad)
                continue
              # For axis 0, we are able to create an efficient IndexedSlices for
              # the gradient.
              if axis == 0:
                self.assertEqual(type(params_grad), ops.IndexedSlices)
                params_grad = ops.convert_to_tensor(params_grad)
              # Build the expected gradient by hand: every gathered slice adds
              # its incoming gradient back onto the params slice it came from.
              correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
              outer_dims = axis
              inner_dims = len(shape) - axis - 1
              # Collapse the indices dimensions into one, so that position k
              # along `axis` corresponds to the k-th entry of indices.flat.
              gather_grad = gather_grad.reshape(
                  shape[:axis] + (indices.size,) + shape[axis + 1:])
              for source_index, dest_index in enumerate(indices.flat):
                dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                              (slice(None),) * inner_dims)
                source_slice = ((slice(None),) * outer_dims + (source_index,) +
                                (slice(None),) * inner_dims)
                correct_params_grad[dest_slice] += gather_grad[source_slice]
              self.assertAllClose(
                  correct_params_grad,
                  self.evaluate(params_grad),
                  atol=2e-6,
                  rtol=2e-6)
175
  def testHigherRankGradientTape(self):
    """Checks values, static shapes and tape gradients on rank-4 params.

    Runs the same value and shape checks as testHigherRank, but obtains the
    gradients from a `backprop.GradientTape` instead of
    `gradients_impl.gradients`, and does not wrap the body in an explicit
    graph.
    """
    # We check that scalar and empty indices shapes work as well
    shape = (2, 1, 3, 2)
    for indices_shape in (), (0,), (2, 0), (2, 3):
      for dtype in _TEST_TYPES:
        for axis in range(len(shape)):
          params = self._buildParams(np.random.randn(*shape), dtype)
          # Random indices that are in bounds for the gathered axis.
          indices = np.random.randint(shape[axis], size=indices_shape)
          with self.subTest(
              indices_shape=indices_shape,
              dtype=dtype,
              axis=axis,
              indices=indices):
            with backprop.GradientTape() as tape:
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both the positive axis and the equivalent negative
              # axis work.
              tf_axis = constant_op.constant(axis)
              # All three inputs are constants, so each is watched explicitly
              # before its gradient is requested below.
              tape.watch(tf_params)
              tape.watch(tf_indices)
              tape.watch(tf_axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
              gather_value, gather_negative_axis_value = self.evaluate(
                  [gather, gather_negative_axis])
              gather_np = np.take(params, indices, axis)
              self.assertAllEqual(gather_np, gather_value)
              self.assertAllEqual(gather_np, gather_negative_axis_value)
              # The gathered axis is replaced by the full indices shape.
              expected_shape = (
                  params.shape[:axis] + indices.shape + params.shape[axis + 1:])
              self.assertEqual(expected_shape, gather.shape)
              self.assertEqual(expected_shape, gather_negative_axis.shape)

              # Test gradients
              gather_grad = np.random.randn(
                  *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
              if dtype.is_complex:
                gather_grad -= 1j * gather_grad
            # The gradient is requested after the tape context has exited.
            params_grad, indices_grad, axis_grad = tape.gradient(
                gather, [tf_params, tf_indices, tf_axis], gather_grad)
            self.assertIsNone(indices_grad)
            self.assertIsNone(axis_grad)
            if dtype.is_integer:
              # Integer params are expected to have no gradient, so there is
              # nothing left to compare for this case.
              self.assertIsNone(params_grad)
              continue
            # For axis 0, we are able to create an efficient IndexedSlices for
            # the gradient.
            if axis == 0:
              self.assertEqual(type(params_grad), ops.IndexedSlices)
              params_grad = ops.convert_to_tensor(params_grad)
            # Build the expected gradient by hand: every gathered slice adds
            # its incoming gradient back onto the params slice it came from.
            correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
            outer_dims = axis
            inner_dims = len(shape) - axis - 1
            # Collapse the indices dimensions into one, so that position k
            # along `axis` corresponds to the k-th entry of indices.flat.
            gather_grad = gather_grad.reshape(shape[:axis] + (indices.size,) +
                                              shape[axis + 1:])
            for source_index, dest_index in enumerate(indices.flat):
              dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                            (slice(None),) * inner_dims)
              source_slice = ((slice(None),) * outer_dims + (source_index,) +
                              (slice(None),) * inner_dims)
              correct_params_grad[dest_slice] += gather_grad[source_slice]
            self.assertAllClose(
                correct_params_grad,
                self.evaluate(params_grad),
                atol=2e-6,
                rtol=2e-6)
244
245  def testString(self):
246    params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
247    self.assertAllEqual([b"qwer", b"uiop"], array_ops.gather(params, 1, axis=0))
248    self.assertAllEqual([b"asdf", b"qwer"], array_ops.gather(params, 0, axis=1))
249
250  def testUInt32AndUInt64(self):
251    for unsigned_type in (dtypes.uint32, dtypes.uint64):
252      with self.subTest(unsigned_type=unsigned_type):
253        params = self._buildParams(
254            np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
255        with self.cached_session():
256          self.assertAllEqual([7, 8, 9], array_ops.gather(params, 1, axis=0))
257          self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1))
258
259  def testUnknownIndices(self):
260    # This test is purely a test for placeholder inputs which is only applicable
261    # in graph mode.
262    with ops.Graph().as_default():
263      params = constant_op.constant([[0, 1, 2]])
264      indices = array_ops.placeholder(dtypes.int32)
265      gather_t = array_ops.gather(params, indices)
266      self.assertEqual(None, gather_t.get_shape())
267
268  def testUnknownAxis(self):
269    # This test is purely a test for placeholder inputs which is only applicable
270    # in graph mode.
271    with ops.Graph().as_default():
272      params = constant_op.constant([[0, 1, 2]])
273      indices = constant_op.constant([[0, 0], [0, 0]])
274      axis = array_ops.placeholder(dtypes.int32)
275      gather_t = array_ops.gather(params, indices, axis=axis)
276      # Rank 2 params with rank 2 indices results in a rank 3 shape.
277      self.assertEqual([None, None, None], gather_t.shape.as_list())
278
279      # If indices is also unknown the result rank is unknown.
280      indices = array_ops.placeholder(dtypes.int32)
281      gather_t = array_ops.gather(params, indices, axis=axis)
282      self.assertEqual(None, gather_t.shape)
283
284  def testBadIndicesType(self):
285    with self.assertRaisesRegex(
286        (TypeError, errors.InvalidArgumentError),
287        "float.* not in.* list of allowed values: int32, int64"):
288      self.evaluate(array_ops.gather([0], 0.))
289
290  @test_util.disable_xla(
291      "Assertion inside an op is not supported in XLA. Instead XLA clamps the "
292      "index to be in bounds and returns the indexed value there (Don't rely "
293      "on this behavior).")
294  def testBadIndicesCPU(self):
295    with test_util.force_cpu():
296      params = [[0, 1, 2], [3, 4, 5]]
297      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
298        self.evaluate(array_ops.gather(params, [[7]], axis=0))
299      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
300        self.evaluate(array_ops.gather(params, [[7]], axis=1))
301
302  def _disabledTestBadIndicesGPU(self):
303    # TODO disabled due to different behavior on GPU and CPU
304    # On GPU the bad indices do not raise error but fetch 0 values
305    if not test.is_gpu_available():
306      return
307    with self.session():
308      params = [[0, 1, 2], [3, 4, 5]]
309      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
310        array_ops.gather(params, [[7]], axis=0).eval()
311      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
312        array_ops.gather(params, [[7]], axis=1).eval()
313
314  def testBadAxis(self):
315    params = [0, 1, 2]
316    indices = 0
317    for bad_axis in (1, 2, -2):
318      # Shape inference can validate axis for known params rank.
319      with self.subTest(bad_axis=bad_axis):
320        with self.assertRaisesRegex(
321            (ValueError, errors.InvalidArgumentError),
322            "Shape must be at least rank .* but is rank 1"):
323          array_ops.gather(params, indices, axis=bad_axis)
324
325  def testEmptySlices(self):
326    for dtype in _TEST_TYPES:
327      for itype in np.int32, np.int64:
328        # Leading axis gather.
329        with self.subTest(dtype=dtype, itype=itype):
330          params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
331          indices = np.array([3, 4], dtype=itype)
332          gather = array_ops.gather(params, indices, axis=0)
333          self.assertAllEqual(gather, np.zeros((2, 0, 0)))
334
335          # Middle axis gather.
336          params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
337          gather = array_ops.gather(params, indices, axis=1)
338          self.assertAllEqual(gather, np.zeros((0, 2, 0)))
339
340          # Trailing axis gather.
341          params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
342          gather = array_ops.gather(params, indices, axis=2)
343          self.assertAllEqual(gather, np.zeros((0, 0, 2)))
344
  @parameterized.parameters([
      # batch_dims=0 (equivalent to tf.gather)
      dict(  # 2D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[2, 1], [0, 3]],
          expected=[[8, 7], [6, 9]]),
      dict(  # 3D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[9, 7], [8, 6]], [[6, 9], [8, 8]]]),
      dict(  # 4D indices
          batch_dims=0,
          params=[8, 9],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[8, 9], [9, 8]], [[8, 8], [9, 9]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),

      # batch_dims=indices.shape.ndims - 1
      # (equivalent to tf.compat.v1.batch_gather)
      dict(  # 2D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=2,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),
      dict(  # 2D indices (1 batch dim)
          batch_dims=-1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=-1,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),

      # batch_dims=indices.shape.ndims
      dict(  # 1D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[2, 1],
          expected=[12, 21]),
      dict(  # 2D indices (2 batch dim)
          batch_dims=2,
          params=[[[100, 101, 102, 103], [110, 111, 112, 113]],
                  [[200, 201, 202, 203], [210, 211, 212, 213]]],
          indices=[[2, 1], [0, 3]],
          expected=[[102, 111], [200, 213]]),

      # 0 < batch_dims < indices.shape.ndims - 1
      dict(  # 3D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[13, 11], [12, 10]], [[20, 23], [22, 22]]]),
      dict(  # 4D indices (1 batch dim)
          batch_dims=1,
          params=[[6, 7], [8, 9]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[6, 7], [7, 6]], [[6, 6], [7, 7]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),
      dict(  # 4D indices (2 batch dims)
          batch_dims=2,
          params=[[[2, 3], [4, 5]], [[6, 7], [8, 9]]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[2, 3], [3, 2]], [[4, 4], [5, 5]]],
                    [[[7, 7], [6, 6]], [[8, 9], [9, 8]]]]),

      # axis > 0
      dict(  # 3D indices, batch_dims=1, axis=2
          # params.shape  = [I1, J1, J2] = [2, 2, 3]
          # indices.shape = [I1, K1, K2] = [2, 1, 5]
          # result.shape  = [I1, J1, K1, K2] = [2, 2, 1, 5]
          batch_dims=1,
          axis=2,
          params=[[[10, 11, 12], [13, 14, 15]], [[20, 21, 22], [23, 24, 25]]],
          indices=[[[0, 1, 2, 1, 0]], [[0, 1, 2, 1, 0]]],
          expected=[[[[10, 11, 12, 11, 10]], [[13, 14, 15, 14, 13]]],
                    [[[20, 21, 22, 21, 20]], [[23, 24, 25, 24, 23]]]]),
      dict(  # 1D indices, batch_dims=None, axis=1
          batch_dims=None,
          axis=1,
          params=[[10, 11, 12], [13, 14, 15]],
          indices=[1, 0],
          expected=[[11, 10], [14, 13]]),
      dict(  # 3D indices, batch_dims=-3, axis=1
          batch_dims=-3,
          axis=1,
          params=[[0, 1, 2], [3, 4, 5]],
          indices=[[[0, 1], [1, 0]]],
          expected=[[[[0, 1], [1, 0]]], [[[3, 4], [4, 3]]]]),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDims(self, params, indices, batch_dims, expected=None,
                    axis=None):
    """Checks gather with `batch_dims` against hand-written expected values.

    After the value check, the test also verifies that the gradient with
    respect to float params has the same shape as `params`, and repeats the
    value check with params and expected values converted to byte strings.

    Args:
      params: Nested list of integers to gather from.
      indices: Nested list of integer indices.
      batch_dims: Value passed as `batch_dims` to `array_ops.gather`.
      expected: Nested list holding the expected gather result.
      axis: Value passed as `axis` to `array_ops.gather`.
    """
    result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    self.assertAllEqual(expected, result)

    # Test the gradients shape.
    with backprop.GradientTape() as tape:
      # Only the shape matters here, so a float32 zeros tensor shaped like
      # params stands in for the integer params.
      zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
      tape.watch(zeros)
      values = zeros * 2 + zeros
      result = array_ops.gather(
          values, indices, axis=axis, batch_dims=batch_dims)
    gradients = tape.gradient(result, zeros)

    self.assertAllEqual(array_ops.shape(params), array_ops.shape(gradients))

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected)
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(expected, result)
470
  @parameterized.parameters([
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=2,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=4,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=5,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-4,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-2,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-1,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
          ),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDimsMatchesPythonBatching(self, params_shape, indices_shape,
                                         batch_dims, axis, output_shape):
    """Checks that batch_dims matches multiple calls to tf.gather().

    The expected value is built with numpy by `_batchNumpyGather`, and the
    comparison is repeated with params converted to byte strings.

    Args:
      params_shape: Shape of the generated `params` array.
      indices_shape: Shape of the generated `indices` array.
      batch_dims: Value passed as `batch_dims` to `array_ops.gather`.
      axis: Value passed as `axis` to `array_ops.gather`; may be negative.
      output_shape: Expected static shape of the gather result.
    """
    # Generate a `params` tensor with the indicated shape.
    params_size = np.prod(params_shape)
    params = np.reshape(np.arange(params_size), params_shape)

    # Generate an `indices` tensor with the indicated shape, where each index
    # is within the appropriate range.
    indices_size = np.prod(indices_shape)
    indices = np.reshape(np.arange(indices_size), indices_shape)
    # Wrap the running counter into [0, params_shape[axis]).
    indices = indices % params_shape[axis]

    # Perform repeated (batched) gather operations with numpy, to find the
    # expected result.
    expected = self._batchNumpyGather(params, indices, axis, batch_dims)

    # On Windows, we get an exception if we pass in the transformed numpy
    # arrays ("Failed to convert numpy ndarray to a Tensor (Unsupported
    # feed type)."); so convert them back to lists before calling tf.gather.
    params = params.tolist()
    indices = indices.tolist()

    result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected.tolist())
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)
573
574  def _batchNumpyGather(self, params, indices, axis, batch_dims):
575    """Performs a batch gather by making recursive calls to np.take().
576
577    This is used by testBatchDims() to construct the expected value.
578
579    Args:
580      params: A numpy array
581      indices: A numpy array
582      axis: An integer
583      batch_dims: An integer
584    Returns:
585      A numpy array
586    """
587    if batch_dims == 0:
588      return np.take(params, indices, axis=axis)
589    self.assertEqual(params.shape[0], indices.shape[0])
590    if axis > 0:
591      axis -= 1
592    return np.stack([
593        self._batchNumpyGather(params[i], indices[i], axis, batch_dims - 1)
594        for i in range(params.shape[0])
595    ])
596
597  @test_util.run_v1_only("RefVariable is not supported in v2")
598  def testGatherRefVariable(self):
599    with self.cached_session():
600      v = variables.RefVariable(constant_op.constant([[1, 2], [3, 4], [5, 6]]))
601      self.evaluate(variables.global_variables_initializer())
602      gather = array_ops.gather(v, [0, 2])
603      if not context.executing_eagerly():  # .op doesn't make sense in Eager
604        self.assertEqual("GatherV2", gather.op.name)
605      self.assertAllEqual([[1, 2], [5, 6]], gather)
606
607  @test_util.run_in_graph_and_eager_modes
608  def testGatherResourceVariable(self):
609    with self.cached_session():
610      v = resource_variable_ops.ResourceVariable(
611          constant_op.constant([[1, 2], [3, 4], [5, 6]]))
612      self.evaluate(variables.global_variables_initializer())
613      gather = array_ops.gather(v, [0, 2])
614      if not context.executing_eagerly():  # .op doesn't make sense in Eager
615        self.assertEqual("ResourceGather", gather.op.inputs[0].op.type)
616      self.assertAllEqual([[1, 2], [5, 6]], gather)
617
# Hand control to the TensorFlow test entry point when run as a script.
if __name__ == "__main__":
  test.main()
620