# Owner(s): ["module: inductor"]
import operator
import os

from torch._inductor.compile_worker.subproc_pool import (
    raise_testexc,
    SubprocException,
    SubprocPool,
)
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.inductor_utils import HAS_CPU


class TestCompileWorker(TestCase):
    """Tests for ``SubprocPool``, the subprocess pool used by Inductor compile workers."""

    @skipIfWindows(msg="pass_fds not supported on Windows.")
    def test_basic_jobs(self):
        """Jobs submitted to the pool run and their return values come back via ``result()``."""
        pool = SubprocPool(2)
        try:
            a = pool.submit(operator.add, 100, 1)
            b = pool.submit(operator.sub, 100, 1)
            self.assertEqual(a.result(), 101)
            self.assertEqual(b.result(), 99)
        finally:
            # Always tear the pool down so worker processes don't leak between tests.
            pool.shutdown()

    @skipIfWindows(msg="pass_fds not supported on Windows.")
    def test_exception(self):
        """An exception raised inside a job surfaces in the parent as ``SubprocException``.

        The regex checks that the propagated message names the original
        exception type (``TestException`` from the subproc_pool module), which
        ``raise_testexc`` is expected to raise.
        """
        pool = SubprocPool(2)
        try:
            a = pool.submit(raise_testexc)
            with self.assertRaisesRegex(
                SubprocException,
                "torch._inductor.compile_worker.subproc_pool.TestException",
            ):
                a.result()
        finally:
            pool.shutdown()

    @skipIfWindows(msg="pass_fds not supported on Windows.")
    def test_crash(self):
        """A job whose process exits abruptly fails, and the pool keeps serving new jobs."""
        pool = SubprocPool(2)
        try:
            # os._exit terminates the process it runs in immediately, without
            # normal cleanup, simulating a hard crash. The submit happens
            # outside assertRaises so that only result() -- not the submission
            # itself -- is allowed to raise; otherwise a broken pool that fails
            # at submit time would be mistaken for the expected crash.
            a = pool.submit(os._exit, 1)
            with self.assertRaises(Exception):
                a.result()

            # Pool should still be usable after a crash
            b = pool.submit(operator.add, 100, 1)
            c = pool.submit(operator.sub, 100, 1)
            self.assertEqual(b.result(), 101)
            self.assertEqual(c.result(), 99)
        finally:
            pool.shutdown()


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Only run the suite when the HAS_CPU flag from inductor_utils is set.
    if HAS_CPU:
        run_tests()