# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for complex numbers division."""

import os

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.platform import googletest

# Ask XLA:CPU to honor NaNs and infinities under fast math; the expected
# values below contain NaN/Inf results. The flags are appended to any
# XLA_FLAGS already in the environment instead of replacing them, so options
# a developer passes that way are kept. They come last so that they
# presumably win over a conflicting earlier value — confirm against XLA's
# flag parser if that matters. With XLA_FLAGS unset, the resulting value is
# exactly the two flags below.
os.environ["XLA_FLAGS"] = " ".join(
    flag for flag in (
        os.environ.get("XLA_FLAGS", ""),
        "--xla_cpu_fast_math_honor_nans=true",
        "--xla_cpu_fast_math_honor_infs=true",
    ) if flag)


class ComplexNumbersDivisionTest(xla_test.XLATestCase):
  """Test cases for complex numbers division operators."""

  def _testBinary(self, op, a, b, expected, equality_test=None):
    """Builds `op` on placeholders for `a` and `b`, runs it, checks the result.

    Args:
      op: Binary op constructor, called as `op(pa, pb)` where `pa`/`pb` are
        placeholders with the dtypes and shapes of `a` and `b`. The op is
        built inside `self.test_scope()`.
      a: NumPy array fed as the first operand.
      b: NumPy array fed as the second operand.
      expected: NumPy array with the expected result of `op(a, b)`.
      equality_test: Optional assertion called as
        `equality_test(actual, expected, rtol=1e-3)`. Defaults to
        `self.assertAllCloseAccordingToType`.
    """
    with self.session() as session:
      with self.test_scope():
        pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
        pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
        output = op(pa, pb)
      result = session.run(output, {pa: a, pb: b})
      if equality_test is None:
        equality_test = self.assertAllCloseAccordingToType
      # Real and imaginary parts are checked as two separate real arrays.
      equality_test(np.real(result), np.real(expected), rtol=1e-3)
      equality_test(np.imag(result), np.imag(expected), rtol=1e-3)

  def testComplexOps(self):
    """Checks `real_div` on complex inputs that involve zero, inf and NaN.

    Runs once per dtype in `self.complex_types`. Expected values come either
    from Python's own complex arithmetic or, where the comments say so, are
    written out by hand because the C++ result differs from Python's.
    """
    for dtype in self.complex_types:
      # Test division by zero with finite, infinite and NaN numerators.
      self._testBinary(
          gen_math_ops.real_div,
          np.array([
              complex(1, 1),
              complex(1, np.inf),
              complex(1, np.nan),
              complex(np.inf, 1),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.nan, 1),
              complex(np.nan, np.inf),
              complex(np.nan, np.nan),
              complex(-np.inf, np.nan),
          ],
                   dtype=dtype),
          np.array([
              0 + 0j,
              0 + 0j,
              0 + 0j,
              0 + 0j,
              0 + 0j,
              0 + 0j,
              0 + 0j,
              0 + 0j,
              0 + 0j,
              0.0 + 0j,
          ],
                   dtype=dtype),
          expected=np.array([
              complex(np.inf, np.inf),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.inf, np.inf),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.nan, np.inf),
              complex(np.nan, np.inf),
              complex(np.nan, np.nan),
              complex(-np.inf, np.nan),
          ],
                            dtype=dtype))

      # Test division with finite numerator, inf/nan denominator.
      self._testBinary(
          gen_math_ops.real_div,
          np.array([
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
          ],
                   dtype=dtype),
          np.array(
              [
                  complex(1, np.inf),
                  complex(1, np.nan),
                  complex(np.inf, 1),
                  complex(np.inf, np.inf),  # C++ and Python diverge here.
                  complex(np.inf, np.nan),  # C++ and Python diverge here.
                  complex(np.nan, 1),
                  complex(np.nan, np.inf),  # C++ and Python diverge here.
                  complex(np.nan, -np.inf),  # C++ and Python diverge here.
                  complex(np.nan, np.nan),
              ],
              dtype=dtype),
          expected=np.array(
              [
                  (1 + 1j) / complex(1, np.inf),
                  (1 + 1j) / complex(1, np.nan),
                  (1 + 1j) / complex(np.inf, 1),
                  complex(0 + 0j),  # C++ and Python diverge here.
                  complex(0 + 0j),  # C++ and Python diverge here.
                  (1 + 1j) / complex(np.nan, 1),
                  complex(0 + 0j),  # C++ and Python diverge here.
                  complex(0 - 0j),  # C++ and Python diverge here.
                  (1 + 1j) / complex(np.nan, np.nan),
              ],
              dtype=dtype))

      # Test division with inf/nan numerator, finite denominator.
      self._testBinary(
          gen_math_ops.real_div,
          np.array([
              complex(1, np.inf),
              complex(1, np.nan),
              complex(np.inf, 1),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.nan, 1),
              complex(np.nan, np.inf),
              complex(np.nan, np.nan),
              complex(np.nan, -np.inf),
          ],
                   dtype=dtype),
          np.array([
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              1 + 1j,
              -1 - 1j,
          ],
                   dtype=dtype),
          expected=np.array(
              [
                  complex(np.inf, np.inf),  # C++ and Python diverge here.
                  complex(1 / np.nan) / (1 + 1j),
                  complex(np.inf / 1) / (1 + 1j),
                  complex(np.inf, -np.nan),  # C++ and Python diverge here.
                  complex(np.inf, -np.inf),  # C++ and Python diverge here.
                  complex(np.nan / 1) / (1 + 1j),
                  complex(np.inf, np.inf),  # C++ and Python diverge here.
                  complex(np.nan / np.nan) / (1 + 1j),
                  complex(np.inf, np.inf),  # C++ and Python diverge here.
              ],
              dtype=dtype))


if __name__ == "__main__":
  googletest.main()