• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for Unstack Op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import gradient_checker_v2
29from tensorflow.python.platform import test
30
31
def np_split_squeeze(array, axis):
  """Splits `array` into its slices along `axis`, dropping that axis.

  This is the NumPy reference that `array_ops.unstack` is checked against.

  Args:
    array: A NumPy array.
    axis: The axis to unstack along; negative values count from the end.

  Returns:
    A list of `array.shape[axis]` ndarrays; the k-th is `array` indexed at
    position k along `axis`. A zero-length axis yields an empty list (the
    previous `np.split`-based version divided by zero in that case).
  """
  # Moving `axis` to the front turns "slice k along axis" into "row k".
  moved = np.moveaxis(array, axis, 0)
  # Indexing with `(k, ...)` always returns an ndarray view (0-d for 1-D
  # input), matching what np.squeeze on an np.split piece produced.
  return [moved[k, ...] for k in range(moved.shape[0])]
39
40
class UnstackOpTest(test.TestCase):
  """Tests for `array_ops.unstack`."""

  # Shapes exercised by most tests below (ranks 1 through 3).
  _SHAPES = ((2,), (3,), (2, 3), (3, 2), (4, 3, 2))
  # Dtypes compared against the NumPy reference on CPU and GPU.
  _DTYPES = (np.bool_, np.float16, np.float32, np.float64, np.uint8,
             np.int32, np.int64)

  def randn(self, shape, dtype):
    """Returns standard-normal random data of `shape` converted to `dtype`.

    For `np.bool_` the data is thresholded at zero rather than cast: casting
    the (almost surely nonzero) floats to bool would make every element True.
    """
    data = np.random.randn(*shape)
    if dtype == np.bool_:
      return data < 0
    return data.astype(dtype)

  def unstackReference(self, data, axis):
    """Use numpy primitives to implement unstack equivalent.

    Args:
      data: NumPy array to unstack.
      axis: Axis to unstack along; negative values count from the end.

    Returns:
      A list of `data.shape[axis]` arrays; the k-th is `data` indexed at k
      along `axis` (e.g. for rank 4 and axis 2, `data[:, :, k, :]`).
    """
    rank = len(data.shape)
    axis = axis + rank if axis < 0 else axis
    result = []
    for k in range(data.shape[axis]):
      # Select index k in the `axis` dimension; keep all other dimensions.
      slice_spec = tuple(
          slice(None) if i != axis else k for i in range(rank))
      result.append(data[slice_spec])
    return result

  def _testUnstackMatchesReference(self):
    """Checks unstack against `unstackReference` for all shapes/axes/dtypes."""
    np.random.seed(7)
    for shape in self._SHAPES:
      rank = len(shape)
      for axis in range(-rank, rank):
        for dtype in self._DTYPES:
          # Give error with loop context
          with self.subTest(shape=shape, axis=axis, dtype=dtype):
            data = self.randn(shape, dtype)
            # Convert data to a single tensorflow tensor
            x = constant_op.constant(data)
            # Unstack into a list of tensors
            ref = self.unstackReference(data, axis)
            cs = array_ops.unstack(x, axis=axis)
            self.assertEqual(type(cs), list)
            self.assertEqual(len(cs), shape[axis])
            for k, c in enumerate(cs):
              with self.subTest(k=k):
                self.assertAllEqual(ref[k], self.evaluate(c))

  def testSimple(self):
    self._testUnstackMatchesReference()

  def testSimpleGpu(self):
    if not test_util.is_gpu_available():
      self.skipTest('No GPU available')

    with test_util.force_gpu():
      self._testUnstackMatchesReference()

  def _testGradients(self, shapes, axis):
    """Checks the gradient of every unstack output along `axis`."""
    for shape in shapes:
      data = np.random.randn(*shape)
      x = constant_op.constant(data)
      num = shape[axis]

      for i in xrange(num):
        # Bind loop values as defaults so the closure captures this iteration.
        def func(x, num=num, i=i):
          return array_ops.unstack(x, num=num, axis=axis)[i]

        with self.cached_session():
          err = gradient_checker_v2.max_error(
              *gradient_checker_v2.compute_gradient(func, [x]))
          self.assertLess(err, 1e-6)

  def testGradientsAxis0(self):
    self._testGradients(self._SHAPES, axis=0)

  def testGradientsAxis1(self):
    # Axis 1 only exists for shapes of rank >= 2.
    self._testGradients([s for s in self._SHAPES if len(s) > 1], axis=1)

  def testInferNum(self):
    for shape in self._SHAPES:
      x = array_ops.ones(shape, dtype=np.float32)
      cs = array_ops.unstack(x)
      self.assertEqual(type(cs), list)
      self.assertEqual(len(cs), shape[0])

  def testCannotInferNumFromUnknownShape(self):
    # Testing unknown shape in graph mode.
    with ops.Graph().as_default():
      x = array_ops.placeholder(np.float32)
      with self.assertRaisesRegex(ValueError,
                                  r'Cannot infer num from shape <unknown>'):
        array_ops.unstack(x)

  def testUnknownShapeOkWithNum(self):
    # Testing unknown shape in graph mode.
    with ops.Graph().as_default():
      x = array_ops.placeholder(np.float32)
      array_ops.unstack(x, num=2)

  def testCannotInferNumFromNoneShape(self):
    # Testing a known rank with an unknown (None) dimension in graph mode.
    with ops.Graph().as_default():
      x = array_ops.placeholder(np.float32, shape=(None,))
      with self.assertRaisesRegex(
          ValueError, r'Cannot infer num from shape \((\?|None),\)'):
        array_ops.unstack(x)

  def testAgainstNumpy(self):
    # For 1 to 5 dimensions; the shape is a random permutation of 1..i.
    for i in range(1, 6):
      a = np.random.random(np.random.permutation(i) + 1)

      # For all the possible axis to split it, including negative indices.
      for j in range(-i, i):
        expected = np_split_squeeze(a, j)

        actual_unstack = self.evaluate(array_ops.unstack(a, axis=j))

        self.assertAllEqual(expected, actual_unstack)

  def testAxis0Default(self):
    a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
    unstacked = self.evaluate(array_ops.unstack(a))

    self.assertEqual(len(unstacked), 2)
    self.assertAllEqual(unstacked[0], [1, 2, 3])
    self.assertAllEqual(unstacked[1], [4, 5, 6])

  def testAxisOutOfRange(self):
    a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
    with self.assertRaisesRegex(ValueError, r'axis = 2 not in \[-2, 2\)'):
      array_ops.unstack(a, axis=2)

  def testAxisOutOfNegativeRange(self):
    a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
    with self.assertRaisesRegex(ValueError, r'axis = -3 not in \[-2, 2\)'):
      array_ops.unstack(a, axis=-3)

  def testZeroLengthDim(self):
    x = array_ops.zeros(shape=(0, 1, 2))
    y = self.evaluate(array_ops.unstack(x, axis=1)[0])
    self.assertEqual(y.shape, (0, 2))

  def testComplexGpu(self):
    if not test_util.is_gpu_available():
      self.skipTest('No GPU available')

    np.random.seed(7)
    with test_util.force_gpu():
      for shape in self._SHAPES:
        for dtype in [np.complex64, np.complex128]:
          # Use a nonzero imaginary part: real-only data could not catch a
          # kernel that drops or corrupts the imaginary component.
          data = (np.random.randn(*shape) +
                  1j * np.random.randn(*shape)).astype(dtype)
          # Convert data to a single tensorflow tensor
          x = constant_op.constant(data)
          # Unstack into a list of tensors
          cs = array_ops.unstack(x, num=shape[0])
          self.assertEqual(type(cs), list)
          self.assertEqual(len(cs), shape[0])
          cs = [self.evaluate(c) for c in cs]
          self.assertAllEqual(cs, data)
224
if __name__ == '__main__':
  # Executed only when run as a script: hand control to TensorFlow's test
  # runner, which discovers and runs the test cases defined above.
  test.main()
227