1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test cases for operators with > 3 or arbitrary numbers of arguments.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import unittest 22 23import numpy as np 24 25from tensorflow.compiler.tests.xla_test import XLATestCase 26from tensorflow.python.framework import dtypes 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.platform import googletest 30 31 32class NAryOpsTest(XLATestCase): 33 34 def _testNAry(self, op, args, expected, equality_fn=None): 35 with self.test_session() as session: 36 with self.test_scope(): 37 placeholders = [ 38 array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) 39 for arg in args 40 ] 41 feeds = {placeholders[i]: args[i] for i in range(0, len(args))} 42 output = op(placeholders) 43 result = session.run(output, feeds) 44 if not equality_fn: 45 equality_fn = self.assertAllClose 46 equality_fn(result, expected, rtol=1e-3) 47 48 def _nAryListCheck(self, results, expected, **kwargs): 49 self.assertEqual(len(results), len(expected)) 50 for (r, e) in zip(results, expected): 51 self.assertAllClose(r, e, **kwargs) 52 53 def _testNAryLists(self, op, args, expected): 54 self._testNAry(op, args, expected, equality_fn=self._nAryListCheck) 55 56 def testFloat(self): 57 self._testNAry(math_ops.add_n, 58 [np.array([[1, 2, 3]], dtype=np.float32)], 59 expected=np.array([[1, 2, 3]], dtype=np.float32)) 60 61 self._testNAry(math_ops.add_n, 62 [np.array([1, 2], dtype=np.float32), 63 np.array([10, 20], dtype=np.float32)], 64 expected=np.array([11, 22], dtype=np.float32)) 65 self._testNAry(math_ops.add_n, 66 [np.array([-4], dtype=np.float32), 67 np.array([10], dtype=np.float32), 68 np.array([42], dtype=np.float32)], 69 expected=np.array([48], dtype=np.float32)) 70 71 def testComplex(self): 72 for dtype in self.complex_types: 73 self._testNAry( 74 math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)], 75 expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)) 76 77 self._testNAry( 78 math_ops.add_n, [ 79 np.array([1 + 2j, 2 - 3j], dtype=dtype), 80 np.array([10j, 20], dtype=dtype) 81 ], 82 expected=np.array([1 + 12j, 22 - 3j], dtype=dtype)) 83 self._testNAry( 84 math_ops.add_n, [ 85 np.array([-4, 5j], dtype=dtype), 86 np.array([2 + 10j, -2], dtype=dtype), 87 np.array([42j, 3 + 3j], dtype=dtype) 88 ], 89 expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype)) 90 91 @unittest.skip("IdentityN is temporarily CompilationOnly as workaround") 92 def testIdentityN(self): 93 self._testNAryLists(array_ops.identity_n, 94 [np.array([[1, 2, 3]], dtype=np.float32)], 95 expected=[np.array([[1, 2, 3]], dtype=np.float32)]) 96 self._testNAryLists(array_ops.identity_n, 97 [np.array([[1, 2], [3, 4]], dtype=np.float32), 98 np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)], 99 expected=[ 100 np.array([[1, 2], [3, 4]], dtype=np.float32), 101 np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)]) 102 self._testNAryLists(array_ops.identity_n, 103 [np.array([[1], [2], [3], [4]], dtype=np.int32), 104 np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)], 105 expected=[ 106 np.array([[1], [2], [3], [4]], dtype=np.int32), 107 np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)]) 108 109 def testConcat(self): 110 self._testNAry( 111 lambda x: array_ops.concat(x, 0), [ 112 np.array( 113 [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array( 114 [[7, 8, 9], [10, 11, 12]], dtype=np.float32) 115 ], 116 expected=np.array( 117 [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32)) 118 119 self._testNAry( 120 lambda x: array_ops.concat(x, 1), [ 121 np.array( 122 [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array( 123 [[7, 8, 9], [10, 11, 12]], dtype=np.float32) 124 ], 125 expected=np.array( 126 [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) 127 128 def testOneHot(self): 129 with self.test_session() as session, self.test_scope(): 130 indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32)) 131 op = array_ops.one_hot(indices, 132 np.int32(4), 133 on_value=np.float32(7), off_value=np.float32(3)) 134 output = session.run(op) 135 expected = np.array([[[3, 3, 7, 3], [3, 3, 3, 7]], 136 [[7, 3, 3, 3], [3, 7, 3, 3]]], 137 dtype=np.float32) 138 self.assertAllEqual(output, expected) 139 140 op = array_ops.one_hot(indices, 141 np.int32(4), 142 on_value=np.int32(2), off_value=np.int32(1), 143 axis=1) 144 output = session.run(op) 145 expected = np.array([[[1, 1], [1, 1], [2, 1], [1, 2]], 146 [[2, 1], [1, 2], [1, 1], [1, 1]]], 147 dtype=np.int32) 148 self.assertAllEqual(output, expected) 149 150 def testSplitV(self): 151 with self.test_session() as session: 152 with self.test_scope(): 153 output = session.run( 154 array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]], 155 dtype=np.float32), 156 [2, 2], 1)) 157 expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32), 158 np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)] 159 self.assertAllEqual(output, expected) 160 161 def testStridedSlice(self): 162 self._testNAry(lambda x: array_ops.strided_slice(*x), 163 [np.array([[], [], []], dtype=np.float32), 164 np.array([1, 0], dtype=np.int32), 165 np.array([3, 0], dtype=np.int32), 166 np.array([1, 1], dtype=np.int32)], 167 expected=np.array([[], []], dtype=np.float32)) 168 169 if np.int64 in self.int_types: 170 self._testNAry( 171 lambda x: array_ops.strided_slice(*x), [ 172 np.array([[], [], []], dtype=np.float32), np.array( 173 [1, 0], dtype=np.int64), np.array([3, 0], dtype=np.int64), 174 np.array([1, 1], dtype=np.int64) 175 ], 176 expected=np.array([[], []], dtype=np.float32)) 177 178 self._testNAry(lambda x: array_ops.strided_slice(*x), 179 [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 180 dtype=np.float32), 181 np.array([1, 1], dtype=np.int32), 182 np.array([3, 3], dtype=np.int32), 183 np.array([1, 1], dtype=np.int32)], 184 expected=np.array([[5, 6], [8, 9]], dtype=np.float32)) 185 186 self._testNAry(lambda x: array_ops.strided_slice(*x), 187 [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 188 dtype=np.float32), 189 np.array([0, 2], dtype=np.int32), 190 np.array([2, 0], dtype=np.int32), 191 np.array([1, -1], dtype=np.int32)], 192 expected=np.array([[3, 2], [6, 5]], dtype=np.float32)) 193 194 self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1], 195 [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 196 dtype=np.float32)], 197 expected=np.array([[[3, 2, 1]], [[6, 5, 4]]], 198 dtype=np.float32)) 199 200 self._testNAry(lambda x: x[0][1, :, array_ops.newaxis], 201 [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 202 dtype=np.float32)], 203 expected=np.array([[4], [5], [6]], dtype=np.float32)) 204 205 def testStridedSliceGrad(self): 206 # Tests cases where input shape is empty. 207 self._testNAry(lambda x: array_ops.strided_slice_grad(*x), 208 [np.array([], dtype=np.int32), 209 np.array([], dtype=np.int32), 210 np.array([], dtype=np.int32), 211 np.array([], dtype=np.int32), 212 np.float32(0.5)], 213 expected=np.array(np.float32(0.5), dtype=np.float32)) 214 215 # Tests case where input shape is non-empty, but gradients are empty. 216 self._testNAry(lambda x: array_ops.strided_slice_grad(*x), 217 [np.array([3], dtype=np.int32), 218 np.array([0], dtype=np.int32), 219 np.array([0], dtype=np.int32), 220 np.array([1], dtype=np.int32), 221 np.array([], dtype=np.float32)], 222 expected=np.array([0, 0, 0], dtype=np.float32)) 223 224 self._testNAry(lambda x: array_ops.strided_slice_grad(*x), 225 [np.array([3, 0], dtype=np.int32), 226 np.array([1, 0], dtype=np.int32), 227 np.array([3, 0], dtype=np.int32), 228 np.array([1, 1], dtype=np.int32), 229 np.array([[], []], dtype=np.float32)], 230 expected=np.array([[], [], []], dtype=np.float32)) 231 232 self._testNAry(lambda x: array_ops.strided_slice_grad(*x), 233 [np.array([3, 3], dtype=np.int32), 234 np.array([1, 1], dtype=np.int32), 235 np.array([3, 3], dtype=np.int32), 236 np.array([1, 1], dtype=np.int32), 237 np.array([[5, 6], [8, 9]], dtype=np.float32)], 238 expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]], 239 dtype=np.float32)) 240 241 def ssg_test(x): 242 return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4, 243 new_axis_mask=0x1) 244 245 self._testNAry(ssg_test, 246 [np.array([3, 1, 3], dtype=np.int32), 247 np.array([0, 0, 0, 2], dtype=np.int32), 248 np.array([0, 3, 1, -4], dtype=np.int32), 249 np.array([1, 2, 1, -3], dtype=np.int32), 250 np.array([[[1], [2]]], dtype=np.float32)], 251 expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]], 252 dtype=np.float32)) 253 254 ssg_test2 = lambda x: array_ops.strided_slice_grad(*x, new_axis_mask=0x15) 255 self._testNAry(ssg_test2, 256 [np.array([4, 4], dtype=np.int32), 257 np.array([0, 0, 0, 1, 0], dtype=np.int32), 258 np.array([0, 3, 0, 4, 0], dtype=np.int32), 259 np.array([1, 2, 1, 2, 1], dtype=np.int32), 260 np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)], 261 expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4], 262 [0, 0, 0, 0]], dtype=np.float32)) 263 264 self._testNAry(lambda x: array_ops.strided_slice_grad(*x), 265 [np.array([3, 3], dtype=np.int32), 266 np.array([0, 2], dtype=np.int32), 267 np.array([2, 0], dtype=np.int32), 268 np.array([1, -1], dtype=np.int32), 269 np.array([[1, 2], [3, 4]], dtype=np.float32)], 270 expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]], 271 dtype=np.float32)) 272 273 self._testNAry(lambda x: array_ops.strided_slice_grad(*x), 274 [np.array([3, 3], dtype=np.int32), 275 np.array([2, 2], dtype=np.int32), 276 np.array([0, 1], dtype=np.int32), 277 np.array([-1, -2], dtype=np.int32), 278 np.array([[1], [2]], dtype=np.float32)], 279 expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]], 280 dtype=np.float32)) 281 282if __name__ == "__main__": 283 googletest.main() 284