"""File invoked through subprocess to actually carry out measurements. `worker/main.py` is deliberately isolated from the rest of the benchmark infrastructure. Other parts of the benchmark rely on this file, but `worker/` has only one Python file and does not import ANYTHING from the rest of the benchmark suite. The reason that this is important is that we can't rely on paths to access the other files (namely `core.api`) since a source command might change the CWD. It also helps keep startup time down by limiting spurious definition work. The life of a worker is very simple: It receives a file containing a `WorkerTimerArgs` telling it what to run, and writes a `WorkerOutput` result back to the same file. Because this file only expects to run in a child context, error handling means plumbing failures up to the caller, not raising in this process. """ import argparse import dataclasses import io import os import pickle import sys import timeit import traceback from typing import Any, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: # Benchmark utils are only partially strict compliant, so MyPy won't follow # imports using the public namespace. (Due to an exclusion rule in # mypy-strict.ini) from torch.utils.benchmark.utils.timer import Language, Timer from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import ( CallgrindStats, ) else: from torch.utils.benchmark import CallgrindStats, Language, Timer WORKER_PATH = os.path.abspath(__file__) # ============================================================================= # == Interface ================================================================ # ============================================================================= # While the point of this is mainly to collect instruction counts, we're going # to have to compile C++ timers anyway (as they're used as a check before # calling Valgrind), so we may as well grab wall times for reference. They # are comparatively inexpensive. MIN_RUN_TIME = 5 # Repeats are inexpensive as long as they are all run in the same process. This # also lets us filter outliers (e.g. malloc arena reorganization), so we don't # need a high CALLGRIND_NUMBER to get good data. CALLGRIND_NUMBER = 100 CALLGRIND_REPEATS = 5 @dataclasses.dataclass(frozen=True) class WorkerTimerArgs: """Container for Timer constructor arguments. This dataclass serves two roles. First, it is a simple interface for defining benchmarks. (See core.api.GroupedStmts and core.api.GroupedModules for the advanced interfaces.) Second, it provides serialization for controlling workers. `Timer` is not pickleable, so instead the main process will pass `WorkerTimerArgs` instances to workers for processing. """ stmt: str setup: str = "pass" global_setup: str = "" num_threads: int = 1 language: Language = Language.PYTHON @dataclasses.dataclass(frozen=True) class WorkerOutput: # Only return values to reduce communication between main process and workers. wall_times: Tuple[float, ...] instructions: Tuple[int, ...] @dataclasses.dataclass(frozen=True) class WorkerFailure: # If a worker fails, we attach the string contents of the Exception # rather than the Exception object itself. This is done for two reasons: # 1) Depending on the type thrown, `e` may or may not be pickleable # 2) If we re-throw in the main process, we lose the true stack trace. failure_trace: str class WorkerUnpickler(pickle.Unpickler): def find_class(self, module: str, name: str) -> Any: """Resolve import for pickle. When the main runner uses a symbol `foo` from this file, it sees it as `worker.main.foo`. However the worker (called as a standalone file) sees the same symbol as `__main__.foo`. We have to help pickle understand that they refer to the same symbols. """ symbol_map = { # Only blessed interface Enums and dataclasses need to be mapped. "WorkerTimerArgs": WorkerTimerArgs, "WorkerOutput": WorkerOutput, "WorkerFailure": WorkerFailure, } if name in symbol_map: return symbol_map[name] return super().find_class(module, name) def load_input(self) -> WorkerTimerArgs: result = self.load() assert isinstance(result, WorkerTimerArgs) return result def load_output(self) -> Union[WorkerTimerArgs, WorkerOutput, WorkerFailure]: """Convenience method for type safe loading.""" result = self.load() assert isinstance(result, (WorkerTimerArgs, WorkerOutput, WorkerFailure)) return result # ============================================================================= # == Execution ================================================================ # ============================================================================= def _run(timer_args: WorkerTimerArgs) -> WorkerOutput: timer = Timer( stmt=timer_args.stmt, setup=timer_args.setup or "pass", global_setup=timer_args.global_setup, # Prevent NotImplementedError on GPU builds and C++ snippets. timer=timeit.default_timer, num_threads=timer_args.num_threads, language=timer_args.language, ) m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME) stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind( number=CALLGRIND_NUMBER, collect_baseline=False, repeats=CALLGRIND_REPEATS, retain_out_file=False, ) return WorkerOutput( wall_times=tuple(m.times), instructions=tuple(s.counts(denoise=True) for s in stats), ) def main(communication_file: str) -> None: result: Union[WorkerOutput, WorkerFailure] try: with open(communication_file, "rb") as f: timer_args: WorkerTimerArgs = WorkerUnpickler(f).load_input() assert isinstance(timer_args, WorkerTimerArgs) result = _run(timer_args) except KeyboardInterrupt: # Runner process sent SIGINT. sys.exit() except BaseException: trace_f = io.StringIO() traceback.print_exc(file=trace_f) result = WorkerFailure(failure_trace=trace_f.getvalue()) if not os.path.exists(os.path.split(communication_file)[0]): # This worker is an orphan, and the parent has already cleaned up the # working directory. In that case we can simply exit. print(f"Orphaned worker {os.getpid()} exiting.") return with open(communication_file, "wb") as f: pickle.dump(result, f) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--communication-file", "--communication_file", type=str) communication_file = parser.parse_args().communication_file main(communication_file)