• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import annotations
2
3import time
4
5from torch._dynamo import device_interface  # noqa: PLC2701 import-private-name
6
7
class DeviceProperties:
    """Fixed device properties reported by this backend.

    Every attribute is a hard-coded constant; nothing is queried from real
    hardware.
    """

    def __init__(self) -> None:
        # TODO: bypass check for H100 in triton_heuristics.py
        self.major: int = 8
        self.max_threads_per_multi_processor: int = 1
        self.multi_processor_count: int = 80
13
14
15class DeviceInterface(device_interface.DeviceInterface):
16    class Event(
17        device_interface._EventBase
18    ):  # pyright: ignore [reportPrivateImportUsage]
19        def __init__(
20            self,
21            enable_timing: bool = False,
22            blocking: bool = False,
23            interprocess: bool = False,
24        ) -> None:
25            self.enable_timing = enable_timing
26            self.recorded_time: int | None = None
27
28        def record(self, stream) -> None:
29            if not self.enable_timing:
30                return
31            assert self.recorded_time is None
32            self.recorded_time = time.perf_counter_ns()
33
34        def elapsed_time(self, end_event: DeviceInterface.Event) -> float:
35            assert self.recorded_time
36            assert end_event.recorded_time
37            # convert to ms
38            return (end_event.recorded_time - self.recorded_time) / 1000000
39
40        def wait(self, stream) -> None:
41            pass
42
43        def query(self) -> None:
44            pass
45
46        def synchronize(self) -> None:
47            pass
48
49    class device:  # noqa: N801 invalid-class-name # pyright: ignore [reportIncompatibleVariableOverride]
50        def __init__(self, device) -> None:
51            self.device = device
52
53    class Worker(device_interface.DeviceInterface.Worker):
54        @staticmethod
55        def set_device(device: int) -> None:
56            # No device index for our backend
57            pass
58
59        @staticmethod
60        def current_device() -> int:
61            # No device index for our backend
62            return 0
63
64        @staticmethod
65        def get_device_properties(
66            device=None,
67        ) -> DeviceProperties:
68            return DeviceProperties()
69
70    @staticmethod
71    def current_device() -> int:
72        return 0
73
74    @staticmethod
75    def set_device(device) -> None:
76        pass
77
78    @staticmethod
79    def device_count() -> int:
80        raise NotImplementedError
81
82    @staticmethod
83    def maybe_exchange_device(device: int) -> int:
84        assert (
85            device == 0
86        ), f"Only device index 0 is supported, tried to set index to {device}"
87        return 0  # previous device is always 0
88
89    @staticmethod
90    def exchange_device(device: int) -> int:
91        assert (
92            device == 0
93        ), f"Only device index 0 is supported, tried to set index to {device}"
94        return 0  # previous device is always 0
95
96    @staticmethod
97    def current_stream():
98        raise NotImplementedError
99
100    @staticmethod
101    def set_stream(stream) -> None:
102        raise NotImplementedError
103
104    @staticmethod
105    def get_raw_stream(device_index: int):
106        return None
107
108    @staticmethod
109    def synchronize(device) -> None:
110        pass
111
112    @staticmethod
113    def get_device_properties(device) -> DeviceProperties:
114        raise NotImplementedError
115
116    # Can be mock patched by @patch decorator.
117    @staticmethod
118    def is_available() -> bool:
119        return True
120
121    @staticmethod
122    def get_compute_capability(device) -> int:
123        return 0
124
125    @staticmethod
126    def triton_supported() -> bool:
127        return True
128