/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pybind_api/ir/primitive_py.h"

#include <map>
#include "ir/signature.h"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "include/common/utils/python_adapter.h"
#include "pybind11/pytypes.h"
#include "include/common/pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/convert_utils_base.h"
#include "include/common/utils/convert_utils_py.h"
#include "utils/ms_context.h"
#include "include/common/utils/primitive_utils.h"
#include "utils/check_convert_utils.h"
#include "pipeline/pynative/pynative_execute.h"
#include "include/common/profiler.h"

namespace mindspore {
namespace {
constexpr auto kBpropAttrName = "bprop";
constexpr auto kCellHookAttrName = "cell_hook";
constexpr auto kCellIDAttrName = "cell_id";
constexpr auto kCustomOpBpropAttrName = "custom_op_bprop";
constexpr auto kIsRecomputeAttr = "is_recompute";

static uint64_t MakeId() {
  // Use atomic to make id generator thread safe.
  static std::atomic<uint64_t> last_id{1};
  return last_id.fetch_add(1, std::memory_order_relaxed);
}
std::map<std::string, std::string> kOpAttrNameReplaceMap = {
  {"data_format", "format"},
};

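// Sync the data of a tensor (or of every tensor in a nested tuple, including stub tensors) so that its
// host data is valid before it is handed to a Python hook or bprop function.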
void SyncData(const py::object &arg) {
  if (py::isinstance<py::tuple>(arg)) {
    py::tuple arg_list = py::cast<py::tuple>(arg);
    for (size_t i = 0; i < arg_list.size(); i++) {
      SyncData(arg_list[i]);
    }
  }
  if (py::isinstance<tensor::Tensor>(arg)) {
    auto tensor = py::cast<tensor::TensorPtr>(arg);
    tensor->data_sync();
  }
  if (IsStubTensor(arg)) {
    auto tensor = ConvertStubTensor(arg);
    tensor->data_sync();
  }
}

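// Build the argument tuple (cell_id, grad_input, grad_output) for a cell backward hook. Single gradients
// are wrapped into one-element tuples so the hook always receives tuples.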
py::tuple ConstructCellHookFnArgs(const std::string &cell_id, const py::object &grad_input,
                                  const py::object &grad_output) {
  constexpr size_t grad_input_index = 1;
  constexpr size_t grad_output_index = 2;
  constexpr size_t input_args_nums = 3;
  // Convert c++ object to python object.
  py::tuple c_grad_args(input_args_nums - 1);
  c_grad_args[0] = grad_input;
  c_grad_args[1] = grad_output;
  py::tuple py_grad_args(input_args_nums - 1);
  ConvertCTensorToPyTensor(c_grad_args, &py_grad_args);
  // Get tuple args of cell hook function.
  py::tuple hook_fn_args(input_args_nums);
  hook_fn_args[0] = cell_id;
  // Set grad in
  if (!py::isinstance<py::tuple>(py_grad_args[0])) {
    hook_fn_args[grad_input_index] = py::make_tuple(py_grad_args[0]);
  } else {
    hook_fn_args[grad_input_index] = py_grad_args[0];
  }
  // Set grad out
  if (!py::isinstance<py::tuple>(py_grad_args[1])) {
    hook_fn_args[grad_output_index] = py::make_tuple(py_grad_args[1]);
  } else {
    hook_fn_args[grad_output_index] = py_grad_args[1];
  }
  return hook_fn_args;
}

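// Check whether the bprop output looks like (input_grads_tuple, weight_grads_dict), i.e. whether it also
// carries gradients for weights in addition to the input gradients.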
bool ContainsWeights(const py::tuple &grads) {
  if (grads.size() < kSizeTwo) {
    return false;
  }
  if (!py::isinstance<py::tuple>(grads[0]) && !py::isinstance<py::dict>(grads[1])) {
    return false;
  }
  return true;
}

struct RunPrimitivePyHookFunctionRegister {
  RunPrimitivePyHookFunctionRegister() {
    python_adapter::PyAdapterCallback::SetRunPrimitivePyHookFunctionHandler(
      [](const PrimitivePtr &prim, const VectorRef &args) -> BaseRef {
        auto py_prim = prim->cast<PrimitivePyPtr>();
        MS_EXCEPTION_IF_NULL(py_prim);
        return py_prim->RunHookFunction(args);
      });
  }
} callback_register;
struct ProcessUnPairedCellHookRegister {
  ProcessUnPairedCellHookRegister() {
    python_adapter::PyAdapterCallback::SetProcessUnPairedCellHookHandler(
      [](bool execute_hook_fn) -> void { PrimitivePy::ProcessUnPairedCellHook(execute_hook_fn); });
  }
} cell_hook_callback_register;
}  // namespace
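// Cache for cell backward hooks: maps a cell id to the pair (registered hook functions, gradient cached by
// the first bprop_cut of that cell).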
std::map<std::string, std::pair<std::map<int, py::function>, py::object>> PrimitivePy::hook_grad_;

PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}

PrimitivePy::PrimitivePy(const PrimitivePy &prim_py)
    : Primitive(prim_py),
      python_obj_(prim_py.python_obj_),
      bprop_cls_name_(prim_py.bprop_cls_name_),
      adapter_(prim_py.adapter_),
      signatures_(prim_py.signatures_),
      bprop_cut_prims_(prim_py.bprop_cut_prims_),
      backward_hook_fn_(prim_py.backward_hook_fn_) {}

PrimitivePy &PrimitivePy::operator=(const PrimitivePy &other) {
  if (this == &other) {
    return *this;
  }
  Primitive::operator=(other);
  python_obj_ = other.python_obj_;
  bprop_cls_name_ = other.bprop_cls_name_;
  adapter_ = other.adapter_;
  signatures_ = other.signatures_;
  bprop_cut_prims_ = other.bprop_cut_prims_;
  backward_hook_fn_ = other.backward_hook_fn_;
  return *this;
}

PrimitivePy::PrimitivePy(const py::object &python_obj)
    : Primitive(python_obj.cast<PrimitivePyAdapterPtr>()->name_, false),
      python_obj_(python_obj),
      adapter_(python_obj.cast<PrimitivePyAdapterPtr>()) {
  MS_LOG(DEBUG) << "New primitive:" << adapter_->name_;
  set_signatures(adapter_->signatures_);
  (void)Primitive::SetAttrs(adapter_->attrs_);
  Primitive::set_prim_type(adapter_->prim_type_);
  Primitive::set_const_prim(adapter_->const_prim_);
  Primitive::set_inplace_prim(adapter_->inplace_prim_);
  Primitive::set_const_input_indexes(adapter_->const_input_indexes_);
  for (const auto &elem : adapter_->backward_hook_fn_) {
    AddBackwardHookFn(elem.first, elem.second);
  }
  set_instance_name(adapter_->instance_name_);
  CloneUserData(adapter_->user_data_);
}

PrimitivePy::~PrimitivePy() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kDefault, name(),
                                     false);
  py::gil_scoped_acquire acquire_gil;
  python_obj_ = py::none();
  backward_hook_fn_.clear();
}

py::function PrimitivePy::GetVmapRuleFunction(const bool, int axis_size) {
  constexpr char get_vmap_rule_func_name[] = "get_vmap_rule";
  if (py::hasattr(python_obj_, get_vmap_rule_func_name)) {
    return python_obj_.attr(get_vmap_rule_func_name)().cast<py::function>();
  }
  return GetVmapRuleFunctionByObj(python_obj_, axis_size);
}

py::function PrimitivePy::GetBpropFunction() {
  static const char *const get_bprop_func_name = "get_bprop";
  if (py::hasattr(python_obj_, get_bprop_func_name)) {
    py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
    return fn;
  }

  auto fn = GetBpropFunctionByObj(python_obj_);
  return fn;
}

py::function PrimitivePy::GetTaylorRuleFunction() {
  static const char *const get_taylor_rule_func_name = "get_taylor_rule";
  if (py::hasattr(python_obj_, get_taylor_rule_func_name)) {
    py::function fn = python_obj_.attr(get_taylor_rule_func_name)().cast<py::function>();
    return fn;
  }
  auto fn = GetTaylorRuleFunctionByObj(python_obj_);
  return fn;
}

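// When the context flag MS_CTX_CHECK_BPROP_FLAG is enabled, validate that the gradients returned by a
// user-defined 'bprop' match the forward inputs in number, dtype and shape.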
void check_bprop_input_grads(const py::tuple &py_args, const py::tuple &grads, const std::string &bprop_cls_name,
                             int filter_args_size) {
  if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
    return;
  }
  if (grads.size() != py_args.size() - filter_args_size) {
    MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name
                            << "', the number of return values(gradients) should be equal to the number of input "
                               "arguments except 'out' and 'dout', which is: "
                            << (py_args.size() - filter_args_size) << ", but got:" << grads.size() << ".";
  }
  for (size_t i = 0; i < grads.size(); i++) {
    if (py::isinstance<tensor::Tensor>(py_args[i]) || IsStubTensor(py_args[i])) {
      if (!py::isinstance<tensor::Tensor>(grads[i]) && !IsStubTensor(grads[i])) {
        MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
                                << "th return value(gradient of the " << i << "th argument) should be Tensor, but got "
                                << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
                                << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
      }

      py::object arg_dtype = py_args[i].attr("dtype");
      py::object grad_dtype = grads[i].attr("dtype");
      py::tuple arg_shape = py_args[i].attr("shape");
      py::tuple grad_shape = grads[i].attr("shape");
      if (!grad_dtype.equal(arg_dtype)) {
        MS_EXCEPTION(TypeError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
                                << "th return value(gradient of the " << i
                                << "th argument) should have the same dtype as the " << i
                                << "th argument, which is:" << py::cast<py::str>(arg_dtype)
                                << ", but got: " << py::cast<py::str>(grad_dtype) << ".";
      }
      if (!grad_shape.equal(arg_shape)) {
        MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
                                 << "th return value(gradient of the " << i
                                 << "th argument) should have the same shape as the " << i
                                 << "th argument, which is:" << py::cast<py::str>(arg_shape)
                                 << ", but got: " << py::cast<py::str>(grad_shape) << ".";
      }
    }
  }
}

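// Normalize the output of a user-defined bprop into a flat tuple of gradients. If the output has the form
// (input_grads, weight_grad_dict), the weight gradients are appended after the input gradients.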
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
  py::tuple grads;
  if (py::isinstance<py::none>(grads_obj)) {
    MS_EXCEPTION(TypeError) << "The python function output is none.";
  } else if (!py::isinstance<py::tuple>(grads_obj)) {
    MS_LOG(DEBUG) << "Wrap a tuple";
    grads = py::make_tuple(grads_obj);
  } else {
    grads = py::cast<py::tuple>(grads_obj);
  }
  if (ContainsWeights(grads)) {
    MS_LOG(DEBUG) << "Contain weights";
    py::tuple input_grads = py::cast<py::tuple>(grads[0]);
    py::dict weight_grads = py::cast<py::dict>(grads[1]);
    check_bprop_input_grads(py_args, input_grads, bprop_cls_name, weight_grads.size() + 1);
    if (weight_grads.empty()) {
      return input_grads;
    }
    py::tuple all_grads(input_grads.size() + weight_grads.size());
    for (size_t i = 0; i < input_grads.size(); ++i) {
      all_grads[i] = input_grads[i];
    }
    size_t i = 0;
    for (auto weight_grad : weight_grads) {
      all_grads[i + input_grads.size()] = weight_grad.second;
      ++i;
    }
    return all_grads;
  } else {
    MS_LOG(DEBUG) << "Not contain weights";
    check_bprop_input_grads(py_args, grads, bprop_cls_name, kSizeTwo);
    return grads;
  }
}

void PrimitivePy::AddBpropCutPrim(const PrimitivePyPtr &bprop_cut_prim) {
  MS_EXCEPTION_IF_NULL(bprop_cut_prim);
  (void)bprop_cut_prims_.emplace_back(bprop_cut_prim);
}

void PrimitivePy::AddBackwardHookFn(const int &key, const py::function &backward_hook_fn) {
  backward_hook_fn_[key] = backward_hook_fn;
  for (const auto &elem : bprop_cut_prims_) {
    PrimitivePyPtr bprop_cut_prim = elem.lock();
    if (bprop_cut_prim != nullptr) {
      bprop_cut_prim->AddBackwardHookFn(key, backward_hook_fn);
    }
  }
}

void PrimitivePy::RemoveBackwardHookFn(const int &key) {
  auto iter = backward_hook_fn_.find(key);
  if (iter != backward_hook_fn_.end()) {
    (void)backward_hook_fn_.erase(key);
  }
  // Remove hook_fn for bprop cut prim on grad graph.
  for (const auto &elem : bprop_cut_prims_) {
    PrimitivePyPtr bprop_cut_prim = elem.lock();
    if (bprop_cut_prim != nullptr) {
      bprop_cut_prim->RemoveBackwardHookFn(key);
    }
  }
}

py::object PrimitivePy::UnpackRetValueOfCellHook(const py::object &grad_out) const {
  if (!py::isinstance<py::tuple>(grad_out)) {
    hook_grad_.clear();
    MS_EXCEPTION(TypeError) << "The return gradient of cell backward hook function should be a tuple!";
  }
  auto out_tuple = py::cast<py::tuple>(grad_out);
  if (out_tuple.size() == 1) {
    // The input number of current cell is 1.
    return out_tuple[0];
  }
  return grad_out;
}

void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out,
                                       const py::object &code_obj, const py::object &co_name) const {
  if (py::isinstance<py::tuple>(expected_grad_out)) {
    if (!py::isinstance<py::tuple>(grad_out)) {
      hook_grad_.clear();
      MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
    }
    auto actual_out_tuple = py::cast<py::tuple>(grad_out);
    auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
    if (actual_out_tuple.size() != expected_out_tuple.size()) {
      hook_grad_.clear();
      MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
                               << ", but it is " << actual_out_tuple.size();
    }
    for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
      CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i], code_obj, co_name);
    }
  }

  if (py::isinstance<tensor::Tensor>(expected_grad_out) || IsStubTensor(expected_grad_out)) {
    if (!py::isinstance<tensor::Tensor>(grad_out) && !IsStubTensor(grad_out)) {
      hook_grad_.clear();
      MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
                              << py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
    }
    tensor::TensorPtr actual_out_tensor =
      IsStubTensor(grad_out) ? ConvertStubTensor(grad_out) : py::cast<tensor::TensorPtr>(grad_out);
    tensor::TensorPtr expected_out_tensor = IsStubTensor(expected_grad_out)
                                              ? ConvertStubTensor(expected_grad_out)
                                              : py::cast<tensor::TensorPtr>(expected_grad_out);
    MS_EXCEPTION_IF_NULL(actual_out_tensor);
    MS_EXCEPTION_IF_NULL(expected_out_tensor);
    if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
      hook_grad_.clear();
      MS_EXCEPTION(ValueError) << "The output type of " << py::str(co_name)
                               << " is not consistent with the expected, it should be "
                               << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but got "
                               << actual_out_tensor->GetShapeAndDataTypeInfo();
    }
  }
}

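// Run a cell's user-defined 'bprop': record it as a sub-graph via PyNativeExecutor (NewGraph/EndGraph),
// call the Python bprop with the converted arguments, and return the validated gradients.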
BaseRef PrimitivePy::RunCellCustomBpropFunction(const py::tuple &py_args) const {
  if (backward_hook_fn_.size() > 1) {
    MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
  }
  py::tuple converted_args(py_args.size());
  ConvertCTensorToPyTensor(py_args, &converted_args);
  MS_LOG(DEBUG) << "Get convert args size " << converted_args.size() << ", args are "
                << ConvertPyObjToString(converted_args);
  // If recompute, just discard dout; otherwise, discard both out and dout.
  bool is_recompute = HasAttr(kIsRecomputeAttr);
  size_t non_inp_args_size = is_recompute ? kSizeOne : kSizeTwo;

  auto inp_args_size = py_args.size() - non_inp_args_size;
  py::tuple input_args(inp_args_size);
  for (size_t i = 0; i < inp_args_size; ++i) {
    input_args[i] = py_args[i];
  }
  MS_LOG(DEBUG) << "Get cell input arg size " << inp_args_size;
  // Run bprop function.
  auto inst = pynative::PyNativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  try {
    MS_LOG(DEBUG) << "Run cell custom bprop function start.";
    py::tuple grads;
    MS_LOG(DEBUG) << "Get num of backward hook fn is " << backward_hook_fn_.size();
    for (const auto &elem : backward_hook_fn_) {
      inst->NewGraph(elem.second, input_args.cast<py::args>());
      py::object grads_obj = elem.second(*converted_args);
      MS_LOG(DEBUG) << "Get cell hook output " << ConvertPyObjToString(grads_obj);
      grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
      py::object out = grads_obj;
      // If grads.size() > inp_args_size, weight gradients exist.
      if (grads.size() > inp_args_size) {
        MS_LOG(DEBUG) << "Get grads size " << grads.size();
        out = py::cast<py::tuple>(grads_obj)[0];
      }
      inst->EndGraph(elem.second, out, input_args.cast<py::args>());
    }
    MS_LOG(DEBUG) << "Run cell custom bprop function end.";
    return std::make_shared<PyObjectRef>(grads);
  } catch (std::exception &bt) {
    inst->ClearRes();
    std::rethrow_exception(std::current_exception());
  }
}

BaseRef PrimitivePy::RunCustomOpBpropFunction(const py::tuple &py_args) const {
  if (backward_hook_fn_.size() > 1) {
    MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
  }
  py::tuple grads;
  SyncData(py_args);
  py::tuple converted_args(py_args.size());
  ConvertCTensorToPyTensor(py_args, &converted_args);
  MS_LOG(DEBUG) << "Get convert args size " << converted_args.size() << ", args are "
                << ConvertPyObjToString(converted_args);
  try {
    MS_LOG(DEBUG) << "Start executing custom op bprop";
    for (const auto &elem : backward_hook_fn_) {
      py::object grads_obj = elem.second(*converted_args);
      grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
    }
    MS_LOG(DEBUG) << "End executing custom op bprop";
    return std::make_shared<PyObjectRef>(grads);
  } catch (std::exception &bt) {
    std::rethrow_exception(std::current_exception());
  }
}

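// Run cell backward hooks. A cell owns two bprop_cut nodes: the first one only caches the gradient it
// receives under the cell id; when the second one runs, every registered hook is called with
// (cell_id, grad_input, grad_output), and a non-None return value replaces the output gradient.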
BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
  const auto args_size = py_args.size();
  // Get the din passed to the current bprop_cut op.
  py::object grad_output = py_args[args_size - 1];
  // Get the cell id.
  auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
  auto iter = hook_grad_.find(cell_id);
  if (iter != hook_grad_.end()) {
    // The second bprop_cut is used to hook the output gradient of the cell.
    for (const auto &elem : backward_hook_fn_) {
      MS_LOG(DEBUG) << "Run cell hook function start.";
      py::object code_obj = py::getattr(elem.second, "__code__");
      py::object co_name = py::getattr(code_obj, "co_name");
      if (std::string(py::str(co_name)) == "staging_specialize") {
        py::object name_obj = py::getattr(elem.second, "__name__");
        MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
      }
      MS_LOG(DEBUG) << "Get cell dout " << ConvertPyObjToString(grad_output);
      SyncData(grad_output);
      const py::object grad_input = iter->second.second;
      py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, grad_input, grad_output);
      py::object ret = elem.second(*hook_fn_args);
      if (!py::isinstance<py::none>(ret)) {
        MS_LOG(DEBUG) << "Get hook output " << ConvertPyObjToString(ret);
        grad_output = UnpackRetValueOfCellHook(ret);
      }
      CheckHookConsistency(grad_output, py_args[args_size - 1], code_obj, co_name);
      MS_LOG(DEBUG) << "Run cell hook function end.";
    }
    (void)hook_grad_.erase(cell_id);
  } else {
    // The first bprop_cut is used to hook the input gradient of the cell.
    MS_LOG(DEBUG) << "Get cell din " << ConvertPyObjToString(grad_output);
    SyncData(grad_output);
    hook_grad_[cell_id] = {backward_hook_fn_, grad_output};
  }
  if (!py::isinstance<py::tuple>(grad_output)) {
    grad_output = py::make_tuple(grad_output);
  }
  return std::make_shared<PyObjectRef>(grad_output);
}

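// Run variable hooks (the hook op and tensor-registered hooks) on the incoming gradient. A tensor hook
// receives the gradient directly; other hooks receive it wrapped in a tuple. A non-None return value
// replaces the gradient.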
BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args, bool is_tensor_hook) const {
  py::tuple converted_args(py_args.size());
  ConvertCTensorToPyTensor(py_args, &converted_args);
  MS_LOG(DEBUG) << "Get convert args size " << converted_args.size() << ", args are "
                << ConvertPyObjToString(converted_args);
  constexpr size_t grad_output_index = 2;
  if (converted_args.size() != kSizeThree) {
    MS_LOG(EXCEPTION) << "Bprop cut must run in the following format: input, output and dout";
  }
  py::object grad_output = converted_args[grad_output_index];
  MS_LOG(DEBUG) << "Get grad output " << ConvertPyObjToString(grad_output);
  for (const auto &elem : backward_hook_fn_) {
    if (is_tensor_hook) {
      MS_LOG(DEBUG) << "Run tensor hook function begin";
      grad_output = elem.second(grad_output);
      if (py::isinstance<py::none>(grad_output)) {
        MS_EXCEPTION(ValueError) << "The bprop function output is None";
      }
      MS_LOG(DEBUG) << "Run tensor hook function end";
    } else {
      MS_LOG(DEBUG) << "Run hook function begin";
      py::object code_obj = py::getattr(elem.second, "__code__");
      py::object co_name = py::getattr(code_obj, "co_name");
      if (std::string(py::str(co_name)) == "staging_specialize") {
        py::object name_obj = py::getattr(elem.second, "__name__");
        MS_LOG(EXCEPTION) << "Decorating hook function " << py::str(name_obj) << " with '@jit' is not supported.";
      }

      py::object ret = elem.second(py::make_tuple(grad_output));
      if (!py::isinstance<py::none>(ret)) {
        MS_LOG(DEBUG) << "Get hook output " << ConvertPyObjToString(ret);
        grad_output = ret;
      }
      CheckHookConsistency(grad_output, py_args[grad_output_index], code_obj, co_name);
      MS_LOG(DEBUG) << "Run hook function end";
    }
  }
  grad_output = py::make_tuple(grad_output);
  return std::make_shared<PyObjectRef>(grad_output);
}

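// Entry point of bprop_cut execution: dispatch to the cell custom bprop, the cell backward hooks, the
// custom op bprop, or the variable/tensor hooks, depending on the attributes attached to this primitive.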
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  py::tuple py_args = ConvertDatatoPyTuple(args);
  MS_LOG(DEBUG) << "Get input args size " << py_args.size() << ", args are " << ConvertPyObjToString(py_args);
  // For a cell that has a custom bprop function.
  bool is_bprop = this->HasAttr(kBpropAttrName);
  if (is_bprop) {
    MS_LOG(DEBUG) << "Run cell custom bprop";
    return RunCellCustomBpropFunction(py_args);
  }

  // For a cell with a registered backward hook.
  bool is_cell = this->HasAttr(kCellHookAttrName);
  if (is_cell) {
    MS_LOG(DEBUG) << "Run cell hook";
    return RunCellHookFunction(py_args);
  }

  // For a custom op that defines both construct and bprop.
  bool is_custom_op_bprop = this->HasAttr(kCustomOpBpropAttrName);
  if (is_custom_op_bprop) {
    MS_LOG(DEBUG) << "Run custom op";
    return RunCustomOpBpropFunction(py_args);
  }

  // For hook use, including the hook op and tensor-registered hooks.
  return RunVariableHookFunction(py_args, this->HasAttr("tensor_hook"));
}

py::function PrimitivePy::GetComputeFunction() const {
  static const char *const compute_func_name = "vm_impl";

  if (py::hasattr(python_obj_, compute_func_name)) {
    MS_LOG(DEBUG) << name() << " compute_func_name";
    py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
    return fn;
  }

  static const std::string vm_module = "mindspore.ops.vm_impl_registry";
  static const std::string get_vm_impl_fn = "get_vm_impl_fn";
  MS_LOG(DEBUG) << name() << ": get_vm_impl_fn";
  py::function get_fn = python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
  py::function vm_fn = get_fn(python_obj_);
  if (py::isinstance<py::none>(vm_fn)) {
    vm_fn = get_fn(name());
  }
  if (py::isinstance<py::none>(vm_fn)) {
    MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
    vm_fn = mindspore::GetComputeFunction(Primitive::name());
  }
  return vm_fn;
}

py::dict PrimitivePy::GetAttrDict() {
  py::dict attr_dict;
  for (auto &attr : attrs_) {
    attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  }
  return attr_dict;
}

void PrimitivePy::CopyHookFunction(const PrimitivePyPtr &primitive_py) {
  MS_EXCEPTION_IF_NULL(primitive_py);
  const auto &backward_hook_fn = primitive_py->backward_hook_fn();
  for (const auto &elem : backward_hook_fn) {
    AddBackwardHookFn(elem.first, elem.second);
  }
  if (primitive_py->HasAttr(kBpropAttrName)) {
    set_bprop_cls_name(primitive_py->bprop_cls_name_);
    (void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
  }
}

BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
  auto py_args = ConvertDatatoPyTuple(args);
  auto result = this->RunPyComputeFunction(py_args);
  if (py::isinstance<py::none>(result)) {
    return std::make_shared<BaseRef>(nullptr);
  }
  return std::make_shared<PyObjectRef>(result);
}

py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  auto func = this->GetComputeFunction();
  if (py::isinstance<py::none>(func)) {
    return py::none();
  }
  auto result = func(*py_args);
  return result;
}

bool PrimitivePy::HasComputeFunction() const {
  auto func = GetComputeFunction();
  return !py::isinstance<py::none>(func);
}

PrimitivePtr PrimitivePy::Clone() {
  auto clone_fn = python_obj_.attr("_clone");
  py::object obj_adapter = clone_fn();
  auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
  auto prim = std::make_shared<PrimitivePy>(obj_adapter);
  prim_adapter->set_attached_primitive(prim);
  return prim;
}

py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // The Python obj may have been replaced by None, so the original info is lost if an exception is thrown in Python.
  if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER)) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER;
  }
  auto infer_func = python_obj_.attr(PY_PRIM_METHOD_INFER);
  return infer_func(*args);
}

void PrimitivePy::RunCheck(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // The Python obj may have been replaced by None, so the original info is lost if an exception is thrown in Python.
  if (!py::hasattr(python_obj_, PY_PRIM_METHOD_CHECK)) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_CHECK;
  }
  auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
  (void)check_func(*args);
}

py::object PrimitivePy::RunInferValue(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // The Python obj may have been replaced by None, so the original info is lost if an exception is thrown in Python.
  if (!py::hasattr(python_obj_, PY_PRIM_METHOD_INFER_VALUE)) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER_VALUE;
  }
  auto infer_value = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
  return infer_value(*args);
}

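// Handle cell hooks whose paired bprop_cut did not run: if requested, call each cached hook once with
// grad_output set to None, then clear the cache.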
void PrimitivePy::ProcessUnPairedCellHook(bool execute_hook_fn) {
  if (execute_hook_fn) {
    for (const auto &[cell_id, pair] : hook_grad_) {
      const auto &hook_fn = pair.first;
      const auto &grad_input = pair.second;
      for (const auto &elem : hook_fn) {
        SyncData(grad_input);
        py::object grad_output = py::none();
        py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, grad_input, grad_output);
        (void)elem.second(*hook_fn_args);
      }
    }
  }
  hook_grad_.clear();
}

void PrimitivePy::ClearHookRes() { hook_grad_.clear(); }

PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : id_(MakeId()), name_(name) {}

PrimitivePyAdapter::PrimitivePyAdapter(const PrimitivePyAdapter &adapter)
    : const_prim_(adapter.const_prim_),
      inplace_prim_(adapter.inplace_prim_),
      backward_hook_fn_key_(adapter.backward_hook_fn_key_),
      id_(adapter.id_),
      name_(adapter.name_),
      instance_name_(adapter.instance_name_),
      prim_type_(adapter.prim_type_),
      attrs_(adapter.attrs_),
      const_input_indexes_(adapter.const_input_indexes_),
      signatures_(adapter.signatures_),
      backward_hook_fn_(adapter.backward_hook_fn_) {}

PrimitivePyAdapter &PrimitivePyAdapter::operator=(const PrimitivePyAdapter &other) {
  if (this == &other) {
    return *this;
  }
  const_prim_ = other.const_prim_;
  inplace_prim_ = other.inplace_prim_;
  backward_hook_fn_key_ = other.backward_hook_fn_key_;
  id_ = other.id_;
  name_ = other.name_;
  instance_name_ = other.instance_name_;
  prim_type_ = other.prim_type_;
  attrs_ = other.attrs_;
  const_input_indexes_ = other.const_input_indexes_;
  signatures_ = other.signatures_;
  backward_hook_fn_ = other.backward_hook_fn_;
  return *this;
}

void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
  std::string attr_name = name;
  ValuePtr converted_res = nullptr;
  if (py::isinstance<py::module>(obj)) {
    MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
                      << " not support py::module to be attribute value; primitive name: " << this->name_
                      << ", attribute name: " << attr_name << " attribute value: " << py::str(obj);
  }
  bool converted = parse::ConvertData(obj, &converted_res);
  if (!converted) {
    MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
                      << " convert python obj to MindSpore obj failed; primitive name: " << this->name_
                      << ", attribute name:" << attr_name << ", attribute value:" << py::str(obj)
                      << ", attribute type:" << py::cast<std::string>(obj.attr("__class__").attr("__name__"));
  }
  if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
    attr_name = kOpAttrNameReplaceMap[attr_name];
  }
  (void)CheckAndConvertUtils::ConvertAttrValueToInt(this->name_, name, &converted_res);
  if (attr_name == "primitive_target") {
    MS_EXCEPTION_IF_NULL(converted_res);
    if (!converted_res->isa<StringImm>()) {
      MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
                        << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
                        << py::str(obj);
    }
    auto target = GetValue<std::string>(converted_res);
    if (!target.empty() && target != kCPUDevice && target != kGPUDevice && target != kAscendDevice &&
        target != "Device") {
      MS_LOG(EXCEPTION) << "Call 'add_attr' to add attribute to primitive '" << this->name_
                        << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
                        << py::str(obj);
    }
  }

  // If it's a func graph, reserve all the func graphs it uses.
  if (converted_res->isa<FuncGraph>()) {
    const auto &fg = dyn_cast<FuncGraph>(converted_res);
    MS_EXCEPTION_IF_NULL(fg);
    fg->set_reserved(true);
    auto manager = Manage({fg}, false);
    const auto &total_used_fg = manager->func_graphs_used_total(fg);
    for (const auto &used_fg : total_used_fg) {
      used_fg->set_reserved(true);
    }
  }

  attrs_[attr_name] = converted_res;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    (void)prim->AddAttr(attr_name, converted_res);
  }
}

void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
  (void)attrs_.erase(name);
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    (void)prim->DelAttr(name);
  }
}

py::dict PrimitivePyAdapter::GetAttrDict() {
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    return prim->GetAttrDict();
  }

  py::dict attr_dict;
  for (auto &attr : attrs_) {
    attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
  }
  return attr_dict;
}

void PrimitivePyAdapter::set_prim_type(const PrimType t) {
  prim_type_ = t;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_prim_type(t);
  }
}

void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
  const_prim_ = is_const_prim;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_const_prim(is_const_prim);
  }
}

void PrimitivePyAdapter::set_inplace_prim(bool is_inplace_prim) {
  inplace_prim_ = is_inplace_prim;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_inplace_prim(is_inplace_prim);
  }
}

void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
  const_input_indexes_ = const_input_indexes;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_const_input_indexes(const_input_indexes);
  }
}

void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
  signatures_ = signatures;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_signatures(signatures);
  }
}

int PrimitivePyAdapter::AddBackwardHookFn(const py::function &backward_hook_fn) {
  ++backward_hook_fn_key_;
  backward_hook_fn_[backward_hook_fn_key_] = backward_hook_fn;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->AddBackwardHookFn(backward_hook_fn_key_, backward_hook_fn);
  }
  return backward_hook_fn_key_;
}

void PrimitivePyAdapter::RemoveBackwardHookFn(int key) {
  const auto iter = backward_hook_fn_.find(key);
  if (iter != backward_hook_fn_.end()) {
    (void)backward_hook_fn_.erase(iter);
  }
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->RemoveBackwardHookFn(key);
  }
}

void PrimitivePyAdapter::set_instance_name(const std::string &s) {
  instance_name_ = s;
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_instance_name(s);
  }
}

void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
  if (attached_primitive_.lock() != nullptr) {
    MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
  }
  MS_EXCEPTION_IF_NULL(prim);
  attached_primitive_ = prim;
}

void PrimitivePyAdapter::SetUserData(const py::str &key, const py::object &value) {
  const std::string name = std::string("__primitive_user_data_") + key.cast<std::string>();
  const auto &primitive_data = std::make_shared<PrimitiveUserData>();
  primitive_data->obj = value;
  // Set into primitive adapter.
  set_user_data<PrimitiveUserData>(name, primitive_data);
  // Set in primitive.
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    prim->set_user_data<PrimitiveUserData>(name, primitive_data);
  }
}

py::object PrimitivePyAdapter::GetUserData(const py::str &key) const {
  const std::string name = std::string("__primitive_user_data_") + key.cast<std::string>();
  // Get from primitive.
  auto prim = attached_primitive_.lock();
  if (prim != nullptr) {
    const auto primitive_data = prim->user_data<PrimitiveUserData>(name);
    return primitive_data->obj;
  }
  // Get from primitive adapter.
  const auto primitive_data = user_data<PrimitiveUserData>(name);
  return primitive_data->obj;
}

void PrimitiveFunctionAdapter::set_label(const std::string &label, const py::object &value) {
  ValuePtr converted_value = nullptr;
  if (!parse::ConvertData(value, &converted_value)) {
    MS_LOG(INTERNAL_EXCEPTION) << "For '" << PrimitiveFunctionAdapter::name() << "', Convert data failed.";
  }
  attached_primitive_function_->AddAttr(label, converted_value);
}

py::object PrimitiveFunctionAdapter::clone() {
  const auto op_path = "mindspore.ops.primitive";
  const auto func = "_create_primitive_function_obj";
  py::object prim_func_adapter_obj = python_adapter::CallPyFn(op_path, func);
  prim_func_adapter_obj.cast<PrimitiveFunctionAdapterPtr>()->set_attached_primitive_function(
    attached_primitive_function_->Clone());
  return prim_func_adapter_obj;
}

void RegPrimitive(const py::module *m) {
  (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
    .value("unknown", PrimType::kPrimTypeUnknown)
    .value("builtin", PrimType::kPrimTypeBuiltIn)
    .value("py_infer_shape", PrimType::kPrimTypePyInfer)
    .value("user_custom", PrimType::kPrimTypeUserCustom)
    .value("py_infer_check", PrimType::kPrimTypePyCheck);
  (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
    .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
    .def(py::init<py::str &>())
    .def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
    .def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
    .def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
    .def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
    .def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
    .def("set_inplace_prim", &PrimitivePyAdapter::set_inplace_prim, "Set primitive is inplace primitive.")
    .def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes, "Set primitive const input indexes.")
    .def("set_signatures", &PrimitivePyAdapter::set_signatures, "Set primitive inputs signature.")
    .def("add_backward_hook_fn", &PrimitivePyAdapter::AddBackwardHookFn, "Add primitive backward hook function.")
    .def("remove_backward_hook_fn", &PrimitivePyAdapter::RemoveBackwardHookFn,
         "Remove primitive backward hook function.")
    .def("set_instance_name", &PrimitivePyAdapter::set_instance_name, "Set primitive instance name.")
    .def("set_user_data", &PrimitivePyAdapter::SetUserData, "Set primitive user data.")
    .def("get_user_data", &PrimitivePyAdapter::GetUserData, "Get primitive user data.");
}

void RegPrimitiveFunction(const py::module *m) {
  (void)py::class_<PrimitiveFunctionAdapter, std::shared_ptr<PrimitiveFunctionAdapter>>(*m, "PrimitiveFunction_")
    .def_readonly(PYTHON_PRIMITIVE_FUNCTION_FLAG, &PrimitiveFunctionAdapter::parse_info_)
    .def(py::init<>())
    .def_property_readonly("name", &PrimitiveFunctionAdapter::name, "Get function name.")
    .def("has_label", &PrimitiveFunctionAdapter::has_label, "Has function attr.")
    .def("set_label", &PrimitiveFunctionAdapter::set_label, "Set function attr.")
    .def("get_label", &PrimitiveFunctionAdapter::get_label, "Get function attr.")
    .def("clone", &PrimitiveFunctionAdapter::clone, "Clone a Primitive and create a PrimitiveFunctionAdapter.");
}
}  // namespace mindspore