1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.tf.gather.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from absl.testing import parameterized 22import numpy as np 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.platform import test 30 31_TEST_TYPES = (dtypes.int64, dtypes.float32, 32 dtypes.complex64, dtypes.complex128) 33 34 35class GatherTest(test.TestCase, parameterized.TestCase): 36 37 def _buildParams(self, data, dtype): 38 data = data.astype(dtype.as_numpy_dtype) 39 # For complex types, add an index-dependent imaginary component so we can 40 # tell we got the right value. 41 if dtype.is_complex: 42 return data + 10j * data 43 return data 44 45 @parameterized.parameters(dtypes.int32, dtypes.int64) 46 def testSimpleGather(self, indices_dtype): 47 data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13]) 48 indices = [3, 4] 49 with self.session(): 50 for dtype in _TEST_TYPES: 51 params_np = self._buildParams(data, dtype) 52 params = constant_op.constant(params_np) 53 indices_tf = constant_op.constant(indices, dtype=indices_dtype) 54 gather_t = array_ops.batch_gather(params, indices_tf) 55 expected_result = np.array([3, 7]) 56 np_val = self._buildParams(expected_result, dtype) 57 gather_val = self.evaluate(gather_t) 58 self.assertAllEqual(np_val, gather_val) 59 self.assertEqual(np_val.shape, gather_t.get_shape()) 60 61 @parameterized.parameters(dtypes.int32, dtypes.int64) 62 def test2DArray(self, indices_dtype): 63 data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]]) 64 indices = [[3], [4]] 65 with self.session(): 66 for dtype in _TEST_TYPES: 67 params_np = self._buildParams(data, dtype) 68 params = constant_op.constant(params_np) 69 indices_tf = constant_op.constant(indices, dtype=indices_dtype) 70 gather_t = array_ops.batch_gather(params, indices_tf) 71 expected_result = np.array([[3], [15]]) 72 np_val = self._buildParams(expected_result, dtype) 73 gather_val = self.evaluate(gather_t) 74 self.assertAllEqual(np_val, gather_val) 75 self.assertEqual(np_val.shape, gather_t.get_shape()) 76 77 def testHigherRank(self): 78 data = np.array([[[0, 1, 2], [3, 7, 5]], [[8, 9, 10], [11, 15, 13]]]) 79 indices = [[[2, 0], [1, 2]], [[2, 0], [0, 1]]] 80 with self.session(): 81 for dtype in _TEST_TYPES: 82 params_np = self._buildParams(data, dtype) 83 params = constant_op.constant(params_np) 84 indices_tf = constant_op.constant(indices) 85 gather_t = array_ops.batch_gather(params, indices_tf) 86 gather_val = self.evaluate(gather_t) 87 expected_result = np.array([[[2, 0], [7, 5]], [[10, 8], [11, 15]]]) 88 np_val = self._buildParams(expected_result, dtype) 89 self.assertAllEqual(np_val, gather_val) 90 self.assertEqual(np_val.shape, gather_t.get_shape()) 91 92 def testString(self): 93 params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]]) 94 with self.cached_session(): 95 indices_tf = constant_op.constant([1]) 96 self.assertAllEqual( 97 [[b"qwer", b"uiop"]], 98 self.evaluate(array_ops.batch_gather(params, indices_tf))) 99 100 def testUnknownIndices(self): 101 # This test needs a placeholder which means we need to construct a graph. 102 with ops.Graph().as_default(): 103 params = constant_op.constant([[0, 1, 2]]) 104 indices = array_ops.placeholder(dtypes.int32, shape=[None, None]) 105 gather_t = array_ops.batch_gather(params, indices) 106 self.assertEqual([1, None], gather_t.get_shape().as_list()) 107 108 @test_util.disable_xla("Cannot force cpu placement for xla_gpu test") 109 def testBadIndicesCPU(self): 110 with ops.device_v2("cpu:0"): 111 params = [[0, 1, 2], [3, 4, 5]] 112 with self.assertRaisesOpError(r"indices\[0\] = 7 is not in \[0, 2\)"): 113 self.evaluate(array_ops.batch_gather(params, [7])) 114 115 def testEmptySlices(self): 116 with self.session(): 117 for dtype in _TEST_TYPES: 118 for itype in np.int32, np.int64: 119 params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype) 120 indices = np.array([3, 4], dtype=itype) 121 self.assertAllEqual( 122 self.evaluate(array_ops.batch_gather(params, indices)), 123 np.zeros((2, 0, 0))) 124 125if __name__ == "__main__": 126 test.main() 127