# Owner(s): ["module: mtia"]

import os
import shutil
import sys
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_MPS,
    TEST_PRIVATEUSE1,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None


def remove_build_path():
    """Delete torch.utils.cpp_extension's default build root if it exists.

    Does nothing on Windows.
    """
    if sys.platform == "win32":
        # Not wiping extensions build folder because Windows
        # (presumably loaded extension libraries cannot be removed there).
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


# Since we use a fake MTIA device backend to test generic Stream/Event, device
# backends are mutually exclusive to each other.
# The test will be skipped if any of the following conditions are met:
@unittest.skipIf(
    IS_ARM64
    or not IS_LINUX
    or TEST_CUDA
    or TEST_XPU
    or TEST_MPS
    or TEST_PRIVATEUSE1
    or TEST_ROCM,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionStreamAndEvent(common.TestCase):
    """Tests Stream and Event with C++ extensions."""

    # Result of torch.utils.cpp_extension.load(), assigned in setUpClass.
    module = None
    # Temporary directory the extension is built into; created in setUpClass
    # and removed in tearDownClass so repeated runs don't leak temp dirs.
    build_dir = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        # Restore the working directory (see setUp) even if the base-class
        # teardown raises, so later tests don't run from the wrong cwd.
        try:
            super().tearDown()
        finally:
            os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        remove_build_path()
        if cls.build_dir is not None:
            shutil.rmtree(cls.build_dir, ignore_errors=True)
            cls.build_dir = None
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        """Build and load the fake MTIA device extension once for the class."""
        super().setUpClass()
        remove_build_path()
        cls.build_dir = tempfile.mkdtemp()
        # setUpClass runs before setUp's chdir, so any path that must point
        # next to this file has to be absolute here.
        here = os.path.dirname(os.path.abspath(__file__))
        # Load the fake device guard impl.
        src = os.path.join(here, "cpp_extensions", "mtia_extension.cpp")
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=[src],
            build_directory=cls.build_dir,
            extra_include_paths=[
                os.path.join(here, "cpp_extensions"),
                # Deliberately awkward (non-existent) paths; presumably these
                # exercise quoting/escaping of include flags — TODO confirm.
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_event(self):
        s = torch.Stream()
        # assertTrue(x, y) treats y as the failure *message* and only checks
        # that x is truthy; compare the values explicitly instead.
        self.assertEqual(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
        e = torch.Event()
        self.assertEqual(e.device.type, "mtia")
        # Should be nullptr by default
        self.assertEqual(e.event_id, 0)
        s.record_event(event=e)
        print(f"recorded event 1: {e}")
        self.assertNotEqual(e.event_id, 0)
        e2 = s.record_event()
        print(f"recorded event 2: {e2}")
        self.assertNotEqual(e2.event_id, 0)
        self.assertNotEqual(e2.event_id, e.event_id)
        e.synchronize()
        e2.synchronize()
        time_elapsed = e.elapsed_time(e2)
        print(f"time elapsed between e1 and e2: {time_elapsed}")
        old_event_id = e.event_id
        # Re-recording an already-created event should keep its id.
        e.record(stream=s)
        print(f"recorded event 1: {e}")
        self.assertEqual(e.event_id, old_event_id)


if __name__ == "__main__":
    common.run_tests()