# mypy: allow-untyped-defs
"""Example of Timer and Compare APIs:

$ python -m examples.compare
"""

import pickle
import sys
import time

import torch

import torch.utils.benchmark as benchmark_utils


class FauxTorch:
    """Emulate different versions of pytorch.

    In normal circumstances this would be done with multiple processes
    writing serialized measurements. This example simplifies that model
    to make it clearer: a single process wraps the real ``torch`` module
    and adds an artificial per-element delay to a few of its ops.
    """

    def __init__(self, real_torch, extra_ns_per_element):
        # real_torch: module whose ops are forwarded (the real ``torch``).
        # extra_ns_per_element: simulated extra cost, in nanoseconds, per
        # element of each op's output.
        self._real_torch = real_torch
        self._extra_ns_per_element = extra_ns_per_element

    def extra_overhead(self, result):
        """Sleep in proportion to ``result.numel()``, then return ``result``.

        ``time.sleep`` has a ~65 us overhead of its own, so the per-element
        delay is only faked when ``result`` has more than 5000 elements.
        Smaller results are returned immediately.
        """
        numel = int(result.numel())
        if numel > 5000:
            time.sleep(numel * self._extra_ns_per_element * 1e-9)
        return result

    def _forward(self, op_name, args, kwargs):
        # Call ``real_torch.<op_name>(*args, **kwargs)`` and pass the output
        # through ``extra_overhead``. args/kwargs are taken as a tuple and a
        # dict, not unpacked, so a caller keyword can never collide with
        # ``op_name``.
        real_op = getattr(self._real_torch, op_name)
        return self.extra_overhead(real_op(*args, **kwargs))

    def add(self, *args, **kwargs):
        """``torch.add`` with simulated overhead."""
        return self._forward("add", args, kwargs)

    def mul(self, *args, **kwargs):
        """``torch.mul`` with simulated overhead."""
        return self._forward("mul", args, kwargs)

    def cat(self, *args, **kwargs):
        """``torch.cat`` with simulated overhead."""
        return self._forward("cat", args, kwargs)

    def matmul(self, *args, **kwargs):
        """``torch.matmul`` with simulated overhead."""
        return self._forward("matmul", args, kwargs)


def main():
    """Time a small grid of statements across fake branches and print a table.

    Each Measurement is pickled and then unpickled before it is compared.
    This mimics results being collected from separate processes.
    """
    # (label, sub_label, stmt) triples to time.
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]
    # (env name, simulated ns per element). "master" uses the real module.
    branches = [("master", None), ("my_branch", 1), ("severe_regression", 5)]
    sizes = [1, 10, 100, 1000, 10000, 50000]
    thread_counts = [1, 4]

    serialized_results = []
    repeats = 2

    # Build timers in branch -> task -> size -> num_threads order.
    timers = []
    for branch, overhead_ns in branches:
        for label, sub_label, stmt in tasks:
            for size in sizes:
                for num_threads in thread_counts:
                    # Each non-master timer gets its own FauxTorch instance.
                    torch_impl = (
                        torch if branch == "master" else FauxTorch(torch, overhead_ns)
                    )
                    timers.append(
                        benchmark_utils.Timer(
                            stmt=stmt,
                            globals={
                                "torch": torch_impl,
                                "x": torch.ones((size, 4)),
                                "y": torch.ones((1, 4)),
                                "zero": torch.zeros(()),
                            },
                            label=label,
                            sub_label=sub_label,
                            description=f"size: {size}",
                            env=branch,
                            num_threads=num_threads,
                        )
                    )

    # Run the full timer list `repeats` times, showing an in-place progress counter.
    total_runs = len(timers) * repeats
    for i, timer in enumerate(timers * repeats):
        serialized_results.append(
            pickle.dumps(timer.blocked_autorange(min_run_time=0.05))
        )
        print(f"\r{i + 1} / {total_runs}", end="")
        sys.stdout.flush()
    print()

    # Only locally produced pickles are loaded here. Never unpickle untrusted data.
    comparison = benchmark_utils.Compare(
        [pickle.loads(blob) for blob in serialized_results]
    )

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()


if __name__ == "__main__":
    main()