1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ 18 #define MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ 19 20 #include <map> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 #include <memory> 25 #include "utils/hash_map.h" 26 #include "abstract/abstract_value.h" 27 #include "ir/primitive.h" 28 #include "ir/signature.h" 29 #include "pybind11/pybind11.h" 30 #include "include/common/utils/convert_utils_py.h" 31 32 namespace py = pybind11; 33 namespace mindspore { 34 35 class PrimitivePy; 36 using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; 37 using PrimitivePyWeakPtr = std::weak_ptr<PrimitivePy>; 38 39 class PrimitivePyAdapter; 40 using PrimitivePyAdapterPtr = std::shared_ptr<PrimitivePyAdapter>; 41 42 class PrimitiveFunctionAdapter; 43 using PrimitiveFunctionAdapterPtr = std::shared_ptr<PrimitiveFunctionAdapter>; 44 45 class PrimitivePy : public Primitive { 46 public: 47 explicit PrimitivePy(const std::string &name); 48 PrimitivePy(const PrimitivePy &prim_py); 49 PrimitivePy &operator=(const PrimitivePy &other); 50 explicit PrimitivePy(const py::object &python_obj); 51 ~PrimitivePy() override; 52 MS_DECLARE_PARENT(PrimitivePy, Primitive); 53 const bool parse_info_ = true; 54 py::function GetVmapRuleFunction(const bool is_side_effect = false, int axis_size = 0); 55 py::function GetBpropFunction(); 56 py::function GetTaylorRuleFunction(); backward_hook_fn()57 const std::map<int, py::function> &backward_hook_fn() const { return backward_hook_fn_; } 58 void CopyHookFunction(const PrimitivePyPtr &primitive_py); 59 void AddBpropCutPrim(const PrimitivePyPtr &bprop_cut_prim); 60 void AddBackwardHookFn(const int &key, const py::function &backward_hook_fn); 61 void RemoveBackwardHookFn(const int &key); 62 BaseRef RunHookFunction(const VectorRef &args) const; 63 BaseRef RunCellCustomBpropFunction(const py::tuple &py_args) const; 64 BaseRef RunCustomOpBpropFunction(const py::tuple &py_args) const; 65 BaseRef RunCellHookFunction(const py::tuple &py_args) const; 66 BaseRef RunVariableHookFunction(const py::tuple &py_args, bool is_tensor_hook) const; 67 BaseRef RunComputeFunction(const VectorRef &args) const override; 68 py::object RunPyComputeFunction(const py::tuple &py_args) const; 69 bool HasComputeFunction() const; 70 py::dict GetAttrDict(); GetPyObj()71 const py::object &GetPyObj() const { return python_obj_; } HasPyObj()72 bool HasPyObj() const { return python_obj_.operator bool(); } 73 void RunCheck(const py::tuple &args); 74 py::dict RunInfer(const py::tuple &args); 75 py::object RunInferValue(const py::tuple &args); 76 PrimitivePtr Clone() override; adapter()77 PrimitivePyAdapterPtr adapter() const { return adapter_; } set_bprop_cls_name(const std::string & name)78 void set_bprop_cls_name(const std::string &name) { bprop_cls_name_ = name; } 79 static void ProcessUnPairedCellHook(bool execute_hook_fn); 80 static void ClearHookRes(); IsPythonPrim()81 bool IsPythonPrim() override { return true; } 82 83 private: 84 py::function GetComputeFunction() const; 85 py::object UnpackRetValueOfCellHook(const py::object &grad_out) const; 86 void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out, const py::object &code_obj, 87 const py::object &co_name) const; 88 py::object python_obj_; 89 std::string bprop_cls_name_; 90 PrimitivePyAdapterPtr adapter_; 91 std::vector<Signature> signatures_; 92 std::vector<PrimitivePyWeakPtr> bprop_cut_prims_; 93 std::map<int, py::function> backward_hook_fn_; 94 static std::map<std::string, std::pair<std::map<int, py::function>, py::object>> hook_grad_; 95 }; 96 97 class PrimitivePyAdapter { 98 public: 99 explicit PrimitivePyAdapter(const py::str &name); 100 PrimitivePyAdapter(const PrimitivePyAdapter &adapter); 101 PrimitivePyAdapter &operator=(const PrimitivePyAdapter &other); 102 ~PrimitivePyAdapter() = default; attrs()103 const mindspore::HashMap<std::string, ValuePtr> &attrs() const { return attrs_; } 104 void AddPyAttr(const py::str &name, const py::object &obj); 105 void DelPyAttr(const py::str &name); 106 py::dict GetAttrDict(); 107 int AddBackwardHookFn(const py::function &backward_hook_fn); 108 void RemoveBackwardHookFn(int key); 109 void set_prim_type(const PrimType t); 110 void set_const_prim(bool is_const_prim); 111 void set_inplace_prim(bool inplace_prim); 112 void set_const_input_indexes(const std::vector<size_t> &const_input_indexes); 113 void set_signatures(const std::vector<Signature> &signatures); 114 void set_instance_name(const std::string &s); 115 void set_attached_primitive(const PrimitivePyPtr &prim); attached_primitive()116 PrimitivePyPtr attached_primitive() const { return attached_primitive_.lock(); } id()117 uint64_t id() const { return id_; } name()118 std::string name() const { return name_; } set_name(const std::string & name)119 void set_name(const std::string &name) { name_ = name; } 120 121 struct PrimitiveUserData { 122 py::object obj; ~PrimitiveUserDataPrimitiveUserData123 ~PrimitiveUserData() { 124 // cppcheck-suppress unreadVariable 125 py::gil_scoped_acquire acquire_gil; 126 obj = py::none(); 127 } 128 }; 129 130 void SetUserData(const py::str &key, const py::object &value); 131 py::object GetUserData(const py::str &key) const; 132 133 const bool parse_info_ = true; 134 135 private: 136 friend PrimitivePy; 137 138 template <typename T> set_user_data(const std::string & key,const std::shared_ptr<T> & value)139 void set_user_data(const std::string &key, const std::shared_ptr<T> &value) { 140 user_data_.set<T>(key, value); 141 } 142 template <typename T> user_data(const std::string & key)143 std::shared_ptr<T> user_data(const std::string &key) const { 144 return user_data_.get<T>(key); 145 } 146 147 bool const_prim_{false}; 148 bool inplace_prim_{false}; 149 int backward_hook_fn_key_{-1}; 150 uint64_t id_; 151 std::string name_; 152 std::string instance_name_; 153 PrimType prim_type_{kPrimTypeBuiltIn}; 154 PrimitivePyWeakPtr attached_primitive_; 155 mindspore::HashMap<std::string, ValuePtr> attrs_; 156 std::vector<size_t> const_input_indexes_; 157 std::vector<Signature> signatures_; 158 std::map<int, py::function> backward_hook_fn_; 159 UserData user_data_; 160 }; 161 162 /// \brief OpPrimPyRegister defines the singleton to save primitivepy which has no attrs. 163 class OpPrimPyRegister { 164 public: 165 /// \brief Destructor of OpPrimPyRegister. ~OpPrimPyRegister()166 ~OpPrimPyRegister() {} 167 168 /// \brief Get the OpPrimPyRegister singleton. 169 /// 170 /// \return The OpPrimPyRegister singleton. GetInstance()171 static OpPrimPyRegister &GetInstance() { 172 static OpPrimPyRegister instance{}; 173 return instance; 174 } 175 176 /// \brief Get PrimPyMap of the OpPrimPyRegister singleton. 177 /// 178 /// \return The PrimPyMap of the OpPrimPyRegister singleton. GetPrimPyMap()179 const HashMap<std::string, ValuePtr> &GetPrimPyMap() const { return primpy_map_; } 180 181 /// \brief Add an element into the PrimPyMap of the OpPrimPyRegister singleton. 182 /// 183 /// param[in] name The operator name. 184 /// param[in] primpy The primitivepy of the operator. SetPrimPyMap(const std::string & name,const ValuePtr & primpy)185 void SetPrimPyMap(const std::string &name, const ValuePtr &primpy) { primpy_map_[name] = primpy; } 186 187 /// \brief Clear the PrimPyMap before the pyobject destroyed. Clear()188 void Clear() { primpy_map_.clear(); } 189 190 private: OpPrimPyRegister()191 OpPrimPyRegister() {} 192 HashMap<std::string, ValuePtr> primpy_map_; // op_name, primpy 193 }; 194 195 class PrimitiveFunctionAdapter { 196 public: 197 PrimitiveFunctionAdapter() = default; 198 virtual ~PrimitiveFunctionAdapter() = default; set_attached_primitive_function(const PrimitivePtr & prim_func)199 void set_attached_primitive_function(const PrimitivePtr &prim_func) { attached_primitive_function_ = prim_func; } attached_primitive_function()200 PrimitivePtr attached_primitive_function() { return attached_primitive_function_; } name()201 virtual std::string name() { return py::str(attached_primitive_function_->name()).cast<std::string>(); } has_label(const std::string & label)202 py::object has_label(const std::string &label) { return py::bool_(attached_primitive_function_->HasAttr(label)); } 203 void set_label(const std::string &label, const py::object &value); get_label(const std::string & label)204 py::object get_label(const std::string &label) { return ValueToPyData(attached_primitive_function_->GetAttr(label)); } 205 py::object clone(); 206 207 const bool parse_info_ = true; 208 209 private: 210 std::string name_; 211 PrimitivePtr attached_primitive_function_; 212 }; 213 } // namespace mindspore 214 #endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ 215