• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for tensorflow.ops.array_ops.batch_gather."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.platform import test
30
# Params dtypes iterated over by the per-dtype loops in GatherTest.
_TEST_TYPES = (
    dtypes.int64,
    dtypes.float32,
    dtypes.complex64,
    dtypes.complex128,
)
33
34
class GatherTest(test.TestCase, parameterized.TestCase):
  """Tests for `array_ops.batch_gather`.

  The value-checking cases run batch_gather once per params dtype in
  `_TEST_TYPES` and are parameterized over int32/int64 for the indices.
  """

  def _buildParams(self, data, dtype):
    """Casts numpy `data` to `dtype` for use as gather params or expectations.

    Args:
      data: numpy array of numeric values.
      dtype: a `DType`; `dtype.as_numpy_dtype` is the numpy target type.

    Returns:
      `data` cast to `dtype.as_numpy_dtype`. For complex dtypes, an imaginary
      part equal to `10 * data` is added as well.
    """
    data = data.astype(dtype.as_numpy_dtype)
    # For complex types, add a value-dependent imaginary component so that a
    # wrong gather is also caught in the imaginary half of each element.
    if dtype.is_complex:
      return data + 10j * data
    return data

  def _checkBatchGather(self, data, indices, expected, indices_dtype):
    """Asserts batch_gather(params, indices) == expected for every test dtype.

    Both the evaluated values and the statically inferred result shape are
    compared against `expected` (after it goes through `_buildParams`).

    Args:
      data: numpy array that becomes the params after `_buildParams`.
      indices: (nested) Python list of indices.
      expected: (nested) Python list of the expected gathered values, before
        `_buildParams` is applied.
      indices_dtype: `DType` used for the indices tensor.
    """
    expected = np.array(expected)
    with self.session():
      # The indices do not depend on the params dtype, so build them once.
      indices_tf = constant_op.constant(indices, dtype=indices_dtype)
      for dtype in _TEST_TYPES:
        params = constant_op.constant(self._buildParams(data, dtype))
        gather_t = array_ops.batch_gather(params, indices_tf)
        np_val = self._buildParams(expected, dtype)
        self.assertAllEqual(np_val, self.evaluate(gather_t))
        self.assertEqual(np_val.shape, gather_t.get_shape())

  @parameterized.parameters(dtypes.int32, dtypes.int64)
  def testSimpleGather(self, indices_dtype):
    # Rank-1 params and indices: the result is data[indices].
    data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
    self._checkBatchGather(data, [3, 4], [3, 7], indices_dtype)

  @parameterized.parameters(dtypes.int32, dtypes.int64)
  def test2DArray(self, indices_dtype):
    # Row i of the result holds data[i, indices[i]].
    data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
    self._checkBatchGather(data, [[3], [4]], [[3], [15]], indices_dtype)

  @parameterized.parameters(dtypes.int32, dtypes.int64)
  def testHigherRank(self, indices_dtype):
    # result[i][j] == data[i][j][indices[i][j]]: the leading dimensions act as
    # batch dimensions and only the last dimension is gathered.
    data = np.array([[[0, 1, 2], [3, 7, 5]], [[8, 9, 10], [11, 15, 13]]])
    indices = [[[2, 0], [1, 2]], [[2, 0], [0, 1]]]
    expected = [[[2, 0], [7, 5]], [[10, 8], [11, 15]]]
    self._checkBatchGather(data, indices, expected, indices_dtype)

  def testString(self):
    # String params with rank-1 indices select whole rows.
    params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
    with self.cached_session():
      indices_tf = constant_op.constant([1])
      self.assertAllEqual(
          [[b"qwer", b"uiop"]],
          self.evaluate(array_ops.batch_gather(params, indices_tf)))

  def testUnknownIndices(self):
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      params = constant_op.constant([[0, 1, 2]])
      indices = array_ops.placeholder(dtypes.int32, shape=[None, None])
      gather_t = array_ops.batch_gather(params, indices)
      # The batch dimension comes from params; the gathered one stays unknown.
      self.assertEqual([1, None], gather_t.get_shape().as_list())

  @test_util.disable_xla("Cannot force cpu placement for xla_gpu test")
  def testBadIndicesCPU(self):
    # An out-of-range index must raise an op error naming the bad index.
    with ops.device_v2("cpu:0"):
      params = [[0, 1, 2], [3, 4, 5]]
      with self.assertRaisesOpError(r"indices\[0\] = 7 is not in \[0, 2\)"):
        self.evaluate(array_ops.batch_gather(params, [7]))

  def testEmptySlices(self):
    # Gathering zero-size slices still yields a correctly shaped result.
    with self.session():
      for dtype in _TEST_TYPES:
        for itype in (np.int32, np.int64):
          params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
          indices = np.array([3, 4], dtype=itype)
          self.assertAllEqual(
              self.evaluate(array_ops.batch_gather(params, indices)),
              np.zeros((2, 0, 0), dtype=dtype.as_numpy_dtype))
124
if __name__ == "__main__":
  # Executed directly rather than imported: hand off to the TensorFlow
  # test runner imported above as `test`.
  test.main()
127