#include <stdexcept>

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif

#include "aoti_custom_class.h"

namespace torch::aot_inductor {

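// Register MyAOTIClass as the TorchScript custom class "aoti::MyAOTIClass".
// def_pickle serializes an instance as its {lib_path, device} strings and
// reconstructs it from those two strings when deserialized.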
static auto registerMyAOTIClass =
    torch::class_<MyAOTIClass>("aoti", "MyAOTIClass")
        .def(torch::init<std::string, std::string>())
        .def("forward", &MyAOTIClass::forward)
        .def_pickle(
            [](const c10::intrusive_ptr<MyAOTIClass>& self)
                -> std::vector<std::string> {
              std::vector<std::string> v;
              v.push_back(self->lib_path());
              v.push_back(self->device());
              return v;
            },
            [](std::vector<std::string> params) {
              return c10::make_intrusive<MyAOTIClass>(params[0], params[1]);
            });

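// Load the AOTInductor-compiled model and create a model container runner
// for the requested device ("cpu" or "cuda").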
MyAOTIClass::MyAOTIClass(
    const std::string& model_path,
    const std::string& device)
    : lib_path_(model_path), device_(device) {
  if (device_ == "cpu") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_path.c_str());
#if defined(USE_CUDA) || defined(USE_ROCM)
  } else if (device_ == "cuda") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_path.c_str());
#endif
  } else {
    throw std::runtime_error("invalid device: " + device);
  }
}

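// Run the compiled model on the given inputs via the container runner.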
std::vector<torch::Tensor> MyAOTIClass::forward(
    std::vector<torch::Tensor> inputs) {
  return runner_->run(inputs);
}

} // namespace torch::aot_inductor