# Owner(s): ["module: mtia"]

import os
import shutil
import sys
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_MPS,
    TEST_PRIVATEUSE1,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None


def remove_build_path():
    """Delete torch's default C++-extension build root.

    This is a no-op on Windows, and a no-op when the directory does not
    exist. Removal errors are ignored.
    """
    if sys.platform == "win32":
        # Not wiping extensions build folder because Windows
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


# Since we use a fake MTIA device backend to test generic Stream/Event, device
# backends are mutually exclusive to each other.
# The test will be skipped if any of the following conditions are met:
@unittest.skipIf(
    IS_ARM64
    or not IS_LINUX
    or TEST_CUDA
    or TEST_XPU
    or TEST_MPS
    or TEST_PRIVATEUSE1
    or TEST_ROCM,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionStreamAndEvent(common.TestCase):
    """Tests Stream and Event with C++ extensions."""

    # Return value of cpp_extension.load(); set once in setUpClass.
    module = None
    # Temporary directory the extension is built into; created in setUpClass
    # and removed in tearDownClass so repeated runs do not leak temp dirs.
    build_dir = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()
        # return the working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        remove_build_path()
        if cls.build_dir is not None:
            # On Linux (the only platform this runs on), an already-loaded
            # shared library stays mapped after its file is unlinked.
            shutil.rmtree(cls.build_dir, ignore_errors=True)
            cls.build_dir = None
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        remove_build_path()
        cls.build_dir = tempfile.mkdtemp()
        # setUpClass runs before setUp's chdir, so anchor paths to this file's
        # directory explicitly instead of relying on the current directory.
        test_dir = os.path.dirname(os.path.abspath(__file__))
        # Load the fake device guard impl.
        src = os.path.join(test_dir, "cpp_extensions", "mtia_extension.cpp")
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=[src],
            build_directory=cls.build_dir,
            extra_include_paths=[
                os.path.join(test_dir, "cpp_extensions"),
                # Deliberately awkward paths: they check that include paths
                # containing spaces and quotes are escaped correctly.
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_event(self):
        """Record and re-record events on a default Stream of the fake MTIA backend."""
        s = torch.Stream()
        # assertEqual, not assertTrue: assertTrue(x, y) treats y as the failure
        # message and only checks that x is truthy.
        self.assertEqual(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
        e = torch.Event()
        self.assertEqual(e.device.type, "mtia")
        # Should be nullptr by default
        self.assertEqual(e.event_id, 0)
        s.record_event(event=e)
        print(f"recorded event 1: {e}")
        self.assertNotEqual(e.event_id, 0)
        e2 = s.record_event()
        print(f"recorded event 2: {e2}")
        self.assertNotEqual(e2.event_id, 0)
        self.assertNotEqual(e2.event_id, e.event_id)
        e.synchronize()
        e2.synchronize()
        time_elapsed = e.elapsed_time(e2)
        print(f"time elapsed between e1 and e2: {time_elapsed}")
        old_event_id = e.event_id
        # Re-recording an already-created event must reuse its id.
        e.record(stream=s)
        print(f"recorded event 1: {e}")
        self.assertEqual(e.event_id, old_event_id)


if __name__ == "__main__":
    common.run_tests()