• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for manip_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import errors_impl
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import gradient_checker
26from tensorflow.python.ops import manip_ops
27from tensorflow.python.platform import test as test_lib
28
# numpy.roll accepts sequences for `shift`/`axis` (multiple shifts) only from
# numpy 1.12.0 on; the multi-shift test cases below are skipped otherwise.
# np.lib.NumpyVersion is used instead of distutils' StrictVersion because
# distutils was removed in Python 3.12 (where the old ImportError fallback
# silently disabled the multi-shift cases) and StrictVersion raises ValueError
# on development/rc versions such as "1.13.0.dev0+abc123", which previously
# crashed this module at import time.
try:
  NP_ROLL_CAN_MULTISHIFT = np.lib.NumpyVersion(np.__version__) >= "1.12.0"
except ValueError:
  # Unparseable version string: conservatively skip the multi-shift cases.
  NP_ROLL_CAN_MULTISHIFT = False
37
38
class RollTest(test_util.TensorFlowTestCase):
  """Checks manip_ops.roll against np.roll, its gradient, and its errors."""

  def _testRoll(self, np_input, shift, axis):
    """Asserts manip_ops.roll(np_input, shift, axis) equals np.roll's result."""
    want = np.roll(np_input, shift, axis)
    with self.test_session():
      got = manip_ops.roll(np_input, shift, axis).eval()
      self.assertAllEqual(got, want)

  def _testGradient(self, np_input, shift, axis):
    """Compares the analytic and numeric Jacobians of manip_ops.roll."""
    with self.test_session():
      x = constant_op.constant(np_input.tolist())
      shape = list(np_input.shape)
      rolled = manip_ops.roll(x, shift, axis)
      # Rolling only permutes elements, so the output shape is the input shape;
      # the same list is passed for both the x and y shapes.
      theoretical, numerical = gradient_checker.compute_gradient(
          x, shape, rolled, shape, x_init_value=np_input)
      self.assertAllClose(theoretical, numerical, rtol=1e-5, atol=1e-5)

  def _testAll(self, np_input, shift, axis):
    """Runs the forward check, plus the gradient check for float32 inputs."""
    self._testRoll(np_input, shift, axis)
    if np_input.dtype != np.float32:
      return
    self._testGradient(np_input, shift, axis)

  def testIntTypes(self):
    for dtype in (np.int32, np.int64):
      self._testAll(np.random.randint(-100, 100, 5).astype(dtype), 3, 0)
      if not NP_ROLL_CAN_MULTISHIFT:
        continue
      # (input shape, shift, axis) for the multi-shift cases.
      for shape, shift, axis in (((4, 4, 3), [1, -2, 3], [0, 1, 2]),
                                 ((4, 2, 1, 3), [0, 1, -2], [1, 2, 3])):
        self._testAll(
            np.random.randint(-100, 100, shape).astype(dtype), shift, axis)

  def testFloatTypes(self):
    for dtype in (np.float32, np.float64):
      self._testAll(np.random.rand(5).astype(dtype), 2, 0)
      if not NP_ROLL_CAN_MULTISHIFT:
        continue
      # (input shape, shift, axis) for the multi-shift cases.
      for shape, shift, axis in (((3, 4), [1, 2], [1, 0]),
                                 ((1, 3, 4), [1, 0, -3], [0, 1, 2])):
        self._testAll(np.random.rand(*shape).astype(dtype), shift, axis)

  def testComplexTypes(self):
    for dtype in (np.complex64, np.complex128):
      base = np.random.rand(4, 4).astype(dtype)
      self._testAll(base + 1j * base, 2, 0)
      if not NP_ROLL_CAN_MULTISHIFT:
        continue
      # (input shape, shift, axis) for the multi-shift cases.
      for shape, shift, axis in (((2, 5), [1, 2], [1, 0]),
                                 ((3, 2, 1, 1), [2, 1, 1, 0], [0, 3, 1, 2])):
        base = np.random.rand(*shape).astype(dtype)
        self._testAll(base + 1j * base, shift, axis)

  def _assertRollRaises(self, tensor, shift, axis, message):
    """Asserts evaluating roll(tensor, shift, axis) raises InvalidArgumentError.

    Args:
      tensor: value passed as the roll input.
      shift: value passed as the roll shift.
      axis: value passed as the roll axis.
      message: regex the error message must match.
    """
    with self.test_session():
      # NOTE(review): assertRaisesRegexp is a deprecated alias that unittest
      # removed in Python 3.12; switch to assertRaisesRegex once Python 2
      # compatibility is no longer required.
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, message):
        manip_ops.roll(tensor, shift, axis).eval()

  def testRollInputMustVectorHigherRaises(self):
    self._assertRollRaises(7, 1, 0, "input must be 1-D or higher")

  def testRollAxisMustBeScalarOrVectorRaises(self):
    self._assertRollRaises([[1, 2], [3, 4]], 1, [[0, 1]],
                           "axis must be a scalar or a 1-D vector")

  def testRollShiftMustBeScalarOrVectorRaises(self):
    self._assertRollRaises([[1, 2], [3, 4]], [[0, 1]], 1,
                           "shift must be a scalar or a 1-D vector")

  def testRollShiftAndAxisMustBeSameSizeRaises(self):
    self._assertRollRaises([[1, 2], [3, 4]], [1], [0, 1],
                           "shift and axis must have the same size")

  def testRollAxisOutOfRangeRaises(self):
    self._assertRollRaises([1, 2], 1, 1, "is out of range")
135
136
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  test_lib.main()
139