• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test cases for complex numbers division."""
16
17import os
18
19import numpy as np
20
21from tensorflow.compiler.tests import xla_test
22from tensorflow.python.framework import dtypes
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import gen_math_ops
25from tensorflow.python.platform import googletest
26
# Overwrite XLA_FLAGS at import time with the two CPU fast-math flags below.
# NOTE(review): presumably these make the CPU backend honor NaN/Inf results,
# which the expectations in this file depend on -- confirm against XLA docs.
os.environ["XLA_FLAGS"] = " ".join((
    "--xla_cpu_fast_math_honor_nans=true",
    "--xla_cpu_fast_math_honor_infs=true",
))
29
30
class ComplexNumbersDivisionTest(xla_test.XLATestCase):
  """Exercises complex division with zero, infinite and NaN operands."""

  def _testBinary(self, op, a, b, expected, equality_test=None):
    """Evaluates `op(a, b)` in a session and checks it against `expected`.

    The real and imaginary parts are compared separately, real part first.

    Args:
      op: Callable applied to the two placeholders; its result is what gets
        run in the session.
      a: NumPy array fed to the first placeholder.
      b: NumPy array fed to the second placeholder.
      expected: Reference values the result is compared against.
      equality_test: Optional callable invoked as
        `equality_test(actual, expected, rtol=1e-3)`. When None,
        `self.assertAllCloseAccordingToType` is used.
    """
    with self.session() as session:
      with self.test_scope():
        lhs = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
        rhs = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
        output = op(lhs, rhs)
      result = session.run(output, {lhs: a, rhs: b})
      check = (
          self.assertAllCloseAccordingToType
          if equality_test is None else equality_test)
      for part in (np.real, np.imag):
        check(part(result), part(expected), rtol=1e-3)

  def testComplexOps(self):
    """Runs `real_div` on inf/nan/zero operand mixes for each complex dtype."""
    unit = 1 + 1j

    for dtype in self.complex_types:
      # Section 1: every numerator is divided by 0 + 0j.
      numerators = np.array(
          [
              complex(1, 1),
              complex(1, np.inf),
              complex(1, np.nan),
              complex(np.inf, 1),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.nan, 1),
              complex(np.nan, np.inf),
              complex(np.nan, np.nan),
              complex(-np.inf, np.nan),
          ],
          dtype=dtype)
      denominators = np.array([0 + 0j] * 10, dtype=dtype)
      quotients = np.array(
          [
              complex(np.inf, np.inf),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.inf, np.inf),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.nan, np.inf),
              complex(np.nan, np.inf),
              complex(np.nan, np.nan),
              complex(-np.inf, np.nan),
          ],
          dtype=dtype)
      self._testBinary(
          gen_math_ops.real_div, numerators, denominators, expected=quotients)

      # Section 2: finite numerator (1 + 1j), inf/nan denominator.
      numerators = np.array([unit] * 9, dtype=dtype)
      denominators = np.array(
          [
              complex(1, np.inf),
              complex(1, np.nan),
              complex(np.inf, 1),
              complex(np.inf, np.inf),  # C++ and Python diverge here.
              complex(np.inf, np.nan),  # C++ and Python diverge here.
              complex(np.nan, 1),
              complex(np.nan, np.inf),  # C++ and Python diverge here.
              complex(np.nan, -np.inf),  # C++ and Python diverge here.
              complex(np.nan, np.nan),
          ],
          dtype=dtype)
      quotients = np.array(
          [
              unit / complex(1, np.inf),
              unit / complex(1, np.nan),
              unit / complex(np.inf, 1),
              complex(0 + 0j),  # C++ and Python diverge here.
              complex(0 + 0j),  # C++ and Python diverge here.
              unit / complex(np.nan, 1),
              complex(0 + 0j),  # C++ and Python diverge here.
              complex(0 - 0j),  # C++ and Python diverge here.
              unit / complex(np.nan, np.nan),
          ],
          dtype=dtype)
      self._testBinary(
          gen_math_ops.real_div, numerators, denominators, expected=quotients)

      # Section 3: inf/nan numerator, finite denominator (+/-(1 + 1j)).
      numerators = np.array(
          [
              complex(1, np.inf),
              complex(1, np.nan),
              complex(np.inf, 1),
              complex(np.inf, np.inf),
              complex(np.inf, np.nan),
              complex(np.nan, 1),
              complex(np.nan, np.inf),
              complex(np.nan, np.nan),
              complex(np.nan, -np.inf),
          ],
          dtype=dtype)
      denominators = np.array([unit] * 8 + [-1 - 1j], dtype=dtype)
      quotients = np.array(
          [
              complex(np.inf, np.inf),  # C++ and Python diverge here.
              complex(1 / np.nan) / unit,
              complex(np.inf / 1) / unit,
              complex(np.inf, -np.nan),  # C++ and Python diverge here.
              complex(np.inf, -np.inf),  # C++ and Python diverge here.
              complex(np.nan / 1) / unit,
              complex(np.inf, np.inf),  # C++ and Python diverge here.
              complex(np.nan / np.nan) / unit,
              complex(np.inf, np.inf),  # C++ and Python diverge here.
          ],
          dtype=dtype)
      self._testBinary(
          gen_math_ops.real_div, numerators, denominators, expected=quotients)
173
174
# When this file is executed as a script, hand control to googletest.main().
if __name__ == "__main__":
  googletest.main()
177