1 #pragma once 2 3 #include <memory> 4 5 #include <torch/torch.h> 6 7 namespace torch::inductor { 8 9 class AOTIModelContainerRunner; 10 11 } // namespace torch::inductor 12 13 namespace torch::aot_inductor { 14 15 class MyAOTIClass : public torch::CustomClassHolder { 16 public: 17 explicit MyAOTIClass( 18 const std::string& model_path, 19 const std::string& device = "cuda"); 20 ~MyAOTIClass()21 ~MyAOTIClass() {} 22 23 MyAOTIClass(const MyAOTIClass&) = delete; 24 MyAOTIClass& operator=(const MyAOTIClass&) = delete; 25 MyAOTIClass& operator=(MyAOTIClass&&) = delete; 26 lib_path()27 const std::string& lib_path() const { 28 return lib_path_; 29 } 30 device()31 const std::string& device() const { 32 return device_; 33 } 34 35 std::vector<torch::Tensor> forward(std::vector<torch::Tensor> inputs); 36 37 private: 38 const std::string lib_path_; 39 40 const std::string device_; 41 42 std::unique_ptr<torch::inductor::AOTIModelContainerRunner> runner_; 43 }; 44 45 } // namespace torch::aot_inductor 46