• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: mtia"]
2
3import os
4import shutil
5import sys
6import tempfile
7import unittest
8
9import torch
10import torch.testing._internal.common_utils as common
11import torch.utils.cpp_extension
12from torch.testing._internal.common_utils import (
13    IS_ARM64,
14    IS_LINUX,
15    skipIfTorchDynamo,
16    TEST_CUDA,
17    TEST_MPS,
18    TEST_PRIVATEUSE1,
19    TEST_XPU,
20)
21from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
22
23
24# define TEST_ROCM before changing TEST_CUDA
25TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
26TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
27
28
def remove_build_path():
    """Delete torch's default C++-extension build root, if it exists.

    This is a no-op on Windows, where the extensions build folder is
    intentionally left in place. Errors raised while deleting the tree are
    ignored.
    """
    # Windows: leave the extensions build folder alone.
    if sys.platform != "win32":
        root = torch.utils.cpp_extension.get_default_build_root()
        if os.path.exists(root):
            shutil.rmtree(root, ignore_errors=True)
36
37
# Since we use a fake MTIA device backend to test the generic Stream/Event APIs,
# this test is mutually exclusive with every real device backend. It is skipped
# if any of the following conditions are met:
@unittest.skipIf(
    IS_ARM64
    or not IS_LINUX
    or TEST_CUDA
    or TEST_XPU
    or TEST_MPS
    or TEST_PRIVATEUSE1
    or TEST_ROCM,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionStreamAndEvent(common.TestCase):
    """Tests Stream and Event with C++ extensions.

    ``setUpClass`` compiles and loads ``cpp_extensions/mtia_extension.cpp``
    (a fake MTIA device backend), and the test then drives it through the
    generic ``torch.Stream`` / ``torch.Event`` APIs.
    """

    # Return value of ``torch.utils.cpp_extension.load`` (set in setUpClass).
    module = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # Return to the working directory saved in setUp, even if the
            # parent tearDown raises.
            os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        try:
            remove_build_path()
        finally:
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        remove_build_path()
        build_dir = tempfile.mkdtemp()
        # Class cleanups also run when setUpClass itself fails, so the
        # temporary build directory is removed in every case.
        cls.addClassCleanup(shutil.rmtree, build_dir, ignore_errors=True)

        test_dir = os.path.dirname(os.path.abspath(__file__))
        # Load the fake device guard impl.
        src = f"{test_dir}/cpp_extensions/mtia_extension.cpp"

        # setUpClass runs before setUp, so the per-test chdir has not happened
        # yet. Switch to this file's directory here as well, so the relative
        # include paths below resolve against it rather than against wherever
        # the test runner was launched.
        old_cwd = os.getcwd()
        os.chdir(test_dir)
        try:
            cls.module = torch.utils.cpp_extension.load(
                name="mtia_extension",
                sources=[src],
                build_directory=build_dir,
                # The paths containing spaces and a quote character presumably
                # exercise quoting of include paths in the generated build.
                extra_include_paths=[
                    "cpp_extensions",
                    "path / with spaces in it",
                    "path with quote'",
                ],
                is_python_module=False,
                verbose=True,
            )
        finally:
            os.chdir(old_cwd)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_event(self):
        s = torch.Stream()
        # assertEqual rather than assertTrue: assertTrue's second argument is
        # only a failure message, so the values would never be compared.
        self.assertEqual(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
        e = torch.Event()
        self.assertEqual(e.device.type, "mtia")
        # A fresh, never-recorded event is expected to report event_id 0.
        self.assertEqual(e.event_id, 0)
        s.record_event(event=e)
        print(f"recorded event 1: {e}")
        self.assertNotEqual(e.event_id, 0)
        e2 = s.record_event()
        print(f"recorded event 2: {e2}")
        self.assertNotEqual(e2.event_id, 0)
        self.assertNotEqual(e2.event_id, e.event_id)
        e.synchronize()
        e2.synchronize()
        time_elapsed = e.elapsed_time(e2)
        print(f"time elapsed between e1 and e2: {time_elapsed}")
        old_event_id = e.event_id
        e.record(stream=s)
        print(f"recorded event 1: {e}")
        # Re-recording an already-recorded event must keep its event_id.
        self.assertEqual(e.event_id, old_event_id)
114
115
116if __name__ == "__main__":
117    common.run_tests()
118