import torch


class WrapperModule:
    """Wraps an instance of wrapped_type.

    In graph mode, traces the wrapped instance with torch.jit.trace.
    Randomly initializes num_params tensors, each holding a single float element.

    Args:
        wrapped_type:
            - Object type to be wrapped.
              Expects wrapped_type to:
              - be constructed with the pt_fn specified in module_config.
              - provide a forward method that takes module_config.num_params args.
        module_config:
            - Specifies the pt_fn to construct wrapped_type with, whether graph_mode
              is enabled, and the number of parameters wrapped_type's forward method
              takes.
        debug:
            - Whether debug mode is enabled.
        save:
            - In graph mode, whether the traced graph is to be saved.
    """

    def __init__(self, wrapped_type, module_config, debug, save=False):
        pt_fn = module_config.pt_fn
        self.module = wrapped_type(pt_fn)
        self.tensor_inputs = []
        self.module_name = wrapped_type.__name__
        # Random single-element tensors, one per argument of the wrapped forward.
        for _ in range(module_config.num_params):
            self.tensor_inputs.append(torch.randn(1))
        if module_config.graph_mode:
            # Trace the wrapped module into a TorchScript graph.
            self.module = torch.jit.trace(self.module, self.tensor_inputs)
            if save:
                file_name = self.module_name + "_" + pt_fn.__name__ + ".pt"
                torch.jit.save(self.module, file_name)
                print(f"Generated graph is saved in {file_name}")
        print(
            f"Benchmarking module {self.module_name} with fn {pt_fn.__name__}: Graph mode:{module_config.graph_mode}"
        )
        if debug and isinstance(self.module, torch.jit.ScriptModule):
            print(self.module.graph)
            print(self.module.code)

    def forward(self, niters):
        # Run the wrapped forward repeatedly without autograd overhead.
        with torch.no_grad():
            for _ in range(niters):
                self.module.forward(*self.tensor_inputs)
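

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The ModuleConfig
# namedtuple, the AddModule wrapped type, and torch.add as pt_fn are
# assumptions chosen only to illustrate how WrapperModule is expected to be
# driven; the real benchmark harness supplies its own wrapped types and
# module_config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from collections import namedtuple

    # Hypothetical config container exposing the fields WrapperModule reads.
    ModuleConfig = namedtuple("ModuleConfig", ["pt_fn", "num_params", "graph_mode"])

    class AddModule(torch.nn.Module):
        """Hypothetical wrapped type: applies pt_fn to its two inputs."""

        def __init__(self, pt_fn):
            super().__init__()
            self.pt_fn = pt_fn

        def forward(self, x, y):
            return self.pt_fn(x, y)

    # Two single-element inputs, traced in graph mode, with debug printing on.
    config = ModuleConfig(pt_fn=torch.add, num_params=2, graph_mode=True)
    wrapper = WrapperModule(AddModule, config, debug=True)
    wrapper.forward(niters=10)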