1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test cases for ternary operators.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.compiler.tests.xla_test import XLATestCase 24from tensorflow.python.framework import dtypes 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import math_ops 27from tensorflow.python.platform import googletest 28 29 30class TernaryOpsTest(XLATestCase): 31 32 def _testTernary(self, op, a, b, c, expected): 33 with self.test_session() as session: 34 with self.test_scope(): 35 pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") 36 pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") 37 pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c") 38 output = op(pa, pb, pc) 39 result = session.run(output, {pa: a, pb: b, pc: c}) 40 self.assertAllClose(result, expected, rtol=1e-3) 41 42 def testLinspace(self): 43 self._testTernary( 44 math_ops.linspace, 45 np.float32(1), 46 np.float32(2), 47 np.int32(1), 48 expected=np.array([1], dtype=np.float32)) 49 self._testTernary( 50 math_ops.linspace, 51 np.float32(1), 52 np.float32(4), 53 np.int32(3), 54 expected=np.array([1, 2.5, 4], dtype=np.float32)) 55 56 def 
testRange(self): 57 self._testTernary( 58 math_ops.range, 59 np.int32(1), 60 np.int32(2), 61 np.int32(1), 62 expected=np.array([1], dtype=np.int32)) 63 self._testTernary( 64 math_ops.range, 65 np.int32(1), 66 np.int32(7), 67 np.int32(2), 68 expected=np.array([1, 3, 5], dtype=np.int32)) 69 70 def testSelect(self): 71 self._testTernary( 72 array_ops.where, 73 np.array(0, dtype=np.bool), 74 np.array(2, dtype=np.float32), 75 np.array(7, dtype=np.float32), 76 expected=np.array(7, dtype=np.float32)) 77 78 self._testTernary( 79 array_ops.where, 80 np.array(1, dtype=np.bool), 81 np.array([1, 2, 3, 4], dtype=np.float32), 82 np.array([5, 6, 7, 8], dtype=np.float32), 83 expected=np.array([1, 2, 3, 4], dtype=np.float32)) 84 85 self._testTernary( 86 array_ops.where, 87 np.array(0, dtype=np.bool), 88 np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), 89 np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32), 90 expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32)) 91 92 self._testTernary( 93 array_ops.where, 94 np.array([0, 1, 1, 0], dtype=np.bool), 95 np.array([1, 2, 3, 4], dtype=np.float32), 96 np.array([5, 6, 7, 8], dtype=np.float32), 97 expected=np.array([5, 2, 3, 8], dtype=np.float32)) 98 99 self._testTernary( 100 array_ops.where, 101 np.array([0, 1, 0], dtype=np.bool), 102 np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), 103 np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32), 104 expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32)) 105 106 def testSlice(self): 107 for dtype in self.numeric_types: 108 self._testTernary( 109 array_ops.slice, 110 np.array([[], [], []], dtype=dtype), 111 np.array([1, 0], dtype=np.int32), 112 np.array([2, 0], dtype=np.int32), 113 expected=np.array([[], []], dtype=dtype)) 114 115 self._testTernary( 116 array_ops.slice, 117 np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), 118 np.array([0, 1], dtype=np.int32), 119 np.array([2, 1], dtype=np.int32), 120 expected=np.array([[2], [5]], dtype=dtype)) 121 


# Script entry point: hand control to TensorFlow's googletest wrapper, which
# presumably discovers and runs the test cases defined in this module.
if __name__ == "__main__":
  googletest.main()