# Owner(s): ["module: dynamo"]

"""Light smoke test switching between numpy to pytorch random streams.
"""
from contextlib import contextmanager
from functools import partial

import numpy as _np
import pytest

import torch._dynamo.config as config
import torch._numpy as tnp
from torch._numpy.testing import assert_equal
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)


@contextmanager
def control_stream(use_numpy=False):
    """Select the random stream for the enclosed code.

    use_numpy=True patches ``config.use_numpy_random_stream`` so that
    torch._numpy random calls follow numpy's stream; False (the default)
    uses pytorch's stream. The patch is reverted on exit.
    """
    with config.patch(use_numpy_random_stream=use_numpy):
        yield


# Sampling functions exercised by TestScalarReturn. The same set is used by
# both the scalar-return and the array-return test, so it is kept in one
# place instead of being duplicated in each @parametrize decorator.
RANDOM_FUNCS = [
    tnp.random.normal,
    tnp.random.rand,
    partial(tnp.random.randint, 0, 5),
    tnp.random.randn,
    subtest(tnp.random.random, name="random_random"),
    subtest(tnp.random.random_sample, name="random_sample"),
    tnp.random.sample,
    tnp.random.uniform,
]


@instantiate_parametrized_tests
class TestScalarReturn(TestCase):
    """Return-type checks: scalar without `size`, ndarray with `size`."""

    @parametrize("use_numpy", [True, False])
    @parametrize("func", RANDOM_FUNCS)
    def test_rndm_scalar(self, func, use_numpy):
        # default `size` means a python scalar return
        with control_stream(use_numpy):
            r = func()
        assert isinstance(r, (int, float))

    @parametrize("use_numpy", [True, False])
    @parametrize("func", RANDOM_FUNCS)
    def test_rndm_array(self, func, use_numpy):
        # with an explicit size, every function returns an ndarray
        with control_stream(use_numpy):
            if func in (tnp.random.rand, tnp.random.randn):
                # rand/randn take the shape positionally, not via `size=`
                r = func(10)
            else:
                r = func(size=10)
        assert isinstance(r, tnp.ndarray)


@instantiate_parametrized_tests
class TestShuffle(TestCase):
    """In-place shuffle semantics for ndarrays (and refusal for lists)."""

    # NOTE(review): `use_numpy` is parametrized in test_1d/test_2d but never
    # applied via control_stream, so both variants currently run with the
    # default stream — confirm whether wrapping in control_stream(use_numpy)
    # was intended.
    @parametrize("use_numpy", [True, False])
    def test_1d(self, use_numpy):
        ax = tnp.asarray([1, 2, 3, 4, 5, 6])
        ox = ax.copy()

        tnp.random.seed(1234)
        tnp.random.shuffle(ax)

        # shuffle works in-place and preserves the ndarray type
        assert isinstance(ax, tnp.ndarray)
        # with this seed the resulting permutation is not the identity
        assert not (ax == ox).all()

    @parametrize("use_numpy", [True, False])
    def test_2d(self, use_numpy):
        # np.shuffle only shuffles the first axis
        ax = tnp.asarray([[1, 2, 3], [4, 5, 6]])
        ox = ax.copy()

        tnp.random.seed(1234)
        tnp.random.shuffle(ax)

        assert isinstance(ax, tnp.ndarray)
        assert not (ax == ox).all()

    @parametrize("use_numpy", [True, False])
    def test_shuffle_list(self, use_numpy):
        # on eager, we refuse to shuffle lists
        # under dynamo, we always fall back to numpy
        # NB: this means that the random stream is different for
        # shuffling a list or an array when USE_NUMPY_STREAM == False
        x = [1, 2, 3]
        with pytest.raises(NotImplementedError):
            tnp.random.shuffle(x)


@instantiate_parametrized_tests
class TestChoice(TestCase):
    """choice(int) and choice(arange(int)) must draw from the same stream."""

    @parametrize("use_numpy", [True, False])
    def test_choice(self, use_numpy):
        kwds = dict(size=3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
        with control_stream(use_numpy):
            tnp.random.seed(12345)
            x = tnp.random.choice(5, **kwds)
            # reseeding must reproduce the same draw via the array form
            tnp.random.seed(12345)
            x_1 = tnp.random.choice(tnp.arange(5), **kwds)
            assert_equal(x, x_1)


class TestNumpyGlobal(TestCase):
    """The numpy-stream mode must reproduce numpy's global stream exactly."""

    def test_numpy_global(self):
        with control_stream(use_numpy=True):
            tnp.random.seed(12345)
            x = tnp.random.uniform(0, 1, size=11)

        # check that the stream is identical to numpy's
        _np.random.seed(12345)
        x_np = _np.random.uniform(0, 1, size=11)
        assert_equal(x, tnp.asarray(x_np))

        # switch to the pytorch stream, variates differ
        with control_stream(use_numpy=False):
            tnp.random.seed(12345)
            x_1 = tnp.random.uniform(0, 1, size=11)

        assert not (x_1 == x).all()


if __name__ == "__main__":
    run_tests()