1import sys 2import textwrap 3from pathlib import Path 4 5 6def check(path): 7 """Check a test file for common issues with pytest->pytorch conversion.""" 8 print(path.name) 9 print("=" * len(path.name), "\n") 10 11 src = path.read_text().split("\n") 12 for num, line in enumerate(src): 13 if is_comment(line): 14 continue 15 16 # module level test functions 17 if line.startswith("def test"): 18 report_violation(line, num, header="Module-level test function") 19 20 # test classes must inherit from TestCase 21 if line.startswith("class Test") and "TestCase" not in line: 22 report_violation( 23 line, num, header="Test class does not inherit from TestCase" 24 ) 25 26 # last vestiges of pytest-specific stuff 27 if "pytest.mark" in line: 28 report_violation(line, num, header="pytest.mark.something") 29 30 for part in ["pytest.xfail", "pytest.skip", "pytest.param"]: 31 if part in line: 32 report_violation(line, num, header=f"stray {part}") 33 34 if textwrap.dedent(line).startswith("@parametrize"): 35 # backtrack to check 36 nn = num 37 for nn in range(num, -1, -1): 38 ln = src[nn] 39 if "class Test" in ln: 40 # hack: large indent => likely an inner class 41 if len(ln) - len(ln.lstrip()) < 8: 42 break 43 else: 44 report_violation(line, num, "off-class parametrize") 45 if not src[nn - 1].startswith("@instantiate_parametrized_tests"): 46 report_violation( 47 line, num, f"missing instantiation of parametrized tests in {ln}?" 48 ) 49 50 51def is_comment(line): 52 return textwrap.dedent(line).startswith("#") 53 54 55def report_violation(line, lineno, header): 56 print(f">>>> line {lineno} : {header}\n {line}\n") 57 58 59if __name__ == "__main__": 60 argv = sys.argv 61 if len(argv) != 2: 62 raise ValueError("Usage : python check_tests_conform path/to/file/or/dir") 63 64 path = Path(argv[1]) 65 66 if path.is_dir(): 67 # run for all files in the directory (no subdirs) 68 for this_path in path.glob("test*.py"): 69 # breakpoint() 70 check(this_path) 71 else: 72 check(path) 73