1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ 18 #define MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <tuple> 24 #include <unordered_map> 25 #include <vector> 26 27 #include "abstract/abstract_value.h" 28 #include "frontend/parallel/ops_info/operator_info.h" 29 #include "ir/primitive.h" 30 #include "ir/signature.h" 31 #include "pybind11/pybind11.h" 32 #include "utils/log_adapter.h" 33 #include "utils/misc.h" 34 35 namespace py = pybind11; 36 namespace mindspore { 37 38 class PrimitivePy; 39 using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; 40 using PrimitivePyWeakPtr = std::weak_ptr<PrimitivePy>; 41 42 class PrimitivePyAdapter; 43 using PrimitivePyAdapterPtr = std::shared_ptr<PrimitivePyAdapter>; 44 45 class PrimitivePy : public Primitive { 46 public: 47 explicit PrimitivePy(const std::string &name); 48 PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter); 49 ~PrimitivePy() override; 50 MS_DECLARE_PARENT(PrimitivePy, Primitive); 51 py::function GetBpropFunction(); 52 53 void set_signatures(const std::vector<Signature> &signatures); 54 signatures()55 const std::vector<Signature> &signatures() const { return signatures_; } 56 57 void CopyHookFunction(const PrimitivePtr &primitive) override; 58 59 py::dict GetAttrDict(); set_hook(const py::function & hook)60 void set_hook(const py::function &hook) { hook_ = hook; } hook()61 py::function hook() const { return hook_; } 62 BaseRef RunHookFunction(const VectorRef &args) const override; 63 BaseRef RunCellBpropFunction(const py::tuple &py_args) const; 64 BaseRef RunCellHookFunction(const py::tuple &py_args) const; 65 BaseRef RunVariableHookFunction(const py::tuple &py_args) const; 66 BaseRef RunComputeFunction(const VectorRef &args) const override; 67 py::object RunPyComputeFunction(const py::tuple &py_args) const; 68 bool HasComputeFunction() const; 69 const bool parse_info_ = true; GetPyObj()70 const py::object &GetPyObj() const { return python_obj_; } 71 py::dict RunInfer(const py::tuple &args); 72 void RunCheck(const py::tuple &args); 73 py::object RunInferValue(const py::tuple &args); HasPyObj()74 bool HasPyObj() { return python_obj_.operator bool(); } 75 PrimitivePtr Clone() override; adapter()76 PrimitivePyAdapterPtr adapter() const { return adapter_; } 77 78 private: 79 py::function GetComputeFunction() const; 80 void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const; 81 void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const; 82 py::object python_obj_; 83 PrimitivePyAdapterPtr adapter_; 84 py::function hook_; 85 std::vector<Signature> signatures_; 86 static std::map<std::string, py::object> hook_grad_; 87 }; 88 89 class PrimitivePyAdapter { 90 public: 91 explicit PrimitivePyAdapter(const py::str &name); 92 ~PrimitivePyAdapter() = default; 93 void AddPyAttr(const py::str &name, const py::object &obj); 94 void DelPyAttr(const py::str &name); 95 py::dict GetAttrDict(); 96 void set_prim_type(const PrimType t); 97 void set_const_prim(bool is_const_prim); 98 void set_const_input_indexes(const std::vector<size_t> &const_input_indexes); 99 void set_signatures(const std::vector<Signature> &signatures); 100 void set_hook(const py::function &hook); 101 void set_instance_name(const std::string &s); 102 void set_attached_primitive(const PrimitivePyPtr &prim); attached_primitive()103 PrimitivePyPtr attached_primitive() { return attached_primitive_.lock(); } name()104 std::string name() const { return name_; } set_name(const std::string & name)105 void set_name(const std::string &name) { name_ = name; } 106 const bool parse_info_ = true; 107 108 private: 109 friend PrimitivePy; 110 std::string name_; 111 PrimitivePyWeakPtr attached_primitive_; 112 std::unordered_map<std::string, ValuePtr> attrs_; 113 PrimType prim_type_{kPrimTypeBuiltIn}; 114 bool is_const_prim_{false}; 115 std::vector<size_t> const_input_indexes_; 116 std::vector<Signature> signatures_; 117 py::function hook_; 118 std::string instance_name_; 119 }; 120 } // namespace mindspore 121 #endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ 122