• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pybind_api/ir/primitive_py.h"
18 
19 #include <map>
20 #include "ir/signature.h"
21 #include "pipeline/jit/ps/parse/data_converter.h"
22 #include "include/common/utils/python_adapter.h"
23 #include "pybind11/pytypes.h"
24 #include "include/common/pybind_api/api_register.h"
25 #include "pybind_api/export_flags.h"
26 #include "pybind_api/ir/base_ref_py.h"
27 #include "utils/convert_utils_base.h"
28 #include "include/common/utils/convert_utils_py.h"
29 #include "utils/ms_context.h"
30 #include "include/common/utils/primitive_utils.h"
31 #include "utils/check_convert_utils.h"
32 #include "pipeline/pynative/pynative_execute.h"
33 #include "include/common/profiler.h"
34 
35 namespace mindspore {
36 namespace {
// Names of the primitive attributes that the hook-running code in this file queries.
constexpr const char *kBpropAttrName = "bprop";
constexpr const char *kCellHookAttrName = "cell_hook";
constexpr const char *kCellIDAttrName = "cell_id";
constexpr const char *kCustomOpBpropAttrName = "custom_op_bprop";
constexpr const char *kIsRecomputeAttr = "is_recompute";
42 
// Hands out a fresh id on every call; the first id returned is 1.
static uint64_t MakeId() {
  // An atomic counter keeps the generator usable from several threads at once.
  static std::atomic<uint64_t> next_id{1};
  const uint64_t allocated = next_id.fetch_add(1, std::memory_order_relaxed);
  return allocated;
}
// Attribute name replacement table (currently only "data_format" -> "format").
// NOTE(review): the lookup site is not visible in this section; confirm the replacement direction there.
std::map<std::string, std::string> kOpAttrNameReplaceMap{
  {"data_format", "format"},
};
51 
SyncData(const py::object & arg)52 void SyncData(const py::object &arg) {
53   if (py::isinstance<py::tuple>(arg)) {
54     py::tuple arg_list = py::cast<py::tuple>(arg);
55     for (size_t i = 0; i < arg_list.size(); i++) {
56       SyncData(arg_list[i]);
57     }
58   }
59   if (py::isinstance<tensor::Tensor>(arg)) {
60     auto tensor = py::cast<tensor::TensorPtr>(arg);
61     tensor->data_sync();
62   }
63   if (IsStubTensor(arg)) {
64     auto tensor = ConvertStubTensor(arg);
65     tensor->data_sync();
66   }
67 }
68 
// Builds the argument tuple (cell_id, grad_input, grad_output) passed to a cell backward hook.
// Both gradients are first converted from c++ tensors to python tensors, and each one that is
// not already a tuple is wrapped into a one-element tuple.
py::tuple ConstructCellHookFnArgs(const std::string &cell_id, const py::object &grad_input,
                                  const py::object &grad_output) {
  constexpr size_t cell_id_pos = 0;
  constexpr size_t grad_input_pos = 1;
  constexpr size_t grad_output_pos = 2;
  constexpr size_t hook_arg_num = 3;
  constexpr size_t grad_num = hook_arg_num - 1;

  // Convert c++ object to python object.
  py::tuple raw_grads(grad_num);
  raw_grads[0] = grad_input;
  raw_grads[1] = grad_output;
  py::tuple converted_grads(grad_num);
  ConvertCTensorToPyTensor(raw_grads, &converted_grads);

  // Leaves tuples as they are and wraps everything else into a tuple of size one.
  const auto as_tuple = [](const py::object &grad) -> py::object {
    if (py::isinstance<py::tuple>(grad)) {
      return grad;
    }
    return py::make_tuple(grad);
  };

  // Assemble the tuple args of cell hook function.
  py::tuple hook_fn_args(hook_arg_num);
  hook_fn_args[cell_id_pos] = cell_id;
  hook_fn_args[grad_input_pos] = as_tuple(converted_grads[0]);
  hook_fn_args[grad_output_pos] = as_tuple(converted_grads[1]);
  return hook_fn_args;
}
97 
ContainsWeights(const py::tuple & grads)98 bool ContainsWeights(const py::tuple &grads) {
99   if (grads.size() < kSizeTwo) {
100     return false;
101   }
102   if (!py::isinstance<py::tuple>(grads[0]) && !py::isinstance<py::dict>(grads[1])) {
103     return false;
104   }
105   return true;
106 }
107 
108 struct RunPrimitivePyHookFunctionRegister {
RunPrimitivePyHookFunctionRegistermindspore::__anond61628c30111::RunPrimitivePyHookFunctionRegister109   RunPrimitivePyHookFunctionRegister() {
110     python_adapter::PyAdapterCallback::SetRunPrimitivePyHookFunctionHandler(
111       [](const PrimitivePtr &prim, const VectorRef &args) -> BaseRef {
112         auto py_prim = prim->cast<PrimitivePyPtr>();
113         MS_EXCEPTION_IF_NULL(py_prim);
114         return py_prim->RunHookFunction(args);
115       });
116   }
117 } callback_register;
118 struct ProcessUnPairedCellHookRegister {
ProcessUnPairedCellHookRegistermindspore::__anond61628c30111::ProcessUnPairedCellHookRegister119   ProcessUnPairedCellHookRegister() {
120     python_adapter::PyAdapterCallback::SetProcessUnPairedCellHookHandler(
121       [](bool execute_hook_fn) -> void { PrimitivePy::ProcessUnPairedCellHook(execute_hook_fn); });
122   }
123 } cell_hook_callback_register;
124 }  // namespace
// Shared cache keyed by cell id. RunCellHookFunction stores {backward hook functions, gradient}
// here on the first bprop_cut of a cell and erases the entry once the second bprop_cut has run
// the hooks; ProcessUnPairedCellHook and ClearHookRes empty the whole cache.
125 std::map<std::string, std::pair<std::map<int, py::function>, py::object>> PrimitivePy::hook_grad_;
126 
PrimitivePy(const std::string & name)127 PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}
128 
// Copy constructor: copies the Primitive base and every member listed below from `prim_py`.
// The python object and adapter are shared with the source, not duplicated.
PrimitivePy::PrimitivePy(const PrimitivePy &prim_py)
    : Primitive(prim_py),
      python_obj_(prim_py.python_obj_),
      bprop_cls_name_(prim_py.bprop_cls_name_),
      adapter_(prim_py.adapter_),
      signatures_(prim_py.signatures_),
      bprop_cut_prims_(prim_py.bprop_cut_prims_),
      backward_hook_fn_(prim_py.backward_hook_fn_) {
}
137 
// Copy assignment: mirrors the copy constructor. Self-assignment is a no-op.
PrimitivePy &PrimitivePy::operator=(const PrimitivePy &other) {
  if (this != &other) {
    // Assign the base part first, then the members owned by PrimitivePy.
    Primitive::operator=(other);
    python_obj_ = other.python_obj_;
    bprop_cls_name_ = other.bprop_cls_name_;
    adapter_ = other.adapter_;
    signatures_ = other.signatures_;
    bprop_cut_prims_ = other.bprop_cut_prims_;
    backward_hook_fn_ = other.backward_hook_fn_;
  }
  return *this;
}
151 
PrimitivePy(const py::object & python_obj)152 PrimitivePy::PrimitivePy(const py::object &python_obj)
153     : Primitive(python_obj.cast<PrimitivePyAdapterPtr>()->name_, false),
154       python_obj_(python_obj),
155       adapter_(python_obj.cast<PrimitivePyAdapterPtr>()) {
156   MS_LOG(DEBUG) << "New primitive:" << adapter_->name_;
157   set_signatures(adapter_->signatures_);
158   (void)Primitive::SetAttrs(adapter_->attrs_);
159   Primitive::set_prim_type(adapter_->prim_type_);
160   Primitive::set_const_prim(adapter_->const_prim_);
161   Primitive::set_inplace_prim(adapter_->inplace_prim_);
162   Primitive::set_const_input_indexes(adapter_->const_input_indexes_);
163   for (const auto &elem : adapter_->backward_hook_fn_) {
164     AddBackwardHookFn(elem.first, elem.second);
165   }
166   set_instance_name(adapter_->instance_name_);
167   CloneUserData(adapter_->user_data_);
168 }
169 
~PrimitivePy()170 PrimitivePy::~PrimitivePy() {
171   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kDefault, name(),
172                                      false);
173   py::gil_scoped_acquire acquire_gil;
174   python_obj_ = py::none();
175   backward_hook_fn_.clear();
176 }
177 
GetVmapRuleFunction(const bool,int axis_size)178 py::function PrimitivePy::GetVmapRuleFunction(const bool, int axis_size) {
179   constexpr char get_vmap_rule_func_name[] = "get_vmap_rule";
180   if (py::hasattr(python_obj_, get_vmap_rule_func_name)) {
181     return python_obj_.attr(get_vmap_rule_func_name)().cast<py::function>();
182   }
183   return GetVmapRuleFunctionByObj(python_obj_, axis_size);
184 }
185 
GetBpropFunction()186 py::function PrimitivePy::GetBpropFunction() {
187   static const char *const get_bprop_func_name = "get_bprop";
188   if (py::hasattr(python_obj_, get_bprop_func_name)) {
189     py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
190     return fn;
191   }
192 
193   auto fn = GetBpropFunctionByObj(python_obj_);
194   return fn;
195 }
196 
GetTaylorRuleFunction()197 py::function PrimitivePy::GetTaylorRuleFunction() {
198   static const char *const get_taylor_rule_func_name = "get_taylor_rule";
199   if (py::hasattr(python_obj_, get_taylor_rule_func_name)) {
200     py::function fn = python_obj_.attr(get_taylor_rule_func_name)().cast<py::function>();
201     return fn;
202   }
203   auto fn = GetTaylorRuleFunctionByObj(python_obj_);
204   return fn;
205 }
206 
check_bprop_input_grads(const py::tuple & py_args,const py::tuple & grads,const std::string & bprop_cls_name,int filter_args_size)207 void check_bprop_input_grads(const py::tuple &py_args, const py::tuple &grads, const std::string &bprop_cls_name,
208                              int filter_args_size) {
209   if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
210     return;
211   }
212   if (grads.size() != py_args.size() - filter_args_size) {
213     MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name
214                             << "', the number of return values(gradients) should be equal to the number of input "
215                                "arguments except 'out' and 'dout', which is: "
216                             << (py_args.size() - filter_args_size) << ", but got:" << grads.size() << ".";
217   }
218   for (size_t i = 0; i < grads.size(); i++) {
219     if (py::isinstance<tensor::Tensor>(py_args[i]) || IsStubTensor(py_args[i])) {
220       if (!py::isinstance<tensor::Tensor>(grads[i]) && !IsStubTensor(grads[i])) {
221         MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
222                                 << "th return value(gradient of the " << i << "th argument) should be Tensor, but got "
223                                 << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
224                                 << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
225       }
226 
227       py::object arg_dtype = py_args[i].attr("dtype");
228       py::object grad_dtype = grads[i].attr("dtype");
229       py::tuple arg_shape = py_args[i].attr("shape");
230       py::tuple grad_shape = grads[i].attr("shape");
231       if (!grad_dtype.equal(arg_dtype)) {
232         MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
233                                 << "th return value(gradient of the " << i
234                                 << "th argument) should have the same dtype as the " << i
235                                 << "th argument, which is:" << py::cast<py::str>(arg_dtype)
236                                 << ", but got: " << py::cast<py::str>(grad_dtype) << ".";
237       }
238       if (!grad_shape.equal(arg_shape)) {
239         MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
240                                  << "th return value(gradient of the " << i
241                                  << "th argument) should have the same shape as the " << i
242                                  << "th argument, which is:" << py::cast<py::str>(arg_shape)
243                                  << ", but got: " << py::cast<py::str>(grad_shape) << ".";
244       }
245     }
246   }
247 }
248 
// Normalizes and validates the value returned by a user defined bprop function.
// A None result raises TypeError; a non-tuple result is wrapped into a one-element tuple.
// When the result carries weight gradients (see ContainsWeights) the returned tuple holds the
// input gradients followed by the values of the weight gradient dict; otherwise the (possibly
// wrapped) result itself is returned. The input gradients are checked by check_bprop_input_grads.
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
  if (py::isinstance<py::none>(grads_obj)) {
    MS_EXCEPTION(TypeError) << "The python function output is none.";
  }
  py::tuple grads;
  if (py::isinstance<py::tuple>(grads_obj)) {
    grads = py::cast<py::tuple>(grads_obj);
  } else {
    MS_LOG(DEBUG) << "Wrap a tuple";
    grads = py::make_tuple(grads_obj);
  }

  if (!ContainsWeights(grads)) {
    MS_LOG(DEBUG) << "Not contain weights";
    check_bprop_input_grads(py_args, grads, bprop_cls_name, kSizeTwo);
    return grads;
  }

  MS_LOG(DEBUG) << "Contain weights";
  py::tuple input_grads = py::cast<py::tuple>(grads[0]);
  py::dict weight_grads = py::cast<py::dict>(grads[1]);
  check_bprop_input_grads(py_args, input_grads, bprop_cls_name, weight_grads.size() + 1);
  if (weight_grads.empty()) {
    return input_grads;
  }
  // Flatten: input gradients first, then the weight gradients in dict iteration order.
  const size_t input_grad_num = input_grads.size();
  py::tuple all_grads(input_grad_num + weight_grads.size());
  for (size_t idx = 0; idx < input_grad_num; ++idx) {
    all_grads[idx] = input_grads[idx];
  }
  size_t write_pos = input_grad_num;
  for (auto weight_item : weight_grads) {
    all_grads[write_pos] = weight_item.second;
    ++write_pos;
  }
  return all_grads;
}
283 
AddBpropCutPrim(const PrimitivePyPtr & bprop_cut_prim)284 void PrimitivePy::AddBpropCutPrim(const PrimitivePyPtr &bprop_cut_prim) {
285   MS_EXCEPTION_IF_NULL(bprop_cut_prim);
286   (void)bprop_cut_prims_.emplace_back(bprop_cut_prim);
287 }
288 
AddBackwardHookFn(const int & key,const py::function & backward_hook_fn)289 void PrimitivePy::AddBackwardHookFn(const int &key, const py::function &backward_hook_fn) {
290   backward_hook_fn_[key] = backward_hook_fn;
291   for (const auto &elem : bprop_cut_prims_) {
292     PrimitivePyPtr bprop_cut_prim = elem.lock();
293     if (bprop_cut_prim != nullptr) {
294       bprop_cut_prim->AddBackwardHookFn(key, backward_hook_fn);
295     }
296   }
297 }
298 
RemoveBackwardHookFn(const int & key)299 void PrimitivePy::RemoveBackwardHookFn(const int &key) {
300   auto iter = backward_hook_fn_.find(key);
301   if (iter != backward_hook_fn_.end()) {
302     (void)backward_hook_fn_.erase(key);
303   }
304   // Remove hook_fn for bprop cut prim on grad graph.
305   for (const auto &elem : bprop_cut_prims_) {
306     PrimitivePyPtr bprop_cut_prim = elem.lock();
307     if (bprop_cut_prim != nullptr) {
308       bprop_cut_prim->RemoveBackwardHookFn(key);
309     }
310   }
311 }
312 
// Unwraps the value returned by a cell backward hook.
// The value must be a tuple, otherwise the hook cache is cleared and TypeError is raised.
// A one-element tuple yields its single element; any other tuple is returned unchanged.
py::object PrimitivePy::UnpackRetValueOfCellHook(const py::object &grad_out) const {
  const bool is_tuple = py::isinstance<py::tuple>(grad_out);
  if (!is_tuple) {
    hook_grad_.clear();
    MS_EXCEPTION(TypeError) << "The return gradient of cell backward hook function should be a tuple!";
  }
  const auto returned_grads = py::cast<py::tuple>(grad_out);
  // The input number of current cell is 1.
  const bool single_input = (returned_grads.size() == 1);
  if (single_input) {
    return returned_grads[0];
  }
  return grad_out;
}
325 
CheckHookConsistency(const py::object & grad_out,const py::object & expected_grad_out,const py::object & code_obj,const py::object & co_name) const326 void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out,
327                                        const py::object &code_obj, const py::object &co_name) const {
328   if (py::isinstance<py::tuple>(expected_grad_out)) {
329     if (!py::isinstance<py::tuple>(grad_out)) {
330       hook_grad_.clear();
331       MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
332     }
333     auto actual_out_tuple = py::cast<py::tuple>(grad_out);
334     auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
335     if (actual_out_tuple.size() != expected_out_tuple.size()) {
336       hook_grad_.clear();
337       MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
338                                << ", but it is " << actual_out_tuple.size();
339     }
340     for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
341       CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i], code_obj, co_name);
342     }
343   }
344 
345   if (py::isinstance<tensor::Tensor>(expected_grad_out) || IsStubTensor(expected_grad_out)) {
346     if (!py::isinstance<tensor::Tensor>(grad_out) && !IsStubTensor(grad_out)) {
347       hook_grad_.clear();
348       MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
349                               << py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
350     }
351     tensor::TensorPtr actual_out_tensor =
352       IsStubTensor(grad_out) ? ConvertStubTensor(grad_out) : py::cast<tensor::TensorPtr>(grad_out);
353     tensor::TensorPtr expected_out_tensor = IsStubTensor(expected_grad_out)
354                                               ? ConvertStubTensor(expected_grad_out)
355                                               : py::cast<tensor::TensorPtr>(expected_grad_out);
356     MS_EXCEPTION_IF_NULL(actual_out_tensor);
357     MS_EXCEPTION_IF_NULL(expected_out_tensor);
358     if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
359       hook_grad_.clear();
360       MS_EXCEPTION(ValueError) << "The output type of " << py::str(co_name)
361                                << " is not consistent with the expected, it should be "
362                                << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but got "
363                                << actual_out_tensor->GetShapeAndDataTypeInfo();
364     }
365   }
366 }
367 
RunCellCustomBpropFunction(const py::tuple & py_args) const368 BaseRef PrimitivePy::RunCellCustomBpropFunction(const py::tuple &py_args) const {
369   if (backward_hook_fn_.size() > 1) {
370     MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
371   }
372   py::tuple converted_args(py_args.size());
373   ConvertCTensorToPyTensor(py_args, &converted_args);
374   MS_LOG(DEBUG) << "Get convert args size " << converted_args.size() << ", args are "
375                 << ConvertPyObjToString(converted_args);
376   // If recompute, just discard dout; Otherwise, discat out and dout
377   bool is_recompute = HasAttr(kIsRecomputeAttr);
378   size_t non_inp_args_size = is_recompute ? kSizeOne : kSizeTwo;
379 
380   auto inp_args_size = py_args.size() - non_inp_args_size;
381   py::tuple input_args(inp_args_size);
382   for (size_t i = 0; i < inp_args_size; ++i) {
383     input_args[i] = py_args[i];
384   }
385   MS_LOG(DEBUG) << "Get cell input arg size " << inp_args_size;
386   // Run bprop function.
387   auto inst = pynative::PyNativeExecutor::GetInstance();
388   MS_EXCEPTION_IF_NULL(inst);
389   try {
390     MS_LOG(DEBUG) << "Run cell custom bprop function start.";
391     py::tuple grads;
392     MS_LOG(DEBUG) << "Get num of backward hook fn is " << backward_hook_fn_.size();
393     for (const auto &elem : backward_hook_fn_) {
394       inst->NewGraph(elem.second, input_args.cast<py::args>());
395       py::object grads_obj = elem.second(*converted_args);
396       MS_LOG(DEBUG) << "Get cell hook output " << ConvertPyObjToString(grads_obj);
397       grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
398       py::object out = grads_obj;
399       // If grads.size() > inp_args_size, that means exist weights.
400       if (grads.size() > inp_args_size) {
401         MS_LOG(DEBUG) << "Get grads size " << grads.size();
402         out = py::cast<py::tuple>(grads_obj)[0];
403       }
404       inst->EndGraph(elem.second, out, input_args.cast<py::args>());
405     }
406     MS_LOG(DEBUG) << "Run cell custom bprop function end.";
407     return std::make_shared<PyObjectRef>(grads);
408   } catch (std::exception &bt) {
409     inst->ClearRes();
410     std::rethrow_exception(std::current_exception());
411   }
412 }
413 
RunCustomOpBpropFunction(const py::tuple & py_args) const414 BaseRef PrimitivePy::RunCustomOpBpropFunction(const py::tuple &py_args) const {
415   if (backward_hook_fn_.size() > 1) {
416     MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
417   }
418   py::tuple grads;
419   SyncData(py_args);
420   py::tuple converted_args(py_args.size());
421   ConvertCTensorToPyTensor(py_args, &converted_args);
422   MS_LOG(DEBUG) << "Get convert args size " << converted_args.size() << ", args are "
423                 << ConvertPyObjToString(converted_args);
424   try {
425     MS_LOG(DEBUG) << "start execute custom op bprop";
426     for (const auto &elem : backward_hook_fn_) {
427       py::object grads_obj = elem.second(*converted_args);
428       grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
429     }
430     MS_LOG(DEBUG) << "end execute custom op bprop";
431     return std::make_shared<PyObjectRef>(grads);
432   } catch (std::exception &bt) {
433     std::rethrow_exception(std::current_exception());
434   }
435 }
436 
RunCellHookFunction(const py::tuple & py_args) const437 BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
438   const auto args_size = py_args.size();
439   // Get the din passed to current bprop cut op.
440   py::object grad_output = py_args[args_size - 1];
441   // Get the cell id.
442   auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
443   auto iter = hook_grad_.find(cell_id);
444   if (iter != hook_grad_.end()) {
445     // The second bprop_cut used to hook output gradient of cell.
446     for (const auto &elem : backward_hook_fn_) {
447       MS_LOG(DEBUG) << "Run cell hook function start.";
448       py::object code_obj = py::getattr(elem.second, "__code__");
449       py::object co_name = py::getattr(code_obj, "co_name");
450       if (std::string(py::str(co_name)) == "staging_specialize") {
451         py::object name_obj = py::getattr(elem.second, "__name__");
452         MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
453       }
454       MS_LOG(DEBUG) << "Get cell dout " << ConvertPyObjToString(grad_output);
455       SyncData(grad_output);
456       const py::object grad_input = iter->second.second;
457       py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, grad_input, grad_output);
458       py::object ret = elem.second(*hook_fn_args);
459       if (!py::isinstance<py::none>(ret)) {
460         MS_LOG(DEBUG) << "Get hook output " << ConvertPyObjToString(ret);
461         grad_output = UnpackRetValueOfCellHook(ret);
462       }
463       CheckHookConsistency(grad_output, py_args[args_size - 1], code_obj, co_name);
464       MS_LOG(DEBUG) << "Run cell hook function end.";
465     }
466     (void)hook_grad_.erase(cell_id);
467   } else {
468     // The first bprop_cut used to hook input gradient of cell.
469     MS_LOG(DEBUG) << "Get cell din " << ConvertPyObjToString(grad_output);
470     SyncData(grad_output);
471     hook_grad_[cell_id] = {backward_hook_fn_, grad_output};
472   }
473   if (!py::isinstance<py::tuple>(grad_output)) {
474     grad_output = py::make_tuple(grad_output);
475   }
476   return std::make_shared<PyObjectRef>(grad_output);
477 }
478 
RunVariableHookFunction(const py::tuple & py_args,bool is_tensor_hook) const479 BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args, bool is_tensor_hook) const {
480   py::tuple converted_args(py_args.size());
481   ConvertCTensorToPyTensor(py_args, &converted_args);
482   MS_LOG(DEBUG) << "Get convert args size " << converted_args.size() << ", args are "
483                 << ConvertPyObjToString(converted_args);
484   constexpr size_t grad_output_index = 2;
485   if (converted_args.size() != kSizeThree) {
486     MS_LOG(EXCEPTION) << "Bprop cut run must in the following format: input, output and dout";
487   }
488   py::object grad_output = converted_args[grad_output_index];
489   MS_LOG(DEBUG) << "Get grad output " << ConvertPyObjToString(grad_output);
490   for (const auto &elem : backward_hook_fn_) {
491     if (is_tensor_hook) {
492       MS_LOG(DEBUG) << "Run tensor hook function begin";
493       grad_output = elem.second(grad_output);
494       if (py::isinstance<py::none>(grad_output)) {
495         MS_EXCEPTION(ValueError) << "The bprop function output is None";
496       }
497       MS_LOG(DEBUG) << "Run tensor hook function end";
498     } else {
499       MS_LOG(DEBUG) << "Run hook function begin";
500       py::object code_obj = py::getattr(elem.second, "__code__");
501       py::object co_name = py::getattr(code_obj, "co_name");
502       if (std::string(py::str(co_name)) == "staging_specialize") {
503         py::object name_obj = py::getattr(elem.second, "__name__");
504         MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
505       }
506 
507       py::object ret = elem.second(py::make_tuple(grad_output));
508       if (!py::isinstance<py::none>(ret)) {
509         MS_LOG(DEBUG) << "Get hook output " << ConvertPyObjToString(ret);
510         grad_output = ret;
511       }
512       CheckHookConsistency(grad_output, py_args[grad_output_index], code_obj, co_name);
513       MS_LOG(DEBUG) << "Run hook function end";
514     }
515   }
516   grad_output = py::make_tuple(grad_output);
517   return std::make_shared<PyObjectRef>(grad_output);
518 }
519 
RunHookFunction(const VectorRef & args) const520 BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
521   py::tuple py_args = ConvertDatatoPyTuple(args);
522   MS_LOG(DEBUG) << "Get input args size " << py_args.size() << ", args are " << ConvertPyObjToString(py_args);
523   // For cell has custom bprop function
524   bool is_bprop = this->HasAttr(kBpropAttrName);
525   if (is_bprop) {
526     MS_LOG(DEBUG) << "Run cell custom bprop";
527     return RunCellCustomBpropFunction(py_args);
528   }
529 
530   // For cell register hook
531   bool is_cell = this->HasAttr(kCellHookAttrName);
532   if (is_cell) {
533     MS_LOG(DEBUG) << "Run cell hook";
534     return RunCellHookFunction(py_args);
535   }
536 
537   // For custom op, which define custrcut and bprop
538   bool is_custom_op_bprop = this->HasAttr(kCustomOpBpropAttrName);
539   if (is_custom_op_bprop) {
540     MS_LOG(DEBUG) << "Run custom op";
541     return RunCustomOpBpropFunction(py_args);
542   }
543 
544   // For hook use, include hook op and tensor register hook
545   return RunVariableHookFunction(py_args, this->HasAttr("tensor_hook"));
546 }
547 
GetComputeFunction() const548 py::function PrimitivePy::GetComputeFunction() const {
549   static const char *const compute_func_name = "vm_impl";
550 
551   if (py::hasattr(python_obj_, compute_func_name)) {
552     MS_LOG(DEBUG) << name() << " compute_func_name";
553     py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
554     return fn;
555   }
556 
557   static const std::string vm_module = "mindspore.ops.vm_impl_registry";
558   static const std::string get_vm_impl_fn = "get_vm_impl_fn";
559   MS_LOG(DEBUG) << name() << ": get_vm_impl_fn";
560   py::function get_fn = python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
561   py::function vm_fn = get_fn(python_obj_);
562   if (py::isinstance<py::none>(vm_fn)) {
563     vm_fn = get_fn(name());
564   }
565   if (py::isinstance<py::none>(vm_fn)) {
566     MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
567     vm_fn = mindspore::GetComputeFunction(Primitive::name());
568   }
569   return vm_fn;
570 }
571 
GetAttrDict()572 py::dict PrimitivePy::GetAttrDict() {
573   py::dict attr_dict;
574   for (auto &attr : attrs_) {
575     attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
576   }
577   return attr_dict;
578 }
579 
CopyHookFunction(const PrimitivePyPtr & primitive_py)580 void PrimitivePy::CopyHookFunction(const PrimitivePyPtr &primitive_py) {
581   MS_EXCEPTION_IF_NULL(primitive_py);
582   const auto &backward_hook_fn = primitive_py->backward_hook_fn();
583   for (const auto &elem : backward_hook_fn) {
584     AddBackwardHookFn(elem.first, elem.second);
585   }
586   if (primitive_py->HasAttr(kBpropAttrName)) {
587     set_bprop_cls_name(primitive_py->bprop_cls_name_);
588     (void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
589   }
590 }
591 
RunComputeFunction(const VectorRef & args) const592 BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
593   auto py_args = ConvertDatatoPyTuple(args);
594   auto result = this->RunPyComputeFunction(py_args);
595   if (py::isinstance<py::none>(result)) {
596     return std::make_shared<BaseRef>(nullptr);
597   }
598   return std::make_shared<PyObjectRef>(result);
599 }
600 
// Calls the compute function of this primitive with `py_args` unpacked as positional
// arguments. Returns None when no compute function is available.
py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  const auto compute_fn = this->GetComputeFunction();
  if (py::isinstance<py::none>(compute_fn)) {
    return py::none();
  }
  return compute_fn(*py_args);
}
609 
HasComputeFunction() const610 bool PrimitivePy::HasComputeFunction() const {
611   auto func = GetComputeFunction();
612   return !py::isinstance<py::none>(func);
613 }
614 
Clone()615 PrimitivePtr PrimitivePy::Clone() {
616   auto clone_fn = python_obj_.attr("_clone");
617   py::object obj_adapter = clone_fn();
618   auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
619   auto prim = std::make_shared<PrimitivePy>(obj_adapter);
620   prim_adapter->set_attached_primitive(prim);
621   return prim;
622 }
623 
// Calls the python infer method (PY_PRIM_METHOD_INFER) with `args` unpacked as positional
// arguments and returns its result as a dict.
// Raises when no python object is attached or when it lacks the infer method.
py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // The python obj could have been replaced by None; check the attribute here so the error still
  // carries the primitive info instead of failing inside python.
  const bool has_infer = py::hasattr(python_obj_, PY_PRIM_METHOD_INFER);
  if (!has_infer) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER;
  }
  auto infer_fn = python_obj_.attr(PY_PRIM_METHOD_INFER);
  return infer_fn(*args);
}
635 
RunCheck(const py::tuple & args)636 void PrimitivePy::RunCheck(const py::tuple &args) {
637   if (!HasPyObj()) {
638     MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
639   }
640   // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
641   if (!py::hasattr(python_obj_, PY_PRIM_METHOD_CHECK)) {
642     MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_CHECK;
643   }
644   auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
645   (void)check_func(*args);
646 }
647 
// Calls the python infer-value method (PY_PRIM_METHOD_INFER_VALUE) with `args` unpacked as
// positional arguments and returns its result.
// Raises when no python object is attached or when it lacks the infer-value method.
py::object PrimitivePy::RunInferValue(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // The python obj could have been replaced by None; check the attribute here so the error still
  // carries the primitive info instead of failing inside python.
  const bool has_infer_value = py::hasattr(python_obj_, PY_PRIM_METHOD_INFER_VALUE);
  if (!has_infer_value) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER_VALUE;
  }
  auto infer_value_fn = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
  return infer_value_fn(*args);
}
659 
ProcessUnPairedCellHook(bool execute_hook_fn)660 void PrimitivePy::ProcessUnPairedCellHook(bool execute_hook_fn) {
661   if (execute_hook_fn) {
662     for (const auto &[cell_id, pair] : hook_grad_) {
663       const auto &hook_fn = pair.first;
664       const auto &grad_input = pair.second;
665       for (const auto &elem : hook_fn) {
666         SyncData(grad_input);
667         py::object grad_output = py::none();
668         py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, grad_input, grad_output);
669         (void)elem.second(*hook_fn_args);
670       }
671     }
672   }
673   hook_grad_.clear();
674 }
675 
ClearHookRes()676 void PrimitivePy::ClearHookRes() { hook_grad_.clear(); }
677 
PrimitivePyAdapter(const py::str & name)678 PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : id_(MakeId()), name_(name) {}
679 
// Copy constructor: duplicates the flags, id, names, primitive type, attributes, const input
// indexes, signatures and backward hook functions of `adapter`.
// Only the members listed below are copied; the attached primitive is not among them.
PrimitivePyAdapter::PrimitivePyAdapter(const PrimitivePyAdapter &adapter)
    : const_prim_(adapter.const_prim_),
      inplace_prim_(adapter.inplace_prim_),
      backward_hook_fn_key_(adapter.backward_hook_fn_key_),
      id_(adapter.id_),
      name_(adapter.name_),
      instance_name_(adapter.instance_name_),
      prim_type_(adapter.prim_type_),
      attrs_(adapter.attrs_),
      const_input_indexes_(adapter.const_input_indexes_),
      signatures_(adapter.signatures_),
      backward_hook_fn_(adapter.backward_hook_fn_) {
}
692 
// Copy assignment: takes over the same set of members as the copy constructor.
// Self-assignment is a no-op. The attached primitive is not among the assigned members.
PrimitivePyAdapter &PrimitivePyAdapter::operator=(const PrimitivePyAdapter &other) {
  if (this != &other) {
    const_prim_ = other.const_prim_;
    inplace_prim_ = other.inplace_prim_;
    backward_hook_fn_key_ = other.backward_hook_fn_key_;
    id_ = other.id_;
    name_ = other.name_;
    instance_name_ = other.instance_name_;
    prim_type_ = other.prim_type_;
    attrs_ = other.attrs_;
    const_input_indexes_ = other.const_input_indexes_;
    signatures_ = other.signatures_;
    backward_hook_fn_ = other.backward_hook_fn_;
  }
  return *this;
}
710 
AddPyAttr(const py::str & name,const py::object & obj)711 void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
712   std::string attr_name = name;
713   ValuePtr converted_res = nullptr;
714   if (py::isinstance<py::module>(obj)) {
715     MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
716                       << " not support py::module to be attribute value; primitive name: " << this->name_
717                       << ", attribute name: " << attr_name << " attribute value: " << py::str(obj);
718   }
719   bool converted = parse::ConvertData(obj, &converted_res);
720   if (!converted) {
721     MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
722                       << " convert python obj to MindSpore obj failed; primitive name: " << this->name_
723                       << ", attribute name:" << attr_name << ", attribute value:" << py::str(obj)
724                       << ", attribute type:" << py::cast<std::string>(obj.attr("__class__").attr("__name__"));
725   }
726   if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
727     attr_name = kOpAttrNameReplaceMap[attr_name];
728   }
729   (void)CheckAndConvertUtils::ConvertAttrValueToInt(this->name_, name, &converted_res);
730   if (attr_name == "primitive_target") {
731     MS_EXCEPTION_IF_NULL(converted_res);
732     if (!converted_res->isa<StringImm>()) {
733       MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
734                         << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
735                         << py::str(obj);
736     }
737     auto target = GetValue<std::string>(converted_res);
738     if (!target.empty() && target != kCPUDevice && target != kGPUDevice && target != kAscendDevice &&
739         target != "Device") {
740       MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
741                         << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
742                         << py::str(obj);
743     }
744   }
745 
746   // If it's func graph, to reserve all used func graphs.
747   if (converted_res->isa<FuncGraph>()) {
748     const auto &fg = dyn_cast<FuncGraph>(converted_res);
749     MS_EXCEPTION_IF_NULL(fg);
750     fg->set_reserved(true);
751     auto manager = Manage({fg}, false);
752     const auto &total_used_fg = manager->func_graphs_used_total(fg);
753     for (const auto &used_fg : total_used_fg) {
754       used_fg->set_reserved(true);
755     }
756   }
757 
758   attrs_[attr_name] = converted_res;
759   auto prim = attached_primitive_.lock();
760   if (prim != nullptr) {
761     (void)prim->AddAttr(attr_name, converted_res);
762   }
763 }
764 
DelPyAttr(const py::str & name)765 void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
766   (void)attrs_.erase(name);
767   auto prim = attached_primitive_.lock();
768   if (prim != nullptr) {
769     (void)prim->DelAttr(name);
770   }
771 }
772 
GetAttrDict()773 py::dict PrimitivePyAdapter::GetAttrDict() {
774   auto prim = attached_primitive_.lock();
775   if (prim != nullptr) {
776     return prim->GetAttrDict();
777   }
778 
779   py::dict attr_dict;
780   for (auto &attr : attrs_) {
781     attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
782   }
783   return attr_dict;
784 }
785 
set_prim_type(const PrimType t)786 void PrimitivePyAdapter::set_prim_type(const PrimType t) {
787   prim_type_ = t;
788   auto prim = attached_primitive_.lock();
789   if (prim != nullptr) {
790     prim->set_prim_type(t);
791   }
792 }
793 
set_const_prim(bool is_const_prim)794 void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
795   const_prim_ = is_const_prim;
796   auto prim = attached_primitive_.lock();
797   if (prim != nullptr) {
798     prim->set_const_prim(is_const_prim);
799   }
800 }
801 
set_inplace_prim(bool is_inplace_prim)802 void PrimitivePyAdapter::set_inplace_prim(bool is_inplace_prim) {
803   inplace_prim_ = is_inplace_prim;
804   auto prim = attached_primitive_.lock();
805   if (prim != nullptr) {
806     prim->set_inplace_prim(is_inplace_prim);
807   }
808 }
809 
set_const_input_indexes(const std::vector<size_t> & const_input_indexes)810 void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
811   const_input_indexes_ = const_input_indexes;
812   auto prim = attached_primitive_.lock();
813   if (prim != nullptr) {
814     prim->set_const_input_indexes(const_input_indexes);
815   }
816 }
817 
set_signatures(const std::vector<Signature> & signatures)818 void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
819   signatures_ = signatures;
820   auto prim = attached_primitive_.lock();
821   if (prim != nullptr) {
822     prim->set_signatures(signatures);
823   }
824 }
825 
AddBackwardHookFn(const py::function & backward_hook_fn)826 int PrimitivePyAdapter::AddBackwardHookFn(const py::function &backward_hook_fn) {
827   ++backward_hook_fn_key_;
828   backward_hook_fn_[backward_hook_fn_key_] = backward_hook_fn;
829   auto prim = attached_primitive_.lock();
830   if (prim != nullptr) {
831     prim->AddBackwardHookFn(backward_hook_fn_key_, backward_hook_fn);
832   }
833   return backward_hook_fn_key_;
834 }
835 
RemoveBackwardHookFn(int key)836 void PrimitivePyAdapter::RemoveBackwardHookFn(int key) {
837   const auto iter = backward_hook_fn_.find(key);
838   if (iter != backward_hook_fn_.end()) {
839     (void)backward_hook_fn_.erase(iter);
840   }
841   auto prim = attached_primitive_.lock();
842   if (prim != nullptr) {
843     prim->RemoveBackwardHookFn(key);
844   }
845 }
846 
set_instance_name(const std::string & s)847 void PrimitivePyAdapter::set_instance_name(const std::string &s) {
848   instance_name_ = s;
849   auto prim = attached_primitive_.lock();
850   if (prim != nullptr) {
851     prim->set_instance_name(s);
852   }
853 }
854 
set_attached_primitive(const PrimitivePyPtr & prim)855 void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
856   if (attached_primitive_.lock() != nullptr) {
857     MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
858   }
859   MS_EXCEPTION_IF_NULL(prim);
860   attached_primitive_ = prim;
861 }
862 
SetUserData(const py::str & key,const py::object & value)863 void PrimitivePyAdapter::SetUserData(const py::str &key, const py::object &value) {
864   const std::string name = std::string("__primitive_user_data_") + key.cast<std::string>();
865   const auto &primitive_data = std::make_shared<PrimitiveUserData>();
866   primitive_data->obj = value;
867   // Set into primitive adapter.
868   set_user_data<PrimitiveUserData>(name, primitive_data);
869   // Set in primitive.
870   auto prim = attached_primitive_.lock();
871   if (prim != nullptr) {
872     prim->set_user_data<PrimitiveUserData>(name, primitive_data);
873   }
874 }
875 
// Fetch the Python object previously stored as user data under `key`.
//
// The storage name is `key` prefixed with "__primitive_user_data_". The attached primitive is
// consulted first; if it is not attached or holds no entry for the key, the adapter's own user
// data is consulted. Returns Python None when neither holds an entry.
//
// The lookup results are checked before use: dereferencing the result for a key that was never
// set would otherwise be a null pointer dereference.
py::object PrimitivePyAdapter::GetUserData(const py::str &key) const {
  const std::string name = std::string("__primitive_user_data_") + key.cast<std::string>();
  // Get from primitive.
  if (const auto prim = attached_primitive_.lock(); prim != nullptr) {
    const auto primitive_data = prim->user_data<PrimitiveUserData>(name);
    if (primitive_data != nullptr) {
      return primitive_data->obj;
    }
  }
  // Get from primitive adapter.
  const auto adapter_data = user_data<PrimitiveUserData>(name);
  if (adapter_data != nullptr) {
    return adapter_data->obj;
  }
  // Nothing was stored under this key.
  return py::none();
}
888 
set_label(const std::string & label,const py::object & value)889 void PrimitiveFunctionAdapter::set_label(const std::string &label, const py::object &value) {
890   ValuePtr converted_value = nullptr;
891   if (!parse::ConvertData(value, &converted_value)) {
892     MS_LOG(INTERNAL_EXCEPTION) << "For '" << PrimitiveFunctionAdapter::name() << "', Convert data failed.";
893   }
894   attached_primitive_function_->AddAttr(label, converted_value);
895 }
896 
clone()897 py::object PrimitiveFunctionAdapter::clone() {
898   const auto op_path = "mindspore.ops.primitive";
899   const auto func = "_create_primitive_function_obj";
900   py::object prim_func_adapter_obj = python_adapter::CallPyFn(op_path, func);
901   prim_func_adapter_obj.cast<PrimitiveFunctionAdapterPtr>()->set_attached_primitive_function(
902     attached_primitive_function_->Clone());
903   return prim_func_adapter_obj;
904 }
905 
RegPrimitive(const py::module * m)906 void RegPrimitive(const py::module *m) {
907   (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
908     .value("unknown", PrimType::kPrimTypeUnknown)
909     .value("builtin", PrimType::kPrimTypeBuiltIn)
910     .value("py_infer_shape", PrimType::kPrimTypePyInfer)
911     .value("user_custom", PrimType::kPrimTypeUserCustom)
912     .value("py_infer_check", PrimType::kPrimTypePyCheck);
913   (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
914     .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
915     .def(py::init<py::str &>())
916     .def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
917     .def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
918     .def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
919     .def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
920     .def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
921     .def("set_inplace_prim", &PrimitivePyAdapter::set_inplace_prim, "Set primitive is inplace primitive.")
922     .def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes, "Set primitive const input indexes.")
923     .def("set_signatures", &PrimitivePyAdapter::set_signatures, "Set primitive inputs signature.")
924     .def("add_backward_hook_fn", &PrimitivePyAdapter::AddBackwardHookFn, "Add primitive backward hook function.")
925     .def("remove_backward_hook_fn", &PrimitivePyAdapter::RemoveBackwardHookFn,
926          "Remove primitive backward hook function.")
927     .def("set_instance_name", &PrimitivePyAdapter::set_instance_name, "Set primitive instance name.")
928     .def("set_user_data", &PrimitivePyAdapter::SetUserData, "Set primitive user data.")
929     .def("get_user_data", &PrimitivePyAdapter::GetUserData, "Get primitive user data.");
930 }
931 
RegPrimitiveFunction(const py::module * m)932 void RegPrimitiveFunction(const py::module *m) {
933   (void)py::class_<PrimitiveFunctionAdapter, std::shared_ptr<PrimitiveFunctionAdapter>>(*m, "PrimitiveFunction_")
934     .def_readonly(PYTHON_PRIMITIVE_FUNCTION_FLAG, &PrimitiveFunctionAdapter::parse_info_)
935     .def(py::init<>())
936     .def_property_readonly("name", &PrimitiveFunctionAdapter::name, "Get function name.")
937     .def("has_label", &PrimitiveFunctionAdapter::has_label, "Has function attr.")
938     .def("set_label", &PrimitiveFunctionAdapter::set_label, "Set function attr.")
939     .def("get_label", &PrimitiveFunctionAdapter::get_label, "Get function attr.")
940     .def("clone", &PrimitiveFunctionAdapter::clone, "Clone a Primitive and create a PrimitiveFunctionAdapter.");
941 }
942 }  // namespace mindspore
943