• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: dynamo"]
2
3import sys
4
5import pytest
6
7import torch._numpy as tnp
8
9
def pytest_configure(config):
    """Register the ``slow`` marker in pytest's ``markers`` ini list."""
    marker_spec = "slow: very slow tests"
    config.addinivalue_line("markers", marker_spec)
12
13
def pytest_addoption(parser):
    """Add the boolean ``--runslow`` and ``--nonp`` command-line flags."""
    flag_table = (
        ("--runslow", "run slow tests"),
        ("--nonp", "error when NumPy is accessed"),
    )
    for flag_name, help_text in flag_table:
        parser.addoption(flag_name, action="store_true", help=help_text)
17
18
class Inaccessible:
    """Placeholder installed as ``numpy`` when ``--nonp`` is given.

    Every instance attribute lookup (``np.<name>``) raises ``RuntimeError``
    naming the attribute that was requested.
    """

    def __getattribute__(self, attr):
        message = f"Using --nonp but accessed np.{attr}"
        raise RuntimeError(message)
22
23
def pytest_sessionstart(session):
    """Poison the ``numpy`` entry in ``sys.modules`` when ``--nonp`` is set.

    A later ``import numpy`` then picks up the ``Inaccessible`` placeholder
    from ``sys.modules`` instead of the real package.
    """
    nonp_requested = session.config.getoption("--nonp")
    if not nonp_requested:
        return
    sys.modules["numpy"] = Inaccessible()
27
28
def pytest_generate_tests(metafunc):
    """
    Hook to parametrize test cases over the NumPy implementations under test.
    See https://docs.pytest.org/en/6.2.x/parametrize.html#pytest-generate-tests

    The logic here allows us to test with both NumPy-proper and torch._numpy.
    Normally we'd just test torch._numpy, e.g.

        import torch._numpy as np
        ...
        def test_foo():
            np.array([42])
            ...

    but this hook allows us to test NumPy-proper as well, e.g.

        def test_foo(np):
            np.array([42])
            ...

    np is a pytest parameter, which is either NumPy-proper or torch._numpy. This
    allows us to sanity check our own tests, so that tested behaviour is
    consistent with NumPy-proper.

    pytest will have test names respective to the library being tested, e.g.

        $ pytest --collect-only
        test_foo[torch._numpy]
        test_foo[numpy]

    The ids are passed explicitly (each module's ``__name__``) so these names
    do not depend on how a particular pytest version derives ids for module
    objects.

    NumPy-proper is only included when it can be imported and has not been
    replaced by an ``Inaccessible`` placeholder (i.e. ``--nonp`` was not
    given).
    """
    # Only tests that take an ``np`` argument are parametrized.
    if "np" not in metafunc.fixturenames:
        return

    np_params = [tnp]

    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # With --nonp, pytest_sessionstart stored an Inaccessible instance in
        # sys.modules["numpy"]; only add NumPy-proper when it is the real module.
        if not isinstance(np, Inaccessible):
            np_params.append(np)

    metafunc.parametrize("np", np_params, ids=[mod.__name__ for mod in np_params])
72
73
def pytest_collection_modifyitems(config, items):
    """Attach a skip marker to every ``slow`` test unless ``--runslow`` is set."""
    if config.getoption("--runslow"):
        return

    slow_skip = pytest.mark.skip(reason="slow test, use --runslow to run")
    slow_items = (item for item in items if "slow" in item.keywords)
    for slow_item in slow_items:
        slow_item.add_marker(slow_skip)
80