• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test cases for ternary operators."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23import scipy.special as sps
24
25from tensorflow.compiler.tests import xla_test
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gen_math_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.platform import googletest
32
33
class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
  """Tests for three-input ops built under the XLA test scope.

  Each case feeds NumPy inputs through placeholders and compares the result
  with a hand-written or NumPy/SciPy-computed reference value.
  """

  def _testTernary(self, op, a, b, c, expected, rtol=1e-3, atol=1e-6):
    """Runs `op(a, b, c)` and checks the output against `expected`.

    Each input is fed through a placeholder whose dtype and shape match the
    corresponding NumPy value, and the op is built inside `self.test_scope()`.

    Args:
      op: Callable taking three tensors and returning a tensor.
      a: First input; a NumPy array or NumPy scalar (must expose `.dtype` and
        `.shape`).
      b: Second input, same requirements as `a`.
      c: Third input, same requirements as `a`.
      expected: Reference value compared with `assertAllClose`.
      rtol: Relative tolerance for the comparison.
      atol: Absolute tolerance for the comparison.

    Returns:
      The value produced by running the op, so callers can make additional
      assertions (e.g. exact endpoint checks).
    """
    with self.session() as session:
      with self.test_scope():
        pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
        pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
        pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
        output = op(pa, pb, pc)
      result = session.run(output, {pa: a, pb: b, pc: c})
      self.assertAllClose(result, expected, rtol=rtol, atol=atol)
      return result

  @parameterized.parameters(
      {'start': 1, 'end': 2, 'num': 1},
      {'start': 1, 'end': 4, 'num': 3},
      {'start': 0, 'end': 41, 'num': 42})
  @test_util.disable_mlir_bridge(
      'TODO(b/156174708): Dynamic result types not supported')
  def testLinspace(self, start, end, num):
    """Checks `linspace` against NumPy, with exact checks on the endpoints."""
    expected = np.linspace(start, end, num, dtype=np.float32)
    result = self._testTernary(
        math_ops.linspace,
        np.float32(start),
        np.float32(end),
        np.int32(num),
        expected)
    # The endpoints must match NumPy exactly, not just within tolerance: the
    # first element is `start` and, when num > 1, the last element is `end`
    # (with num == 1 the only element is `start`).
    self.assertEqual(result[-1], expected[-1])
    self.assertEqual(result[0], expected[0])

  def testRange(self):
    """Checks `range` for a single-element and a strided int32 sequence."""
    self._testTernary(
        math_ops.range,
        np.int32(1),
        np.int32(2),
        np.int32(1),
        expected=np.array([1], dtype=np.int32))
    self._testTernary(
        math_ops.range,
        np.int32(1),
        np.int32(7),
        np.int32(2),
        expected=np.array([1, 3, 5], dtype=np.int32))

  def _testSelectCommon(self, op, dtype):
    """Runs the `where` cases shared by `testSelect` and `testSelectV2`.

    Covers a scalar condition with scalar, vector and matrix branches, and an
    elementwise vector condition whose shape matches the branches.

    Args:
      op: The select op under test (`array_ops.where` or
        `array_ops.where_v2`).
      dtype: NumPy dtype used for both branch values and the expected result.
    """
    self._testTernary(
        op,
        np.array(False),
        np.array(2, dtype=dtype),
        np.array(7, dtype=dtype),
        expected=np.array(7, dtype=dtype))

    self._testTernary(
        op,
        np.array(True),
        np.array([1, 2, 3, 4], dtype=dtype),
        np.array([5, 6, 7, 8], dtype=dtype),
        expected=np.array([1, 2, 3, 4], dtype=dtype))

    self._testTernary(
        op,
        np.array(False),
        np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
        np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
        expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))

    self._testTernary(
        op,
        np.array([0, 1, 1, 0], dtype=np.bool_),
        np.array([1, 2, 3, 4], dtype=dtype),
        np.array([5, 6, 7, 8], dtype=dtype),
        expected=np.array([5, 2, 3, 8], dtype=dtype))

  def testSelect(self):
    """Checks `where` for every numeric dtype."""
    for dtype in self.numeric_types:
      self._testSelectCommon(array_ops.where, dtype)

      # A rank-1 condition with rank-2 branches selects whole rows.
      self._testTernary(
          array_ops.where,
          np.array([0, 1, 0], dtype=np.bool_),
          np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
          np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
          expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype))

  def testSelectV2(self):
    """Checks `where_v2`, including broadcasting of condition and branches."""
    for dtype in self.numeric_types:
      self._testSelectCommon(array_ops.where_v2, dtype)

      # Broadcast the condition
      self._testTernary(
          array_ops.where_v2,
          np.array([0, 1], dtype=np.bool_),
          np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
          np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
          expected=np.array([[7, 2], [9, 4], [11, 6]], dtype=dtype))

      # Broadcast the then branch to the else
      self._testTernary(
          array_ops.where_v2,
          np.array([[0, 1], [1, 0], [1, 1]], dtype=np.bool_),
          np.array([[1, 2]], dtype=dtype),
          np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
          expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))

      # Broadcast the else branch to the then
      self._testTernary(
          array_ops.where_v2,
          np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool_),
          np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
          np.array([[1, 2]], dtype=dtype),
          expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype))

      # Broadcast the then/else branches to the condition
      self._testTernary(
          array_ops.where_v2,
          np.array([[1, 0], [0, 1], [1, 1]], dtype=np.bool_),
          np.array(7, dtype=dtype),
          np.array(8, dtype=dtype),
          expected=np.array([[7, 8], [8, 7], [7, 7]], dtype=dtype))
      self._testTernary(
          array_ops.where_v2,
          np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool_),
          np.array(7, dtype=dtype),
          np.array([8, 9], dtype=dtype),
          expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))

  def testSlice(self):
    """Checks `slice` on a zero-width matrix and a 2x1 window of a 3x3."""
    for dtype in self.numeric_types:
      self._testTernary(
          array_ops.slice,
          np.array([[], [], []], dtype=dtype),
          np.array([1, 0], dtype=np.int32),
          np.array([2, 0], dtype=np.int32),
          expected=np.array([[], []], dtype=dtype))

      self._testTernary(
          array_ops.slice,
          np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype),
          np.array([0, 1], dtype=np.int32),
          np.array([2, 1], dtype=np.int32),
          expected=np.array([[2], [5]], dtype=dtype))

  def testClipByValue(self):
    """Checks `_clip_by_value` with scalar and array bounds against NumPy."""
    for dtype in self.numeric_types - self.complex_types:
      test_cases = [
          # Array lower bound, scalar upper bound.
          (np.array([2, 4, 5], dtype=dtype), dtype(7)),
          # Scalar lower bound, array upper bound.
          (dtype(1), np.array([2, 4, 5], dtype=dtype)),
          # Array lower and upper bounds.
          (np.array([-2, 7, 7], dtype=dtype), np.array([-2, 9, 8], dtype=dtype))
      ]
      x = np.array([-2, 10, 6], dtype=dtype)
      for lower, upper in test_cases:
        self._testTernary(
            gen_math_ops._clip_by_value,
            x,
            lower,
            upper,
            expected=np.minimum(np.maximum(x, lower), upper))

  def testBetaincSanity(self):
    """Checks `betainc` on inputs that hit known closed-form identities."""
    # This operation is only supported for float32 and float64.
    for dtype in self.numeric_types & {np.float32, np.float64}:
      # Every element is compared against SciPy; the inputs are chosen so
      # that they also exercise these identities:
      # - elements 0 and 1 have b == 1, so betainc(a, 1, x) == x ** a;
      # - element 2 has x == 0, so betainc(a, b, 0) == 0;
      # - element 3 is a generic point.
      # betainc(a, b, 1) == 1 is not exercised here (no element has x == 1).
      a = np.array([.3, .4, .2, .2], dtype=dtype)
      b = np.array([1., 1., .4, .4], dtype=dtype)
      x = np.array([.3, .4, .0, .1], dtype=dtype)
      expected = sps.betainc(a, b, x)
      self._testTernary(
          math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6)

  @parameterized.parameters(
      {
          'sigma': 1e15,
          'rtol': 1e-6,
          'atol': 1e-4
      },
      {
          'sigma': 30,
          'rtol': 1e-6,
          'atol': 2e-3
      },
      {
          'sigma': 1e-8,
          'rtol': 5e-4,
          'atol': 3e-4
      },
      {
          'sigma': 1e-16,
          'rtol': 1e-6,
          'atol': 2e-4
      },
  )
  def testBetainc(self, sigma, rtol, atol):
    """Compares `betainc` with SciPy on random inputs scaled by `sigma`."""
    # This operation is only supported for float32 and float64.
    for dtype in self.numeric_types & {np.float32, np.float64}:
      # Randomly generate a, b, x in the numerical domain of betainc.
      # Compare against the implementation in SciPy.
      a = np.abs(np.random.randn(10, 10) * sigma).astype(dtype)  # in (0, infty)
      b = np.abs(np.random.randn(10, 10) * sigma).astype(dtype)  # in (0, infty)
      x = np.random.rand(10, 10).astype(dtype)  # in (0, 1)
      # `dtype=` asks SciPy to compute the reference in the same dtype.
      expected = sps.betainc(a, b, x, dtype=dtype)
      self._testTernary(
          math_ops.betainc, a, b, x, expected, rtol=rtol, atol=atol)
265
266
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's googletest entry point.
  googletest.main()
269