1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test cases for ternary operators.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from absl.testing import parameterized 22import numpy as np 23import scipy.special as sps 24 25from tensorflow.compiler.tests import xla_test 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import gen_math_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.platform import googletest 32 33 34class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): 35 36 def _testTernary(self, op, a, b, c, expected, rtol=1e-3, atol=1e-6): 37 with self.session() as session: 38 with self.test_scope(): 39 pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") 40 pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") 41 pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c") 42 output = op(pa, pb, pc) 43 result = session.run(output, {pa: a, pb: b, pc: c}) 44 self.assertAllClose(result, expected, rtol=rtol, atol=atol) 45 return result 46 47 @parameterized.parameters( 48 {'start': 1, 'end': 2, 'num': 1}, 49 {'start': 1, 'end': 4, 'num': 3}, 50 {'start': 0, 'end': 41, 'num': 42}) 51 @test_util.disable_mlir_bridge( 52 'TODO(b/156174708): Dynamic result types not supported') 53 def testLinspace(self, start, end, num): 54 expected = np.linspace(start, end, num, dtype=np.float32) 55 result = self._testTernary( 56 math_ops.linspace, 57 np.float32(start), 58 np.float32(end), 59 np.int32(num), 60 expected) 61 # According to linspace spec, start has to be the first element and end has 62 # to be last element. 63 self.assertEqual(result[-1], expected[-1]) 64 self.assertEqual(result[0], expected[0]) 65 66 def testRange(self): 67 self._testTernary( 68 math_ops.range, 69 np.int32(1), 70 np.int32(2), 71 np.int32(1), 72 expected=np.array([1], dtype=np.int32)) 73 self._testTernary( 74 math_ops.range, 75 np.int32(1), 76 np.int32(7), 77 np.int32(2), 78 expected=np.array([1, 3, 5], dtype=np.int32)) 79 80 def testSelect(self): 81 for dtype in self.numeric_types: 82 self._testTernary( 83 array_ops.where, 84 np.array(False), 85 np.array(2, dtype=dtype), 86 np.array(7, dtype=dtype), 87 expected=np.array(7, dtype=dtype)) 88 89 self._testTernary( 90 array_ops.where, 91 np.array(True), 92 np.array([1, 2, 3, 4], dtype=dtype), 93 np.array([5, 6, 7, 8], dtype=dtype), 94 expected=np.array([1, 2, 3, 4], dtype=dtype)) 95 96 self._testTernary( 97 array_ops.where, 98 np.array(False), 99 np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), 100 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), 101 expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype)) 102 103 self._testTernary( 104 array_ops.where, 105 np.array([0, 1, 1, 0], dtype=np.bool_), 106 np.array([1, 2, 3, 4], dtype=dtype), 107 np.array([5, 6, 7, 8], dtype=dtype), 108 expected=np.array([5, 2, 3, 8], dtype=dtype)) 109 110 self._testTernary( 111 array_ops.where, 112 np.array([0, 1, 0], dtype=np.bool_), 113 np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), 114 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), 115 expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype)) 116 117 def testSelectV2(self): 118 for dtype in self.numeric_types: 119 self._testTernary( 120 array_ops.where_v2, 121 np.array(False), 122 np.array(2, dtype=dtype), 123 np.array(7, dtype=dtype), 124 expected=np.array(7, dtype=dtype)) 125 126 self._testTernary( 127 array_ops.where_v2, 128 np.array(True), 129 np.array([1, 2, 3, 4], dtype=dtype), 130 np.array([5, 6, 7, 8], dtype=dtype), 131 expected=np.array([1, 2, 3, 4], dtype=dtype)) 132 133 self._testTernary( 134 array_ops.where_v2, 135 np.array(False), 136 np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), 137 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), 138 expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype)) 139 140 self._testTernary( 141 array_ops.where_v2, 142 np.array([0, 1, 1, 0], dtype=np.bool_), 143 np.array([1, 2, 3, 4], dtype=dtype), 144 np.array([5, 6, 7, 8], dtype=dtype), 145 expected=np.array([5, 2, 3, 8], dtype=dtype)) 146 147 # Broadcast the condition 148 self._testTernary( 149 array_ops.where_v2, 150 np.array([0, 1], dtype=np.bool_), 151 np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), 152 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), 153 expected=np.array([[7, 2], [9, 4], [11, 6]], dtype=dtype)) 154 155 # Broadcast the then branch to the else 156 self._testTernary( 157 array_ops.where_v2, 158 np.array([[0, 1], [1, 0], [1, 1]], dtype=np.bool_), 159 np.array([[1, 2]], dtype=dtype), 160 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), 161 expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype)) 162 163 # Broadcast the else branch to the then 164 self._testTernary( 165 array_ops.where_v2, 166 np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool_), 167 np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), 168 np.array([[1, 2]], dtype=dtype), 169 expected=np.array([[7, 2], [1, 10], [1, 2]], dtype=dtype)) 170 171 # Broadcast the then/else branches to the condition 172 self._testTernary( 173 array_ops.where_v2, 174 np.array([[1, 0], [0, 1], [1, 1]], dtype=np.bool_), 175 np.array(7, dtype=dtype), 176 np.array(8, dtype=dtype), 177 expected=np.array([[7, 8], [8, 7], [7, 7]], dtype=dtype)) 178 self._testTernary( 179 array_ops.where_v2, 180 np.array([[1, 0], [0, 1], [0, 0]], dtype=np.bool_), 181 np.array(7, dtype=dtype), 182 np.array([8, 9], dtype=dtype), 183 expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype)) 184 185 def testSlice(self): 186 for dtype in self.numeric_types: 187 self._testTernary( 188 array_ops.slice, 189 np.array([[], [], []], dtype=dtype), 190 np.array([1, 0], dtype=np.int32), 191 np.array([2, 0], dtype=np.int32), 192 expected=np.array([[], []], dtype=dtype)) 193 194 self._testTernary( 195 array_ops.slice, 196 np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), 197 np.array([0, 1], dtype=np.int32), 198 np.array([2, 1], dtype=np.int32), 199 expected=np.array([[2], [5]], dtype=dtype)) 200 201 def testClipByValue(self): 202 for dtype in self.numeric_types - self.complex_types: 203 test_cases = [ 204 (np.array([2, 4, 5], dtype=dtype), dtype(7)), # 205 (dtype(1), np.array([2, 4, 5], dtype=dtype)), # 206 (np.array([-2, 7, 7], dtype=dtype), np.array([-2, 9, 8], dtype=dtype)) 207 ] 208 x = np.array([-2, 10, 6], dtype=dtype) 209 for lower, upper in test_cases: 210 self._testTernary( 211 gen_math_ops._clip_by_value, 212 x, 213 lower, 214 upper, 215 expected=np.minimum(np.maximum(x, lower), upper)) 216 217 def testBetaincSanity(self): 218 # This operation is only supported for float32 and float64. 219 for dtype in self.numeric_types & {np.float32, np.float64}: 220 # Sanity check a few identities: 221 # - betainc(a, b, 0) == 0 222 # - betainc(a, b, 1) == 1 223 # - betainc(a, 1, x) == x ** a 224 # Compare against the implementation in SciPy. 225 a = np.array([.3, .4, .2, .2], dtype=dtype) 226 b = np.array([1., 1., .4, .4], dtype=dtype) 227 x = np.array([.3, .4, .0, .1], dtype=dtype) 228 expected = sps.betainc(a, b, x) 229 self._testTernary( 230 math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6) 231 232 @parameterized.parameters( 233 { 234 'sigma': 1e15, 235 'rtol': 1e-6, 236 'atol': 1e-4 237 }, 238 { 239 'sigma': 30, 240 'rtol': 1e-6, 241 'atol': 2e-3 242 }, 243 { 244 'sigma': 1e-8, 245 'rtol': 5e-4, 246 'atol': 3e-4 247 }, 248 { 249 'sigma': 1e-16, 250 'rtol': 1e-6, 251 'atol': 2e-4 252 }, 253 ) 254 def testBetainc(self, sigma, rtol, atol): 255 # This operation is only supported for float32 and float64. 256 for dtype in self.numeric_types & {np.float32, np.float64}: 257 # Randomly generate a, b, x in the numerical domain of betainc. 258 # Compare against the implementation in SciPy. 259 a = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty) 260 b = np.abs(np.random.randn(10, 10) * sigma).astype(dtype) # in (0, infty) 261 x = np.random.rand(10, 10).astype(dtype) # in (0, 1) 262 expected = sps.betainc(a, b, x, dtype=dtype) 263 self._testTernary( 264 math_ops.betainc, a, b, x, expected, rtol=rtol, atol=atol) 265 266 267if __name__ == "__main__": 268 googletest.main() 269