• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <memory>
4 
5 #include <torch/torch.h>
6 
7 namespace torch::inductor {
8 
9 class AOTIModelContainerRunner;
10 
11 } // namespace torch::inductor
12 
13 namespace torch::aot_inductor {
14 
15 class MyAOTIClass : public torch::CustomClassHolder {
16  public:
17   explicit MyAOTIClass(
18       const std::string& model_path,
19       const std::string& device = "cuda");
20 
~MyAOTIClass()21   ~MyAOTIClass() {}
22 
23   MyAOTIClass(const MyAOTIClass&) = delete;
24   MyAOTIClass& operator=(const MyAOTIClass&) = delete;
25   MyAOTIClass& operator=(MyAOTIClass&&) = delete;
26 
lib_path()27   const std::string& lib_path() const {
28     return lib_path_;
29   }
30 
device()31   const std::string& device() const {
32     return device_;
33   }
34 
35   std::vector<torch::Tensor> forward(std::vector<torch::Tensor> inputs);
36 
37  private:
38   const std::string lib_path_;
39 
40   const std::string device_;
41 
42   std::unique_ptr<torch::inductor::AOTIModelContainerRunner> runner_;
43 };
44 
45 } // namespace torch::aot_inductor
46