• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#pragma once

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/jit/runtime/static/impl.h>

6 namespace torch::jit {
7 
8 class StaticMethod : public torch::IMethod {
9  public:
StaticMethod(std::shared_ptr<StaticModule> static_module,std::string method_name)10   StaticMethod(
11       std::shared_ptr<StaticModule> static_module,
12       std::string method_name)
13       : static_module_(std::move(static_module)),
14         method_name_(std::move(method_name)) {
15     TORCH_CHECK(static_module_);
16   }
17 
operator()18   c10::IValue operator()(
19       std::vector<IValue> args,
20       const IValueMap& kwargs = IValueMap()) const override {
21     return (*static_module_)(std::move(args), kwargs);
22   }
23 
name()24   const std::string& name() const override {
25     return method_name_;
26   }
27 
28  protected:
setArgumentNames(std::vector<std::string> & argument_names_out)29   void setArgumentNames(
30       std::vector<std::string>& argument_names_out) const override {
31     const auto& schema = static_module_->schema();
32     CAFFE_ENFORCE(schema.has_value());
33     const auto& arguments = schema->arguments();
34     argument_names_out.clear();
35     argument_names_out.reserve(arguments.size());
36     std::transform(
37         arguments.begin(),
38         arguments.end(),
39         std::back_inserter(argument_names_out),
40         [](const c10::Argument& arg) -> std::string { return arg.name(); });
41   }
42 
43  private:
44   std::shared_ptr<StaticModule> static_module_;
45   std::string method_name_;
46 };
47 
48 } // namespace torch::jit
49