# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.gather."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

_TEST_TYPES = (dtypes.int64, dtypes.float32,
               dtypes.complex64, dtypes.complex128)

# TODO(virimia): Add a benchmark for gather_v2, with batch_dims and axis set.


def _to_str_elements(values):
  """Converts the inner list elements to strings."""
  if isinstance(values, list):
    return [_to_str_elements(value) for value in values]
  else:
    return str(values).encode("utf-8")
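# Example (illustrative, not part of the original tests): the conversion
# mirrors the input's nesting, encoding leaf values as UTF-8 bytes, e.g.
# _to_str_elements([[1, 2], [3]]) -> [[b"1", b"2"], [b"3"]].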


class GatherTest(test.TestCase, parameterized.TestCase):

  def _buildParams(self, data, dtype):
    data = data.astype(dtype.as_numpy_dtype)
    # For complex types, add an index-dependent imaginary component so we can
    # tell we got the right value.
    if dtype.is_complex:
      return data + 10j * data
    return data
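  # Example (illustrative, not part of the original tests): for a complex
  # dtype, _buildParams(np.array([1, 2]), dtypes.complex64) yields
  # [1+10j, 2+20j]; the index-dependent imaginary part makes a wrongly
  # gathered element immediately visible.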

  def testScalar1D(self):
    with self.cached_session():
      data = np.array([0, 1, 2, 3, 7, 5])
      for dtype in _TEST_TYPES:
        for indices in 4, [1, 2, 2, 4, 5]:
          with self.subTest(dtype=dtype, indices=indices):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            indices_tf = constant_op.constant(indices)
            gather_t = array_ops.gather(params, indices_tf)
            gather_val = self.evaluate(gather_t)
            np_val = params_np[indices]
            self.assertAllEqual(np_val, gather_val)
            self.assertEqual(np_val.shape, gather_t.get_shape())

  def testScalar2D(self):
    with self.session():
      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                       [9, 10, 11], [12, 13, 14]])
      for dtype in _TEST_TYPES:
        for axis in range(data.ndim):
          with self.subTest(dtype=dtype, axis=axis):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            indices = constant_op.constant(2)
            gather_t = array_ops.gather(params, indices, axis=axis)
            gather_val = self.evaluate(gather_t)
            self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val)
            expected_shape = data.shape[:axis] + data.shape[axis + 1:]
            self.assertEqual(expected_shape, gather_t.get_shape())

  def testSimpleTwoD32(self):
    with self.session():
      data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                       [9, 10, 11], [12, 13, 14]])
      for dtype in _TEST_TYPES:
        for axis in range(data.ndim):
          with self.subTest(dtype=dtype, axis=axis):
            params_np = self._buildParams(data, dtype)
            params = constant_op.constant(params_np)
            # The indices must be in bounds for any axis.
            indices = constant_op.constant([0, 1, 0, 2])
            gather_t = array_ops.gather(params, indices, axis=axis)
            gather_val = self.evaluate(gather_t)
            self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis),
                                gather_val)
            expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
            self.assertEqual(expected_shape, gather_t.get_shape())

  def testHigherRank(self):
    with ops.Graph().as_default():
      # We check that scalar and empty indices shapes work as well
      shape = (2, 1, 3, 2)
      for indices_shape in (), (0,), (2, 0), (2, 3):
        for dtype in _TEST_TYPES:
          for axis in range(len(shape)):
            params = self._buildParams(np.random.randn(*shape), dtype)
            indices = np.random.randint(shape[axis], size=indices_shape)
            with self.subTest(
                indices_shape=indices_shape,
                dtype=dtype,
                axis=axis,
                indices=indices):
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both positive and negative indices for axis work.
              tf_axis = constant_op.constant(axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
              gather_value, gather_negative_axis_value = self.evaluate(
                  [gather, gather_negative_axis])
              gather_np = np.take(params, indices, axis)
              self.assertAllEqual(gather_np, gather_value)
              self.assertAllEqual(gather_np, gather_negative_axis_value)
              expected_shape = (params.shape[:axis] + indices.shape +
                                params.shape[axis + 1:])
              self.assertEqual(expected_shape, gather.shape)
              self.assertEqual(expected_shape, gather_negative_axis.shape)

              # Test gradients
              gather_grad = np.random.randn(
                  *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
              if dtype.is_complex:
                gather_grad -= 1j * gather_grad
              params_grad, indices_grad, axis_grad = gradients_impl.gradients(
                  gather, [tf_params, tf_indices, tf_axis], gather_grad)
              self.assertIsNone(indices_grad)
              self.assertIsNone(axis_grad)
              if dtype.is_integer:
                self.assertIsNone(params_grad)
                continue
              # For axis 0, we are able to create an efficient IndexedSlices for
              # the gradient.
              if axis == 0:
                self.assertEqual(
                    type(params_grad), indexed_slices.IndexedSlices)
                params_grad = ops.convert_to_tensor(params_grad)
              correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
              outer_dims = axis
              inner_dims = len(shape) - axis - 1
              gather_grad = gather_grad.reshape(
                  shape[:axis] + (indices.size,) + shape[axis + 1:])
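              # Scatter-add each slice of the incoming gradient back to the
              # `params` position it was gathered from; repeated indices
              # accumulate their contributions.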
              for source_index, dest_index in enumerate(indices.flat):
                dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                              (slice(None),) * inner_dims)
                source_slice = ((slice(None),) * outer_dims + (source_index,) +
                                (slice(None),) * inner_dims)
                correct_params_grad[dest_slice] += gather_grad[source_slice]
              self.assertAllClose(
                  correct_params_grad,
                  self.evaluate(params_grad),
                  atol=2e-6,
                  rtol=2e-6)

  def testHigherRankGradientTape(self):
    # We check that scalar and empty indices shapes work as well
    shape = (2, 1, 3, 2)
    for indices_shape in (), (0,), (2, 0), (2, 3):
      for dtype in _TEST_TYPES:
        for axis in range(len(shape)):
          params = self._buildParams(np.random.randn(*shape), dtype)
          indices = np.random.randint(shape[axis], size=indices_shape)
          with self.subTest(
              indices_shape=indices_shape,
              dtype=dtype,
              axis=axis,
              indices=indices):
            with backprop.GradientTape() as tape:
              tf_params = constant_op.constant(params)
              tf_indices = constant_op.constant(indices)
              # Check that both positive and negative indices for axis work.
              tf_axis = constant_op.constant(axis)
              tape.watch(tf_params)
              tape.watch(tf_indices)
              tape.watch(tf_axis)
              tf_negative_axis = constant_op.constant(-len(shape) + axis)
              gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
              gather_negative_axis = array_ops.gather(
                  tf_params, tf_indices, axis=tf_negative_axis)
              gather_value, gather_negative_axis_value = self.evaluate(
                  [gather, gather_negative_axis])
              gather_np = np.take(params, indices, axis)
              self.assertAllEqual(gather_np, gather_value)
              self.assertAllEqual(gather_np, gather_negative_axis_value)
              expected_shape = (
                  params.shape[:axis] + indices.shape + params.shape[axis + 1:])
              self.assertEqual(expected_shape, gather.shape)
              self.assertEqual(expected_shape, gather_negative_axis.shape)

              # Test gradients
              gather_grad = np.random.randn(
                  *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
              if dtype.is_complex:
                gather_grad -= 1j * gather_grad
            params_grad, indices_grad, axis_grad = tape.gradient(
                gather, [tf_params, tf_indices, tf_axis], gather_grad)
            self.assertIsNone(indices_grad)
            self.assertIsNone(axis_grad)
            if dtype.is_integer:
              self.assertIsNone(params_grad)
              continue
            # For axis 0, we are able to create an efficient IndexedSlices for
            # the gradient.
            if axis == 0:
              self.assertEqual(type(params_grad), indexed_slices.IndexedSlices)
              params_grad = ops.convert_to_tensor(params_grad)
            correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
            outer_dims = axis
            inner_dims = len(shape) - axis - 1
            gather_grad = gather_grad.reshape(shape[:axis] + (indices.size,) +
                                              shape[axis + 1:])
            for source_index, dest_index in enumerate(indices.flat):
              dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
                            (slice(None),) * inner_dims)
              source_slice = ((slice(None),) * outer_dims + (source_index,) +
                              (slice(None),) * inner_dims)
              correct_params_grad[dest_slice] += gather_grad[source_slice]
            self.assertAllClose(
                correct_params_grad,
                self.evaluate(params_grad),
                atol=2e-6,
                rtol=2e-6)

  def testString(self):
    params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
    self.assertAllEqual([b"qwer", b"uiop"], array_ops.gather(params, 1, axis=0))
    self.assertAllEqual([b"asdf", b"qwer"], array_ops.gather(params, 0, axis=1))

  def testUInt32AndUInt64(self):
    for unsigned_type in (dtypes.uint32, dtypes.uint64):
      with self.subTest(unsigned_type=unsigned_type):
        params = self._buildParams(
            np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
        with self.cached_session():
          self.assertAllEqual([7, 8, 9], array_ops.gather(params, 1, axis=0))
          self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1))

  def testUnknownIndices(self):
    # This test exercises placeholder indices, which are only applicable in
    # graph mode.
    with ops.Graph().as_default():
      params = constant_op.constant([[0, 1, 2]])
      indices = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices)
      self.assertEqual(None, gather_t.get_shape())

  def testUnknownAxis(self):
    # This test exercises a placeholder axis, which is only applicable in
    # graph mode.
    with ops.Graph().as_default():
      params = constant_op.constant([[0, 1, 2]])
      indices = constant_op.constant([[0, 0], [0, 0]])
      axis = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices, axis=axis)
      # Rank 2 params with rank 2 indices results in a rank 3 shape
      # (params rank + indices rank - 1).
      self.assertEqual([None, None, None], gather_t.shape.as_list())

      # If indices is also unknown the result rank is unknown.
      indices = array_ops.placeholder(dtypes.int32)
      gather_t = array_ops.gather(params, indices, axis=axis)
      self.assertEqual(None, gather_t.shape)

  def testBadIndicesType(self):
    with self.assertRaisesRegex(
        (TypeError, errors.InvalidArgumentError),
        "float.* not in.* list of allowed values: int16, int32, int64"):
      self.evaluate(array_ops.gather([0], 0.))

  @test_util.disable_xla(
      "Assertion inside an op is not supported in XLA. Instead XLA clamps the "
      "index to be in bounds and returns the indexed value there (Don't rely "
      "on this behavior).")
  def testBadIndicesCPU(self):
    with test_util.force_cpu():
      params = [[0, 1, 2], [3, 4, 5]]
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
        self.evaluate(array_ops.gather(params, [[7]], axis=0))
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
        self.evaluate(array_ops.gather(params, [[7]], axis=1))

  def _disabledTestBadIndicesGPU(self):
    # TODO: Disabled due to different behavior on GPU and CPU.
    # On the GPU, bad indices do not raise an error; they fetch zero values.
    if not test.is_gpu_available():
      return
    with self.session():
      params = [[0, 1, 2], [3, 4, 5]]
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
        array_ops.gather(params, [[7]], axis=0).eval()
      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
        array_ops.gather(params, [[7]], axis=1).eval()

  def testBadAxis(self):

    @def_function.function(autograph=False, jit_compile=False)
    def gather(x, indices, axis):
      return array_ops.gather(x, indices, axis=axis)

    @def_function.function(
        autograph=False,
        jit_compile=False,
        input_signature=[
            tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
        ] * 3)
    def gather_shape_inf_disabled(x, indices, axis):
      return array_ops.gather(x, indices, axis=axis)

    @def_function.function(
        autograph=False,
        jit_compile=True,
        input_signature=[
            tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
        ] * 3)
    def xla_gather(x, indices, axis):
      return array_ops.gather(x, indices, axis=axis)

    params = [0, 1, 2]
    indices = 0
    functions = [("array_ops.gather", array_ops.gather), ("gather", gather),
                 ("gather_shape_inf_disabled", gather_shape_inf_disabled),
                 ("xla_gather", xla_gather)]
    for bad_axis in (1, 2, -2):
      for fn_name, fn in functions:
        # Shape inference can validate axis for known params rank.
        with self.subTest(bad_axis=bad_axis, msg=fn_name, fn=fn):
          with self.assertRaisesRegex(
              (ValueError, errors.InvalidArgumentError),
              "Shape must be at least rank .* but is rank 1"):
            fn(params, indices, axis=bad_axis)

  def testEmptySlices(self):
    for dtype in _TEST_TYPES:
      for itype in np.int32, np.int64:
        # Leading axis gather.
        with self.subTest(dtype=dtype, itype=itype):
          params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
          indices = np.array([3, 4], dtype=itype)
          gather = array_ops.gather(params, indices, axis=0)
          self.assertAllEqual(gather, np.zeros((2, 0, 0)))

          # Middle axis gather.
          params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
          gather = array_ops.gather(params, indices, axis=1)
          self.assertAllEqual(gather, np.zeros((0, 2, 0)))

          # Trailing axis gather.
          params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
          gather = array_ops.gather(params, indices, axis=2)
          self.assertAllEqual(gather, np.zeros((0, 0, 2)))

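  # Illustrative sketch (not from the original file; it mirrors the
  # batch_dims=1 case below and _batchNumpyGather): with one batch dimension,
  # tf.gather is equivalent to gathering within each batch element and
  # stacking the results:
  #   params  = [[10, 11, 12, 13], [20, 21, 22, 23]]
  #   indices = [[2, 1], [0, 3]]
  #   tf.stack([tf.gather(params[i], indices[i]) for i in range(2)])
  #   # -> [[12, 11], [20, 23]]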
  @parameterized.parameters([
      # batch_dims=0 (equivalent to tf.gather)
      dict(  # 2D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[2, 1], [0, 3]],
          expected=[[8, 7], [6, 9]]),
      dict(  # 3D indices
          batch_dims=0,
          params=[6, 7, 8, 9],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[9, 7], [8, 6]], [[6, 9], [8, 8]]]),
      dict(  # 4D indices
          batch_dims=0,
          params=[8, 9],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[8, 9], [9, 8]], [[8, 8], [9, 9]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),

      # batch_dims=indices.shape.ndims - 1
      # (equivalent to tf.compat.v1.batch_gather)
      dict(  # 2D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=2,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),
      dict(  # 2D indices (1 batch dim)
          batch_dims=-1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[2, 1], [0, 3]],
          expected=[[12, 11], [20, 23]]),
      dict(  # 3D indices (2 batch dims)
          batch_dims=-1,
          params=[[[100, 101], [110, 111]], [[200, 201], [210, 211]]],
          indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
          expected=[[[100, 101], [111, 110]], [[200, 200], [211, 211]]]),

      # batch_dims=indices.shape.ndims
      dict(  # 1D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[2, 1],
          expected=[12, 21]),
      dict(  # 2D indices (2 batch dims)
          batch_dims=2,
          params=[[[100, 101, 102, 103], [110, 111, 112, 113]],
                  [[200, 201, 202, 203], [210, 211, 212, 213]]],
          indices=[[2, 1], [0, 3]],
          expected=[[102, 111], [200, 213]]),

      # 0 < batch_dims < indices.shape.ndims - 1
      dict(  # 3D indices (1 batch dim)
          batch_dims=1,
          params=[[10, 11, 12, 13], [20, 21, 22, 23]],
          indices=[[[3, 1], [2, 0]], [[0, 3], [2, 2]]],
          expected=[[[13, 11], [12, 10]], [[20, 23], [22, 22]]]),
      dict(  # 4D indices (1 batch dim)
          batch_dims=1,
          params=[[6, 7], [8, 9]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[6, 7], [7, 6]], [[6, 6], [7, 7]]],
                    [[[9, 9], [8, 8]], [[8, 9], [9, 8]]]]),
      dict(  # 4D indices (2 batch dims)
          batch_dims=2,
          params=[[[2, 3], [4, 5]], [[6, 7], [8, 9]]],
          indices=[[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
                   [[[1, 1], [0, 0]], [[0, 1], [1, 0]]]],
          expected=[[[[2, 3], [3, 2]], [[4, 4], [5, 5]]],
                    [[[7, 7], [6, 6]], [[8, 9], [9, 8]]]]),

      # axis > 0
      dict(  # 3D indices, batch_dims=1, axis=2
          # params.shape  = [I1, J1, J2] = [2, 2, 3]
          # indices.shape = [I1, K1, K2] = [2, 1, 5]
          # result.shape  = [I1, J1, K1, K2] = [2, 2, 1, 5]
          batch_dims=1,
          axis=2,
          params=[[[10, 11, 12], [13, 14, 15]], [[20, 21, 22], [23, 24, 25]]],
          indices=[[[0, 1, 2, 1, 0]], [[0, 1, 2, 1, 0]]],
          expected=[[[[10, 11, 12, 11, 10]], [[13, 14, 15, 14, 13]]],
                    [[[20, 21, 22, 21, 20]], [[23, 24, 25, 24, 23]]]]),
      dict(  # 1D indices, batch_dims=None, axis=1
          batch_dims=None,
          axis=1,
          params=[[10, 11, 12], [13, 14, 15]],
          indices=[1, 0],
          expected=[[11, 10], [14, 13]]),
      dict(  # 3D indices, batch_dims=-3, axis=1
          batch_dims=-3,
          axis=1,
          params=[[0, 1, 2], [3, 4, 5]],
          indices=[[[0, 1], [1, 0]]],
          expected=[[[[0, 1], [1, 0]]], [[[3, 4], [4, 3]]]]),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDims(self, params, indices, batch_dims, expected=None,
                    axis=None):
    result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    self.assertAllEqual(expected, result)

    # Test gradients
    f64_params = math_ops.cast(params, dtypes.float64)
    def gather(params):
      return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        gather, [f64_params])
    self.assertAllClose(theoretical, numerical)

    # Test gradients when input shapes are unknown
    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float64),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.int32)
    ])
    def gather_unknown_shapes(params, indices):
      return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
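    # With this unknown-rank signature, a negative batch_dims cannot be
    # resolved against the rank of `indices`, so taking the gradient is
    # expected to fail (see the else branch below).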
    if batch_dims is None or batch_dims >= 0:
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          lambda p: gather_unknown_shapes(p, indices), [f64_params])
      self.assertAllClose(theoretical, numerical)
    else:
      with self.assertRaisesRegex(
          ValueError,
          "Currently, it is unsupported to take the gradient of tf.gather"):
        gradient_checker_v2.compute_gradient(
            lambda p: gather_unknown_shapes(p, indices), [f64_params])

    # Test the gradient's shape.
    with backprop.GradientTape() as tape:
      zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
      tape.watch(zeros)
      values = zeros * 2 + zeros
      result = array_ops.gather(
          values, indices, axis=axis, batch_dims=batch_dims)
    gradients = tape.gradient(result, zeros)

    self.assertAllEqual(array_ops.shape(params), array_ops.shape(gradients))

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected)
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(expected, result)

  @parameterized.parameters([
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=2,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=4,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=5,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-4,
          output_shape=[2, 3, 8, 9, 10, 5, 6, 7]
          # = params.shape[:2] + indices.shape[2:] + params.shape[3:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-3,
          output_shape=[2, 3, 4, 8, 9, 10, 6, 7]
          # = params.shape[:3] + indices.shape[2:] + params.shape[4:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-2,
          output_shape=[2, 3, 4, 5, 8, 9, 10, 7]
          # = params.shape[:4] + indices.shape[2:] + params.shape[5:]
          ),
      dict(
          params_shape=[2, 3, 4, 5, 6, 7],
          indices_shape=[2, 3, 8, 9, 10],
          batch_dims=2,
          axis=-1,
          output_shape=[2, 3, 4, 5, 6, 8, 9, 10]
          # = params.shape[:5] + indices.shape[2:] + params.shape[6:]
          ),
  ])
  @test_util.run_in_graph_and_eager_modes
  def testBatchDimsMatchesPythonBatching(self, params_shape, indices_shape,
                                         batch_dims, axis, output_shape):
    """Checks that batch_dims matches multiple calls to tf.gather()."""
    # Generate a `params` tensor with the indicated shape.
    params_size = np.prod(params_shape)
    params = np.reshape(np.arange(params_size), params_shape)

    # Generate an `indices` tensor with the indicated shape, where each index
    # is within the appropriate range.
    indices_size = np.prod(indices_shape)
    indices = np.reshape(np.arange(indices_size), indices_shape)
    indices = indices % params_shape[axis]

    # Perform repeated (batched) gather operations with numpy, to find the
    # expected result.
    expected = self._batchNumpyGather(params, indices, axis, batch_dims)

    # On Windows, we get an exception if we pass in the transformed numpy
    # arrays ("Failed to convert numpy ndarray to a Tensor (Unsupported
    # feed type)."); so convert them back to lists before calling tf.gather.
    params = params.tolist()
    indices = indices.tolist()

    result = array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)
    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

    # Run the same test for strings.
    params = _to_str_elements(params)
    expected = _to_str_elements(expected.tolist())
    result = array_ops.gather(
        params, indices, axis=axis, batch_dims=batch_dims)

    self.assertAllEqual(output_shape, result.shape.as_list())
    self.assertAllEqual(expected, result)

  def _batchNumpyGather(self, params, indices, axis, batch_dims):
    """Performs a batch gather by making recursive calls to np.take().

    This is used by testBatchDimsMatchesPythonBatching() to construct the
    expected value.

    Args:
      params: A numpy array.
      indices: A numpy array.
      axis: An integer.
      batch_dims: An integer.

    Returns:
      A numpy array.
    """
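    # Illustrative usage (hypothetical values): with batch_dims=1 the
    # recursion peels one shared batch dimension and gathers within each
    # batch element:
    #   _batchNumpyGather(np.array([[10, 11], [20, 21]]),
    #                     np.array([[1], [0]]), axis=1, batch_dims=1)
    #   # -> [[11], [20]]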
    if batch_dims == 0:
      return np.take(params, indices, axis=axis)
    self.assertEqual(params.shape[0], indices.shape[0])
    if axis > 0:
      axis -= 1
    return np.stack([
        self._batchNumpyGather(params[i], indices[i], axis, batch_dims - 1)
        for i in range(params.shape[0])
    ])

  @test_util.run_v1_only("RefVariable is not supported in v2")
  def testGatherRefVariable(self):
    with self.cached_session():
      v = variables.RefVariable(constant_op.constant([[1, 2], [3, 4], [5, 6]]))
      self.evaluate(variables.global_variables_initializer())
      gather = array_ops.gather(v, [0, 2])
      if not context.executing_eagerly():  # .op doesn't make sense in Eager
        self.assertEqual("GatherV2", gather.op.name)
      self.assertAllEqual([[1, 2], [5, 6]], gather)

  @test_util.run_in_graph_and_eager_modes
  def testGatherResourceVariable(self):
    with self.cached_session():
      v = resource_variable_ops.ResourceVariable(
          constant_op.constant([[1, 2], [3, 4], [5, 6]]))
      self.evaluate(variables.global_variables_initializer())
      gather = array_ops.gather(v, [0, 2])
      if not context.executing_eagerly():  # .op doesn't make sense in Eager
        self.assertEqual("ResourceGather", gather.op.inputs[0].op.type)
      self.assertAllEqual([[1, 2], [5, 6]], gather)


if __name__ == "__main__":
  test.main()