import torch
from torch.export import Dim

# Custom op library that knows how to load and run the AOT-compiled model.
# Loading it registers torch.classes.aoti.MyAOTIClass.
AOTI_CUSTOM_OP_LIB = "libaoti_custom_class.so"
torch.classes.load_library(AOTI_CUSTOM_OP_LIB)


class TensorSerializer(torch.nn.Module):
    """Attach each entry of `data` as an attribute so the whole dict can be
    saved/restored via torch.jit.script(...).save().

    `data` maps attribute names to torchscript-serializable values
    (here: lists of tensors).
    """

    def __init__(self, data):
        super().__init__()
        for key in data:
            setattr(self, key, data[key])


class SimpleModule(torch.nn.Module):
    """A simple module to be AOT-compiled: Linear(4, 6) followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 6)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        a = self.fc(x)
        b = self.relu(a)
        return b


class MyAOTIModule(torch.nn.Module):
    """A wrapper nn.Module whose forward delegates to the custom torchbind
    class MyAOTIClass, which runs the AOT-compiled shared library."""

    def __init__(self, lib_path, device):
        super().__init__()
        # MyAOTIClass is registered by loading AOTI_CUSTOM_OP_LIB above.
        self.aoti_custom_op = torch.classes.aoti.MyAOTIClass(
            lib_path,
            device,
        )

    def forward(self, *x):
        # The custom op takes the inputs as a single tuple and returns a
        # list of outputs; normalize to a tuple for torch.jit.trace.
        outputs = self.aoti_custom_op.forward(x)
        return tuple(outputs)


def make_script_module(lib_path, device, *inputs):
    """Wrap the compiled .so at `lib_path` in MyAOTIModule and trace it.

    Runs the module once on `inputs` as a sanity check before tracing.
    Returns the traced ScriptModule.
    """
    m = MyAOTIModule(lib_path, device)
    # sanity check: make sure the compiled model runs before tracing it
    m(*inputs)
    return torch.jit.trace(m, inputs)


def compile_model(device, data):
    """AOT-compile SimpleModule for `device`, save a traced wrapper module,
    and record sample inputs / reference outputs into `data` (mutated
    in place, keyed by device).
    """
    module = SimpleModule().to(device)
    x = torch.randn((4, 4), device=device)
    inputs = (x,)
    # Make the batch dimension dynamic so the compiled model accepts
    # batch sizes in [1, 1024].
    batch_dim = Dim("batch", min=1, max=1024)
    dynamic_shapes = {
        "x": {0: batch_dim},
    }
    with torch.no_grad():
        # aot-compile the module into a .so pointed to by lib_path
        lib_path = torch._export.aot_compile(
            module, inputs, dynamic_shapes=dynamic_shapes
        )
    script_module = make_script_module(lib_path, device, *inputs)
    aoti_script_model = f"script_model_{device}.pt"
    script_module.save(aoti_script_model)

    # save sample inputs and ref output for later comparison
    with torch.no_grad():
        ref_output = module(*inputs)
    data.update(
        {
            f"inputs_{device}": list(inputs),
            f"outputs_{device}": [ref_output],
        }
    )


def main():
    """Compile for CPU (and CUDA when available) and serialize the
    collected inputs/outputs via TensorSerializer."""
    data = {}
    for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
        compile_model(device, data)

    torch.jit.script(TensorSerializer(data)).save("script_data.pt")


if __name__ == "__main__":
    main()