• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: dynamo"]
2
3"""Test examples for NEP 50."""
4
5import itertools
6from unittest import skipIf as skipif, SkipTest
7
8
try:
    import numpy as _np

    v = _np.__version__.split(".")
    # Compare (major, minor) as a tuple. The previous check
    # ``major >= 1 and minor >= 24`` rejected every NumPy 2.x release
    # (minor < 24) and would have accepted e.g. a hypothetical 0.24.
    # The NumPy comparison tests toggle NEP 50 through the private
    # ``_get_promotion_state`` / ``_set_promotion_state`` API, so also require
    # that API to exist (NOTE(review): newer NumPy releases may drop it).
    HAVE_NUMPY = (int(v[0]), int(v[1])) >= (1, 24) and all(
        hasattr(_np, attr)
        for attr in ("_get_promotion_state", "_set_promotion_state")
    )
except ImportError:
    HAVE_NUMPY = False
16
17import torch._numpy as tnp
18from torch._numpy import (  # noqa: F401
19    array,
20    bool_,
21    complex128,
22    complex64,
23    float32,
24    float64,
25    inf,
26    int16,
27    int32,
28    int64,
29    uint8,
30)
31from torch._numpy.testing import assert_allclose
32from torch.testing._internal.common_utils import (
33    instantiate_parametrized_tests,
34    parametrize,
35    run_tests,
36    TestCase,
37)
38
39
# Stand-in for ``uint16``, which this file does not import from torch._numpy.
# It is only referenced by the "old result" of the "array([1], uint8) + 300"
# entry in ``examples``. That entry's new result is ``Exception``, so the old
# value is never compared and any dtype works here.
uint16 = uint8  # placeholder only; never compared against
41
42
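# To run the examples table against NumPy itself (in NEP 50 "weak" mode)
# instead of torch._numpy, swap in the following imports:
# from numpy import array, uint8, uint16, int64, float32, float64, inf
# from numpy.testing import assert_allclose
# import numpy as np
# np._set_promotion_state('weak')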
47
48from pytest import raises as assert_raises
49
50
51unchanged = None
52
53# expression    old result   new_result
54examples = {
55    "uint8(1) + 2": (int64(3), uint8(3)),
56    "array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)),
57    "array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)),
58    "array([1.], float32) + float64(1.)": (
59        array([2.0], float32),
60        array([2.0], float64),
61    ),
62    "array([1.], float32) + array(1., float64)": (
63        array([2.0], float32),
64        array([2.0], float64),
65    ),
66    "array([1], uint8) + 1": (array([2], uint8), unchanged),
67    "array([1], uint8) + 200": (array([201], uint8), unchanged),
68    "array([100], uint8) + 200": (array([44], uint8), unchanged),
69    "array([1], uint8) + 300": (array([301], uint16), Exception),
70    "uint8(1) + 300": (int64(301), Exception),
71    "uint8(100) + 200": (int64(301), uint8(44)),  # and RuntimeWarning
72    "float32(1) + 3e100": (float64(3e100), float32(inf)),  # and RuntimeWarning [T7]
73    "array([1.0], float32) + 1e-14 == 1.0": (array([True]), unchanged),
74    "array([0.1], float32) == float64(0.1)": (array([True]), array([False])),
75    "array(1.0, float32) + 1e-14 == 1.0": (array(False), array(True)),
76    "array([1.], float32) + 3": (array([4.0], float32), unchanged),
77    "array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)),
78    "3j + array(3, complex64)": (array(3 + 3j, complex128), array(3 + 3j, complex64)),
79    "float32(1) + 1j": (array(1 + 1j, complex128), array(1 + 1j, complex64)),
80    "int32(1) + 5j": (array(1 + 5j, complex128), unchanged),
81    # additional examples from the NEP text
82    "int16(2) + 2": (int64(4), int16(4)),
83    "int16(4) + 4j": (complex128(4 + 4j), unchanged),
84    "float32(5) + 5j": (complex128(5 + 5j), complex64(5 + 5j)),
85    "bool_(True) + 1": (int64(2), unchanged),
86    "True + uint8(2)": (uint8(3), unchanged),
87}
88
89
@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestNEP50Table(TestCase):
    """Evaluate every entry of ``examples`` with torch._numpy."""

    @parametrize("example", examples)
    def test_nep50_exceptions(self, example):
        """Check that ``example`` produces its expected NEP 50 ("new") result.

        A new result of ``Exception`` means the expression must raise
        ``OverflowError``; ``unchanged`` means the old result is expected.
        Otherwise the value (to ``atol=1e-16``) and the dtype must both match.
        """
        old, new = examples[example]

        # ``new`` may be an array, so test for the sentinel by identity:
        # ``new == Exception`` would dispatch to the array's ``__eq__``.
        if new is Exception:
            with assert_raises(OverflowError):
                eval(example)
            return

        # The example strings are fixed literals defined in this module.
        result = eval(example)

        expected = old if new is unchanged else new
        assert_allclose(result, expected, atol=1e-16)
        assert result.dtype == expected.dtype
109
110
111# ### Directly compare to numpy ###
112
# Python scalars, which NEP 50 treats as "weak" operands.
weaks = (True, 1, 2.0, 3j)

# tnp values carrying an explicit dtype: a 0-d bool array followed by a
# scalar of each listed type, all with value 1.
non_weaks = (tnp.asarray(True),) + tuple(
    getattr(tnp, type_name)(1)
    for type_name in (
        "uint8",
        "int8",
        "int32",
        "int64",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
)

# Values for the optional ``dtype=`` argument of test_direct_compare; None
# means "do not pass dtype". Without NumPy only the None entry is kept.
if not HAVE_NUMPY:
    dtypes = (None,)
else:
    dtypes = (None,) + tuple(
        getattr(_np, type_name)
        for type_name in (
            "bool_",
            "uint8",
            "int8",
            "int32",
            "int64",
            "float32",
            "float64",
            "complex64",
            "complex128",
        )
    )
140
141
# ufunc name -> dtype names for which test_compare_ufuncs skips the check.
# The skip triggers when either the array operand's dtype name or the dtype
# name of the Python scalar (via tnp.asarray) appears in the list.
corners = {
    **{
        ufunc_name: ["bool_", "uint8", "int8", "int16", "int32", "int64"]
        for ufunc_name in (
            "true_divide",
            "divide",
            "arctan2",
            "copysign",
            "heaviside",
            "ldexp",
        )
    },
    "power": ["uint8"],
    "nextafter": ["float32"],
}
153
154
@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestCompareToNumpy(TestCase):
    """Compare torch._numpy results with NumPy running in NEP 50 "weak" mode."""

    @parametrize("scalar, array, dtype", itertools.product(weaks, non_weaks, dtypes))
    def test_direct_compare(self, scalar, array, dtype):
        """Check ``add(scalar, array)`` (optionally with ``dtype=``) against NumPy.

        ``dtype`` is either None (no ``dtype=`` argument) or a NumPy scalar
        type; the tnp call receives the tnp attribute with the same
        ``__name__``. Both the result dtype and the value must match.
        """
        # Read the current state before entering ``try`` so the ``finally``
        # clause never refers to an unbound name if this call fails.
        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            # Bind ``kwargs`` unconditionally. It used to be assigned only when
            # ``dtype`` was given, so every dtype=None case raised NameError,
            # was swallowed by the ``except`` below, and passed untested.
            kwargs = {} if dtype is None else {"dtype": dtype}
            try:
                result_numpy = _np.add(scalar, array.tensor.numpy(), **kwargs)
            except Exception:
                # NumPy rejects this operand/dtype combination; there is no
                # reference result to compare against.
                return

            # NOTE(review): relies on tnp exposing a type with the same
            # ``__name__`` as the NumPy type (e.g. ``bool_``) — confirm for
            # NumPy versions where that name differs.
            tnp_kwargs = (
                {} if dtype is None else {"dtype": getattr(tnp, dtype.__name__)}
            )
            result = tnp.add(scalar, array, **tnp_kwargs).tensor.numpy()
            assert result.dtype == result_numpy.dtype
            assert result == result_numpy
        finally:
            _np._set_promotion_state(state)

    @parametrize("name", tnp._ufuncs._binary)
    @parametrize("scalar, array", itertools.product(weaks, non_weaks))
    def test_compare_ufuncs(self, name, scalar, array):
        """Check that binary ufunc ``name`` gives the same result dtype as NumPy.

        Combinations listed in ``corners`` are skipped. When either library
        fails to support the inputs, the comparison is not made.
        """
        if name in corners and (
            array.dtype.name in corners[name]
            or tnp.asarray(scalar).dtype.name in corners[name]
        ):
            raise SkipTest(f"{name}(..., dtype=array.dtype)")

        # These ufuncs are excluded from the comparison. Exit before touching
        # NumPy's global promotion state so nothing needs restoring.
        if name in ["matmul", "modf", "divmod", "ldexp"]:
            return

        ufunc = getattr(tnp, name)
        ufunc_numpy = getattr(_np, name)

        # Read the state outside ``try`` so ``finally`` can always restore it.
        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            try:
                result = ufunc(scalar, array)
            except RuntimeError:
                # RuntimeError: "bitwise_xor_cpu" not implemented for 'ComplexDouble' etc
                result = None

            try:
                result_numpy = ufunc_numpy(scalar, array.tensor.numpy())
            except TypeError:
                # TypeError: ufunc 'hypot' not supported for the input types
                result_numpy = None

            if result is not None and result_numpy is not None:
                assert result.tensor.numpy().dtype == result_numpy.dtype
        finally:
            _np._set_promotion_state(state)
217
218
# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()
221