• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
18 #define MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
19 
20 #include <map>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include <memory>
25 #include "utils/hash_map.h"
26 #include "abstract/abstract_value.h"
27 #include "ir/primitive.h"
28 #include "ir/signature.h"
29 #include "pybind11/pybind11.h"
30 #include "include/common/utils/convert_utils_py.h"
31 
32 namespace py = pybind11;
33 namespace mindspore {
34 
// Forward declarations for the classes defined later in this header, followed by the
// smart-pointer aliases used to pass them around.
class PrimitivePy;
class PrimitivePyAdapter;
class PrimitiveFunctionAdapter;

using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
using PrimitivePyWeakPtr = std::weak_ptr<PrimitivePy>;
using PrimitivePyAdapterPtr = std::shared_ptr<PrimitivePyAdapter>;
using PrimitiveFunctionAdapterPtr = std::shared_ptr<PrimitiveFunctionAdapter>;
44 
45 class PrimitivePy : public Primitive {
46  public:
47   explicit PrimitivePy(const std::string &name);
48   PrimitivePy(const PrimitivePy &prim_py);
49   PrimitivePy &operator=(const PrimitivePy &other);
50   explicit PrimitivePy(const py::object &python_obj);
51   ~PrimitivePy() override;
52   MS_DECLARE_PARENT(PrimitivePy, Primitive);
53   const bool parse_info_ = true;
54   py::function GetVmapRuleFunction(const bool is_side_effect = false, int axis_size = 0);
55   py::function GetBpropFunction();
56   py::function GetTaylorRuleFunction();
backward_hook_fn()57   const std::map<int, py::function> &backward_hook_fn() const { return backward_hook_fn_; }
58   void CopyHookFunction(const PrimitivePyPtr &primitive_py);
59   void AddBpropCutPrim(const PrimitivePyPtr &bprop_cut_prim);
60   void AddBackwardHookFn(const int &key, const py::function &backward_hook_fn);
61   void RemoveBackwardHookFn(const int &key);
62   BaseRef RunHookFunction(const VectorRef &args) const;
63   BaseRef RunCellCustomBpropFunction(const py::tuple &py_args) const;
64   BaseRef RunCustomOpBpropFunction(const py::tuple &py_args) const;
65   BaseRef RunCellHookFunction(const py::tuple &py_args) const;
66   BaseRef RunVariableHookFunction(const py::tuple &py_args, bool is_tensor_hook) const;
67   BaseRef RunComputeFunction(const VectorRef &args) const override;
68   py::object RunPyComputeFunction(const py::tuple &py_args) const;
69   bool HasComputeFunction() const;
70   py::dict GetAttrDict();
GetPyObj()71   const py::object &GetPyObj() const { return python_obj_; }
HasPyObj()72   bool HasPyObj() const { return python_obj_.operator bool(); }
73   void RunCheck(const py::tuple &args);
74   py::dict RunInfer(const py::tuple &args);
75   py::object RunInferValue(const py::tuple &args);
76   PrimitivePtr Clone() override;
adapter()77   PrimitivePyAdapterPtr adapter() const { return adapter_; }
set_bprop_cls_name(const std::string & name)78   void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; }
79   static void ProcessUnPairedCellHook(bool execute_hook_fn);
80   static void ClearHookRes();
IsPythonPrim()81   bool IsPythonPrim() override { return true; }
82 
83  private:
84   py::function GetComputeFunction() const;
85   py::object UnpackRetValueOfCellHook(const py::object &grad_out) const;
86   void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out, const py::object &code_obj,
87                             const py::object &co_name) const;
88   py::object python_obj_;
89   std::string bprop_cls_name_;
90   PrimitivePyAdapterPtr adapter_;
91   std::vector<Signature> signatures_;
92   std::vector<PrimitivePyWeakPtr> bprop_cut_prims_;
93   std::map<int, py::function> backward_hook_fn_;
94   static std::map<std::string, std::pair<std::map<int, py::function>, py::object>> hook_grad_;
95 };
96 
97 class PrimitivePyAdapter {
98  public:
99   explicit PrimitivePyAdapter(const py::str &name);
100   PrimitivePyAdapter(const PrimitivePyAdapter &adapter);
101   PrimitivePyAdapter &operator=(const PrimitivePyAdapter &other);
102   ~PrimitivePyAdapter() = default;
attrs()103   const mindspore::HashMap<std::string, ValuePtr> &attrs() const { return attrs_; }
104   void AddPyAttr(const py::str &name, const py::object &obj);
105   void DelPyAttr(const py::str &name);
106   py::dict GetAttrDict();
107   int AddBackwardHookFn(const py::function &backward_hook_fn);
108   void RemoveBackwardHookFn(int key);
109   void set_prim_type(const PrimType t);
110   void set_const_prim(bool is_const_prim);
111   void set_inplace_prim(bool inplace_prim);
112   void set_const_input_indexes(const std::vector<size_t> &const_input_indexes);
113   void set_signatures(const std::vector<Signature> &signatures);
114   void set_instance_name(const std::string &s);
115   void set_attached_primitive(const PrimitivePyPtr &prim);
attached_primitive()116   PrimitivePyPtr attached_primitive() const { return attached_primitive_.lock(); }
id()117   uint64_t id() const { return id_; }
name()118   std::string name() const { return name_; }
set_name(const std::string & name)119   void set_name(const std::string &name) { name_ = name; }
120 
121   struct PrimitiveUserData {
122     py::object obj;
~PrimitiveUserDataPrimitiveUserData123     ~PrimitiveUserData() {
124       // cppcheck-suppress unreadVariable
125       py::gil_scoped_acquire acquire_gil;
126       obj = py::none();
127     }
128   };
129 
130   void SetUserData(const py::str &key, const py::object &value);
131   py::object GetUserData(const py::str &key) const;
132 
133   const bool parse_info_ = true;
134 
135  private:
136   friend PrimitivePy;
137 
138   template <typename T>
set_user_data(const std::string & key,const std::shared_ptr<T> & value)139   void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
140     user_data_.set<T>(key, value);
141   }
142   template <typename T>
user_data(const std::string & key)143   std::shared_ptr<T> user_data(const std::string &key) const {
144     return user_data_.get<T>(key);
145   }
146 
147   bool const_prim_{false};
148   bool inplace_prim_{false};
149   int backward_hook_fn_key_{-1};
150   uint64_t id_;
151   std::string name_;
152   std::string instance_name_;
153   PrimType prim_type_{kPrimTypeBuiltIn};
154   PrimitivePyWeakPtr attached_primitive_;
155   mindspore::HashMap<std::string, ValuePtr> attrs_;
156   std::vector<size_t> const_input_indexes_;
157   std::vector<Signature> signatures_;
158   std::map<int, py::function> backward_hook_fn_;
159   UserData user_data_;
160 };
161 
162 /// \brief OpPrimPyRegister defines the singleton to save primitivepy which has no attrs.
163 class OpPrimPyRegister {
164  public:
165   /// \brief Destructor of OpPrimPyRegister.
~OpPrimPyRegister()166   ~OpPrimPyRegister() {}
167 
168   /// \brief Get the OpPrimPyRegister singleton.
169   ///
170   /// \return The OpPrimPyRegister singleton.
GetInstance()171   static OpPrimPyRegister &GetInstance() {
172     static OpPrimPyRegister instance{};
173     return instance;
174   }
175 
176   /// \brief Get PrimPyMap of the OpPrimPyRegister singleton.
177   ///
178   /// \return The PrimPyMap of the OpPrimPyRegister singleton.
GetPrimPyMap()179   const HashMap<std::string, ValuePtr> &GetPrimPyMap() const { return primpy_map_; }
180 
181   /// \brief Add an element into the PrimPyMap of the OpPrimPyRegister singleton.
182   ///
183   /// param[in] name The operator name.
184   /// param[in] primpy The primitivepy of the operator.
SetPrimPyMap(const std::string & name,const ValuePtr & primpy)185   void SetPrimPyMap(const std::string &name, const ValuePtr &primpy) { primpy_map_[name] = primpy; }
186 
187   /// \brief Clear the PrimPyMap before the pyobject destroyed.
Clear()188   void Clear() { primpy_map_.clear(); }
189 
190  private:
OpPrimPyRegister()191   OpPrimPyRegister() {}
192   HashMap<std::string, ValuePtr> primpy_map_;  // op_name, primpy
193 };
194 
195 class PrimitiveFunctionAdapter {
196  public:
197   PrimitiveFunctionAdapter() = default;
198   virtual ~PrimitiveFunctionAdapter() = default;
set_attached_primitive_function(const PrimitivePtr & prim_func)199   void set_attached_primitive_function(const PrimitivePtr &prim_func) { attached_primitive_function_ = prim_func; }
attached_primitive_function()200   PrimitivePtr attached_primitive_function() { return attached_primitive_function_; }
name()201   virtual std::string name() { return py::str(attached_primitive_function_->name()).cast<std::string>(); }
has_label(const std::string & label)202   py::object has_label(const std::string &label) { return py::bool_(attached_primitive_function_->HasAttr(label)); }
203   void set_label(const std::string &label, const py::object &value);
get_label(const std::string & label)204   py::object get_label(const std::string &label) { return ValueToPyData(attached_primitive_function_->GetAttr(label)); }
205   py::object clone();
206 
207   const bool parse_info_ = true;
208 
209  private:
210   std::string name_;
211   PrimitivePtr attached_primitive_function_;
212 };
213 }  // namespace mindspore
214 #endif  // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
215