• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: dynamo"]
2
"""Light smoke test for switching between the numpy and pytorch random streams.
"""
5from contextlib import contextmanager
6from functools import partial
7
8import numpy as _np
9import pytest
10
11import torch._dynamo.config as config
12import torch._numpy as tnp
13from torch._numpy.testing import assert_equal
14from torch.testing._internal.common_utils import (
15    instantiate_parametrized_tests,
16    parametrize,
17    run_tests,
18    subtest,
19    TestCase,
20)
21
22
23@contextmanager
24def control_stream(use_numpy=False):
25    with config.patch(use_numpy_random_stream=use_numpy):
26        yield
27
28
@instantiate_parametrized_tests
class TestScalarReturn(TestCase):
    """Return type of the ``tnp.random`` samplers, on both random streams."""

    # The samplers under test, shared by both tests below so the scalar and
    # array variants always cover exactly the same set of functions (the
    # list used to be duplicated in each decorator and could drift apart).
    # A tuple, so the same object can safely feed several ``parametrize``s.
    _RNDM_FUNCS = (
        tnp.random.normal,
        tnp.random.rand,
        partial(tnp.random.randint, 0, 5),
        tnp.random.randn,
        subtest(tnp.random.random, name="random_random"),
        subtest(tnp.random.random_sample, name="random_sample"),
        tnp.random.sample,
        tnp.random.uniform,
    )

    @parametrize("use_numpy", [True, False])
    @parametrize("func", _RNDM_FUNCS)
    def test_rndm_scalar(self, func, use_numpy):
        # Calling with the default `size` should give back a python scalar.
        with control_stream(use_numpy):
            r = func()
        assert isinstance(r, (int, float))

    @parametrize("use_numpy", [True, False])
    @parametrize("func", _RNDM_FUNCS)
    def test_rndm_array(self, func, use_numpy):
        # Asking for 10 samples should give back a tnp.ndarray.
        with control_stream(use_numpy):
            if func in (tnp.random.rand, tnp.random.randn):
                # rand/randn are called with positional dimensions
                # rather than a `size=` keyword.
                r = func(10)
            else:
                r = func(size=10)
        assert isinstance(r, tnp.ndarray)
72
73
@instantiate_parametrized_tests
class TestShuffle(TestCase):
    """``tnp.random.shuffle`` on arrays and lists, on both random streams.

    Every test applies ``control_stream(use_numpy)``; without it the
    ``use_numpy`` parametrization would run the same stream twice.
    """

    @parametrize("use_numpy", [True, False])
    def test_1d(self, use_numpy):
        ax = tnp.asarray([1, 2, 3, 4, 5, 6])
        ox = ax.copy()

        with control_stream(use_numpy):
            tnp.random.seed(1234)
            tnp.random.shuffle(ax)

        # shuffled in place: still the same tnp.ndarray, in a new order
        assert isinstance(ax, tnp.ndarray)
        assert not (ax == ox).all()

    @parametrize("use_numpy", [True, False])
    def test_2d(self, use_numpy):
        # np.shuffle only shuffles the first axis.
        # Six rows rather than two: a two-row shuffle leaves the order
        # unchanged half the time, so the "order changed" check would pass
        # or fail depending on the particular seed/stream combination.
        ax = tnp.arange(18).reshape(6, 3)
        ox = ax.copy()

        with control_stream(use_numpy):
            tnp.random.seed(1234)
            tnp.random.shuffle(ax)

        assert isinstance(ax, tnp.ndarray)
        assert not (ax == ox).all()
        # Rows move as whole units, so every column still holds the same
        # values; sorting column-wise recovers the original arange layout.
        assert_equal(tnp.sort(ax, axis=0), ox)

    @parametrize("use_numpy", [True, False])
    def test_shuffle_list(self, use_numpy):
        # On eager with the pytorch stream, we refuse to shuffle lists.
        # Under dynamo, we always fall back to numpy.
        # NB: this means that the random stream is different for shuffling a
        # list or an array when USE_NUMPY_STREAM == False.
        # NOTE(review): with the numpy stream the call is expected to be
        # forwarded to numpy's shuffle, which permutes lists in place rather
        # than raising -- confirm against torch._numpy.random.
        x = [1, 2, 3]
        with control_stream(use_numpy):
            if use_numpy:
                tnp.random.shuffle(x)
                assert sorted(x) == [1, 2, 3]
            else:
                with pytest.raises(NotImplementedError):
                    tnp.random.shuffle(x)
108
109
@instantiate_parametrized_tests
class TestChoice(TestCase):
    """``choice`` draws the same sample from ``5`` and from ``arange(5)``."""

    @parametrize("use_numpy", [True, False])
    def test_choice(self, use_numpy):
        choice_kwargs = {
            "size": 3,
            "replace": False,
            "p": [0.1, 0, 0.3, 0.6, 0],
        }
        with control_stream(use_numpy):
            # Re-seed before each draw so both calls start from the same state.
            tnp.random.seed(12345)
            from_int = tnp.random.choice(5, **choice_kwargs)

            tnp.random.seed(12345)
            from_range = tnp.random.choice(tnp.arange(5), **choice_kwargs)

            assert_equal(from_int, from_range)
121
122
class TestNumpyGlobal(TestCase):
    """Compare the numpy stream against numpy's own global generator."""

    def test_numpy_global(self):
        seed, n_draws = 12345, 11

        with control_stream(use_numpy=True):
            tnp.random.seed(seed)
            via_numpy_stream = tnp.random.uniform(0, 1, size=n_draws)

        # Seeding numpy's global generator identically must give the same draws.
        _np.random.seed(seed)
        reference = tnp.asarray(_np.random.uniform(0, 1, size=n_draws))
        assert_equal(via_numpy_stream, reference)

        # The pytorch stream, seeded with the same value, produces other variates.
        with control_stream(use_numpy=False):
            tnp.random.seed(seed)
            via_torch_stream = tnp.random.uniform(0, 1, size=n_draws)

        assert not (via_torch_stream == via_numpy_stream).all()
140
141
if __name__ == "__main__":
    # Run directly as a script (not imported): hand off to the common
    # PyTorch test runner imported from torch.testing._internal.common_utils.
    run_tests()
144