1# Owner(s): ["module: dynamo"] 2 3""" 4Basic tests to assert and illustrate the behavior around the decision to use 0D 5arrays in place of array scalars. 6 7Extensive tests of this sort of functionality is in numpy_tests/core/*scalar* 8 9Also test the isscalar function (which is deliberately a bit more lax). 10""" 11 12from torch.testing._internal.common_utils import ( 13 instantiate_parametrized_tests, 14 parametrize, 15 run_tests, 16 subtest, 17 TEST_WITH_TORCHDYNAMO, 18 TestCase, 19 xfailIfTorchDynamo, 20) 21 22 23if TEST_WITH_TORCHDYNAMO: 24 import numpy as np 25 from numpy.testing import assert_equal 26else: 27 import torch._numpy as np 28 from torch._numpy.testing import assert_equal 29 30 31parametrize_value = parametrize( 32 "value", 33 [ 34 subtest(np.int64(42), name="int64"), 35 subtest(np.array(42), name="array"), 36 subtest(np.asarray(42), name="asarray"), 37 subtest(np.asarray(np.int64(42)), name="asarray_int"), 38 ], 39) 40 41 42@instantiate_parametrized_tests 43class TestArrayScalars(TestCase): 44 @parametrize_value 45 def test_array_scalar_basic(self, value): 46 assert value.ndim == 0 47 assert value.shape == () 48 assert value.size == 1 49 assert value.dtype == np.dtype("int64") 50 51 @parametrize_value 52 def test_conversion_to_int(self, value): 53 py_scalar = int(value) 54 assert py_scalar == 42 55 assert isinstance(py_scalar, int) 56 assert not isinstance(value, int) 57 58 @parametrize_value 59 def test_decay_to_py_scalar(self, value): 60 # NumPy distinguishes array scalars and 0D arrays. For instance 61 # `scalar * list` is equivalent to `int(scalar) * list`, but 62 # `0D array * list` is equivalent to `0D array * np.asarray(list)`. 63 # Our scalars follow 0D array behavior (because they are 0D arrays) 64 lst = [1, 2, 3] 65 66 product = value * lst 67 assert isinstance(product, np.ndarray) 68 assert product.shape == (3,) 69 assert_equal(product, [42, 42 * 2, 42 * 3]) 70 71 # repeat with right-mulitply 72 product = lst * value 73 assert isinstance(product, np.ndarray) 74 assert product.shape == (3,) 75 assert_equal(product, [42, 42 * 2, 42 * 3]) 76 77 def test_scalar_comparisons(self): 78 scalar = np.int64(42) 79 arr = np.array(42) 80 81 assert arr == scalar 82 assert arr >= scalar 83 assert arr <= scalar 84 85 assert scalar == 42 86 assert arr == 42 87 88 89# @xfailIfTorchDynamo 90@instantiate_parametrized_tests 91class TestIsScalar(TestCase): 92 # 93 # np.isscalar(...) checks that its argument is a numeric object with exactly one element. 94 # 95 # This differs from NumPy which also requires that shape == (). 96 # 97 scalars = [ 98 subtest(42, "literal"), 99 subtest(int(42.0), "int"), 100 subtest(np.float32(42), "float32"), 101 subtest(np.array(42), "array_0D", decorators=[xfailIfTorchDynamo]), 102 subtest([42], "list", decorators=[xfailIfTorchDynamo]), 103 subtest([[42]], "list-list", decorators=[xfailIfTorchDynamo]), 104 subtest(np.array([42]), "array_1D", decorators=[xfailIfTorchDynamo]), 105 subtest(np.array([[42]]), "array_2D", decorators=[xfailIfTorchDynamo]), 106 ] 107 108 import math 109 110 not_scalars = [ 111 int, 112 np.float32, 113 subtest("s", decorators=[xfailIfTorchDynamo]), 114 subtest("string", decorators=[xfailIfTorchDynamo]), 115 (), 116 [], 117 math.sin, 118 np, 119 np.transpose, 120 [1, 2], 121 np.asarray([1, 2]), 122 np.float32([1, 2]), 123 ] 124 125 @parametrize("value", scalars) 126 def test_is_scalar(self, value): 127 assert np.isscalar(value) 128 129 @parametrize("value", not_scalars) 130 def test_is_not_scalar(self, value): 131 assert not np.isscalar(value) 132 133 134if __name__ == "__main__": 135 run_tests() 136