1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for manip_ops.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import errors_impl 24from tensorflow.python.framework import test_util 25from tensorflow.python.ops import gradient_checker 26from tensorflow.python.ops import manip_ops 27from tensorflow.python.platform import test as test_lib 28 29# pylint: disable=g-import-not-at-top 30try: 31 from distutils.version import StrictVersion as Version 32 # numpy.roll for multiple shifts was introduced in numpy version 1.12.0 33 NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0") 34except ImportError: 35 NP_ROLL_CAN_MULTISHIFT = False 36# pylint: enable=g-import-not-at-top 37 38 39class RollTest(test_util.TensorFlowTestCase): 40 41 def _testRoll(self, np_input, shift, axis): 42 expected_roll = np.roll(np_input, shift, axis) 43 with self.test_session(): 44 roll = manip_ops.roll(np_input, shift, axis) 45 self.assertAllEqual(roll.eval(), expected_roll) 46 47 def _testGradient(self, np_input, shift, axis): 48 with self.test_session(): 49 inx = constant_op.constant(np_input.tolist()) 50 xs = list(np_input.shape) 51 y = manip_ops.roll(inx, shift, axis) 52 # Expected y's shape to be the same 53 ys = xs 54 jacob_t, jacob_n = gradient_checker.compute_gradient( 55 inx, xs, y, ys, x_init_value=np_input) 56 self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5) 57 58 def _testAll(self, np_input, shift, axis): 59 self._testRoll(np_input, shift, axis) 60 if np_input.dtype == np.float32: 61 self._testGradient(np_input, shift, axis) 62 63 def testIntTypes(self): 64 for t in [np.int32, np.int64]: 65 self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0) 66 if NP_ROLL_CAN_MULTISHIFT: 67 self._testAll( 68 np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -2, 3], 69 [0, 1, 2]) 70 self._testAll( 71 np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2], 72 [1, 2, 3]) 73 74 def testFloatTypes(self): 75 for t in [np.float32, np.float64]: 76 self._testAll(np.random.rand(5).astype(t), 2, 0) 77 if NP_ROLL_CAN_MULTISHIFT: 78 self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0]) 79 self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2]) 80 81 def testComplexTypes(self): 82 for t in [np.complex64, np.complex128]: 83 x = np.random.rand(4, 4).astype(t) 84 self._testAll(x + 1j * x, 2, 0) 85 if NP_ROLL_CAN_MULTISHIFT: 86 x = np.random.rand(2, 5).astype(t) 87 self._testAll(x + 1j * x, [1, 2], [1, 0]) 88 x = np.random.rand(3, 2, 1, 1).astype(t) 89 self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2]) 90 91 def testRollInputMustVectorHigherRaises(self): 92 tensor = 7 93 shift = 1 94 axis = 0 95 with self.test_session(): 96 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 97 "input must be 1-D or higher"): 98 manip_ops.roll(tensor, shift, axis).eval() 99 100 def testRollAxisMustBeScalarOrVectorRaises(self): 101 tensor = [[1, 2], [3, 4]] 102 shift = 1 103 axis = [[0, 1]] 104 with self.test_session(): 105 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 106 "axis must be a scalar or a 1-D vector"): 107 manip_ops.roll(tensor, shift, axis).eval() 108 109 def testRollShiftMustBeScalarOrVectorRaises(self): 110 tensor = [[1, 2], [3, 4]] 111 shift = [[0, 1]] 112 axis = 1 113 with self.test_session(): 114 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 115 "shift must be a scalar or a 1-D vector"): 116 manip_ops.roll(tensor, shift, axis).eval() 117 118 def testRollShiftAndAxisMustBeSameSizeRaises(self): 119 tensor = [[1, 2], [3, 4]] 120 shift = [1] 121 axis = [0, 1] 122 with self.test_session(): 123 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 124 "shift and axis must have the same size"): 125 manip_ops.roll(tensor, shift, axis).eval() 126 127 def testRollAxisOutOfRangeRaises(self): 128 tensor = [1, 2] 129 shift = 1 130 axis = 1 131 with self.test_session(): 132 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 133 "is out of range"): 134 manip_ops.roll(tensor, shift, axis).eval() 135 136 137if __name__ == "__main__": 138 test_lib.main() 139