1from __future__ import annotations 2 3import time 4 5from torch._dynamo import device_interface # noqa: PLC2701 import-private-name 6 7 8class DeviceProperties: 9 def __init__(self) -> None: 10 self.major = 8 # TODO: bypass check for H100 in triton_heuristics.py 11 self.max_threads_per_multi_processor = 1 12 self.multi_processor_count = 80 13 14 15class DeviceInterface(device_interface.DeviceInterface): 16 class Event( 17 device_interface._EventBase 18 ): # pyright: ignore [reportPrivateImportUsage] 19 def __init__( 20 self, 21 enable_timing: bool = False, 22 blocking: bool = False, 23 interprocess: bool = False, 24 ) -> None: 25 self.enable_timing = enable_timing 26 self.recorded_time: int | None = None 27 28 def record(self, stream) -> None: 29 if not self.enable_timing: 30 return 31 assert self.recorded_time is None 32 self.recorded_time = time.perf_counter_ns() 33 34 def elapsed_time(self, end_event: DeviceInterface.Event) -> float: 35 assert self.recorded_time 36 assert end_event.recorded_time 37 # convert to ms 38 return (end_event.recorded_time - self.recorded_time) / 1000000 39 40 def wait(self, stream) -> None: 41 pass 42 43 def query(self) -> None: 44 pass 45 46 def synchronize(self) -> None: 47 pass 48 49 class device: # noqa: N801 invalid-class-name # pyright: ignore [reportIncompatibleVariableOverride] 50 def __init__(self, device) -> None: 51 self.device = device 52 53 class Worker(device_interface.DeviceInterface.Worker): 54 @staticmethod 55 def set_device(device: int) -> None: 56 # No device index for our backend 57 pass 58 59 @staticmethod 60 def current_device() -> int: 61 # No device index for our backend 62 return 0 63 64 @staticmethod 65 def get_device_properties( 66 device=None, 67 ) -> DeviceProperties: 68 return DeviceProperties() 69 70 @staticmethod 71 def current_device() -> int: 72 return 0 73 74 @staticmethod 75 def set_device(device) -> None: 76 pass 77 78 @staticmethod 79 def device_count() -> int: 80 raise NotImplementedError 81 82 @staticmethod 83 def maybe_exchange_device(device: int) -> int: 84 assert ( 85 device == 0 86 ), f"Only device index 0 is supported, tried to set index to {device}" 87 return 0 # previous device is always 0 88 89 @staticmethod 90 def exchange_device(device: int) -> int: 91 assert ( 92 device == 0 93 ), f"Only device index 0 is supported, tried to set index to {device}" 94 return 0 # previous device is always 0 95 96 @staticmethod 97 def current_stream(): 98 raise NotImplementedError 99 100 @staticmethod 101 def set_stream(stream) -> None: 102 raise NotImplementedError 103 104 @staticmethod 105 def get_raw_stream(device_index: int): 106 return None 107 108 @staticmethod 109 def synchronize(device) -> None: 110 pass 111 112 @staticmethod 113 def get_device_properties(device) -> DeviceProperties: 114 raise NotImplementedError 115 116 # Can be mock patched by @patch decorator. 117 @staticmethod 118 def is_available() -> bool: 119 return True 120 121 @staticmethod 122 def get_compute_capability(device) -> int: 123 return 0 124 125 @staticmethod 126 def triton_supported() -> bool: 127 return True 128