1 #pragma once 2 3 #include <torch/csrc/api/include/torch/imethod.h> 4 #include <torch/csrc/jit/runtime/static/impl.h> 5 6 namespace torch::jit { 7 8 class StaticMethod : public torch::IMethod { 9 public: StaticMethod(std::shared_ptr<StaticModule> static_module,std::string method_name)10 StaticMethod( 11 std::shared_ptr<StaticModule> static_module, 12 std::string method_name) 13 : static_module_(std::move(static_module)), 14 method_name_(std::move(method_name)) { 15 TORCH_CHECK(static_module_); 16 } 17 operator()18 c10::IValue operator()( 19 std::vector<IValue> args, 20 const IValueMap& kwargs = IValueMap()) const override { 21 return (*static_module_)(std::move(args), kwargs); 22 } 23 name()24 const std::string& name() const override { 25 return method_name_; 26 } 27 28 protected: setArgumentNames(std::vector<std::string> & argument_names_out)29 void setArgumentNames( 30 std::vector<std::string>& argument_names_out) const override { 31 const auto& schema = static_module_->schema(); 32 CAFFE_ENFORCE(schema.has_value()); 33 const auto& arguments = schema->arguments(); 34 argument_names_out.clear(); 35 argument_names_out.reserve(arguments.size()); 36 std::transform( 37 arguments.begin(), 38 arguments.end(), 39 std::back_inserter(argument_names_out), 40 [](const c10::Argument& arg) -> std::string { return arg.name(); }); 41 } 42 43 private: 44 std::shared_ptr<StaticModule> static_module_; 45 std::string method_name_; 46 }; 47 48 } // namespace torch::jit 49