1 #include <stdexcept> 2 3 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h> 4 #if defined(USE_CUDA) || defined(USE_ROCM) 5 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h> 6 #endif 7 8 #include "aoti_custom_class.h" 9 10 namespace torch::aot_inductor { 11 12 static auto registerMyAOTIClass = 13 torch::class_<MyAOTIClass>("aoti", "MyAOTIClass") 14 .def(torch::init<std::string, std::string>()) 15 .def("forward", &MyAOTIClass::forward) 16 .def_pickle( 17 [](const c10::intrusive_ptr<MyAOTIClass>& self) __anon2ec357170102(const c10::intrusive_ptr<MyAOTIClass>& self) 18 -> std::vector<std::string> { 19 std::vector<std::string> v; 20 v.push_back(self->lib_path()); 21 v.push_back(self->device()); 22 return v; 23 }, __anon2ec357170202(std::vector<std::string> params) 24 [](std::vector<std::string> params) { 25 return c10::make_intrusive<MyAOTIClass>(params[0], params[1]); 26 }); 27 MyAOTIClass(const std::string & model_path,const std::string & device)28MyAOTIClass::MyAOTIClass( 29 const std::string& model_path, 30 const std::string& device) 31 : lib_path_(model_path), device_(device) { 32 if (device_ == "cpu") { 33 runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>( 34 model_path.c_str()); 35 #if defined(USE_CUDA) || defined(USE_ROCM) 36 } else if (device_ == "cuda") { 37 runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>( 38 model_path.c_str()); 39 #endif 40 } else { 41 throw std::runtime_error("invalid device: " + device); 42 } 43 } 44 forward(std::vector<torch::Tensor> inputs)45std::vector<torch::Tensor> MyAOTIClass::forward( 46 std::vector<torch::Tensor> inputs) { 47 return runner_->run(inputs); 48 } 49 50 } // namespace torch::aot_inductor 51