• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
import sys
import unittest
from pathlib import Path


# Repository root: three directory levels above this file.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
try:
    # Put the repository root on sys.path so the `tools` package is importable.
    sys.path.append(str(REPO_ROOT))
    from tools.testing.test_run import ShardedTest, TestRun
except ModuleNotFoundError as err:
    # Surface which module could not be found (on stderr) instead of hiding
    # the cause behind a generic message; still exit with status 1.
    print(f"Can't import required modules, exiting: {err}", file=sys.stderr)
    sys.exit(1)
14
15
class TestTestRun(unittest.TestCase):
    """Set-algebra tests for ``TestRun``.

    As pinned down by the assertions below: ``TestRun("foo")`` selects the whole
    file ``foo``, ``TestRun("foo::bar")`` selects only ``bar`` in that file,
    ``included=[...]`` selects just the listed tests and ``excluded=[...]``
    selects everything in the file except the listed tests. ``|`` acts as
    union, ``-`` as difference and ``&`` as intersection over those
    selections. Symmetric operators are checked in both operand orders.
    """

    def test_union_with_full_run(self) -> None:
        """A full-file run absorbs a single-test run of the same file."""
        run1 = TestRun("foo")
        run2 = TestRun("foo::bar")

        self.assertEqual(run1 | run2, run1)
        self.assertEqual(run2 | run1, run1)

    def test_union_with_inclusions(self) -> None:
        """Two single-test runs merge into one run including both tests."""
        run1 = TestRun("foo::bar")
        run2 = TestRun("foo::baz")

        expected = TestRun("foo", included=["bar", "baz"])

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_non_overlapping_exclusions(self) -> None:
        """Each run covers the other's exclusion, so the union is the full run."""
        run1 = TestRun("foo", excluded=["bar"])
        run2 = TestRun("foo", excluded=["baz"])

        expected = TestRun("foo")

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_overlapping_exclusions(self) -> None:
        """Only tests excluded by both runs remain excluded in the union."""
        run1 = TestRun("foo", excluded=["bar", "car"])
        run2 = TestRun("foo", excluded=["bar", "caz"])

        expected = TestRun("foo", excluded=["bar"])

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_mixed_inclusion_exclusions(self) -> None:
        """Including a test cancels the matching exclusion in the union."""
        run1 = TestRun("foo", excluded=["baz", "car"])
        run2 = TestRun("foo", included=["baz"])

        expected = TestRun("foo", excluded=["car"])

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_mixed_files_fails(self) -> None:
        """Runs over different files cannot be unioned, in either operand order."""
        run1 = TestRun("foo")
        run2 = TestRun("bar")

        with self.assertRaises(AssertionError):
            run1 | run2
        with self.assertRaises(AssertionError):
            run2 | run1

    def test_union_with_empty_file_yields_orig_file(self) -> None:
        """The empty run is the identity element for union."""
        run1 = TestRun("foo")
        run2 = TestRun.empty()

        self.assertEqual(run1 | run2, run1)
        self.assertEqual(run2 | run1, run1)

    def test_subtracting_full_run_fails(self) -> None:
        """Subtracting the full run from a single-test run leaves nothing.

        Despite the name, this is not an error case: the result is the empty
        run. The name is kept so the test ID stays stable.
        """
        run1 = TestRun("foo::bar")
        run2 = TestRun("foo")

        self.assertEqual(run1 - run2, TestRun.empty())

    def test_subtracting_empty_file_yields_orig_file(self) -> None:
        """Subtracting the empty run is a no-op; the empty run minus anything is empty."""
        run1 = TestRun("foo")
        run2 = TestRun.empty()

        self.assertEqual(run1 - run2, run1)
        self.assertEqual(run2 - run1, TestRun.empty())

    def test_empty_is_falsey(self) -> None:
        """The empty run evaluates as false in a boolean context."""
        self.assertFalse(TestRun.empty())

    def test_subtracting_inclusion_from_full_run(self) -> None:
        """Removing one test from the full run turns it into an exclusion."""
        run1 = TestRun("foo")
        run2 = TestRun("foo::bar")

        expected = TestRun("foo", excluded=["bar"])

        self.assertEqual(run1 - run2, expected)

    def test_subtracting_inclusion_from_overlapping_inclusion(self) -> None:
        """Removing an included test drops it from the inclusion list."""
        run1 = TestRun("foo", included=["bar", "baz"])
        run2 = TestRun("foo::baz")

        self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))

    def test_subtracting_inclusion_from_nonoverlapping_inclusion(self) -> None:
        """Removing a test that was never included changes nothing."""
        run1 = TestRun("foo", included=["bar", "baz"])
        run2 = TestRun("foo", included=["car"])

        self.assertEqual(run1 - run2, TestRun("foo", included=["bar", "baz"]))

    def test_subtracting_exclusion_from_full_run(self) -> None:
        """Full run minus 'everything but bar' leaves exactly bar."""
        run1 = TestRun("foo")
        run2 = TestRun("foo", excluded=["bar"])

        self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))

    def test_subtracting_exclusion_from_superset_exclusion(self) -> None:
        """Difference of two exclusion runs, in both directions."""
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", excluded=["baz"])

        # run1 selects a subset of run2, so nothing is left.
        self.assertEqual(run1 - run2, TestRun.empty())
        # run2 additionally selects bar, which run1 excludes.
        self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))

    def test_subtracting_exclusion_from_nonoverlapping_exclusion(self) -> None:
        """What remains is exactly what the subtrahend excluded."""
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", excluded=["car"])

        self.assertEqual(run1 - run2, TestRun("foo", included=["car"]))
        self.assertEqual(run2 - run1, TestRun("foo", included=["bar", "baz"]))

    def test_subtracting_inclusion_from_exclusion_without_overlaps(self) -> None:
        """The included test is already excluded, so neither side changes."""
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", included=["bar"])

        self.assertEqual(run1 - run2, run1)
        self.assertEqual(run2 - run1, run2)

    def test_subtracting_inclusion_from_exclusion_with_overlaps(self) -> None:
        """Subtracting inclusions grows the exclusions; the reverse keeps the overlap."""
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", included=["bar", "car"])

        self.assertEqual(run1 - run2, TestRun("foo", excluded=["bar", "baz", "car"]))
        self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))

    def test_and(self) -> None:
        """Intersection of two inclusion runs keeps the shared tests, in either order."""
        run1 = TestRun("foo", included=["bar", "baz"])
        run2 = TestRun("foo", included=["bar", "car"])

        expected = TestRun("foo", included=["bar"])

        self.assertEqual(run1 & run2, expected)
        self.assertEqual(run2 & run1, expected)

    def test_and_exclusions(self) -> None:
        """Intersection of two exclusion runs excludes the union of their exclusions."""
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", excluded=["bar", "car"])

        expected = TestRun("foo", excluded=["bar", "baz", "car"])

        self.assertEqual(run1 & run2, expected)
        self.assertEqual(run2 & run1, expected)
155
156
class TestShardedTest(unittest.TestCase):
    """Checks the pytest command-line arguments a ShardedTest produces."""

    def test_get_pytest_args(self) -> None:
        """A run limited to bar and baz becomes a single ``-k`` selection expression."""
        restricted_run = TestRun("foo", included=["bar", "baz"])
        # NOTE(review): the two trailing `1` arguments look like shard index and
        # shard count (i.e. a single shard) -- confirm against ShardedTest.
        shard = ShardedTest(restricted_run, 1, 1)

        self.assertListEqual(shard.get_pytest_args(), ["-k", "bar or baz"])


if __name__ == "__main__":
    # When executed directly, run every TestCase defined in this module.
    unittest.main()
169