• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test cases for ternary operators."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.compiler.tests.xla_test import XLATestCase
24from tensorflow.python.framework import dtypes
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.platform import googletest
28
29
class TernaryOpsTest(XLATestCase):
  """Tests three-input ops (linspace, range, where, slice) in the XLA scope."""

  def _testTernary(self, op, a, b, c, expected):
    """Builds `op` over three placeholders, runs it and checks the result.

    Each placeholder takes its dtype and shape from the matching numpy input.
    The op is built inside `self.test_scope()`; the session run that feeds
    `a`, `b` and `c` happens outside that scope.

    Args:
      op: Callable taking three tensors and returning one tensor.
      a: First input; a numpy array or numpy scalar (needs `.dtype`/`.shape`).
      b: Second input, same requirements as `a`.
      c: Third input, same requirements as `a`.
      expected: Value the output must match under `assertAllClose` with
        `rtol=1e-3`.
    """
    with self.test_session() as session:
      with self.test_scope():
        pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
        pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
        pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
        output = op(pa, pb, pc)
      result = session.run(output, {pa: a, pb: b, pc: c})
      self.assertAllClose(result, expected, rtol=1e-3)

  def testLinspace(self):
    """linspace(start, stop, num) yields `num` values from start to stop."""
    # num=1 yields only the start value.
    self._testTernary(
        math_ops.linspace,
        np.float32(1),
        np.float32(2),
        np.int32(1),
        expected=np.array([1], dtype=np.float32))
    # num=3 yields both endpoints plus the midpoint.
    self._testTernary(
        math_ops.linspace,
        np.float32(1),
        np.float32(4),
        np.int32(3),
        expected=np.array([1, 2.5, 4], dtype=np.float32))

  def testRange(self):
    """range(start, limit, delta) steps by `delta` and stops before `limit`."""
    self._testTernary(
        math_ops.range,
        np.int32(1),
        np.int32(2),
        np.int32(1),
        expected=np.array([1], dtype=np.int32))
    self._testTernary(
        math_ops.range,
        np.int32(1),
        np.int32(7),
        np.int32(2),
        expected=np.array([1, 3, 5], dtype=np.int32))

  def testSelect(self):
    """where(cond, x, y) with scalar, elementwise and row-wise conditions."""
    # Conditions use `np.bool_`: the old `np.bool` alias of the builtin `bool`
    # is missing from NumPy 1.24-1.26, and `np.bool_` produces the same
    # boolean dtype on every NumPy version.

    # Scalar False condition with scalar operands selects `y`.
    self._testTernary(
        array_ops.where,
        np.array(0, dtype=np.bool_),
        np.array(2, dtype=np.float32),
        np.array(7, dtype=np.float32),
        expected=np.array(7, dtype=np.float32))

    # Scalar True condition selects all of `x`.
    self._testTernary(
        array_ops.where,
        np.array(1, dtype=np.bool_),
        np.array([1, 2, 3, 4], dtype=np.float32),
        np.array([5, 6, 7, 8], dtype=np.float32),
        expected=np.array([1, 2, 3, 4], dtype=np.float32))

    # Scalar False condition selects all of `y`, here for matrices.
    self._testTernary(
        array_ops.where,
        np.array(0, dtype=np.bool_),
        np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
        np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
        expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32))

    # Condition with the same shape as `x` and `y` selects element by element.
    self._testTernary(
        array_ops.where,
        np.array([0, 1, 1, 0], dtype=np.bool_),
        np.array([1, 2, 3, 4], dtype=np.float32),
        np.array([5, 6, 7, 8], dtype=np.float32),
        expected=np.array([5, 2, 3, 8], dtype=np.float32))

    # Rank-1 condition over matrix operands selects whole rows.
    self._testTernary(
        array_ops.where,
        np.array([0, 1, 0], dtype=np.bool_),
        np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
        np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
        expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32))

  def testSlice(self):
    """slice(input, begin, size) for every dtype in `self.numeric_types`."""
    for dtype in self.numeric_types:
      # Slicing a 3x0 input keeps the zero-width columns.
      self._testTernary(
          array_ops.slice,
          np.array([[], [], []], dtype=dtype),
          np.array([1, 0], dtype=np.int32),
          np.array([2, 0], dtype=np.int32),
          expected=np.array([[], []], dtype=dtype))

      # Two rows starting at row 0, one column starting at column 1.
      self._testTernary(
          array_ops.slice,
          np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype),
          np.array([0, 1], dtype=np.int32),
          np.array([2, 1], dtype=np.int32),
          expected=np.array([[2], [5]], dtype=dtype))
121
122
if __name__ == "__main__":
  # Script entry point: hand control to TensorFlow's googletest main.
  googletest.main()
125