from __future__ import annotations import time from torch._dynamo import device_interface # noqa: PLC2701 import-private-name class DeviceProperties: def __init__(self) -> None: self.major = 8 # TODO: bypass check for H100 in triton_heuristics.py self.max_threads_per_multi_processor = 1 self.multi_processor_count = 80 class DeviceInterface(device_interface.DeviceInterface): class Event( device_interface._EventBase ): # pyright: ignore [reportPrivateImportUsage] def __init__( self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False, ) -> None: self.enable_timing = enable_timing self.recorded_time: int | None = None def record(self, stream) -> None: if not self.enable_timing: return assert self.recorded_time is None self.recorded_time = time.perf_counter_ns() def elapsed_time(self, end_event: DeviceInterface.Event) -> float: assert self.recorded_time assert end_event.recorded_time # convert to ms return (end_event.recorded_time - self.recorded_time) / 1000000 def wait(self, stream) -> None: pass def query(self) -> None: pass def synchronize(self) -> None: pass class device: # noqa: N801 invalid-class-name # pyright: ignore [reportIncompatibleVariableOverride] def __init__(self, device) -> None: self.device = device class Worker(device_interface.DeviceInterface.Worker): @staticmethod def set_device(device: int) -> None: # No device index for our backend pass @staticmethod def current_device() -> int: # No device index for our backend return 0 @staticmethod def get_device_properties( device=None, ) -> DeviceProperties: return DeviceProperties() @staticmethod def current_device() -> int: return 0 @staticmethod def set_device(device) -> None: pass @staticmethod def device_count() -> int: raise NotImplementedError @staticmethod def maybe_exchange_device(device: int) -> int: assert ( device == 0 ), f"Only device index 0 is supported, tried to set index to {device}" return 0 # previous device is always 0 @staticmethod def exchange_device(device: int) -> int: assert ( device == 0 ), f"Only device index 0 is supported, tried to set index to {device}" return 0 # previous device is always 0 @staticmethod def current_stream(): raise NotImplementedError @staticmethod def set_stream(stream) -> None: raise NotImplementedError @staticmethod def get_raw_stream(device_index: int): return None @staticmethod def synchronize(device) -> None: pass @staticmethod def get_device_properties(device) -> DeviceProperties: raise NotImplementedError # Can be mock patched by @patch decorator. @staticmethod def is_available() -> bool: return True @staticmethod def get_compute_capability(device) -> int: return 0 @staticmethod def triton_supported() -> bool: return True