• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Test cases for operators with > 3 or arbitrary numbers of arguments."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import unittest
22
23import numpy as np
24
25from tensorflow.compiler.tests.xla_test import XLATestCase
26from tensorflow.python.framework import dtypes
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.platform import googletest
30
31
class NAryOpsTest(XLATestCase):
  """Tests ops that take more than three, or an arbitrary number of, inputs.

  Each case builds the op inside `self.test_scope()`, evaluates it with
  `session.run`, and compares the result against a NumPy reference value.
  """

  def _testNAry(self, op, args, expected, equality_fn=None):
    """Runs `op` on placeholders fed with `args` and checks the result.

    Args:
      op: callable that receives the list of placeholder tensors (one per
        entry of `args`, in the same order) and returns what to evaluate.
      args: NumPy arrays or scalars. Each one's `dtype` and `shape` define
        its placeholder, and its value is fed to that placeholder.
      expected: reference value the evaluated output is compared with.
      equality_fn: comparison invoked as
        `equality_fn(result, expected, rtol=1e-3)`; defaults to
        `self.assertAllClose`.
    """
    if equality_fn is None:
      equality_fn = self.assertAllClose
    with self.test_session() as session:
      # Only graph construction happens under test_scope(); the run and the
      # comparison below happen outside it.
      with self.test_scope():
        placeholders = [
            array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
            for arg in args
        ]
        # Pair each placeholder with the value it was built from.
        feeds = dict(zip(placeholders, args))
        output = op(placeholders)
      result = session.run(output, feeds)
      equality_fn(result, expected, rtol=1e-3)

  def _nAryListCheck(self, results, expected, **kwargs):
    """Asserts two sequences have equal length and all-close elements.

    Extra keyword arguments (e.g. `rtol`) are forwarded to
    `self.assertAllClose` for every element pair.
    """
    self.assertEqual(len(results), len(expected))
    for result, want in zip(results, expected):
      self.assertAllClose(result, want, **kwargs)

  def _testNAryLists(self, op, args, expected):
    """Like `_testNAry`, but compares list outputs element by element."""
    self._testNAry(op, args, expected, equality_fn=self._nAryListCheck)

  def testFloat(self):
    # add_n with one, two and three float32 inputs.
    self._testNAry(math_ops.add_n,
                   [np.array([[1, 2, 3]], dtype=np.float32)],
                   expected=np.array([[1, 2, 3]], dtype=np.float32))

    self._testNAry(math_ops.add_n,
                   [np.array([1, 2], dtype=np.float32),
                    np.array([10, 20], dtype=np.float32)],
                   expected=np.array([11, 22], dtype=np.float32))

    self._testNAry(math_ops.add_n,
                   [np.array([-4], dtype=np.float32),
                    np.array([10], dtype=np.float32),
                    np.array([42], dtype=np.float32)],
                   expected=np.array([48], dtype=np.float32))

  def testComplex(self):
    # add_n with one, two and three inputs for every complex type under test.
    for dtype in self.complex_types:
      self._testNAry(
          math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)],
          expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype))

      self._testNAry(
          math_ops.add_n, [
              np.array([1 + 2j, 2 - 3j], dtype=dtype),
              np.array([10j, 20], dtype=dtype)
          ],
          expected=np.array([1 + 12j, 22 - 3j], dtype=dtype))

      self._testNAry(
          math_ops.add_n, [
              np.array([-4, 5j], dtype=dtype),
              np.array([2 + 10j, -2], dtype=dtype),
              np.array([42j, 3 + 3j], dtype=dtype)
          ],
          expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype))

  @unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
  def testIdentityN(self):
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2, 3]], dtype=np.float32)],
                        expected=[np.array([[1, 2, 3]], dtype=np.float32)])
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1, 2], [3, 4]], dtype=np.float32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1, 2], [3, 4]], dtype=np.float32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])
    # Mixed dtypes (int32 and float32) in a single call.
    self._testNAryLists(array_ops.identity_n,
                        [np.array([[1], [2], [3], [4]], dtype=np.int32),
                         np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)],
                        expected=[
                            np.array([[1], [2], [3], [4]], dtype=np.int32),
                            np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)])

  def testConcat(self):
    # Concatenate along axis 0.
    self._testNAry(
        lambda x: array_ops.concat(x, 0),
        [np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
         np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)],
        expected=np.array(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32))

    # Concatenate along axis 1.
    self._testNAry(
        lambda x: array_ops.concat(x, 1),
        [np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
         np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32)],
        expected=np.array(
            [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))

  def testOneHot(self):
    with self.test_session() as session, self.test_scope():
      indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32))
      # Default axis with float on/off values.
      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.float32(7), off_value=np.float32(3))
      output = session.run(op)
      expected = np.array([[[3, 3, 7, 3], [3, 3, 3, 7]],
                           [[7, 3, 3, 3], [3, 7, 3, 3]]],
                          dtype=np.float32)
      self.assertAllEqual(output, expected)

      # Explicit axis=1 with int32 on/off values.
      op = array_ops.one_hot(indices,
                             np.int32(4),
                             on_value=np.int32(2), off_value=np.int32(1),
                             axis=1)
      output = session.run(op)
      expected = np.array([[[1, 1], [1, 1], [2, 1], [1, 2]],
                           [[2, 1], [1, 2], [1, 1], [1, 1]]],
                          dtype=np.int32)
      self.assertAllEqual(output, expected)

  def testSplitV(self):
    with self.test_session() as session:
      with self.test_scope():
        # Explicit split sizes [2, 2] along axis 1 (presumably lowered to
        # SplitV, per the test name).
        output = session.run(
            array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
                                     dtype=np.float32),
                            [2, 2], 1))
        expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32),
                    np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
        self.assertAllEqual(output, expected)

  def testStridedSlice(self):
    # Empty input with int32 begin/end/strides.
    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[], [], []], dtype=np.float32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[], []], dtype=np.float32))

    # Same empty-input case with int64 indices, only when np.int64 is among
    # self.int_types.
    if np.int64 in self.int_types:
      self._testNAry(lambda x: array_ops.strided_slice(*x),
                     [np.array([[], [], []], dtype=np.float32),
                      np.array([1, 0], dtype=np.int64),
                      np.array([3, 0], dtype=np.int64),
                      np.array([1, 1], dtype=np.int64)],
                     expected=np.array([[], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32)],
                   expected=np.array([[5, 6], [8, 9]], dtype=np.float32))

    # A negative stride on the second dimension.
    self._testNAry(lambda x: array_ops.strided_slice(*x),
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32)],
                   expected=np.array([[3, 2], [6, 5]], dtype=np.float32))

    # Python slicing syntax on a tensor: newaxis plus a reversed dimension.
    self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[[3, 2, 1]], [[6, 5, 4]]],
                                     dtype=np.float32))

    # Python slicing syntax: integer index plus a trailing newaxis.
    self._testNAry(lambda x: x[0][1, :, array_ops.newaxis],
                   [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                             dtype=np.float32)],
                   expected=np.array([[4], [5], [6]], dtype=np.float32))

  def testStridedSliceGrad(self):
    # NOTE(review): the five positional args look like
    # (shape, begin, end, strides, dy) -- confirm against strided_slice_grad.

    # Empty input shape; the gradient is the scalar 0.5.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.array([], dtype=np.int32),
                    np.float32(0.5)],
                   expected=np.array(np.float32(0.5), dtype=np.float32))

    # Non-empty input shape, but the incoming gradient is empty.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([0], dtype=np.int32),
                    np.array([1], dtype=np.int32),
                    np.array([], dtype=np.float32)],
                   expected=np.array([0, 0, 0], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 0], dtype=np.int32),
                    np.array([1, 0], dtype=np.int32),
                    np.array([3, 0], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[], []], dtype=np.float32)],
                   expected=np.array([[], [], []], dtype=np.float32))

    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([3, 3], dtype=np.int32),
                    np.array([1, 1], dtype=np.int32),
                    np.array([[5, 6], [8, 9]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]],
                                     dtype=np.float32))

    # Exercises shrink_axis_mask together with new_axis_mask.
    def ssg_test(x):
      return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4,
                                          new_axis_mask=0x1)

    self._testNAry(ssg_test,
                   [np.array([3, 1, 3], dtype=np.int32),
                    np.array([0, 0, 0, 2], dtype=np.int32),
                    np.array([0, 3, 1, -4], dtype=np.int32),
                    np.array([1, 2, 1, -3], dtype=np.int32),
                    np.array([[[1], [2]]], dtype=np.float32)],
                   expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]],
                                     dtype=np.float32))

    # Exercises a new_axis_mask with several bits set (0x15).
    def ssg_test2(x):
      return array_ops.strided_slice_grad(*x, new_axis_mask=0x15)

    self._testNAry(ssg_test2,
                   [np.array([4, 4], dtype=np.int32),
                    np.array([0, 0, 0, 1, 0], dtype=np.int32),
                    np.array([0, 3, 0, 4, 0], dtype=np.int32),
                    np.array([1, 2, 1, 2, 1], dtype=np.int32),
                    np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)],
                   expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4],
                                      [0, 0, 0, 0]], dtype=np.float32))

    # A negative stride on the second dimension.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([0, 2], dtype=np.int32),
                    np.array([2, 0], dtype=np.int32),
                    np.array([1, -1], dtype=np.int32),
                    np.array([[1, 2], [3, 4]], dtype=np.float32)],
                   expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]],
                                     dtype=np.float32))

    # Negative strides on both dimensions.
    self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
                   [np.array([3, 3], dtype=np.int32),
                    np.array([2, 2], dtype=np.int32),
                    np.array([0, 1], dtype=np.int32),
                    np.array([-1, -2], dtype=np.int32),
                    np.array([[1], [2]], dtype=np.float32)],
                   expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]],
                                     dtype=np.float32))
if __name__ == "__main__":
  # Hand control to TensorFlow's googletest runner when run as a script.
  googletest.main()