# mypy: allow-untyped-defs from __future__ import annotations import functools import os import sys import warnings from types import ModuleType from typing import Any, Callable, Dict def _reload_triton_kernel_in_subproc(reload_module, kernel_name): return _module_to_triton_kernel(reload_module(), kernel_name) def _module_to_triton_kernel(mod, kernel_name): kernel = getattr(mod, kernel_name) kernel._reload_in_subproc = functools.partial( _reload_triton_kernel_in_subproc, mod._reload_in_subproc, kernel_name, ) return kernel def _reload_python_module_in_subproc(key, path): codecache = sys.modules.get("torch._inductor.codecache") if codecache: return codecache.PyCodeCache.load_by_key_path(key, path) else: return _reload_python_module(key, path) def _reload_python_module(key, path): with open(path) as f: try: code = compile(f.read(), path, "exec", dont_inherit=True) except Exception as e: raise RuntimeError( f"Failed to import {path}\n{type(e).__name__}: {e}" ) from None mod = ModuleType(f"{__name__}.{key}") mod.__file__ = path mod.key = key # type: ignore[attr-defined] exec(code, mod.__dict__, mod.__dict__) sys.modules[mod.__name__] = mod return mod @functools.lru_cache(None) def _set_triton_ptxas_path() -> None: if os.environ.get("TRITON_PTXAS_PATH") is not None: return ptxas_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas") ) if not os.path.exists(ptxas_path): return if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK): os.environ["TRITON_PTXAS_PATH"] = ptxas_path else: warnings.warn(f"{ptxas_path} exists but is not an executable") def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]): _set_triton_ptxas_path() os.environ.update(extra_env) load_kernel().precompile(warm_cache_only=True)