1# Owner(s): ["module: dynamo"] 2 3"""Test examples for NEP 50.""" 4 5import itertools 6from unittest import skipIf as skipif, SkipTest 7 8 9try: 10 import numpy as _np 11 12 v = _np.__version__.split(".") 13 HAVE_NUMPY = int(v[0]) >= 1 and int(v[1]) >= 24 14except ImportError: 15 HAVE_NUMPY = False 16 17import torch._numpy as tnp 18from torch._numpy import ( # noqa: F401 19 array, 20 bool_, 21 complex128, 22 complex64, 23 float32, 24 float64, 25 inf, 26 int16, 27 int32, 28 int64, 29 uint8, 30) 31from torch._numpy.testing import assert_allclose 32from torch.testing._internal.common_utils import ( 33 instantiate_parametrized_tests, 34 parametrize, 35 run_tests, 36 TestCase, 37) 38 39 40uint16 = uint8 # can be anything here, see below 41 42 43# from numpy import array, uint8, uint16, int64, float32, float64, inf 44# from numpy.testing import assert_allclose 45# import numpy as np 46# np._set_promotion_state('weak') 47 48from pytest import raises as assert_raises 49 50 51unchanged = None 52 53# expression old result new_result 54examples = { 55 "uint8(1) + 2": (int64(3), uint8(3)), 56 "array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)), 57 "array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)), 58 "array([1.], float32) + float64(1.)": ( 59 array([2.0], float32), 60 array([2.0], float64), 61 ), 62 "array([1.], float32) + array(1., float64)": ( 63 array([2.0], float32), 64 array([2.0], float64), 65 ), 66 "array([1], uint8) + 1": (array([2], uint8), unchanged), 67 "array([1], uint8) + 200": (array([201], uint8), unchanged), 68 "array([100], uint8) + 200": (array([44], uint8), unchanged), 69 "array([1], uint8) + 300": (array([301], uint16), Exception), 70 "uint8(1) + 300": (int64(301), Exception), 71 "uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning 72 "float32(1) + 3e100": (float64(3e100), float32(inf)), # and RuntimeWarning [T7] 73 "array([1.0], float32) + 1e-14 == 1.0": (array([True]), unchanged), 74 "array([0.1], float32) == float64(0.1)": (array([True]), array([False])), 75 "array(1.0, float32) + 1e-14 == 1.0": (array(False), array(True)), 76 "array([1.], float32) + 3": (array([4.0], float32), unchanged), 77 "array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)), 78 "3j + array(3, complex64)": (array(3 + 3j, complex128), array(3 + 3j, complex64)), 79 "float32(1) + 1j": (array(1 + 1j, complex128), array(1 + 1j, complex64)), 80 "int32(1) + 5j": (array(1 + 5j, complex128), unchanged), 81 # additional examples from the NEP text 82 "int16(2) + 2": (int64(4), int16(4)), 83 "int16(4) + 4j": (complex128(4 + 4j), unchanged), 84 "float32(5) + 5j": (complex128(5 + 5j), complex64(5 + 5j)), 85 "bool_(True) + 1": (int64(2), unchanged), 86 "True + uint8(2)": (uint8(3), unchanged), 87} 88 89 90@skipif(not HAVE_NUMPY, reason="NumPy not found") 91@instantiate_parametrized_tests 92class TestNEP50Table(TestCase): 93 @parametrize("example", examples) 94 def test_nep50_exceptions(self, example): 95 old, new = examples[example] 96 97 if new == Exception: 98 with assert_raises(OverflowError): 99 eval(example) 100 101 else: 102 result = eval(example) 103 104 if new is unchanged: 105 new = old 106 107 assert_allclose(result, new, atol=1e-16) 108 assert result.dtype == new.dtype 109 110 111# ### Directly compare to numpy ### 112 113weaks = (True, 1, 2.0, 3j) 114non_weaks = ( 115 tnp.asarray(True), 116 tnp.uint8(1), 117 tnp.int8(1), 118 tnp.int32(1), 119 tnp.int64(1), 120 tnp.float32(1), 121 tnp.float64(1), 122 tnp.complex64(1), 123 tnp.complex128(1), 124) 125if HAVE_NUMPY: 126 dtypes = ( 127 None, 128 _np.bool_, 129 _np.uint8, 130 _np.int8, 131 _np.int32, 132 _np.int64, 133 _np.float32, 134 _np.float64, 135 _np.complex64, 136 _np.complex128, 137 ) 138else: 139 dtypes = (None,) 140 141 142# ufunc name: [array.dtype] 143corners = { 144 "true_divide": ["bool_", "uint8", "int8", "int16", "int32", "int64"], 145 "divide": ["bool_", "uint8", "int8", "int16", "int32", "int64"], 146 "arctan2": ["bool_", "uint8", "int8", "int16", "int32", "int64"], 147 "copysign": ["bool_", "uint8", "int8", "int16", "int32", "int64"], 148 "heaviside": ["bool_", "uint8", "int8", "int16", "int32", "int64"], 149 "ldexp": ["bool_", "uint8", "int8", "int16", "int32", "int64"], 150 "power": ["uint8"], 151 "nextafter": ["float32"], 152} 153 154 155@skipif(not HAVE_NUMPY, reason="NumPy not found") 156@instantiate_parametrized_tests 157class TestCompareToNumpy(TestCase): 158 @parametrize("scalar, array, dtype", itertools.product(weaks, non_weaks, dtypes)) 159 def test_direct_compare(self, scalar, array, dtype): 160 # compare to NumPy w/ NEP 50. 161 try: 162 state = _np._get_promotion_state() 163 _np._set_promotion_state("weak") 164 165 if dtype is not None: 166 kwargs = {"dtype": dtype} 167 try: 168 result_numpy = _np.add(scalar, array.tensor.numpy(), **kwargs) 169 except Exception: 170 return 171 172 kwargs = {} 173 if dtype is not None: 174 kwargs = {"dtype": getattr(tnp, dtype.__name__)} 175 result = tnp.add(scalar, array, **kwargs).tensor.numpy() 176 assert result.dtype == result_numpy.dtype 177 assert result == result_numpy 178 179 finally: 180 _np._set_promotion_state(state) 181 182 @parametrize("name", tnp._ufuncs._binary) 183 @parametrize("scalar, array", itertools.product(weaks, non_weaks)) 184 def test_compare_ufuncs(self, name, scalar, array): 185 if name in corners and ( 186 array.dtype.name in corners[name] 187 or tnp.asarray(scalar).dtype.name in corners[name] 188 ): 189 raise SkipTest(f"{name}(..., dtype=array.dtype)") 190 191 try: 192 state = _np._get_promotion_state() 193 _np._set_promotion_state("weak") 194 195 if name in ["matmul", "modf", "divmod", "ldexp"]: 196 return 197 ufunc = getattr(tnp, name) 198 ufunc_numpy = getattr(_np, name) 199 200 try: 201 result = ufunc(scalar, array) 202 except RuntimeError: 203 # RuntimeError: "bitwise_xor_cpu" not implemented for 'ComplexDouble' etc 204 result = None 205 206 try: 207 result_numpy = ufunc_numpy(scalar, array.tensor.numpy()) 208 except TypeError: 209 # TypeError: ufunc 'hypot' not supported for the input types 210 result_numpy = None 211 212 if result is not None and result_numpy is not None: 213 assert result.tensor.numpy().dtype == result_numpy.dtype 214 215 finally: 216 _np._set_promotion_state(state) 217 218 219if __name__ == "__main__": 220 run_tests() 221