• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import sys
2import textwrap
3from pathlib import Path
4
5
def check(path):
    """Check a test file for common issues with pytest->pytorch conversion.

    Prints a header with the file name, then one report per suspicious line
    via ``report_violation``. Reported line numbers are 1-based so they match
    editors and tracebacks. Comment-only lines are skipped.

    Flagged patterns:
      * module-level ``def test...`` functions;
      * ``class Test...`` headers that do not mention ``TestCase``;
      * leftover ``pytest.mark`` / ``pytest.xfail`` / ``pytest.skip`` /
        ``pytest.param``;
      * ``@parametrize`` decorators with no enclosing ``class Test...``, or
        whose enclosing class lacks an ``@instantiate_parametrized_tests``
        decorator.

    Args:
        path: ``pathlib.Path`` of the file to check; it is read as UTF-8.
    """
    print(path.name)
    print("=" * len(path.name), "\n")

    src = path.read_text(encoding="utf-8").split("\n")
    for idx, line in enumerate(src):
        if is_comment(line):
            continue
        # ``idx`` indexes ``src``; ``lineno`` is the human-facing 1-based number.
        lineno = idx + 1

        # module level test functions
        if line.startswith("def test"):
            report_violation(line, lineno, header="Module-level test function")

        # test classes must inherit from TestCase
        if line.startswith("class Test") and "TestCase" not in line:
            report_violation(
                line, lineno, header="Test class does not inherit from TestCase"
            )

        # last vestiges of pytest-specific stuff
        if "pytest.mark" in line:
            report_violation(line, lineno, header="pytest.mark.something")

        for part in ["pytest.xfail", "pytest.skip", "pytest.param"]:
            if part in line:
                report_violation(line, lineno, header=f"stray {part}")

        if textwrap.dedent(line).startswith("@parametrize"):
            cls_idx = _find_enclosing_test_class(src, idx)
            if cls_idx is None:
                # No class to inspect: do not go on to check decorators
                # (the old code then read ``src[-1]``, the file's last line).
                report_violation(line, lineno, "off-class parametrize")
            elif not _has_instantiate_decorator(src, cls_idx):
                report_violation(
                    line,
                    lineno,
                    f"missing instantiation of parametrized tests in {src[cls_idx]}?",
                )


def _find_enclosing_test_class(src, idx):
    """Return the index of the nearest shallow ``class Test`` line at or above
    ``idx``, or ``None`` if there is none."""
    for nn in range(idx, -1, -1):
        ln = src[nn]
        # hack: large indent => likely an inner class, so keep searching upward
        if "class Test" in ln and len(ln) - len(ln.lstrip()) < 8:
            return nn
    return None


def _has_instantiate_decorator(src, cls_idx):
    """Return True if the run of decorator lines directly above
    ``src[cls_idx]`` contains ``@instantiate_parametrized_tests``."""
    nn = cls_idx - 1
    while nn >= 0:
        # strip so decorators of indented classes are recognised too
        above = src[nn].strip()
        if not above.startswith("@"):
            break
        if above.startswith("@instantiate_parametrized_tests"):
            return True
        nn -= 1
    return False
49
50
51def is_comment(line):
52    return textwrap.dedent(line).startswith("#")
53
54
55def report_violation(line, lineno, header):
56    print(f">>>> line {lineno} : {header}\n {line}\n")
57
58
if __name__ == "__main__":
    # Command-line entry point: takes exactly one argument, either a single
    # test file or a directory whose top-level ``test*.py`` files are checked.
    argv = sys.argv
    if len(argv) != 2:
        raise ValueError("Usage : python check_tests_conform path/to/file/or/dir")

    path = Path(argv[1])

    if path.is_dir():
        # run for all files in the directory (no subdirs); sorted because
        # Path.glob yields files in arbitrary, filesystem-dependent order
        for this_path in sorted(path.glob("test*.py")):
            check(this_path)
    else:
        check(path)
73