• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
18 #define MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <tuple>
24 #include <unordered_map>
25 #include <vector>
26 
27 #include "abstract/abstract_value.h"
28 #include "frontend/parallel/ops_info/operator_info.h"
29 #include "ir/primitive.h"
30 #include "ir/signature.h"
31 #include "pybind11/pybind11.h"
32 #include "utils/log_adapter.h"
33 #include "utils/misc.h"
34 
35 namespace py = pybind11;
36 namespace mindspore {
37 
// Forward declarations: the two classes below refer to each other, so both
// names are introduced before either definition.
class PrimitivePy;
class PrimitivePyAdapter;

// Shared and weak (non-owning) handle types for PrimitivePy, plus the shared
// handle type for PrimitivePyAdapter.
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
using PrimitivePyWeakPtr = std::weak_ptr<PrimitivePy>;
using PrimitivePyAdapterPtr = std::shared_ptr<PrimitivePyAdapter>;
44 
45 class PrimitivePy : public Primitive {
46  public:
47   explicit PrimitivePy(const std::string &name);
48   PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter);
49   ~PrimitivePy() override;
50   MS_DECLARE_PARENT(PrimitivePy, Primitive);
51   py::function GetBpropFunction();
52 
53   void set_signatures(const std::vector<Signature> &signatures);
54 
signatures()55   const std::vector<Signature> &signatures() const { return signatures_; }
56 
57   void CopyHookFunction(const PrimitivePtr &primitive) override;
58 
59   py::dict GetAttrDict();
set_hook(const py::function & hook)60   void set_hook(const py::function &hook) { hook_ = hook; }
hook()61   py::function hook() const { return hook_; }
62   BaseRef RunHookFunction(const VectorRef &args) const override;
63   BaseRef RunCellBpropFunction(const py::tuple &py_args) const;
64   BaseRef RunCellHookFunction(const py::tuple &py_args) const;
65   BaseRef RunVariableHookFunction(const py::tuple &py_args) const;
66   BaseRef RunComputeFunction(const VectorRef &args) const override;
67   py::object RunPyComputeFunction(const py::tuple &py_args) const;
68   bool HasComputeFunction() const;
69   const bool parse_info_ = true;
GetPyObj()70   const py::object &GetPyObj() const { return python_obj_; }
71   py::dict RunInfer(const py::tuple &args);
72   void RunCheck(const py::tuple &args);
73   py::object RunInferValue(const py::tuple &args);
HasPyObj()74   bool HasPyObj() { return python_obj_.operator bool(); }
75   PrimitivePtr Clone() override;
adapter()76   PrimitivePyAdapterPtr adapter() const { return adapter_; }
77 
78  private:
79   py::function GetComputeFunction() const;
80   void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const;
81   void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const;
82   py::object python_obj_;
83   PrimitivePyAdapterPtr adapter_;
84   py::function hook_;
85   std::vector<Signature> signatures_;
86   static std::map<std::string, py::object> hook_grad_;
87 };
88 
89 class PrimitivePyAdapter {
90  public:
91   explicit PrimitivePyAdapter(const py::str &name);
92   ~PrimitivePyAdapter() = default;
93   void AddPyAttr(const py::str &name, const py::object &obj);
94   void DelPyAttr(const py::str &name);
95   py::dict GetAttrDict();
96   void set_prim_type(const PrimType t);
97   void set_const_prim(bool is_const_prim);
98   void set_const_input_indexes(const std::vector<size_t> &const_input_indexes);
99   void set_signatures(const std::vector<Signature> &signatures);
100   void set_hook(const py::function &hook);
101   void set_instance_name(const std::string &s);
102   void set_attached_primitive(const PrimitivePyPtr &prim);
attached_primitive()103   PrimitivePyPtr attached_primitive() { return attached_primitive_.lock(); }
name()104   std::string name() const { return name_; }
set_name(const std::string & name)105   void set_name(const std::string &name) { name_ = name; }
106   const bool parse_info_ = true;
107 
108  private:
109   friend PrimitivePy;
110   std::string name_;
111   PrimitivePyWeakPtr attached_primitive_;
112   std::unordered_map<std::string, ValuePtr> attrs_;
113   PrimType prim_type_{kPrimTypeBuiltIn};
114   bool is_const_prim_{false};
115   std::vector<size_t> const_input_indexes_;
116   std::vector<Signature> signatures_;
117   py::function hook_;
118   std::string instance_name_;
119 };
120 }  // namespace mindspore
121 #endif  // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
122