#pragma once #include #include #include #include #include #include #include #include #include namespace torch::jit { std::string typeString(py::handle h); inline std::shared_ptr toSimple(Value* v) { return std::make_shared(v); } // NB: This should be the single entry-point for instantiating a SugaredValue // from a Python object. If you are adding support for converting a new Python // type, *add it in this function's implementation*. std::shared_ptr toSugaredValue( py::object obj, GraphFunction& m, const SourceRange& loc, bool is_constant = false); std::optional as_function(const py::object& obj); struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { PythonValue( py::object the_self, std::optional rcb = std::nullopt, Value* module_self = nullptr) : self(std::move(the_self)), rcb(std::move(rcb)), moduleSelf_(module_self) {} FunctionSchema getSchema( const size_t n_args, const size_t n_binders, const SourceRange& loc); // call it like a function, e.g. `outputs = this(inputs)` std::shared_ptr call( const SourceRange& loc, GraphFunction& m, at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override; std::string kind() const override; std::vector> asTuple( const SourceRange& loc, GraphFunction& m, const std::optional& size_hint = {}) override; std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; Value* asValue(const SourceRange& loc, GraphFunction& m) override { throw( ErrorReport(loc) << kind() << " cannot be used as a value. " << "Perhaps it is a closed over global variable? If so, please " << "consider passing it in as an argument or use a local varible " << "instead."); } protected: py::object getattr(const SourceRange& loc, const std::string& name); void checkForAddToConstantsError(std::stringstream& ss); // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) py::object self; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::optional rcb; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) Value* moduleSelf_ = nullptr; }; struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {} std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; }; // Used for desugaring uses of the torch.cuda module. All the CUDA APIs with // torch.cuda.* are resolved using CUDAPythonModuleValue. struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue { explicit CUDAPythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {} std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; }; // Represents all the parameters of a module as a List[Tensor] struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { ConstantParameterList(Value* the_list) : the_list_(the_list) {} std::string kind() const override { return "constant parameter list"; } std::shared_ptr call( const SourceRange& loc, GraphFunction& caller, at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override { return toSimple(the_list_); } private: Value* the_list_; }; struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue { explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name) : iterable_(std::move(iterable)), name_(std::move(name)){}; std::string kind() const override { return name_; } std::shared_ptr call( const SourceRange& loc, GraphFunction& f, at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override { if (!args.empty() || !kwargs.empty()) { throw( ErrorReport(loc) << name_ << " method does not accept any arguments"); } return iterable_; } SugaredValuePtr iterable_; // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string name_; }; struct SugaredDict; // defines how modules/methods behave inside the script subset. // for now this does not have any interaction with python. // in the future, we will add the ability to resolve `self.foo` to python // {functions, modules, constants} so this SugaredValue is defined here // anticipating we will eventually need to replace Module with a py::object // holding the actual nn.Module class. struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { ModuleValue(Value* self, std::shared_ptr concreteType) : self_(self), concreteType_(std::move(concreteType)) {} std::string kind() const override { return "module"; } Value* asValue(const SourceRange& loc, GraphFunction& m) override; SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m) override; // select an attribute on it, e.g. `this.field` std::shared_ptr tryGetAttr( const SourceRange& loc, GraphFunction& m, const std::string& field); // select an attribute on it, e.g. `this.field` std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; // select an attribute on it, e.g. `this.field` bool hasAttr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; // call module.forward with pre_hooks and hooks std::shared_ptr call( const SourceRange& loc, GraphFunction& caller, at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override; std::shared_ptr getSugaredDict( const SourceRange& loc, GraphFunction& m); std::shared_ptr getSugaredNamedBufferDict( const SourceRange& loc, GraphFunction& m); std::shared_ptr getSugaredNamedParameterList( const SourceRange& loc, GraphFunction& m); std::shared_ptr getSugaredNamedParameterDict( const SourceRange& loc, GraphFunction& m); void setAttr( const SourceRange& loc, GraphFunction& m, const std::string& field, Value* newValue) override; SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override; std::shared_ptr getitem( const SourceRange& loc, GraphFunction& m, Value* idx, TypePtr type_hint) override; private: // Check that the type of all submodules is a subtype of ty. If the function // returns false, more information about why it returns false (e.g. which // submodule's type is not a subtype of ty) is printed it why_not if it is not // null. bool areAllSubmodulesSubtypeOf( const TypePtr& ty, std::ostream* why_not = nullptr) const; Value* self_; std::shared_ptr concreteType_; }; bool isNamedTupleClass(const py::object& obj); TypePtr registerNamedTuple( const py::object& obj, const SourceRange& loc, const ResolutionCallback& rcb); void recurseThroughNestedModules( const SourceRange& loc, GraphFunction& m, std::vector& keys, std::vector& values, std::shared_ptr& self, const std::string& prefix, const std::string& field); // Used to support named_modules() struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue { explicit SugaredDict( std::shared_ptr self, std::shared_ptr keys, std::shared_ptr modules) : self_(std::move(self)), keys_(std::move(keys)), modules_(std::move(modules)) {} std::string kind() const override { return "ModuleDict"; } std::shared_ptr getKeys() { return keys_; } std::shared_ptr getModules() { return modules_; } std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override { return keys_; }; std::shared_ptr self_; std::shared_ptr keys_; std::shared_ptr modules_; }; struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { BooleanDispatchValue(py::dict dispatched_fn) : dispatched_fn_(std::move(dispatched_fn)) {} std::string kind() const override { return "boolean dispatch"; } std::shared_ptr call( const SourceRange& loc, GraphFunction& caller, at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override; private: py::dict dispatched_fn_; }; struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue { PythonClassValue(ClassTypePtr type, py::object py_type) : ClassValue(std::move(type)), py_type_(std::move(py_type)) {} std::string kind() const override { return "Python type"; } std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; bool hasAttr( const SourceRange& loc, GraphFunction& m, const std::string& field) override; private: py::object py_type_; }; struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { explicit PythonExceptionValue(const py::object& exception_class) : ExceptionValue( py::str(py::getattr(exception_class, "__name__", py::str("")))), exception_class_qualified_name_( py::str(py::module::import("torch._jit_internal") .attr("_qualified_name")( exception_class, /*mangle_name=*/false))) {} std::string kind() const override { return "Python exception"; } std::shared_ptr call( const SourceRange& loc, GraphFunction& caller, at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override; private: std::string exception_class_qualified_name_; }; // Python Slice class. struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue { explicit PythonSliceClass() = default; std::string kind() const override { return "Python slice class"; } std::shared_ptr call( const SourceRange& loc, GraphFunction& caller, at::ArrayRef args, at::ArrayRef kwargs, size_t n_binders) override; }; } // namespace torch::jit