1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for Unstack Op.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22from six.moves import xrange # pylint: disable=redefined-builtin 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import gradient_checker_v2 29from tensorflow.python.platform import test 30 31 32def np_split_squeeze(array, axis): 33 axis_len = array.shape[axis] 34 return [ 35 np.squeeze( 36 arr, axis=(axis,)) for arr in np.split( 37 array, axis_len, axis=axis) 38 ] 39 40 41class UnstackOpTest(test.TestCase): 42 43 def randn(self, shape, dtype): 44 data = np.random.randn(*shape) 45 if dtype == np.bool_: 46 return data < 0 # Naive casting yields True with P(1)! 47 else: 48 return data.astype(dtype) 49 50 def unstackReference(self, data, axis): 51 """Use numpy primitives to implement unstack equivalent.""" 52 result = [] 53 rank = len(data.shape) 54 axis = axis + rank if axis < 0 else axis 55 for k in range(data.shape[axis]): 56 axis = rank + axis if axis < 0 else axis 57 # Slice in axis dimension of k'th slice. 58 # e.g. if rank=4 k=2, axis=2 then equivalent of data[:,:,2,:] 59 # Give error with loop context 60 slice_spec = tuple( 61 slice(None) if i != axis else k for i in range(rank)) 62 result.append(data.__getitem__(slice_spec)) 63 return result 64 65 def testSimple(self): 66 np.random.seed(7) 67 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 68 rank = len(shape) 69 for axis in range(-rank, rank): 70 for dtype in [ 71 np.bool_, np.float16, np.float32, np.float64, np.uint8, np.int32, 72 np.int64 73 ]: 74 data = self.randn(shape, dtype) 75 # Convert data to a single tensorflow tensor 76 x = constant_op.constant(data) 77 78 # Unstack into a list of tensors 79 ref = self.unstackReference(data, axis) 80 cs = array_ops.unstack(x, axis=axis) 81 self.assertEqual(type(cs), list) 82 self.assertEqual(len(cs), shape[axis]) 83 for k, c in enumerate(cs): 84 with self.subTest(shape=shape, k=k, axis=axis, dtype=dtype): 85 self.assertAllEqual(ref[k], self.evaluate(c)) 86 87 def testSimpleGpu(self): 88 if not test_util.is_gpu_available(): 89 self.skipTest('No GPU available') 90 91 np.random.seed(7) 92 with test_util.force_gpu(): 93 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 94 rank = len(shape) 95 for axis in range(-rank, rank): 96 for dtype in [ 97 np.bool_, np.float16, np.float32, np.float64, np.uint8, np.int32, 98 np.int64 99 ]: 100 data = self.randn(shape, dtype) 101 # Convert data to a single tensorflow tensor 102 x = constant_op.constant(data) 103 # Unstack into a list of tensors 104 ref = self.unstackReference(data, axis) 105 cs = array_ops.unstack(x, axis=axis) 106 self.assertEqual(type(cs), list) 107 self.assertEqual(len(cs), shape[axis]) 108 for k, c in enumerate(cs): 109 # Give error with loop context 110 with self.subTest(shape=shape, k=k, axis=axis, dtype=dtype): 111 self.assertAllEqual(ref[k], self.evaluate(c)) 112 113 def testGradientsAxis0(self): 114 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 115 data = np.random.randn(*shape) 116 x = constant_op.constant(data) 117 118 for i in xrange(shape[0]): 119 def func(x, shape=shape, i=i): 120 return array_ops.unstack(x, num=shape[0])[i] 121 122 with self.cached_session(): 123 err = gradient_checker_v2.max_error( 124 *gradient_checker_v2.compute_gradient(func, [x])) 125 self.assertLess(err, 1e-6) 126 127 def testGradientsAxis1(self): 128 for shape in (2, 3), (3, 2), (4, 3, 2): 129 data = np.random.randn(*shape) 130 x = constant_op.constant(data) 131 132 for i in xrange(shape[1]): 133 def func(x, shape=shape, i=i): 134 return array_ops.unstack(x, num=shape[1], axis=1)[i] 135 136 with self.cached_session(): 137 err = gradient_checker_v2.max_error( 138 *gradient_checker_v2.compute_gradient(func, [x])) 139 self.assertLess(err, 1e-6) 140 141 def testInferNum(self): 142 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 143 x = array_ops.ones(shape, dtype=np.float32) 144 cs = array_ops.unstack(x) 145 self.assertEqual(type(cs), list) 146 self.assertEqual(len(cs), shape[0]) 147 148 def testCannotInferNumFromUnknownShape(self): 149 # Testing unknown shape in graph mode. 150 with ops.Graph().as_default(): 151 x = array_ops.placeholder(np.float32) 152 with self.assertRaisesRegex(ValueError, 153 r'Cannot infer num from shape <unknown>'): 154 array_ops.unstack(x) 155 156 def testUnknownShapeOkWithNum(self): 157 # Testing unknown shape in graph mode. 158 with ops.Graph().as_default(): 159 x = array_ops.placeholder(np.float32) 160 array_ops.unstack(x, num=2) 161 162 def testCannotInferNumFromNoneShape(self): 163 # Testing unknown shape in graph mode. 164 with ops.Graph().as_default(): 165 x = array_ops.placeholder(np.float32, shape=(None,)) 166 with self.assertRaisesRegex( 167 ValueError, r'Cannot infer num from shape \((\?|None),\)'): 168 array_ops.unstack(x) 169 170 def testAgainstNumpy(self): 171 # For 1 to 5 dimensions. 172 for i in range(1, 6): 173 a = np.random.random(np.random.permutation(i) + 1) 174 175 # For all the possible axis to split it, including negative indices. 176 for j in range(-i, i): 177 expected = np_split_squeeze(a, j) 178 179 actual_unstack = self.evaluate(array_ops.unstack(a, axis=j)) 180 181 self.assertAllEqual(expected, actual_unstack) 182 183 def testAxis0Default(self): 184 a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a') 185 unstacked = self.evaluate(array_ops.unstack(a)) 186 187 self.assertEqual(len(unstacked), 2) 188 self.assertAllEqual(unstacked[0], [1, 2, 3]) 189 self.assertAllEqual(unstacked[1], [4, 5, 6]) 190 191 def testAxisOutOfRange(self): 192 a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a') 193 with self.assertRaisesRegex(ValueError, r'axis = 2 not in \[-2, 2\)'): 194 array_ops.unstack(a, axis=2) 195 196 def testAxisOutOfNegativeRange(self): 197 a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a') 198 with self.assertRaisesRegex(ValueError, r'axis = -3 not in \[-2, 2\)'): 199 array_ops.unstack(a, axis=-3) 200 201 def testZeroLengthDim(self): 202 x = array_ops.zeros(shape=(0, 1, 2)) 203 y = self.evaluate(array_ops.unstack(x, axis=1)[0]) 204 self.assertEqual(y.shape, (0, 2)) 205 206 def testComplexGpu(self): 207 if not test_util.is_gpu_available(): 208 self.skipTest('No GPU available') 209 210 np.random.seed(7) 211 with test_util.force_gpu(): 212 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 213 for dtype in [np.complex64, np.complex128]: 214 data = np.random.randn(*shape).astype(dtype) 215 # Convert data to a single tensorflow tensor 216 x = constant_op.constant(data) 217 # Unstack into a list of tensors 218 cs = array_ops.unstack(x, num=shape[0]) 219 self.assertEqual(type(cs), list) 220 self.assertEqual(len(cs), shape[0]) 221 cs = [self.evaluate(c) for c in cs] 222 self.assertAllEqual(cs, data) 223 224 225if __name__ == '__main__': 226 test.main() 227