• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pybind_api/ir/primitive_py.h"
18 
19 #include <mutex>
20 #include <map>
21 #include <utility>
22 #include "ir/signature.h"
23 #include "pipeline/jit/parse/data_converter.h"
24 #include "pipeline/jit/parse/python_adapter.h"
25 #include "pybind11/pytypes.h"
26 #include "pybind_api/api_register.h"
27 #include "pybind_api/export_flags.h"
28 #include "pybind_api/ir/base_ref_py.h"
29 #include "utils/convert_utils_base.h"
30 #include "utils/convert_utils_py.h"
31 #include "utils/ms_context.h"
32 #include "utils/primitive_utils.h"
33 #include "utils/check_convert_utils.h"
34 #include "pipeline/jit/resource.h"
35 #include "pipeline/pynative/pynative_execute.h"
36 
37 namespace mindspore {
38 namespace {
39 constexpr auto kBpropAttrName = "bprop";
40 constexpr auto kCellHookAttrName = "cell_hook";
41 constexpr auto kCellIDAttrName = "cell_id";
42 std::map<std::string, std::string> kOpAttrNameReplaceMap = {
43   {"data_format", "format"},
44 };
45 
SyncData(const py::object & arg)46 void SyncData(const py::object &arg) {
47   if (py::isinstance<py::tuple>(arg)) {
48     py::tuple arg_list = py::cast<py::tuple>(arg);
49     for (size_t i = 0; i < arg_list.size(); i++) {
50       SyncData(arg_list[i]);
51     }
52   }
53   if (py::isinstance<tensor::Tensor>(arg)) {
54     auto tensor = py::cast<tensor::TensorPtr>(arg);
55     tensor->data_sync();
56   }
57 }
58 }  // namespace
59 std::map<std::string, py::object> PrimitivePy::hook_grad_;
60 
PrimitivePy(const std::string & name)61 PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}
62 
PrimitivePy(const py::object & python_obj,const PrimitivePyAdapterPtr & adapter)63 PrimitivePy::PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter)
64     : Primitive(adapter->name_, false), python_obj_(python_obj), adapter_(adapter) {
65   MS_LOG(DEBUG) << "New primitive:" << adapter->name_;
66   set_signatures(adapter->signatures_);
67   (void)Primitive::SetAttrs(adapter->attrs_);
68   Primitive::set_prim_type(adapter->prim_type_);
69   Primitive::set_const_prim(adapter->is_const_prim_);
70   Primitive::set_const_input_indexes(adapter->const_input_indexes_);
71   set_hook(adapter->hook_);
72   set_instance_name(adapter->instance_name_);
73 }
~PrimitivePy()74 PrimitivePy::~PrimitivePy() {}
75 
set_signatures(const std::vector<Signature> & signatures)76 void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
77   signatures_ = signatures;
78   set_has_signature(!signatures.empty());
79 }
80 
GetBpropFunction()81 py::function PrimitivePy::GetBpropFunction() {
82   static const char *const get_bprop_func_name = "get_bprop";
83   if (py::hasattr(python_obj_, get_bprop_func_name)) {
84     py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
85     return fn;
86   } else {
87     auto fn = GetBpropFunctionByObj(python_obj_);
88     return fn;
89   }
90 }
91 
// Validates the gradients returned by a user-defined cell bprop.
//
// `py_args` holds the forward inputs followed by two trailing entries
// (presumably the forward output and its gradient — see RunCellBpropFunction),
// so exactly `py_args.size() - 2` gradients are expected. A single non-tuple
// result is wrapped in a one-element tuple.
//
// When MS_CTX_CHECK_BPROP_FLAG is enabled, each gradient whose matching input is
// a Tensor must also be a Tensor with the same dtype and shape.
//
// Returns the gradients as a tuple.
// Throws TypeError if there are too few args, the gradient count mismatches, or
// a dtype mismatches. Throws ValueError if a gradient is not a Tensor or its
// shape mismatches.
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) {
  py::tuple grads;
  if (!py::isinstance<py::tuple>(grads_obj)) {
    grads = py::make_tuple(grads_obj);
  } else {
    grads = py::cast<py::tuple>(grads_obj);
  }
  constexpr size_t filter_args_size = 2;
  // py_args.size() is unsigned: without this guard the subtraction below would
  // wrap around and report a meaningless expected count.
  if (py_args.size() < filter_args_size) {
    MS_EXCEPTION(TypeError) << "For user define net bprop, the args number: " << py_args.size()
                            << " should not be less than " << filter_args_size << ".";
  }
  const size_t expected_grads_num = py_args.size() - filter_args_size;
  if (grads.size() != expected_grads_num) {
    MS_EXCEPTION(TypeError) << "For user define net bprop, the gradients number: " << grads.size()
                            << " is not equal to the args number: " << expected_grads_num << ".";
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (!context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
    return grads;
  }
  for (size_t i = 0; i < grads.size(); ++i) {
    // Only Tensor inputs constrain their gradient.
    if (!py::isinstance<tensor::Tensor>(py_args[i])) {
      continue;
    }
    if (!py::isinstance<tensor::Tensor>(grads[i])) {
      MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
                               << "th arg should be Tensor, but got "
                               << py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
                               << ", and the value is " << py::cast<py::str>(grads[i]) << ".";
    }

    py::object arg_dtype = py_args[i].attr("dtype");
    py::object grad_dtype = grads[i].attr("dtype");
    py::tuple arg_shape = py_args[i].attr("shape");
    py::tuple grad_shape = grads[i].attr("shape");
    if (!grad_dtype.equal(arg_dtype)) {
      MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
                              << "th arg should have the same dtype as the " << i << "th arg, but the " << i
                              << "th arg dtype is: " << py::cast<py::str>(arg_dtype)
                              << ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
    }
    if (!grad_shape.equal(arg_shape)) {
      MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
                               << "th arg should have the same shape as the " << i << "th arg, but the " << i
                               << "th arg shape is: " << py::cast<py::str>(arg_shape)
                               << ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
    }
  }
  return grads;
}
135 
ConvertCTensorToPyTensor(const py::tuple & input_args,py::tuple * convert_args) const136 void PrimitivePy::ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const {
137   MS_EXCEPTION_IF_NULL(convert_args);
138   if (input_args.size() != (*convert_args).size()) {
139     MS_LOG(EXCEPTION) << "The size of input_args: " << input_args.size()
140                       << " should be equal to the size of convert_args: " << (*convert_args).size();
141   }
142   for (size_t i = 0; i < input_args.size(); ++i) {
143     (*convert_args)[i] = py::isinstance<tensor::Tensor>(input_args[i])
144                            ? parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE,
145                                                              parse::PYTHON_MOD_CONVERT_TO_MS_TENSOR, input_args[i])
146                            : input_args[i];
147   }
148 }
149 
CheckHookConsistency(const py::object & grad_out,const py::object & expected_grad_out) const150 void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const {
151   if (py::isinstance<py::tuple>(expected_grad_out)) {
152     if (!py::isinstance<py::tuple>(grad_out)) {
153       hook_grad_.clear();
154       MS_EXCEPTION(TypeError) << "The output gradient should be a tuple!";
155     }
156     auto actual_out_tuple = py::cast<py::tuple>(grad_out);
157     auto expected_out_tuple = py::cast<py::tuple>(expected_grad_out);
158     if (actual_out_tuple.size() != expected_out_tuple.size()) {
159       hook_grad_.clear();
160       MS_EXCEPTION(ValueError) << "The tuple size of output gradient should be " << expected_out_tuple.size()
161                                << ", but it is " << actual_out_tuple.size();
162     }
163     for (size_t i = 0; i < expected_out_tuple.size(); ++i) {
164       CheckHookConsistency(actual_out_tuple[i], expected_out_tuple[i]);
165     }
166   }
167 
168   if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
169     if (!py::isinstance<tensor::Tensor>(grad_out)) {
170       hook_grad_.clear();
171       MS_EXCEPTION(TypeError) << "The output gradient should be a tensor!";
172     }
173     auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
174     auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
175     MS_EXCEPTION_IF_NULL(actual_out_tensor);
176     MS_EXCEPTION_IF_NULL(expected_out_tensor);
177     if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
178       hook_grad_.clear();
179       MS_EXCEPTION(ValueError) << "The output gradient is not consistent with the expected, it should be "
180                                << expected_out_tensor->GetShapeAndDataTypeInfo() << ", but it is "
181                                << actual_out_tensor->GetShapeAndDataTypeInfo();
182     }
183   }
184 }
185 
RunCellBpropFunction(const py::tuple & py_args) const186 BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
187   SyncData(py_args);
188   auto size = py_args.size();
189   constexpr size_t grad_param_nums = 2;
190   py::tuple input_args(size - grad_param_nums);
191   for (size_t i = 0; i < size - grad_param_nums; ++i) {
192     input_args[i] = py_args[i];
193   }
194   py::tuple convert_args(py_args.size());
195   ConvertCTensorToPyTensor(py_args, &convert_args);
196   auto inst = pynative::PynativeExecutor::GetInstance();
197   MS_EXCEPTION_IF_NULL(inst);
198   try {
199     MS_LOG(DEBUG) << "Run bprop function start";
200     inst->NewGraph(hook_, input_args.cast<py::args>());
201     py::object grads_obj = hook_(*convert_args);
202     py::tuple grads = check_bprop_out(grads_obj, py_args);
203     inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
204     MS_LOG(DEBUG) << "Run bprop function end";
205     return std::make_shared<PyObjectRef>(grads);
206   } catch (std::exception &bt) {
207     inst->ClearRes();
208     std::rethrow_exception(std::current_exception());
209   }
210 }
211 
RunCellHookFunction(const py::tuple & py_args) const212 BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
213   constexpr size_t grad_input_index = 1;
214   constexpr size_t grad_output_index = 2;
215   constexpr size_t input_param_nums = 3;
216   SyncData(py_args[grad_output_index]);
217 
218   py::object obj;
219   auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
220   auto iter = hook_grad_.find(cell_id);
221   if (iter != hook_grad_.end()) {
222     py::object code_obj = py::getattr(hook_, "__code__");
223     py::object co_name = py::getattr(code_obj, "co_name");
224     if (std::string(py::str(co_name)) == "staging_specialize") {
225       MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
226     }
227 
228     py::tuple convert_args(input_param_nums - 1);
229     py::tuple input_args(input_param_nums - 1);
230     input_args[0] = iter->second;
231     input_args[1] = py_args[grad_output_index];
232     ConvertCTensorToPyTensor(input_args, &convert_args);
233     auto hook_args = py::tuple(input_param_nums);
234     hook_args[0] = cell_id;
235     hook_args[grad_input_index] = py::make_tuple(convert_args[0]);
236     hook_args[grad_output_index] = py::make_tuple(convert_args[1]);
237     obj = hook_(*hook_args);
238     if (py::isinstance<py::none>(obj)) {
239       obj = py_args[grad_output_index];
240     }
241     CheckHookConsistency(obj, py_args[grad_output_index]);
242     (void)hook_grad_.erase(cell_id);
243   } else {
244     hook_grad_[cell_id] = py_args[grad_output_index];
245     obj = py_args[grad_output_index];
246   }
247   obj = py::make_tuple(obj);
248   return std::make_shared<PyObjectRef>(obj);
249 }
250 
RunVariableHookFunction(const py::tuple & py_args) const251 BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
252   py::object code_obj = py::getattr(hook_, "__code__");
253   py::object co_name = py::getattr(code_obj, "co_name");
254   if (std::string(py::str(co_name)) == "staging_specialize") {
255     MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
256   }
257 
258   constexpr size_t grad_output_index = 2;
259   SyncData(py_args[grad_output_index]);
260   py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));
261   if (py::isinstance<py::none>(obj)) {
262     obj = py_args[grad_output_index];
263   }
264   CheckHookConsistency(obj, py_args[grad_output_index]);
265   obj = py::make_tuple(obj);
266   return std::make_shared<PyObjectRef>(obj);
267 }
268 
RunHookFunction(const VectorRef & args) const269 BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
270   py::tuple py_args = ConvertDatatoPyTuple(args);
271   bool is_bprop = this->HasAttr(kBpropAttrName);
272   if (is_bprop) {
273     return RunCellBpropFunction(py_args);
274   }
275   bool is_cell = this->HasAttr(kCellHookAttrName);
276   if (is_cell) {
277     return RunCellHookFunction(py_args);
278   }
279   return RunVariableHookFunction(py_args);
280 }
281 
GetComputeFunction() const282 py::function PrimitivePy::GetComputeFunction() const {
283   static const char *const compute_func_name = "vm_impl";
284 
285   if (py::hasattr(python_obj_, compute_func_name)) {
286     MS_LOG(DEBUG) << name() << " compute_func_name";
287     py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
288     return fn;
289   }
290 
291   static const std::string vm_module = "mindspore.ops.vm_impl_registry";
292   static const std::string get_vm_impl_fn = "get_vm_impl_fn";
293   MS_LOG(DEBUG) << name() << ": get_vm_impl_fn";
294   py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
295   py::function vm_fn = get_fn(python_obj_);
296   if (py::isinstance<py::none>(vm_fn)) {
297     MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
298     vm_fn = mindspore::GetComputeFunction(Primitive::name());
299   }
300   return vm_fn;
301 }
302 
GetAttrDict()303 py::dict PrimitivePy::GetAttrDict() {
304   py::dict attr_dict;
305   for (auto &attr : attrs_) {
306     attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
307   }
308   return attr_dict;
309 }
310 
CopyHookFunction(const PrimitivePtr & primitive)311 void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
312   MS_EXCEPTION_IF_NULL(primitive);
313   if (!primitive->isa<PrimitivePy>()) {
314     MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!";
315   }
316   auto primitive_py = primitive->cast<PrimitivePyPtr>();
317   MS_EXCEPTION_IF_NULL(primitive_py);
318   this->set_hook(primitive_py->hook());
319   if (primitive_py->HasAttr(kBpropAttrName)) {
320     (void)this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
321   }
322 }
323 
RunComputeFunction(const VectorRef & args) const324 BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
325   auto py_args = ConvertDatatoPyTuple(args);
326   auto result = this->RunPyComputeFunction(py_args);
327   if (py::isinstance<py::none>(result)) {
328     return std::make_shared<BaseRef>(nullptr);
329   }
330   return std::make_shared<PyObjectRef>(result);
331 }
332 
// Calls the VM compute function with `py_args` unpacked. Returns None when no
// compute function is available.
py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
  const py::function compute_fn = GetComputeFunction();
  if (!py::isinstance<py::none>(compute_fn)) {
    return compute_fn(*py_args);
  }
  return py::none();
}
341 
HasComputeFunction() const342 bool PrimitivePy::HasComputeFunction() const {
343   auto func = GetComputeFunction();
344   return !py::isinstance<py::none>(func);
345 }
346 
Clone()347 PrimitivePtr PrimitivePy::Clone() {
348   auto clone_fn = python_obj_.attr("_clone");
349   py::object obj_adapter = clone_fn();
350   auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
351   auto prim = std::make_shared<PrimitivePy>(obj_adapter, prim_adapter);
352   prim_adapter->set_attached_primitive(prim);
353   return prim;
354 }
355 
// Calls the Python primitive's PY_PRIM_METHOD_INFER method with `args` unpacked
// and returns its result as a dict.
// Raises if there is no Python object or the object lacks that method.
py::dict PrimitivePy::RunInfer(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // The Python object may have been replaced by None. A failure raised from
  // Python would then lose the original primitive info, so check up front.
  const bool has_method = py::hasattr(python_obj_, PY_PRIM_METHOD_INFER);
  if (!has_method) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER;
  }
  py::object infer_fn = python_obj_.attr(PY_PRIM_METHOD_INFER);
  return infer_fn(*args);
}
367 
RunCheck(const py::tuple & args)368 void PrimitivePy::RunCheck(const py::tuple &args) {
369   if (!HasPyObj()) {
370     MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
371   }
372   // Python obj could be replaced as None, so it will losed the original info when throw exception in python.
373   if (!py::hasattr(python_obj_, PY_PRIM_METHOD_CHECK)) {
374     MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_CHECK;
375   }
376   auto check_func = python_obj_.attr(PY_PRIM_METHOD_CHECK);
377   (void)check_func(*args);
378 }
379 
// Calls the Python primitive's PY_PRIM_METHOD_INFER_VALUE method with `args`
// unpacked and returns its result.
// Raises if there is no Python object or the object lacks that method.
py::object PrimitivePy::RunInferValue(const py::tuple &args) {
  if (!HasPyObj()) {
    MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
  }
  // The Python object may have been replaced by None. A failure raised from
  // Python would then lose the original primitive info, so check up front.
  const bool has_method = py::hasattr(python_obj_, PY_PRIM_METHOD_INFER_VALUE);
  if (!has_method) {
    MS_LOG(EXCEPTION) << "prim:" << ToString() << " has no attr:" << PY_PRIM_METHOD_INFER_VALUE;
  }
  py::object infer_value_fn = python_obj_.attr(PY_PRIM_METHOD_INFER_VALUE);
  return infer_value_fn(*args);
}
391 
PrimitivePyAdapter(const py::str & name)392 PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : name_(name) {}
393 
AddPyAttr(const py::str & name,const py::object & obj)394 void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
395   std::string attr_name = name;
396   ValuePtr converted_ret = nullptr;
397   if (py::isinstance<py::module>(obj)) {
398     MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
399   }
400   bool converted = parse::ConvertData(obj, &converted_ret);
401   if (!converted) {
402     MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
403   }
404   if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
405     attr_name = kOpAttrNameReplaceMap[attr_name];
406   }
407   (void)CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret);
408   attrs_[attr_name] = converted_ret;
409   auto prim = attached_primitive_.lock();
410   if (prim != nullptr) {
411     (void)prim->AddAttr(attr_name, converted_ret);
412   }
413 
414   if (attr_name == "primitive_target") {
415     MS_EXCEPTION_IF_NULL(converted_ret);
416     if (!converted_ret->isa<StringImm>()) {
417       MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
418     }
419 
420     auto target = GetValue<std::string>(converted_ret);
421     if (target != kCPUDevice && target != kGPUDevice) {
422       auto context_ptr = MsContext::GetInstance();
423       MS_EXCEPTION_IF_NULL(context_ptr);
424       context_ptr->set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true);
425     }
426   }
427 }
428 
DelPyAttr(const py::str & name)429 void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
430   (void)attrs_.erase(name);
431   auto prim = attached_primitive_.lock();
432   if (prim != nullptr) {
433     (void)prim->DelAttr(name);
434   }
435 }
436 
GetAttrDict()437 py::dict PrimitivePyAdapter::GetAttrDict() {
438   auto prim = attached_primitive_.lock();
439   if (prim != nullptr) {
440     return prim->GetAttrDict();
441   }
442 
443   py::dict attr_dict;
444   for (auto &attr : attrs_) {
445     attr_dict[py::str(attr.first)] = ValueToPyData(attr.second);
446   }
447   return attr_dict;
448 }
449 
set_prim_type(const PrimType t)450 void PrimitivePyAdapter::set_prim_type(const PrimType t) {
451   prim_type_ = t;
452   auto prim = attached_primitive_.lock();
453   if (prim != nullptr) {
454     prim->set_prim_type(t);
455   }
456 }
set_const_prim(bool is_const_prim)457 void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
458   is_const_prim_ = is_const_prim;
459   auto prim = attached_primitive_.lock();
460   if (prim != nullptr) {
461     prim->set_const_prim(is_const_prim);
462   }
463 }
set_const_input_indexes(const std::vector<size_t> & const_input_indexes)464 void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
465   const_input_indexes_ = const_input_indexes;
466   auto prim = attached_primitive_.lock();
467   if (prim != nullptr) {
468     prim->set_const_input_indexes(const_input_indexes);
469   }
470 }
471 
set_signatures(const std::vector<Signature> & signatures)472 void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
473   signatures_ = signatures;
474   auto prim = attached_primitive_.lock();
475   if (prim != nullptr) {
476     prim->set_signatures(signatures);
477   }
478 }
479 
set_hook(const py::function & hook)480 void PrimitivePyAdapter::set_hook(const py::function &hook) {
481   hook_ = hook;
482   auto prim = attached_primitive_.lock();
483   if (prim != nullptr) {
484     prim->set_hook(hook);
485   }
486 }
487 
set_instance_name(const std::string & s)488 void PrimitivePyAdapter::set_instance_name(const std::string &s) {
489   instance_name_ = s;
490   auto prim = attached_primitive_.lock();
491   if (prim != nullptr) {
492     prim->set_instance_name(s);
493   }
494 }
495 
set_attached_primitive(const PrimitivePyPtr & prim)496 void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
497   if (attached_primitive_.lock() != nullptr) {
498     MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
499   }
500   MS_EXCEPTION_IF_NULL(prim);
501   attached_primitive_ = prim;
502 }
503 
__anonda41446c0202(const py::module *m) 504 REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
505                          (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
506                            .value("unknown", PrimType::kPrimTypeUnknown)
507                            .value("builtin", PrimType::kPrimTypeBuiltIn)
508                            .value("py_infer_shape", PrimType::kPrimTypePyInfer)
509                            .value("user_custom", PrimType::kPrimTypeUserCustom)
510                            .value("py_infer_check", PrimType::kPrimTypePyCheck);
511                          (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
512                            .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
513                            .def(py::init<py::str &>())
514                            .def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
515                            .def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
516                            .def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
517                            .def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
518                            .def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
519                            .def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes,
520                                 "Set primitive const input indexes.")
521                            .def("set_signatures", &PrimitivePyAdapter::set_signatures,
522                                 "Set primitive inputs signature.")
523                            .def("register_hook", &PrimitivePyAdapter::set_hook, "Set primitive hook function.")
524                            .def("set_instance_name", &PrimitivePyAdapter::set_instance_name,
525                                 "Set primitive instance name.");
526                        }));
527 }  // namespace mindspore
528