#include #include #if defined(USE_CUDA) || defined(USE_ROCM) #include #endif #include "aoti_custom_class.h" namespace torch::aot_inductor { static auto registerMyAOTIClass = torch::class_("aoti", "MyAOTIClass") .def(torch::init()) .def("forward", &MyAOTIClass::forward) .def_pickle( [](const c10::intrusive_ptr& self) -> std::vector { std::vector v; v.push_back(self->lib_path()); v.push_back(self->device()); return v; }, [](std::vector params) { return c10::make_intrusive(params[0], params[1]); }); MyAOTIClass::MyAOTIClass( const std::string& model_path, const std::string& device) : lib_path_(model_path), device_(device) { if (device_ == "cpu") { runner_ = std::make_unique( model_path.c_str()); #if defined(USE_CUDA) || defined(USE_ROCM) } else if (device_ == "cuda") { runner_ = std::make_unique( model_path.c_str()); #endif } else { throw std::runtime_error("invalid device: " + device); } } std::vector MyAOTIClass::forward( std::vector inputs) { return runner_->run(inputs); } } // namespace torch::aot_inductor