• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: dynamo"]
2
3"""
Basic tests to assert and illustrate the behavior around the decision to use 0D
5arrays in place of array scalars.
6
More extensive tests of this sort of functionality are in numpy_tests/core/*scalar*
8
Also test the isscalar function (which is deliberately a bit more lax than NumPy's).
10"""
11
import math

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    xfailIfTorchDynamo,
)
21
22
23if TEST_WITH_TORCHDYNAMO:
24    import numpy as np
25    from numpy.testing import assert_equal
26else:
27    import torch._numpy as np
28    from torch._numpy.testing import assert_equal
29
30
31parametrize_value = parametrize(
32    "value",
33    [
34        subtest(np.int64(42), name="int64"),
35        subtest(np.array(42), name="array"),
36        subtest(np.asarray(42), name="asarray"),
37        subtest(np.asarray(np.int64(42)), name="asarray_int"),
38    ],
39)
40
41
@instantiate_parametrized_tests
class TestArrayScalars(TestCase):
    """Scalars are represented as 0-D arrays.

    Every ``value`` supplied by ``parametrize_value`` is expected to behave as a
    0-D int64 array holding 42, regardless of how it was constructed.
    """

    @parametrize_value
    def test_array_scalar_basic(self, value):
        # No axes and an empty shape...
        assert value.ndim == 0 and value.shape == ()
        # ...holding exactly one element...
        assert value.size == 1
        # ...whose dtype is int64 for every construction path.
        assert value.dtype == np.dtype("int64")

    @parametrize_value
    def test_conversion_to_int(self, value):
        # int() must produce a genuine Python int equal to 42...
        as_int = int(value)
        assert as_int == 42 and isinstance(as_int, int)
        # ...even though the source object itself is not a Python int.
        assert not isinstance(value, int)

    @parametrize_value
    def test_decay_to_py_scalar(self, value):
        # NumPy treats array scalars and 0D arrays differently here:
        # `array_scalar * list` behaves like `int(scalar) * list`, while
        # `0D_array * list` behaves like `0D_array * np.asarray(list)`.
        # Our scalars *are* 0D arrays, so they must follow the latter rule,
        # i.e. multiply elementwise, for both operand orders.
        factors = [1, 2, 3]

        def check_elementwise(product):
            assert isinstance(product, np.ndarray)
            assert product.shape == (3,)
            assert_equal(product, [42, 42 * 2, 42 * 3])

        # Left-multiply first, then right-multiply.
        check_elementwise(value * factors)
        check_elementwise(factors * value)

    def test_scalar_comparisons(self):
        int64_scalar = np.int64(42)
        zero_dim = np.array(42)

        # A 0-D array holding 42 is ==, >= and <= the equivalent scalar...
        assert zero_dim == int64_scalar
        assert zero_dim >= int64_scalar
        assert zero_dim <= int64_scalar

        # ...and both objects compare equal to the Python int 42.
        assert int64_scalar == 42
        assert zero_dim == 42
88
89# @xfailIfTorchDynamo
90@instantiate_parametrized_tests
91class TestIsScalar(TestCase):
92    #
93    # np.isscalar(...) checks that its argument is a numeric object with exactly one element.
94    #
95    # This differs from NumPy which also requires that shape == ().
96    #
97    scalars = [
98        subtest(42, "literal"),
99        subtest(int(42.0), "int"),
100        subtest(np.float32(42), "float32"),
101        subtest(np.array(42), "array_0D", decorators=[xfailIfTorchDynamo]),
102        subtest([42], "list", decorators=[xfailIfTorchDynamo]),
103        subtest([[42]], "list-list", decorators=[xfailIfTorchDynamo]),
104        subtest(np.array([42]), "array_1D", decorators=[xfailIfTorchDynamo]),
105        subtest(np.array([[42]]), "array_2D", decorators=[xfailIfTorchDynamo]),
106    ]
107
108    import math
109
110    not_scalars = [
111        int,
112        np.float32,
113        subtest("s", decorators=[xfailIfTorchDynamo]),
114        subtest("string", decorators=[xfailIfTorchDynamo]),
115        (),
116        [],
117        math.sin,
118        np,
119        np.transpose,
120        [1, 2],
121        np.asarray([1, 2]),
122        np.float32([1, 2]),
123    ]
124
125    @parametrize("value", scalars)
126    def test_is_scalar(self, value):
127        assert np.isscalar(value)
128
129    @parametrize("value", not_scalars)
130    def test_is_not_scalar(self, value):
131        assert not np.isscalar(value)
132
133
# When executed as a script, hand control to run_tests() from
# torch.testing._internal.common_utils (imported at the top of this file).
if __name__ == "__main__":
    run_tests()
136