• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/pynative_execute.h"
18 
19 #include <typeinfo>
20 #include <set>
21 #include <memory>
22 #include <sstream>
23 #include <unordered_set>
24 #include <algorithm>
25 
26 #include "debug/trace.h"
27 #include "debug/anf_ir_dump.h"
28 #include "pybind_api/api_register.h"
29 #include "pybind_api/pybind_patch.h"
30 #include "pybind_api/ir/tensor_py.h"
31 #include "ir/param_info.h"
32 #include "ir/anf.h"
33 #include "ir/cell.h"
34 #include "ir/tensor.h"
35 #include "utils/any.h"
36 #include "utils/utils.h"
37 #include "utils/ms_context.h"
38 #include "utils/check_convert_utils.h"
39 #include "utils/context/context_extends.h"
40 #include "utils/config_manager.h"
41 #include "utils/convert_utils_py.h"
42 #include "utils/scoped_long_running.h"
43 #include "frontend/optimizer/ad/grad.h"
44 #include "frontend/optimizer/ad/prim_bprop_optimizer.h"
45 #include "frontend/operator/ops.h"
46 #include "frontend/operator/composite/do_signature.h"
47 #include "frontend/parallel/context.h"
48 #include "pipeline/jit/action.h"
49 #include "pipeline/jit/pass.h"
50 #include "pipeline/jit/parse/data_converter.h"
51 #include "pipeline/jit/parse/parse_dynamic.h"
52 #include "pipeline/jit/static_analysis/prim.h"
53 #include "pipeline/jit/static_analysis/auto_monad.h"
54 #include "pipeline/jit/pipeline.h"
55 #include "pipeline/jit/resource.h"
56 #include "pipeline/pynative/base.h"
57 #include "backend/session/session_factory.h"
58 #include "backend/optimizer/common/const_input_to_attr_registry.h"
59 #include "backend/optimizer/common/helper.h"
60 #include "runtime/hardware/device_context_manager.h"
61 #include "vm/transform.h"
62 
63 #ifdef ENABLE_GE
64 #include "pipeline/pynative/pynative_execute_ge.h"
65 #endif
66 
67 using mindspore::tensor::TensorPy;
68 
69 namespace mindspore::pynative {
// Process-wide singleton state for the PyNative front end.  The executor and
// its forward/grad sub-executors start null and are created lazily;
// NOTE(review): instance_lock_ presumably serializes that creation — confirm
// against PynativeExecutor::GetInstance(), which is outside this chunk.
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr;
GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
74 
75 namespace {
// File-local tuning constants.  ARG_SIZE is the minimum group size for
// implicit type conversion (see GetDstType below); PTR_LEN and
// MAX_TOP_CELL_COUNTS are used outside this chunk — NOTE(review): confirm
// their semantics at the use sites.
const size_t PTR_LEN = 15;
const size_t ARG_SIZE = 2;
const size_t MAX_TOP_CELL_COUNTS = 20;

// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> kVmOperators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
                                            "mixed_precision_cast"};
// Python module that hosts the functional ops referenced from C++.
const char kOpsFunctionModelName[] = "mindspore.ops.functional";
// Lazily-created backend handles; initialized elsewhere in this file.
std::shared_ptr<session::SessionBasic> kSession = nullptr;
std::shared_ptr<compile::MindRTBackend> mind_rt_backend = nullptr;
// Cache of Python-object -> id string for Cells/functions; see GetId().
PyObjectIdCache g_pyobj_id_cache;
87 
// Invokes method(ret, args...) and guarantees the executor state is cleared
// (inst->ClearRes()) before any exception escapes, so a failed compile/run
// does not leave stale graph state behind.  Each py::* exception type is
// caught and re-thrown as its concrete type so the Python interpreter raises
// the matching Python exception class rather than a generic RuntimeError.
template <typename T, typename... Args>
void PynativeExecutorTry(const std::function<void(T *ret, const Args &...)> &method, T *ret, const Args &... args) {
  const auto inst = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  MS_EXCEPTION_IF_NULL(method);
  try {
    method(ret, args...);
  } catch (const py::error_already_set &ex) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    inst->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    inst->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    inst->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    inst->ClearRes();
    throw py::index_error(ex);
  } catch (const py::name_error &ex) {
    inst->ClearRes();
    throw py::name_error(ex);
  } catch (const std::exception &ex) {
    inst->ClearRes();
    // re-throw this exception to Python interpreter to handle it
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    inst->ClearRes();
    // Unknown exception: the best we can report is the mangled RTTI name.
    auto exception_type = abi::__cxa_current_exception_type();
    MS_EXCEPTION_IF_NULL(exception_type);
    std::string ex_name(exception_type->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << ex_name;
  }
}
131 
PyObjToValue(const py::object & obj)132 inline ValuePtr PyObjToValue(const py::object &obj) {
133   ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
134   if (!converted_ret) {
135     MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
136   }
137   return converted_ret;
138 }
139 
GetPyObjId(const py::handle & obj)140 std::string GetPyObjId(const py::handle &obj) {
141   py::object out = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
142   if (py::isinstance<py::none>(out)) {
143     MS_LOG(EXCEPTION) << "Get pyobj failed";
144   }
145   return out.cast<std::string>();
146 }
147 
GetId(const py::handle & obj)148 std::string GetId(const py::handle &obj) {
149   if (py::isinstance<tensor::Tensor>(obj)) {
150     auto tensor_ptr = py::cast<tensor::TensorPtr>(obj);
151     return tensor_ptr->id();
152   } else if (py::isinstance<mindspore::Type>(obj)) {
153     auto type_ptr = py::cast<mindspore::TypePtr>(obj);
154     return "type" + type_ptr->ToString();
155   } else if (py::isinstance<py::str>(obj) || py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj)) {
156     return std::string(py::str(obj));
157   } else if (py::isinstance<py::none>(obj)) {
158     return "none";
159   } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
160     auto p_list = py::cast<py::tuple>(obj);
161     string prefix = py::isinstance<py::tuple>(obj) ? "tuple:" : "list";
162     if (p_list.empty()) {
163       prefix = "empty";
164     } else {
165       std::string key;
166       for (size_t i = 0; i < p_list.size(); ++i) {
167         key += std::string(py::str(GetId(p_list[i]))) + ":";
168       }
169       prefix += key;
170     }
171     return prefix;
172   }
173 
174   if (py::isinstance<Cell>(obj) || py::isinstance<py::function>(obj)) {
175     const auto &it = g_pyobj_id_cache.find(obj);
176     if (it == g_pyobj_id_cache.end()) {
177       auto &&id = GetPyObjId(obj);
178       g_pyobj_id_cache[obj] = id;
179       return std::move(id);
180     } else {
181       return it->second;
182     }
183   } else {
184     return GetPyObjId(obj);
185   }
186 }
187 
GetTypeIndex(const std::vector<SignatureEnumDType> & dtypes,std::unordered_map<SignatureEnumDType,std::vector<size_t>> * type_indexes)188 void GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes,
189                   std::unordered_map<SignatureEnumDType, std::vector<size_t>> *type_indexes) {
190   MS_EXCEPTION_IF_NULL(type_indexes);
191   for (size_t i = 0; i < dtypes.size(); ++i) {
192     auto it = type_indexes->find(dtypes[i]);
193     if (it == type_indexes->end()) {
194       (void)type_indexes->emplace(std::make_pair(dtypes[i], std::vector<size_t>{i}));
195     } else {
196       it->second.emplace_back(i);
197     }
198   }
199 }
200 
JudgeMaxType(TypeId max_type,bool has_scalar_float32,bool has_scalar_int64,bool has_tensor_int8)201 TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
202   if (max_type == TypeId::kNumberTypeBool) {
203     if (has_scalar_int64) {
204       max_type = TypeId::kNumberTypeInt64;
205     }
206     if (has_scalar_float32) {
207       max_type = TypeId::kNumberTypeFloat32;
208     }
209   }
210   if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
211       max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
212     max_type = TypeId::kNumberTypeFloat32;
213   }
214   if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
215     max_type = TypeId::kNumberTypeInt16;
216   }
217   return max_type;
218 }
219 
// For each signature-dtype group that has at least ARG_SIZE members, computes
// the destination TypeId every argument in that group should be implicitly
// cast to: the highest-priority tensor dtype in the group, adjusted by scalar
// and int8 promotion rules (see JudgeMaxType).  Results go into *dst_type.
void GetDstType(const py::tuple &py_args,
                const std::unordered_map<SignatureEnumDType, std::vector<size_t>> &type_indexes,
                std::unordered_map<SignatureEnumDType, TypeId> *dst_type) {
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    const auto &type = it->first;
    const auto &indexes = it->second;
    // Groups with no declared dtype, or a single member, need no conversion.
    if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < ARG_SIZE) {
      continue;
    }
    size_t priority = 0;
    TypeId max_type = TypeId::kTypeUnknown;
    bool has_scalar_float32 = false;
    bool has_scalar_int64 = false;
    bool has_tensor_int8 = false;
    // Find the maximum priority of the same dtype
    for (size_t index : indexes) {
      if (index >= py_args.size()) {
        MS_LOG(EXCEPTION) << "The index " << index << " exceeds the size of py_args " << py_args.size();
      }
      const auto &obj = py_args[index];
      if (py::isinstance<py::float_>(obj)) {
        has_scalar_float32 = true;
      }
      // A Python bool is also a py::int_, so exclude it explicitly.
      if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
        has_scalar_int64 = true;
      }
      if (py::isinstance<tensor::Tensor>(obj)) {
        auto arg = py::cast<tensor::TensorPtr>(obj);
        TypeId arg_type_id = arg->data_type();
        auto type_priority = prim::type_map.find(arg_type_id);
        if (type_priority == prim::type_map.end()) {
          continue;
        }
        if (arg_type_id == kNumberTypeInt8) {
          has_tensor_int8 = true;
        }
        if (type_priority->second > priority) {
          max_type = type_priority->first;
          priority = type_priority->second;
        }
      }
    }
    max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
    MS_EXCEPTION_IF_NULL(dst_type);
    (void)dst_type->emplace(std::make_pair(type, max_type));
  }
}
267 
TypeIdToMsTypeStr(const TypeId & type_id)268 const std::string &TypeIdToMsTypeStr(const TypeId &type_id) {
269   const auto &type_name = type_name_map.find(type_id);
270   if (type_name == type_name_map.end()) {
271     MS_LOG(EXCEPTION) << "For implicit type conversion, not support convert to the type: " << TypeIdToType(type_id);
272   }
273   return type_name->second;
274 }
275 
GetSignatureType(const PrimitivePyPtr & prim,std::vector<SignatureEnumDType> * dtypes)276 bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
277   MS_EXCEPTION_IF_NULL(prim);
278   MS_EXCEPTION_IF_NULL(dtypes);
279   const auto &signature = prim->signatures();
280   bool has_sig_dtype = false;
281   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
282                        [&has_sig_dtype](const Signature &sig) {
283                          auto dtype = sig.dtype;
284                          if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
285                            has_sig_dtype = true;
286                          }
287                          return dtype;
288                        });
289   return has_sig_dtype;
290 }
291 
// Runs single-primitive abstract inference for `prim` over `args_spec_list`
// and stores the inferred abstract on op_exec_info->abstract.
// Begin/EndRecordAddAttr bracket the evaluation so attrs added during infer
// are recorded on the primitive.
void PynativeInfer(const PrimitivePyPtr &prim, OpExecInfo *const op_exec_info,
                   const abstract::AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
  prim->BeginRecordAddAttr();
  auto eval_ret = EvalOnePrim(prim, args_spec_list);
  MS_EXCEPTION_IF_NULL(eval_ret);
  AbstractBasePtr infer_res = eval_ret->abstract();
  MS_EXCEPTION_IF_NULL(infer_res);
  prim->EndRecordAddAttr();
  MS_EXCEPTION_IF_NULL(op_exec_info);
  op_exec_info->abstract = infer_res;
  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}
307 
GetSingleOpGraphInfo(const OpExecInfoPtr & op_exec_info,const std::vector<tensor::TensorPtr> & input_tensors,const std::vector<int64_t> & tensors_mask,std::string * graph_info_key)308 void GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector<tensor::TensorPtr> &input_tensors,
309                           const std::vector<int64_t> &tensors_mask, std::string *graph_info_key) {
310   MS_EXCEPTION_IF_NULL(op_exec_info);
311   MS_EXCEPTION_IF_NULL(graph_info_key);
312   auto &graph_info = *graph_info_key;
313   if (input_tensors.size() != tensors_mask.size()) {
314     MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
315                       << tensors_mask.size();
316   }
317   std::ostringstream buf;
318   buf << op_exec_info->op_name;
319   bool has_const_input = false;
320   for (size_t index = 0; index < input_tensors.size(); ++index) {
321     MS_EXCEPTION_IF_NULL(input_tensors[index]);
322     buf << input_tensors[index]->shape();
323     buf << input_tensors[index]->data_type();
324     buf << input_tensors[index]->padding_type();
325     // In the case of the same shape, but dtype and format are inconsistent
326     auto tensor_addr = input_tensors[index]->device_address();
327     if (tensor_addr != nullptr) {
328       auto p_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr);
329       MS_EXCEPTION_IF_NULL(p_address);
330       buf << p_address->type_id();
331       buf << p_address->format();
332     }
333     // For constant input
334     if (tensors_mask[index] == kValueNodeTensorMask) {
335       has_const_input = true;
336       auto dtype = input_tensors[index]->Dtype();
337       MS_EXCEPTION_IF_NULL(dtype);
338       if (dtype->type_id() == kNumberTypeInt64) {
339         buf << *reinterpret_cast<int *>(input_tensors[index]->data_c());
340       } else if (dtype->type_id() == kNumberTypeFloat32 || dtype->type_id() == kNumberTypeFloat16) {
341         buf << *reinterpret_cast<float *>(input_tensors[index]->data_c());
342       } else {
343         MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!";
344       }
345     }
346     buf << "_";
347   }
348   // The value of the attribute affects the operator selection
349   const auto &op_prim = op_exec_info->py_primitive;
350   MS_EXCEPTION_IF_NULL(op_prim);
351   const auto &attr_map = op_prim->attrs();
352   (void)std::for_each(attr_map.begin(), attr_map.end(),
353                       [&buf](const auto &element) { buf << element.second->ToString(); });
354 
355   // Constant input affects output, operators like DropoutGenMask whose output is related to values of input when input
356   // shapes are the same but values are different
357   if (has_const_input) {
358     buf << "_";
359     auto abstr = op_exec_info->abstract;
360     MS_EXCEPTION_IF_NULL(abstr);
361     auto build_shape = abstr->BuildShape();
362     MS_EXCEPTION_IF_NULL(build_shape);
363     buf << build_shape->ToString();
364     auto build_type = abstr->BuildType();
365     MS_EXCEPTION_IF_NULL(build_type);
366     buf << build_type->type_id();
367   }
368   graph_info = buf.str();
369 }
370 
// Keeps only the Tensor arguments from `args`.  When has_sens is true, the
// final argument is the sens value and is appended unconditionally.
py::list FilterTensorArgs(const py::args &args, bool has_sens = false) {
  const size_t total = args.size();
  if (total == 0 && has_sens) {
    MS_LOG(EXCEPTION) << "The size of args is 0, when the flag of sens is set to True";
  }
  const size_t forward_count = has_sens ? total - 1 : total;
  py::list tensors_only;
  for (size_t idx = 0; idx < forward_count; ++idx) {
    if (py::isinstance<tensor::Tensor>(args[idx])) {
      tensors_only.append(args[idx]);
    }
  }
  if (has_sens) {
    // The trailing sens argument is kept regardless of its type.
    tensors_only.append(args[forward_count]);
  }
  return tensors_only;
}
388 
RunOpConvertConstInputToAttr(const py::object & input_object,size_t input_index,const PrimitivePtr & op_prim,const std::unordered_set<size_t> & input_attrs)389 bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
390                                   const std::unordered_set<size_t> &input_attrs) {
391   MS_EXCEPTION_IF_NULL(op_prim);
392   const auto &input_names_value = op_prim->GetAttr(kAttrInputNames);
393   if (input_names_value == nullptr) {
394     return false;
395   }
396   const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
397   if (input_index >= input_names_vec.size()) {
398     MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!";
399   }
400 
401   if (input_attrs.find(input_index) != input_attrs.end()) {
402     const auto &value = PyObjToValue(input_object);
403     auto input_name = input_names_vec[input_index];
404     op_prim->AddAttr(input_name, value);
405     return true;
406   }
407   return false;
408 }
409 
PlantTensorTupleToVector(const py::tuple & tuple_inputs,const PrimitivePtr & op_prim,std::vector<tensor::TensorPtr> * input_tensors)410 void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
411                               std::vector<tensor::TensorPtr> *input_tensors) {
412   MS_EXCEPTION_IF_NULL(op_prim);
413   MS_EXCEPTION_IF_NULL(input_tensors);
414   for (const auto &input_object : tuple_inputs) {
415     if (!py::isinstance<tensor::Tensor>(input_object)) {
416       MS_LOG(EXCEPTION) << "The input object is not a tensor!";
417     }
418     auto tensor = py::cast<tensor::TensorPtr>(input_object);
419     MS_EXCEPTION_IF_NULL(tensor);
420     input_tensors->emplace_back(tensor);
421   }
422   op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(tuple_inputs.size())}));
423 }
424 
ConvertValueTupleToTensor(const py::object & input_object,std::vector<tensor::TensorPtr> * input_tensors)425 void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
426   MS_EXCEPTION_IF_NULL(input_tensors);
427   const auto &input_value = PyObjToValue(input_object);
428   MS_EXCEPTION_IF_NULL(input_value);
429   if (!input_value->isa<ValueTuple>()) {
430     MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
431   }
432   auto value_tuple = input_value->cast<ValueTuplePtr>();
433   MS_EXCEPTION_IF_NULL(value_tuple);
434   tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
435   MS_EXCEPTION_IF_NULL(tensor_ptr);
436   input_tensors->emplace_back(tensor_ptr);
437 }
438 
ConvertMultiPyObjectToTensor(const py::object & input_object,const PrimitivePtr & op_prim,std::vector<tensor::TensorPtr> * input_tensors,int64_t * const tensor_mask)439 void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
440                                   std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
441   MS_EXCEPTION_IF_NULL(op_prim);
442   MS_EXCEPTION_IF_NULL(input_tensors);
443   MS_EXCEPTION_IF_NULL(tensor_mask);
444 
445   if (!py::isinstance<py::tuple>(input_object)) {
446     MS_LOG(EXCEPTION) << "The input should be a tuple!";
447   }
448   auto tuple_inputs = py::cast<py::tuple>(input_object);
449   if (tuple_inputs.empty()) {
450     MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
451   }
452   if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
453     PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
454   } else {
455     ConvertValueTupleToTensor(input_object, input_tensors);
456     *tensor_mask = kValueNodeTensorMask;
457   }
458 }
459 
ConvertPyObjectToTensor(const py::object & input_object,const PrimitivePtr & op_prim,std::vector<tensor::TensorPtr> * input_tensors,int64_t * const tensor_mask)460 void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
461                              std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
462   MS_EXCEPTION_IF_NULL(op_prim);
463   MS_EXCEPTION_IF_NULL(input_tensors);
464   MS_EXCEPTION_IF_NULL(tensor_mask);
465   tensor::TensorPtr tensor_ptr = nullptr;
466   if (py::isinstance<tensor::Tensor>(input_object)) {
467     tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
468   } else if (py::isinstance<py::float_>(input_object)) {
469     double input_value = py::cast<py::float_>(input_object);
470     tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
471     *tensor_mask = kValueNodeTensorMask;
472   } else if (py::isinstance<py::int_>(input_object)) {
473     tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
474     *tensor_mask = kValueNodeTensorMask;
475   } else if (py::isinstance<py::array>(input_object)) {
476     tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
477   } else if (py::isinstance<py::list>(input_object)) {
478     auto list_inputs = py::cast<py::list>(input_object);
479     py::tuple tuple_inputs(list_inputs.size());
480     for (size_t i = 0; i < tuple_inputs.size(); ++i) {
481       tuple_inputs[i] = list_inputs[i];
482     }
483     ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
484     return;
485   } else if (py::isinstance<py::tuple>(input_object)) {
486     ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
487     return;
488   } else if (py::isinstance<py::none>(input_object)) {
489     return;
490   } else {
491     MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
492   }
493   MS_EXCEPTION_IF_NULL(tensor_ptr);
494   input_tensors->emplace_back(tensor_ptr);
495 }
496 
// Builds the flat tensor input list (and matching mask) for a single op
// launch.  Inputs registered for const-input-to-attr conversion are folded
// into primitive attrs instead; the rest are converted to tensors.
// Mask values: common tensor data: 0, weight param: 1, valuenode: 2.
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  // Checking whether attr conversion is needed.
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  // Dynamic-shape ops keep const inputs as inputs unless whitelisted.
  if (op_run_info->is_dynamic_shape &&
      dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
    MS_LOG(DEBUG) << "current node is dynamic shape: " << op_run_info->op_name;
    reg_exist = false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  // EmbeddingLookup converts const input to attr only on the CPU backend.
  if (device_target != kCPUDevice && op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
    reg_exist = false;
  }
  // Gather op needs converting const input to attr on GPU device
  if (device_target != kGPUDevice && op_run_info->op_name == prim::kPrimGatherD->name()) {
    reg_exist = false;
  }
  // Get input tensors
  MS_EXCEPTION_IF_NULL(op_prim);
  op_prim->BeginRecordAddAttr();
  size_t input_num = op_run_info->op_inputs.size();
  if (input_num != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "The op input size " << input_num << ", but the size of input mask "
                      << op_run_info->inputs_mask.size();
  }
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    int64_t tensor_mask = op_run_info->inputs_mask[index];
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // Mark tensors, common tensor data : 0, weight param: 1, valuenode(float_, int_): 2
    op_run_info->inputs_mask[index] = tensor_mask;
    // One Python input may expand to several tensors (tuple); replicate the
    // mask once per tensor appended by the conversion above.
    std::vector<int64_t> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
  op_prim->EndRecordAddAttr();
}
545 
ConvertAttrToUnifyMindIR(const OpExecInfoPtr & op_run_info)546 void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
547   MS_EXCEPTION_IF_NULL(op_run_info);
548   const auto &op_prim = op_run_info->py_primitive;
549   MS_EXCEPTION_IF_NULL(op_prim);
550 
551   const auto &op_name = op_run_info->op_name;
552   auto attrs = op_prim->attrs();
553   for (auto attr : attrs) {
554     bool converted = CheckAndConvertUtils::ConvertAttrValueToString(op_name, attr.first, &attr.second);
555     if (converted) {
556       op_prim->set_attr(attr.first, attr.second);
557     }
558     bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(op_name, attr.first, &attr.second);
559     if (converted_ir_attr) {
560       op_prim->set_attr(attr.first, attr.second);
561     }
562   }
563 }
564 
GetTupleSize(const py::tuple & args)565 size_t GetTupleSize(const py::tuple &args) {
566   size_t count = 0;
567   for (size_t i = 0; i < args.size(); i++) {
568     if (py::isinstance<py::tuple>(args[i])) {
569       count += GetTupleSize(args[i]);
570     } else {
571       count += 1;
572     }
573   }
574   return count;
575 }
576 
ConvertTupleArg(py::tuple * res,size_t * const index,const py::tuple & arg)577 void ConvertTupleArg(py::tuple *res, size_t *const index, const py::tuple &arg) {
578   MS_EXCEPTION_IF_NULL(res);
579   MS_EXCEPTION_IF_NULL(index);
580   auto res_size = res->size();
581   for (size_t i = 0; i < arg.size(); i++) {
582     if (py::isinstance<py::tuple>(arg[i])) {
583       ConvertTupleArg(res, index, arg[i]);
584     } else {
585       if (*index >= res_size) {
586         MS_LOG(EXCEPTION) << "Convert tuple error, index is greater than tuple size, index " << (*index)
587                           << ", tuple size " << res_size;
588       }
589       (*res)[(*index)++] = arg[i];
590     }
591   }
592 }
593 
// Flattens a (possibly nested) argument tuple into a single flat tuple whose
// size equals the total leaf count.
py::tuple ConvertArgs(const py::tuple &args) {
  const size_t flat_size = GetTupleSize(args);
  py::tuple flat(flat_size);
  size_t cursor = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      ConvertTupleArg(&flat, &cursor, args[i]);
    } else {
      if (cursor >= flat_size) {
        MS_LOG(EXCEPTION) << "Convert error, index is greater than tuple size, index " << cursor << ", tuple size "
                          << flat_size;
      }
      flat[cursor++] = args[i];
    }
  }
  return flat;
}
611 
ResetTopCellInfo(const TopCellInfoPtr & top_cell,const py::args & args)612 void ResetTopCellInfo(const TopCellInfoPtr &top_cell, const py::args &args) {
613   MS_EXCEPTION_IF_NULL(top_cell);
614   top_cell->set_op_num(0);
615   top_cell->all_op_info().clear();
616   top_cell->set_forward_already_run(true);
617   std::string input_args_id;
618   for (size_t i = 0; i < args.size(); ++i) {
619     input_args_id += GetId(args[i]) + "_";
620   }
621   top_cell->set_input_args_id(input_args_id);
622 }
623 
// Rewrites the forward-output value nodes referenced by `added_make_tuple`
// with the freshly produced tensors in `total_output_tensors`, so the grad
// graph (which shares these value nodes) sees the outputs of the current run.
// Tensors are consumed in order and the total count must match exactly.
void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
                const FuncGraphPtr &grad_graph) {
  MS_EXCEPTION_IF_NULL(grad_graph);
  MS_EXCEPTION_IF_NULL(added_make_tuple);
  size_t index = 0;
  // Input 0 of a CNode is the primitive, so the real inputs start at 1.
  for (size_t i = 1; i < added_make_tuple->size(); ++i) {
    const auto &input_i = added_make_tuple->input(i);
    MS_EXCEPTION_IF_NULL(input_i);
    auto cnode = input_i->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(DEBUG) << "Replace new output tensors for cnode: " << cnode->DebugString();
    auto output_vnode = cnode->forward().first;
    MS_EXCEPTION_IF_NULL(output_vnode);
    grad_graph->AddValueNode(output_vnode);
    MS_LOG(DEBUG) << "Original output value node: " << output_vnode << " info: " << output_vnode->ToString();
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    if (index + output_num > total_output_tensors.size()) {
      MS_LOG(EXCEPTION) << "The size of total_output_tensors: " << total_output_tensors.size()
                        << ", but the current index: " << index << ", output num: " << output_num;
    }
    // Get new tensors.
    std::vector<ValuePtr> new_values;
    for (size_t j = index; j < index + output_num; ++j) {
      new_values.push_back(total_output_tensors[j]);
    }
    index = index + output_num;
    // Replace new tensors.  Multi-output nodes get a ValueTuple.
    if (output_num == 1) {
      output_vnode->set_value(new_values[0]);
    } else if (output_num > 1) {
      output_vnode->set_value(std::make_shared<ValueTuple>(new_values));
    } else {
      MS_LOG(EXCEPTION) << "The output value of forward cnode is empty, forward cnode info: " << cnode->ToString();
    }
    MS_LOG(DEBUG) << "New output value node: " << output_vnode << " info: " << output_vnode->ToString();
  }
  // Save op info with new tensors for current running ms_function func graph.
  if (index != total_output_tensors.size()) {
    MS_LOG(EXCEPTION) << "The index: " << index
                      << " should be equal to the size of total_output_tensors: " << total_output_tensors.size();
  }
}
666 
// Pushes the newly computed forward outputs (`added_out`) of an ms_function
// call back into the grad graph.  The ms_function graph's output is expected
// to be a 3-input make_tuple whose third element carries the added forward
// nodes: if that element is a plain value node the tensors are only recorded
// on the top cell; otherwise each forward cnode's cached output value node is
// rewritten in place via RunReplace.
void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecInfoPtr &op_exec_info,
                                  const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
                                  const FuncGraphPtr &grad_graph) {
  MS_EXCEPTION_IF_NULL(top_cell);
  MS_EXCEPTION_IF_NULL(grad_graph);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  // Get added forward nodes.
  auto merge_node = ms_func_graph->output();
  MS_EXCEPTION_IF_NULL(merge_node);
  auto merge_make_tuple = merge_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(merge_make_tuple);
  constexpr size_t merge_output_size = 3;
  if (merge_make_tuple->size() != merge_output_size) {
    MS_LOG(EXCEPTION) << "The input size of merge make tuple node should be 3, but it is: " << merge_make_tuple->size();
  }
  constexpr size_t added_output_index = 2;
  const auto &added_forward_node = merge_make_tuple->input(added_output_index);
  MS_EXCEPTION_IF_NULL(added_forward_node);
  if (added_forward_node->isa<ValueNode>()) {
    // No cnodes to rewrite; just record the tensors for this op info.
    MS_LOG(DEBUG) << "The added forward output node is value node: " << added_forward_node->DebugString();
    std::vector<tensor::TensorPtr> total_output_tensors;
    TensorValueToTensor(added_out, &total_output_tensors);
    top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
    return;
  }
  // Replace new output tensors for forward nodes, it will also work in grad graph with same value node.
  auto added_make_tuple = added_forward_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(added_make_tuple);
  MS_LOG(DEBUG) << "The added forward make tuple node info: " << added_make_tuple->DebugString();
  std::vector<tensor::TensorPtr> total_output_tensors;
  TensorValueToTensor(added_out, &total_output_tensors);
  RunReplace(added_make_tuple, total_output_tensors, grad_graph);
  top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
}
702 
SaveOpInfo(const TopCellInfoPtr & top_cell,const std::string & op_info,const std::vector<tensor::TensorPtr> & op_out_tensors)703 void SaveOpInfo(const TopCellInfoPtr &top_cell, const std::string &op_info,
704                 const std::vector<tensor::TensorPtr> &op_out_tensors) {
705   MS_EXCEPTION_IF_NULL(top_cell);
706   auto &op_info_with_tensor_id = top_cell->op_info_with_tensor_id();
707   if (op_info_with_tensor_id.find(op_info) != op_info_with_tensor_id.end()) {
708     MS_LOG(EXCEPTION) << "Top cell: " << top_cell.get() << " records op info with tensor id, but get op info "
709                       << op_info << " in op_info_with_tensor_id map";
710   }
711   // Record the relationship between the forward op and its output tensor id
712   std::for_each(op_out_tensors.begin(), op_out_tensors.end(),
713                 [&op_info_with_tensor_id, &op_info](const tensor::TensorPtr &tensor) {
714                   op_info_with_tensor_id[op_info].emplace_back(tensor->id());
715                 });
716 }
717 
// Propagate shape, dtype and device data from new_tensor to every tensor in
// pre_tensors. On non-CPU targets the device address pointer is simply shared;
// on CPU the underlying bytes are copied into the pre-existing device address
// so that aliases of the old buffer keep seeing fresh data.
void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<tensor::TensorPtr> &pre_tensors) {
  MS_EXCEPTION_IF_NULL(new_tensor);
  if (pre_tensors.empty()) {
    MS_LOG(EXCEPTION) << "The size of pre tensors is empty.";
  }

  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  for (auto &pre_tensor : pre_tensors) {
    MS_EXCEPTION_IF_NULL(pre_tensor);
    MS_LOG(DEBUG) << "Replace Old tensor " << pre_tensor.get() << " id " << pre_tensor->id()
                  << " device_address: " << pre_tensor->device_address() << " shape and type "
                  << pre_tensor->GetShapeAndDataTypeInfo() << " with New tensor " << new_tensor.get() << " id "
                  << new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
                  << new_tensor->GetShapeAndDataTypeInfo();
    pre_tensor->set_shape(new_tensor->shape());
    pre_tensor->set_data_type(new_tensor->data_type());
    if (device_target != kCPUDevice) {
      // Non-CPU: share the new device address directly, no data copy needed.
      pre_tensor->set_device_address(new_tensor->device_address());
      continue;
    }
    // Replace data in device address when run in CPU device.
    if (pre_tensor->device_address() != nullptr) {
      auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(pre_tensor->device_address());
      auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
      MS_EXCEPTION_IF_NULL(old_device_address);
      auto old_ptr = old_device_address->GetMutablePtr();
      MS_EXCEPTION_IF_NULL(old_ptr);
      MS_EXCEPTION_IF_NULL(new_device_address);
      auto new_ptr = new_device_address->GetPtr();
      MS_EXCEPTION_IF_NULL(new_ptr);
      // memcpy_s fails (non-EOK) if the destination is smaller than the source,
      // which surfaces size mismatches instead of overflowing the old buffer.
      auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
      }
    }
  }
}
755 
CheckPyNativeContext()756 void CheckPyNativeContext() {
757   const auto &parallel_context = parallel::ParallelContext::GetInstance();
758   MS_EXCEPTION_IF_NULL(parallel_context);
759   const auto &ms_context = MsContext::GetInstance();
760   MS_EXCEPTION_IF_NULL(ms_context);
761   const auto &parallel_mode = parallel_context->parallel_mode();
762   if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL &&
763       ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
764     MS_LOG(EXCEPTION) << "PyNative Only support STAND_ALONE and DATA_PARALLEL, but got:" << parallel_mode;
765   }
766 }
767 
GetDstType(const TypeId & type_id)768 py::object GetDstType(const TypeId &type_id) {
769   ValuePtr value = nullptr;
770   if (type_id == kNumberTypeFloat16) {
771     value = std::make_shared<Float>(16);
772   } else if (type_id == kNumberTypeFloat32) {
773     value = std::make_shared<Float>(32);
774   } else if (type_id == kNumberTypeFloat64) {
775     value = std::make_shared<Float>(64);
776   } else if (type_id == kNumberTypeBool) {
777     value = std::make_shared<Bool>();
778   } else if (type_id == kNumberTypeInt8) {
779     value = std::make_shared<Int>(8);
780   } else if (type_id == kNumberTypeUInt8) {
781     value = std::make_shared<UInt>(8);
782   } else if (type_id == kNumberTypeInt16) {
783     value = std::make_shared<Int>(16);
784   } else if (type_id == kNumberTypeInt32) {
785     value = std::make_shared<Int>(32);
786   } else if (type_id == kNumberTypeInt64) {
787     value = std::make_shared<Int>(64);
788   } else {
789     MS_LOG(EXCEPTION) << "Not support dst type";
790   }
791   MS_EXCEPTION_IF_NULL(value);
792   return py::cast(value);
793 }
794 }  // namespace
795 
// Entry point for running a single op from python in PyNative mode: validates
// the context, builds the OpExecInfo from the python args and executes it
// through the forward executor.
py::object RealRunOp(const py::args &args) {
  CheckPyNativeContext();
  const auto &executor = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  const auto &forward = executor->forward_executor();
  OpExecInfoPtr op_exec_info = forward->GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  py::object result = py::none();
  PynativeExecutorTry(forward->RunOpS, &result, op_exec_info);
  return result;
}
806 
grad() const807 GradExecutorPtr ForwardExecutor::grad() const {
808   auto grad_executor = grad_executor_.lock();
809   MS_EXCEPTION_IF_NULL(grad_executor);
810   return grad_executor;
811 }
812 
IsSubCell(const std::string & cell_id) const813 bool TopCellInfo::IsSubCell(const std::string &cell_id) const {
814   if (sub_cell_list_.empty()) {
815     MS_LOG(DEBUG) << "The sub cell list is empty, there is no sub cell";
816     return false;
817   }
818   if (sub_cell_list_.find(cell_id) != sub_cell_list_.end()) {
819     return true;
820   }
821   return false;
822 }
823 
ClearDeviceMemory()824 void TopCellInfo::ClearDeviceMemory() {
825   MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
826   auto ms_context = MsContext::GetInstance();
827   MS_EXCEPTION_IF_NULL(ms_context);
828   const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
829   if (device_target == kCPUDevice) {
830     MS_LOG(DEBUG) << "No need to clear device address when run in CPU device.";
831     return;
832   }
833   k_pynative_cell_ptr_ = nullptr;
834   // Get all tensors obj in value node of running graph
835   std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
836   MS_EXCEPTION_IF_NULL(resource_);
837   const auto &bprop_graph = resource_->func_graph();
838   MS_EXCEPTION_IF_NULL(bprop_graph);
839   const auto &value_node_list = bprop_graph->value_nodes();
840   for (const auto &elem : value_node_list) {
841     auto &node = elem.first;
842     MS_EXCEPTION_IF_NULL(node);
843     auto value_node = node->cast<ValueNodePtr>();
844     MS_EXCEPTION_IF_NULL(value_node);
845     TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
846   }
847   for (const auto &tensor : tensors_in_bprop_graph) {
848     MS_EXCEPTION_IF_NULL(tensor);
849     MS_LOG(DEBUG) << "Clear device address for tensor: " << tensor->ToString();
850     tensor->set_device_address(nullptr);
851   }
852 }
853 
// Reset this top cell to a pristine state: zero the flags/counters, release the
// graph-related resources and empty every bookkeeping container. The statement
// order is kept as-is to preserve the release order of the shared resources.
void TopCellInfo::Clear() {
  MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
  // Scalar state.
  op_num_ = 0;
  is_dynamic_ = false;
  vm_compiled_ = false;
  ms_function_flag_ = false;
  is_init_kpynative_ = false;
  need_compile_graph_ = false;
  forward_already_run_ = false;
  input_args_id_.clear();
  all_op_info_.clear();
  // Graph resources.
  resource_ = nullptr;
  df_builder_ = nullptr;
  k_pynative_cell_ptr_ = nullptr;
  // Bookkeeping maps/sets.
  graph_info_map_.clear();
  sub_cell_list_.clear();
  op_info_with_tensor_id_.clear();
  tensor_id_with_tensor_object_.clear();
  op_info_with_ms_func_forward_tensors_.clear();
}
874 
RunOpInner(py::object * ret,const OpExecInfoPtr & op_exec_info)875 void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info) {
876   MS_EXCEPTION_IF_NULL(ret);
877   MS_EXCEPTION_IF_NULL(op_exec_info);
878   MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name;
879   if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
880     RunMixedPrecisionCastOp(op_exec_info, ret);
881     return;
882   }
883 
884   // 1.Set cast for inputs
885   SetCastForInputs(op_exec_info);
886   // 2.Construct graph, first step abs will update by node
887   auto cnode = ConstructForwardGraph(op_exec_info);
888   // 3.Get inputs abstract
889   abstract::AbstractBasePtrList args_spec_list;
890   GetInputsArgsSpec(op_exec_info, &args_spec_list);
891   // 4.Get output abstract
892   bool prim_cache_hit = false;
893   GetOpOutputAbstract(op_exec_info, args_spec_list, &prim_cache_hit);
894   // 5.Get output
895   GetOpOutput(op_exec_info, args_spec_list, cnode, prim_cache_hit, ret);
896 }
897 
GenerateOpExecInfo(const py::args & args)898 OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
899   if (args.size() != PY_ARGS_NUM) {
900     MS_LOG(EXCEPTION) << "Three args are needed by RunOp";
901   }
902   const auto &op_exec_info = std::make_shared<OpExecInfo>();
903   const auto &op_name = py::cast<std::string>(args[PY_NAME]);
904   op_exec_info->op_name = op_name;
905 
906   const auto &adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]);
907   MS_EXCEPTION_IF_NULL(adapter);
908   auto prim = adapter->attached_primitive();
909   if (prim == nullptr) {
910     prim = std::make_shared<PrimitivePy>(args[PY_PRIM], adapter);
911     adapter->set_attached_primitive(prim);
912   }
913 
914   if (!prim->HasPyObj()) {
915     MS_LOG(EXCEPTION) << "Pyobj is empty";
916   }
917   op_exec_info->py_primitive = prim;
918   op_exec_info->op_inputs = args[PY_INPUTS];
919   op_exec_info->lazy_build = lazy_build_;
920   return op_exec_info;
921 }
922 
SetCastForInputs(const OpExecInfoPtr & op_exec_info)923 void ForwardExecutor::SetCastForInputs(const OpExecInfoPtr &op_exec_info) {
924   MS_EXCEPTION_IF_NULL(op_exec_info);
925   // No need cast self
926   if (op_exec_info->op_name == prim::kPrimCast->name()) {
927     return;
928   }
929 
930   // Mixed precision conversion tensors which has cast dtype
931   SetTensorMixPrecisionCast(op_exec_info);
932   // Implicit transform
933   SetImplicitCast(op_exec_info);
934 }
935 
RunMixedPrecisionCastOp(const OpExecInfoPtr & op_exec_info,py::object * ret)936 void ForwardExecutor::RunMixedPrecisionCastOp(const OpExecInfoPtr &op_exec_info, py::object *ret) {
937   py::tuple res = RunOpWithInitBackendPolicy(op_exec_info);
938   MS_EXCEPTION_IF_NULL(ret);
939   if (res.size() == 1) {
940     *ret = res[0];
941     return;
942   }
943   *ret = std::move(res);
944 }
945 
SetNonCostantValueAbs(const AbstractBasePtr & abs,size_t i,const std::string & id)946 void ForwardExecutor::SetNonCostantValueAbs(const AbstractBasePtr &abs, size_t i, const std::string &id) {
947   MS_EXCEPTION_IF_NULL(abs);
948   if (abs->isa<abstract::AbstractTensor>()) {
949     abs->set_value(kAnyValue);
950   } else if (abs->isa<abstract::AbstractTuple>() || abs->isa<abstract::AbstractList>()) {
951     const auto &abs_seq = abs->cast<abstract::AbstractSequeuePtr>();
952     MS_EXCEPTION_IF_NULL(abs_seq);
953     for (auto &item : abs_seq->elements()) {
954       MS_EXCEPTION_IF_NULL(item);
955       if (item->isa<abstract::AbstractTensor>()) {
956         item->set_value(kAnyValue);
957       }
958     }
959   }
960   MS_LOG(DEBUG) << "Set " << i << "th abs " << abs->ToString();
961   node_abs_map_[id] = abs;
962 }
963 
GetInputsArgsSpec(const OpExecInfoPtr & op_exec_info,abstract::AbstractBasePtrList * args_spec_list)964 void ForwardExecutor::GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info,
965                                         abstract::AbstractBasePtrList *args_spec_list) {
966   MS_EXCEPTION_IF_NULL(op_exec_info);
967   MS_EXCEPTION_IF_NULL(args_spec_list);
968   auto prim = op_exec_info->py_primitive;
969   MS_EXCEPTION_IF_NULL(prim);
970   for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
971     abstract::AbstractBasePtr abs = nullptr;
972     const auto &obj = op_exec_info->op_inputs[i];
973     const auto &id = GetId(obj);
974     MS_LOG(DEBUG) << "Set input abs " << id;
975     auto it = node_abs_map_.find(id);
976     if (it != node_abs_map_.end()) {
977       abs = it->second;
978     }
979     const auto const_input_index = prim->get_const_input_indexes();
980     bool have_const_input = !const_input_index.empty();
981     bool is_const_prim = prim->is_const_prim();
982     MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
983                   << prim->is_const_prim();
984     bool is_const_input =
985       have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
986     if (abs == nullptr || is_const_prim || is_const_input) {
987       abs = PyObjToValue(obj)->ToAbstract();
988       if (!is_const_prim && !is_const_input) {
989         SetNonCostantValueAbs(abs, i, id);
990       }
991     }
992     args_spec_list->emplace_back(abs);
993   }
994 }
995 
// Build the forward CNode for this op (prim + inputs) in the current grad
// graph, and fill op_exec_info->inputs_mask with a per-input parameter flag.
// Returns nullptr when no grad graph is being constructed.
CNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  std::vector<AnfNodePtr> inputs;
  std::vector<int64_t> op_masks;
  inputs.emplace_back(NewValueNode(prim));
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    const auto &obj = op_exec_info->op_inputs[i];
    // op_mask marks inputs that are Parameters (trainable weights).
    bool op_mask = false;
    tensor::MetaTensorPtr meta_tensor = nullptr;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor) {
        op_mask = meta_tensor->is_parameter();
      }
    }
    MS_LOG(DEBUG) << "Args i " << i << ", op mask " << op_mask;
    op_masks.emplace_back(static_cast<int64_t>(op_mask));

    // Construct grad graph
    if (grad()->need_construct_graph()) {
      const auto &id = GetId(obj);
      AnfNodePtr input_node = nullptr;
      input_node = grad()->GetInput(obj, op_mask);
      // update abstract
      if (input_node != nullptr) {
        if (input_node->abstract() != nullptr) {
          abstract::AbstractBasePtr abs = input_node->abstract();
          node_abs_map_[id] = abs;
        }
        inputs.emplace_back(input_node);
      }
    }
  }
  op_exec_info->inputs_mask = std::move(op_masks);
  CNodePtr cnode = nullptr;
  if (grad()->need_construct_graph()) {
    cnode = grad()->curr_g()->NewCNodeInOrder(inputs);
    MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << ", new cnode is " << cnode->DebugString();
  }
  return cnode;
}
1038 
// Resolve the op's output abstract. A two-level cache is consulted first:
// keyed by (prim name, hash, attrs) and then by the input abstracts. On a miss
// (or for ops in force_infer_prim) the python infer method is invoked, unless
// the op is listed in ignore_infer_prim. Dynamic output shapes disable the
// backend op-graph cache for the current top cell.
void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
                                          const abstract::AbstractBasePtrList &args_spec_list, bool *prim_cache_hit) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(prim_cache_hit);
  auto op_name = op_exec_info->op_name;
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);

  AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
  auto temp = prim_abs_list_.find(key);
  if (temp != prim_abs_list_.end()) {
    MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
    auto iter = temp->second.find(args_spec_list);
    if (iter != temp->second.end()) {
      // Cache hit: reuse the abstract and the attrs added by evaluation.
      MS_LOG(DEBUG) << "Match prim ok " << op_name;
      op_exec_info->abstract = iter->second.abs;
      prim->set_evaluate_added_attrs(iter->second.attrs);
      *prim_cache_hit = true;
    }
  }

  if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
    // Use python infer method
    if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
      PynativeInfer(prim, op_exec_info.get(), args_spec_list);
    }
  }
  // Get output dynamic shape info
  auto abstract = op_exec_info->abstract;
  MS_EXCEPTION_IF_NULL(abstract);
  auto shape = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);

  if (shape->IsDynamic()) {
    op_exec_info->is_dynamic_shape = true;
    // Dynamic shape operator in the current top cell, disable backend cache
    grad()->EnableOpGraphCache(false);
  }
}
1078 
// Produce the op's output into *ret. The output may come from constant folding
// of the inferred abstract (no backend run), or from the selected backend.
// Also maintains the abstract cache and, when a grad graph is being built,
// records the cnode/output for autodiff and updates forward tensor info.
void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
                                  const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
                                  bool prim_cache_hit, py::object *ret) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  // Infer output value by constant folding
  MS_EXCEPTION_IF_NULL(ret);
  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
  if (!output["value"].is_none()) {
    // The abstract carries a concrete value: no backend run needed.
    *ret = output["value"];
    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
    return;
  }
  if (prim->is_const_prim()) {
    *ret = py::cast("");
    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
    return;
  }

  // Add output abstract info into cache, the const value needs to infer evert step
  if (grad()->enable_op_cache() && !prim_cache_hit && !op_exec_info->is_dynamic_shape) {
    AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
    auto &out = prim_abs_list_[key];
    out[args_spec_list].abs = op_exec_info->abstract;
    out[args_spec_list].attrs = prim->evaluate_added_attrs();
  }
  // run op with selected backend
  auto result = RunOpWithInitBackendPolicy(op_exec_info);
  py::object out_real = result;
  // Unwrap a single-output result unless the abstract says it is a sequence.
  if (result.size() == 1 && op_exec_info->abstract != nullptr &&
      !op_exec_info->abstract->isa<abstract::AbstractSequeue>()) {
    out_real = result[0];
  }
  // get output value
  ValuePtr out_real_value = nullptr;
  if (grad()->grad_flag()) {
    out_real_value = PyObjToValue(out_real);
  }
  // Save cnode info and build grad graph
  if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &obj_id = GetId(out_real);
    cnode->set_abstract(op_exec_info->abstract);
    node_abs_map_[obj_id] = op_exec_info->abstract;
    grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
    grad()->DoOpGrad(op_exec_info, cnode, out_real_value);
  } else {
    node_abs_map_.clear();
  }
  // Record op info for judge whether the construct of cell has been changed
  grad()->RecordGradOpInfo(op_exec_info, out_real_value);
  grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
  *ret = out_real;
}
1134 
// Cast arg to type_id by synthesising and running a Cast op. op_name/index
// identify the consumer op and input slot the cast result feeds into.
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
                                       size_t index) {
  // The python "cast" function object is fetched once and reused for every call.
  static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
  MS_EXCEPTION_IF_NULL(adapter);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
    adapter->set_attached_primitive(prim);
  }

  const auto &cast_exec_info = std::make_shared<OpExecInfo>();
  cast_exec_info->op_name = prim::kPrimCast->name();
  cast_exec_info->py_primitive = prim;
  cast_exec_info->is_mixed_precision_cast = true;
  cast_exec_info->next_op_name = op_name;
  cast_exec_info->next_input_index = index;
  cast_exec_info->lazy_build = lazy_build_;
  // Cast takes (value, target dtype) as its inputs.
  py::tuple cast_inputs(ARG_SIZE);
  cast_inputs[0] = arg;
  cast_inputs[1] = GetDstType(type_id);
  cast_exec_info->op_inputs = cast_inputs;

  py::object cast_result = py::none();
  RunOpInner(&cast_result, cast_exec_info);
  return cast_result;
}
1161 
// Cast every element of a tuple to type_id, recursing into nested
// tuples/lists; returns a new tuple of the same arity.
py::object ForwardExecutor::DoAutoCastTuple(const py::tuple &tuple, const TypeId &type_id, const std::string &op_name,
                                            size_t index) {
  const auto tuple_size = tuple.size();
  py::tuple casted(tuple_size);
  for (size_t pos = 0; pos < tuple_size; ++pos) {
    if (py::isinstance<py::tuple>(tuple[pos]) || py::isinstance<py::list>(tuple[pos])) {
      casted[pos] = DoAutoCastTuple(tuple[pos], type_id, op_name, index);
    } else {
      casted[pos] = DoAutoCast(tuple[pos], type_id, op_name, index);
    }
  }
  return std::move(casted);
}
1175 
// Mixed-precision cast for a single tensor input: when the tensor carries a
// cast dtype and its (float) source dtype differs, cast it and set *is_cast;
// otherwise return the object unchanged.
py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name,
                                                    size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  const auto &tensor = py::cast<tensor::TensorPtr>(obj);
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &cast_type = tensor->cast_dtype();
  if (cast_type == nullptr) {
    // No mixed-precision dtype attached to this tensor.
    return obj;
  }
  auto source_element = tensor->Dtype();
  const bool need_cast =
    source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type;
  if (!need_cast) {
    return obj;
  }
  MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
  *is_cast = true;
  return DoAutoCast(obj, cast_type->type_id(), op_name, index);
}
1192 
// Mixed-precision cast applied element-wise over a tuple: tensors go through
// DoParamMixPrecisionCast, nested tuples/lists recurse, anything else passes
// through unchanged. *is_cast is set if any element was cast.
py::object ForwardExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple,
                                                         const std::string &op_name, size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  const auto tuple_size = tuple.size();
  py::tuple casted(tuple_size);
  for (size_t pos = 0; pos < tuple_size; ++pos) {
    if (py::isinstance<tensor::MetaTensor>(tuple[pos])) {
      MS_LOG(DEBUG) << "Call cast for item " << pos;
      casted[pos] = DoParamMixPrecisionCast(is_cast, tuple[pos], op_name, index);
    } else if (py::isinstance<py::tuple>(tuple[pos]) || py::isinstance<py::list>(tuple[pos])) {
      casted[pos] = DoParamMixPrecisionCastTuple(is_cast, tuple[pos], op_name, index);
    } else {
      casted[pos] = tuple[pos];
    }
  }
  return std::move(casted);
}
1210 
// Implicit (signature-driven) cast of the op's inputs: for every input whose
// signature dtype group resolved to a target TypeId in dst_type, cast the input
// to that type. Inputs already at the target type are skipped; kRWWrite
// (ref/parameter) inputs must already match and raise otherwise; inputs that
// are not Tensor/int/float cannot be implicitly converted and raise TypeError.
void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim,
                                      const std::unordered_map<SignatureEnumDType, TypeId> &dst_type,
                                      const std::vector<SignatureEnumDType> &dtypes,
                                      const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &signature = prim->signatures();
  auto &input_args = op_exec_info->op_inputs;
  size_t input_args_size = input_args.size();
  if (!dtypes.empty() && input_args_size > dtypes.size()) {
    MS_LOG(EXCEPTION) << "The input args size " << input_args_size << " exceeds the size of dtypes " << dtypes.size();
  }
  for (size_t i = 0; i < input_args_size; ++i) {
    // No need to implicit cast if no dtype.
    if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      continue;
    }
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      // This dtype group has no resolved target type.
      continue;
    }
    MS_LOG(DEBUG) << "Check inputs " << i;
    const auto &obj = input_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "Signature size is not equal to index, signature size " << signature.size()
                                 << ", index " << i;
      }
      sig = signature[i].rw;
    }
    TypeId arg_type_id = kTypeUnknown;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      const auto &arg = py::cast<tensor::MetaTensorPtr>(obj);
      arg_type_id = arg->data_type();
    }
    // Implicit cast
    bool is_same_type = false;
    if (arg_type_id != kTypeUnknown) {
      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
    }
    if (sig == SignatureEnumRW::kRWWrite && arg_type_id != kTypeUnknown && !is_same_type) {
      // Writable (ref) inputs cannot be implicitly converted: report the mismatch.
      prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
                                             TypeIdToMsTypeStr(it->second));
    }
    if (is_same_type) {
      continue;
    }

    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
                              << "th input is a not support implicit conversion type: "
                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
                              << py::cast<py::str>(obj) << ".";
    }
    py::object cast_output = DoAutoCast(input_args[i], it->second, op_exec_info->op_name, i);
    input_args[i] = cast_output;
  }
}
1270 
// Mixed-precision pass over the op's inputs: tensors carrying a cast dtype are
// cast in place (tuples/lists element-wise). Writable parameters (not kRWRead)
// are left untouched so their ref dtype is preserved.
void ForwardExecutor::SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &signature = prim->signatures();
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    const auto &obj = op_exec_info->op_inputs[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "Signature size is not equal to index, signature size " << signature.size()
                                 << ", index " << i;
      }
      sig = signature[i].rw;
    }
    MS_LOG(DEBUG) << "Check mix precision " << op_exec_info->op_name << " input " << i;
    // mix precision for non param
    bool is_cast = false;
    py::object cast_output;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor && meta_tensor->is_parameter()) {
        // If parameter write(not kRWRead), no need cast
        if (sig != SignatureEnumRW::kRWRead) {
          continue;
        }
      }
      cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i);
    } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
      // mix precision for tuple inputs
      cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i);
    }
    // Only overwrite the input when an actual cast happened.
    if (is_cast) {
      op_exec_info->op_inputs[i] = cast_output;
    }
  }
}
1308 
// Implicit type promotion driven by the primitive's signature. The signature
// analysis (dtype groups and their input indexes) is computed once per
// primitive name and cached in implicit_cast_map_; subsequent calls reuse the
// cached analysis and only recompute the destination types for the current
// inputs before applying DoSignatrueCast.
void ForwardExecutor::SetImplicitCast(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &it = implicit_cast_map_.find(prim->name());
  if (it == implicit_cast_map_.end()) {
    MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name << " first";
    const auto &signature = prim->signatures();
    auto sig_size = signature.size();
    // Ignore monad signature
    for (const auto &sig : signature) {
      if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
        --sig_size;
      }
    }
    auto size = op_exec_info->op_inputs.size();
    if (sig_size > 0 && sig_size != size) {
      MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
                               << "signature size " << sig_size;
    }
    std::vector<SignatureEnumDType> dtypes;
    std::unordered_map<SignatureEnumDType, std::vector<size_t>> type_indexes;
    bool has_dtype_sig = GetSignatureType(op_exec_info->py_primitive, &dtypes);
    if (has_dtype_sig) {
      // Resolve a destination TypeId per dtype group, then cast the inputs.
      std::unordered_map<SignatureEnumDType, TypeId> dst_type;
      GetTypeIndex(dtypes, &type_indexes);
      GetDstType(op_exec_info->op_inputs, type_indexes, &dst_type);
      DoSignatrueCast(op_exec_info->py_primitive, dst_type, dtypes, op_exec_info);
    }
    // Cache the analysis so later calls for this primitive skip it.
    PrimSignature sig_value{has_dtype_sig, dtypes, type_indexes};
    implicit_cast_map_[prim->name()] = sig_value;
  } else {
    if (!it->second.has_dtype_sig) {
      MS_LOG(DEBUG) << op_exec_info->op_name << " have no dtype sig";
      return;
    }
    MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name << " with cache";
    std::unordered_map<SignatureEnumDType, TypeId> dst_type;
    GetDstType(op_exec_info->op_inputs, it->second.type_indexes, &dst_type);
    DoSignatrueCast(op_exec_info->py_primitive, dst_type, it->second.dtypes, op_exec_info);
  }
}
1351 
GetInput(const py::object & obj,bool op_mask)1352 AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) {
1353   AnfNodePtr node = nullptr;
1354   const auto &obj_id = GetId(obj);
1355 
1356   if (op_mask) {
1357     MS_LOG(DEBUG) << "Cell parameters(weights)";
1358     // get the parameter name from parameter object
1359     auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name");
1360     if (py::isinstance<py::none>(name_attr)) {
1361       MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
1362     }
1363     const auto &param_name = py::cast<std::string>(name_attr);
1364     auto df_builder = top_cell()->df_builder();
1365     MS_EXCEPTION_IF_NULL(df_builder);
1366     auto graph_info = top_cell()->graph_info_map().at(df_builder);
1367     MS_EXCEPTION_IF_NULL(graph_info);
1368     if (graph_info->params.find(obj_id) == graph_info->params.end()) {
1369       auto free_param = df_builder->add_parameter();
1370       free_param->set_name(param_name);
1371       free_param->debug_info()->set_name(param_name);
1372       auto value = py::cast<tensor::TensorPtr>(obj);
1373       free_param->set_default_param(value);
1374       MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
1375       SetParamNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
1376       SetParamNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
1377       SetNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
1378       SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
1379       return free_param;
1380     }
1381     node = graph_info->params.at(obj_id);
1382     MS_EXCEPTION_IF_NULL(node);
1383     MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
1384     return node;
1385   }
1386 
1387   auto curr_graph_info = top_cell()->graph_info_map().at(curr_g_);
1388   MS_EXCEPTION_IF_NULL(curr_graph_info);
1389   if (curr_graph_info->node_map.find(obj_id) != curr_graph_info->node_map.end()) {
1390     // op(x, y)
1391     // out = op(op1(x, y))
1392     // out = op(cell1(x, y))
1393     // out = op(cell1(x, y)[0])
1394     node = GetObjNode(obj, obj_id);
1395   } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
1396     // out = op((x, y))
1397     // out = cell((x, y))
1398     auto tuple = obj.cast<py::tuple>();
1399     // cell((1,2)): support not mix (scalar, tensor)
1400     if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
1401       return MakeValueNode(obj, obj_id);
1402     }
1403     std::vector<AnfNodePtr> args;
1404     args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1405     auto tuple_size = tuple.size();
1406     for (size_t i = 0; i < tuple_size; i++) {
1407       args.emplace_back(GetInput(tuple[i], false));
1408     }
1409     auto cnode = curr_g_->NewCNode(args);
1410     SetNodeMapInGraphInfoMap(curr_g_, GetId(obj), cnode);
1411     node = cnode;
1412   } else {
1413     node = MakeValueNode(obj, obj_id);
1414   }
1415   node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
1416                   : MS_LOG(DEBUG) << "Get input node " << node->ToString() << ", id " << obj_id;
1417   return node;
1418 }
1419 
// Look up (and if necessary materialize) the graph node that produced python object `obj`.
// node_map stores (node, index-path); an index-path of exactly {-1} means "the node itself",
// otherwise a chain of TupleGetItem cnodes is created to extract the referenced element.
AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
  auto graph_info = top_cell()->graph_info_map().at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  const auto &out = graph_info->node_map.at(obj_id);
  if (out.second.size() == 1 && out.second[0] == -1) {
    // {-1} marks the whole output: return the recorded node directly.
    return out.first;
  }
  MS_LOG(DEBUG) << "Output size " << out.second.size();

  // Params node
  if (graph_info->params.find(obj_id) != graph_info->params.end()) {
    // For a parameter only the TupleGetItem chain is needed; no forward-value tracking.
    auto para_node = out.first;
    for (auto &v : out.second) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, NewValueNode(v)};
      MS_EXCEPTION_IF_NULL(curr_g_);
      para_node = curr_g_->NewCNode(tuple_get_item_inputs);
    }
    return para_node;
  }

  // Normal node
  auto node = out.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto abs = node->abstract();
  ValuePtr out_obj = nullptr;
  if (node->forward().first != nullptr) {
    // Prefer the forward output value recorded on the cnode when available.
    out_obj = node->forward().first->value();
  } else {
    out_obj = PyObjToValue(obj);
  }
  for (const auto idx : out.second) {
    // Descend one tuple level per index: node = TupleGetItem(node, idx).
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
    if (out_obj->isa<ValueTuple>()) {
      // Keep the forward value attached to the getitem node so later passes can reuse it;
      // out_obj is narrowed to the selected element for the next iteration.
      node->add_input_value(out_obj, "");
      node->add_input_value(MakeValue(idx), "");
      auto out_tuple = out_obj->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(out_tuple);
      if (static_cast<size_t>(idx) >= out_tuple->size()) {
        MS_LOG(EXCEPTION) << "Index exceeds the size of tuple. Index " << idx << ", tuple size " << out_tuple->size();
      }
      out_obj = (*out_tuple)[static_cast<size_t>(idx)];
      node->set_forward(NewValueNode(out_obj), "");
    }
    // NOTE(review): `abs` is taken from the original node and never narrowed inside this
    // loop; for a multi-level index path the element abstract is always selected from the
    // top-level tuple — confirm whether index paths deeper than one level can occur here.
    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
      auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
      MS_EXCEPTION_IF_NULL(abs_tuple);
      const auto &elements = abs_tuple->elements();
      if (static_cast<size_t>(idx) >= elements.size()) {
        MS_LOG(EXCEPTION) << "Index exceeds the size of elements. Index " << idx << ", elements size "
                          << elements.size();
      }
      auto prim_abs = elements[static_cast<size_t>(idx)];
      MS_EXCEPTION_IF_NULL(prim_abs);
      MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
      node->set_abstract(prim_abs);
    }
  }
  if (node->abstract() != nullptr) {
    // Cache the final abstract for this object id in the forward executor.
    forward()->node_abs_map()[obj_id] = node->abstract();
  }
  MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString();
  return node;
}
1484 
MakeValueNode(const py::object & obj,const std::string & obj_id)1485 AnfNodePtr GradExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
1486   ValuePtr converted_ret = nullptr;
1487   parse::ConvertData(obj, &converted_ret);
1488   auto node = NewValueNode(converted_ret);
1489   SetNodeMapInGraphInfoMap(curr_g_, obj_id, node);
1490   return node;
1491 }
1492 
GetTopCell(const std::string & already_run_cell_id)1493 TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id) {
1494   TopCellInfoPtr find_top_cell = nullptr;
1495   for (const auto &top_cell : top_cell_list_) {
1496     MS_EXCEPTION_IF_NULL(top_cell);
1497     // Complete match, means run grad operation first
1498     if (top_cell->already_run_cell_id() == already_run_cell_id) {
1499       return top_cell;
1500     }
1501     // Partial match, means run forward first
1502     if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos &&
1503         top_cell->already_run_cell_id().back() == '_') {
1504       find_top_cell = top_cell;
1505       break;
1506     }
1507   }
1508   // Same topcell info, but grad operation is not the same, construct backward graph again
1509   if (find_top_cell != nullptr) {
1510     if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) {
1511       MS_LOG(DEBUG) << "Already exist grad operation " << find_top_cell->grad_operation() << " is different with new "
1512                     << grad_operation_;
1513       EraseTopCellFromTopCellList(find_top_cell);
1514       (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
1515       return nullptr;
1516     } else {
1517       return find_top_cell;
1518     }
1519   }
1520   return nullptr;
1521 }
1522 
EnableOpGraphCache(bool is_enable)1523 void GradExecutor::EnableOpGraphCache(bool is_enable) {
1524   MS_LOG(DEBUG) << "Op cache is enable: " << is_enable;
1525   enable_op_cache_ = is_enable;
1526   const auto inst = MsContext::GetInstance();
1527   MS_EXCEPTION_IF_NULL(inst);
1528   inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
1529 }
1530 
RecordGradOpInfo(const OpExecInfoPtr & op_exec_info,const ValuePtr & op_out)1531 void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
1532   if (!grad_flag_) {
1533     MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
1534     return;
1535   }
1536   MS_EXCEPTION_IF_NULL(op_exec_info);
1537   MS_EXCEPTION_IF_NULL(op_out);
1538   std::string input_args_info;
1539   // Record input args info (weight or data)
1540   for (const auto mask : op_exec_info->inputs_mask) {
1541     if (mask) {
1542       input_args_info += "w";
1543       continue;
1544     }
1545     input_args_info += "d";
1546   }
1547   // Record op name and index
1548   op_exec_info->op_info.clear();
1549   const auto &curr_op_num = top_cell()->op_num();
1550   op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
1551   // The out shape is added to determine those ops that change the shape
1552   auto out_abs = op_out->ToAbstract();
1553   if (out_abs != nullptr) {
1554     auto out_shape = out_abs->BuildShape()->ToString();
1555     if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
1556       op_exec_info->op_info += "-" + out_shape;
1557     }
1558   }
1559   top_cell()->all_op_info() += "-" + op_exec_info->op_info;
1560   top_cell()->set_op_num(curr_op_num + 1);
1561 }
1562 
SaveOutputNodeMap(const std::string & obj_id,const py::object & out_real,const CNodePtr & cnode)1563 void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const CNodePtr &cnode) {
1564   if (cell_stack_.empty()) {
1565     MS_LOG(DEBUG) << "No need save output";
1566     return;
1567   }
1568   MS_EXCEPTION_IF_NULL(cnode);
1569   MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << " id " << obj_id;
1570   if (py::isinstance<py::tuple>(out_real)) {
1571     auto value = py::cast<py::tuple>(out_real);
1572     auto size = static_cast<int64_t>(value.size());
1573     if (size > 1) {
1574       for (int64_t i = 0; i < size; ++i) {
1575         auto value_id = GetId(value[static_cast<size_t>(i)]);
1576         SetNodeMapInGraphInfoMap(curr_g_, value_id, cnode, i);
1577       }
1578     }
1579   }
1580   SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode);
1581 }
1582 
1583 // Run ad grad for curr op and connect grad graph with previous op
DoOpGrad(const OpExecInfoPtr & op_exec_info,const CNodePtr & cnode,const ValuePtr & op_out)1584 void GradExecutor::DoOpGrad(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const ValuePtr &op_out) {
1585   MS_EXCEPTION_IF_NULL(op_out);
1586   if (grad_is_running_ && !bprop_grad_stack_.top().second) {
1587     MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
1588     return;
1589   }
1590   ValuePtrList input_args;
1591   for (size_t i = 0; i < op_exec_info->op_inputs.size(); ++i) {
1592     const auto &arg = PyObjToValue(op_exec_info->op_inputs[i]);
1593     input_args.emplace_back(arg);
1594   }
1595 
1596   if (!ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, op_out)) {
1597     MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_exec_info->op_name;
1598   }
1599 }
1600 
// When a previously executed ms_function graph runs again, refresh the forward output
// tensors recorded for it in the top cell with the newly produced values, so the bprop
// graph sees up-to-date device data. Raises if old/new tensor counts disagree.
void GradExecutor::UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info,
                                                  const ValuePtr &new_forward_value) {
  MS_LOG(DEBUG) << "Ms func graph has already ran before. The graph phase is: " << graph_phase();
  MS_EXCEPTION_IF_NULL(new_forward_value);
  MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
  std::vector<tensor::TensorPtr> new_tensors;
  TensorValueToTensor(new_forward_value, &new_tensors);
  if (new_tensors.empty()) {
    MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
    return;
  }

  MS_EXCEPTION_IF_NULL(op_exec_info);
  // Old tensors recorded for this exact op info on the first run; counts must match.
  const auto &old_tensors = top_cell()->op_info_with_ms_func_forward_tensors().at(op_exec_info->op_info);
  if (old_tensors.size() != new_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
                      << ", but the size of new tensors is: " << new_tensors.size()
                      << ", the current op info is: " << op_exec_info->op_info;
  }
  for (size_t i = 0; i < new_tensors.size(); ++i) {
    // Copy the fresh tensor's info onto the recorded tensor and mark it for device sync.
    UpdateTensorInfo(new_tensors[i], {old_tensors[i]});
    old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
  }
}
1625 
// Build the forward CNode that calls an ms_function graph from the current graph.
// Collects the python args as input nodes/values, then appends the graph's weight
// parameters: a weight already registered on df_builder is shared (the graph's own
// parameter is replaced via the manager), otherwise it is added as a new free parameter.
// Outputs: *input_values receives the Value of every input (args then weights),
// *ms_function_cnode receives the created call node.
void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
                                          ValuePtrList *input_values, CNodePtr *ms_function_cnode) {
  // Get input node info of ms_function
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
  MS_EXCEPTION_IF_NULL(input_values);
  for (size_t i = 0; i < args.size(); ++i) {
    auto input_i_node = GetInput(args[i], false);
    MS_EXCEPTION_IF_NULL(input_i_node);
    MS_LOG(DEBUG) << "The input " << i << " node of ms_function graph is: " << input_i_node->DebugString();
    input_nodes.emplace_back(input_i_node);
    const auto &inp_i_value = PyObjToValue(args[i]);
    MS_LOG(DEBUG) << "The input " << i << " value of ms_function graph is: " << inp_i_value->ToString();
    (*input_values).emplace_back(inp_i_value);
  }

  // Get dfbuilder and graph info map
  auto df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(df_builder);
  const auto &graph_info = top_cell()->graph_info_map().at(df_builder);
  MS_EXCEPTION_IF_NULL(graph_info);
  // Get weights info of ms_function
  std::vector<AnfNodePtr> new_params;
  // Temporary manager so Replace() can rewrite uses inside ms_func_graph.
  auto manage = Manage(ms_func_graph, false);
  for (const auto &anf_node : ms_func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(anf_node);
    auto param = anf_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    if (!param->has_default()) {
      // Plain (non-weight) parameter: keep as-is; its value came from args above.
      new_params.push_back(param);
      continue;
    }
    auto param_info = param->param_info();
    MS_EXCEPTION_IF_NULL(param_info);
    auto param_name = param_info->name();
    if (graph_info->params.count(param_name)) {
      // Share same weight parameter in different ms_function call.
      auto same_param = graph_info->params.at(param_name);
      manage->Replace(anf_node, same_param);
      param = same_param;
    } else {
      df_builder->add_parameter(param);
      param->debug_info()->set_name(param_name);
    }
    new_params.push_back(param);
    input_nodes.emplace_back(param);
    (*input_values).emplace_back(param->default_param());
    SetParamNodeMapInGraphInfoMap(df_builder, param_name, param);
    MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
                  << param->default_param()->ToString() << ". Its name is: " << param_name;
  }
  ms_func_graph->set_parameters(new_params);
  manage->Clear();

  // Make a CNode which includes ms_function fprop graph and inputs node
  MS_EXCEPTION_IF_NULL(ms_function_cnode);
  *ms_function_cnode = curr_g_->NewCNode(input_nodes);
  MS_LOG(DEBUG) << "Make ms function forward cnode: " << (*ms_function_cnode)->DebugString();
}
1685 
1686 // Make adjoint for ms_function fprop graph and connect it with previous op
// Make adjoint for ms_function fprop graph and connect it with previous op.
// Creates the forward call cnode, registers its output in the graph info maps, then
// attaches the pre-built grad graph to the k-pynative cell so backprop can flow through.
void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
                                            const py::object &actual_out, const py::args &args,
                                            const ValuePtr &actual_out_v) {
  ValuePtrList input_values;
  CNodePtr ms_function_cnode = nullptr;
  MakeCNodeForMsFunction(ms_func_graph, args, &input_values, &ms_function_cnode);
  MS_EXCEPTION_IF_NULL(ms_function_cnode);
  // Map the (possibly tuple) output object ids to the new cnode.
  SetTupleArgsToGraphInfoMap(curr_g_, actual_out, ms_function_cnode);
  SetNodeMapInGraphInfoMap(curr_g_, GetId(actual_out), ms_function_cnode);

  // Connect grad graph of ms_function to context.
  auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  MS_EXCEPTION_IF_NULL(grad_graph);
  if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, input_values, actual_out_v, grad_graph)) {
    MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
                      << ms_function_cnode->DebugString();
  }
  top_cell()->set_ms_function_flag(true);
}
1707 
// On a re-run of an already-built top cell, propagate this op's freshly computed output
// tensor info into the tensors captured in the previously built bprop graph, matching
// old and new tensors positionally via the tensor ids recorded under the same op info.
void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "The grad flag is false, no need to update forward op info in bprop graph";
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_out);
  const auto &op_info = op_exec_info->op_info;
  MS_LOG(DEBUG) << "Current op info: " << op_info;

  std::vector<tensor::TensorPtr> all_op_tensors;
  // Get output tensors
  TensorValueToTensor(op_out, &all_op_tensors);
  // Save all tensors info of current op
  if (need_construct_graph()) {
    SaveOpInfo(top_cell_, op_info, all_op_tensors);
  }

  // First run top cell
  if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
    MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " run firstly";
    if (!need_construct_graph()) {
      MS_LOG(EXCEPTION) << "The cell stack is empty when running a new top cell " << top_cell_->cell_id();
    }
    // Nothing to update on the first run; the info recorded above is the baseline.
    return;
  }
  // Non-first run
  const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id());
  MS_EXCEPTION_IF_NULL(pre_top_cell);
  if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) {
    MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell "
                  << top_cell_->cell_id();
    return;
  }

  // Update new output tensor info in bprop graph
  const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_info);
  if (pre_op_tensor_id.size() != all_op_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of pre op tensor id: " << pre_op_tensor_id.size()
                      << " is not equal to the size of all tensors of current op " << all_op_tensors.size();
  }
  const auto &pre_tensor_id_with_tensor_object = pre_top_cell->tensor_id_with_tensor_object();
  for (size_t i = 0; i < pre_op_tensor_id.size(); ++i) {
    auto pre_id = pre_op_tensor_id[i];
    // Only tensors that were actually captured in the bprop graph need updating.
    if (pre_tensor_id_with_tensor_object.find(pre_id) == pre_tensor_id_with_tensor_object.end()) {
      continue;
    }
    const auto &new_tensor = all_op_tensors[i];
    const auto &pre_tensor_object = pre_tensor_id_with_tensor_object.at(pre_id);
    UpdateTensorInfo(new_tensor, pre_tensor_object);
  }
}
1760 
// After compiling the bprop graph, record which forward-op tensors it captured:
// every tensor embedded in a value node of the bprop graph whose id belongs to a
// forward op (and that has a device address) is stored in the top cell's
// tensor_id_with_tensor_object map, keyed by tensor id, for later in-place updates.
void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
  MS_EXCEPTION_IF_NULL(resource);
  // Get all tensors id of forward op
  std::unordered_set<std::string> forward_op_tensor_id;
  const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
  for (const auto &record : op_info_with_tensor_id) {
    std::for_each(record.second.begin(), record.second.end(),
                  [&forward_op_tensor_id](const std::string &tensor_id) { forward_op_tensor_id.emplace(tensor_id); });
  }
  // Get all tensors obj in value node of bprop graph
  const auto &bprop_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(bprop_graph);
  const auto &value_node_list = bprop_graph->value_nodes();
  std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
  for (const auto &elem : value_node_list) {
    auto value_node = elem.first->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }

  auto &tensor_id_with_tensor_object = top_cell()->tensor_id_with_tensor_object();
  // This routine runs once per top-graph compile; a non-empty map means it ran twice.
  if (!tensor_id_with_tensor_object.empty()) {
    MS_LOG(EXCEPTION) << "When compile a top graph, the tensor_id_with_tensor_object map should be empty. Top cell: "
                      << top_cell()->cell_id();
  }
  // Save tensor in value node of bprop graph
  for (const auto &tensor : tensors_in_bprop_graph) {
    MS_EXCEPTION_IF_NULL(tensor);
    // Skip tensors not produced by a forward op, and those without device data.
    if (forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end() || tensor->device_address() == nullptr) {
      continue;
    }
    tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor);
    MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
                  << " device address: " << tensor->device_address() << " shape and dtype "
                  << tensor->GetShapeAndDataTypeInfo();
  }
}
1798 
RunOpWithInitBackendPolicy(const OpExecInfoPtr & op_exec_info)1799 py::tuple ForwardExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
1800   MS_EXCEPTION_IF_NULL(op_exec_info);
1801   auto backend_policy = InitEnv(op_exec_info);
1802   PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
1803   // returns a null py::tuple on error
1804   py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
1805   if (status != PYNATIVE_SUCCESS) {
1806     MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name;
1807   }
1808   MS_LOG(DEBUG) << "RunOp end";
1809   return result;
1810 }
1811 
// Prepare the execution environment for a single op and choose the backend policy.
// Non-GE builds: make sure TSD is opened and pick ms-prior or vm-only per the context's
// backend setting. GE builds: initialise GE and force the GE-only policy. Ops listed in
// kVmOperators always run on the VM regardless of the above.
MsBackendPolicy ForwardExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp start, op name is: " << op_exec_info->op_name;
  parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::IsTsdOpened(ms_context)) {
    if (!context::OpenTsd(ms_context)) {
      MS_LOG(EXCEPTION) << "Open tsd failed";
    }
  }
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  context::PynativeInitGe(ms_context);
  backend_policy = kMsBackendGeOnly;
#endif
  // Certain ops must always run on the VM, overriding the policy chosen above.
  if (kVmOperators.find(op_exec_info->op_name) != kVmOperators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  return backend_policy;
}
1841 
// Dispatch op execution according to the backend policy.
// vm-only: run on the VM. ge-prior (GE builds only): try GE, fall back to VM.
// ms-prior: run on the ms backend; on failure only an error is logged and the
// (possibly empty) result is returned with *status left unchanged.
py::object ForwardExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                                   PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // use vm only
      MS_LOG(DEBUG) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // use GE first, use vm when GE fails
      MS_LOG(DEBUG) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      // In non-GE builds this case falls through and returns an empty result.
      break;
    }
    case kMsBackendMsPrior: {
      // use Ms first,use others when ms failed
      MS_LOG(DEBUG) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
1878 
// Execute an op on the python VM backend.
// HookBackward / InsertGradientOf / stop_gradient are identity-like: their tensor
// inputs are passed through (wrapped in fresh Tensor objects for the latter two so the
// outputs are distinct python objects sharing the same data). All other ops run the
// primitive's registered python compute function; the result is normalized to a tuple.
py::object ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(DEBUG) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);

  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
      op_exec_info->op_name == "stop_gradient") {
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto tensor = py::cast<tensor::TensorPtr>(input);
      MS_EXCEPTION_IF_NULL(tensor);
      if (op_exec_info->op_name == "HookBackward") {
        // the input object is not a output of forward cnode, eg: parameter
        result[i] = tensor;
      } else {
        // the input object is a output of forward cnode
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_sync_status(tensor->sync_status());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(DEBUG) << "RunOpInVM end";
    // std::move converts py::tuple -> py::object without an extra refcount bump.
    return std::move(result);
  }

  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  MS_LOG(DEBUG) << "RunOpInVM end";
  if (py::isinstance<py::none>(result)) {
    // No python compute function registered (or it returned None).
    MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  *status = PYNATIVE_SUCCESS;
  if (py::isinstance<py::tuple>(result)) {
    return result;
  }
  // Wrap a single result so callers always receive a tuple.
  py::tuple tuple_result = py::make_tuple(result);
  return std::move(tuple_result);
}
1926 
// Execute an op on the ms backend.
// Lazily creates either the legacy session (kSession) or the MindRT backend depending
// on MS_CTX_ENABLE_MINDRT, builds the single-op graph inputs and cache key, runs the
// graph, and converts the outputs back to python. The PYNATIVE_INFER context flag is
// set for the duration of the call.
py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(status);
  MS_LOG(DEBUG) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
  compile::SetMindRTEnable();

  // Legacy (non-MindRT) path: create the device session once and reuse it.
  if (kSession == nullptr && !ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    kSession = session::SessionFactory::Get().Create(device_target);
    MS_EXCEPTION_IF_NULL(kSession);
    kSession->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  }

  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int64_t> tensors_mask;
  std::string graph_info;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  ConvertAttrToUnifyMindIR(op_exec_info);
  // get graph info for checking it whether existing in the cache
  GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask, &graph_info);
#if defined(__APPLE__)
  // On macOS next_input_index is narrowed to int to match the OpRunInfo field type.
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->lazy_build,
                                    op_exec_info->next_op_name,
                                    static_cast<int>(op_exec_info->next_input_index)};
#else
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->lazy_build,
                                    op_exec_info->next_op_name,
                                    op_exec_info->next_input_index};
#endif
  VectorRef outputs;
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    kSession->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
  } else {
    // MindRT path: create the backend once, then compile and run the single-op actor set.
    if (mind_rt_backend == nullptr) {
      const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
      uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
      mind_rt_backend = std::make_shared<compile::MindRTBackend>("ms", device_target, device_id);
    }

    // Release the GIL-equivalent long-running guard while the backend executes.
    mindspore::ScopedLongRunning long_running;
    const compile::ActorInfo &actor_info =
      mind_rt_backend->CompileGraph(op_run_info, graph_info, &tensors_mask, &input_tensors);
    mind_rt_backend->RunGraph(actor_info, &op_run_info, &tensors_mask, &input_tensors, &outputs);
  }

  if (op_exec_info->is_dynamic_shape) {
    // Dynamic shape: the inferred abstract is only known after execution.
    op_exec_info->abstract = op_run_info.abstract;
  }
  auto result = BaseRefToPyData(outputs);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(DEBUG) << "End run op [" << op_exec_info->op_name << "] with backend policy ms";
  return result;
}
1994 
ClearRes()1995 void ForwardExecutor::ClearRes() {
1996   MS_LOG(DEBUG) << "Clear forward res";
1997   lazy_build_ = false;
1998   implicit_cast_map_.clear();
1999   prim_abs_list_.clear();
2000   node_abs_map_.clear();
2001 }
2002 
forward() const2003 ForwardExecutorPtr GradExecutor::forward() const {
2004   auto forward_executor = forward_executor_.lock();
2005   MS_EXCEPTION_IF_NULL(forward_executor);
2006   return forward_executor;
2007 }
2008 
top_cell() const2009 TopCellInfoPtr GradExecutor::top_cell() const {
2010   MS_EXCEPTION_IF_NULL(top_cell_);
2011   return top_cell_;
2012 }
2013 
curr_g() const2014 FuncGraphPtr GradExecutor::curr_g() const {
2015   MS_EXCEPTION_IF_NULL(curr_g_);
2016   return curr_g_;
2017 }
2018 
PushCellStack(const std::string & cell_id)2019 void GradExecutor::PushCellStack(const std::string &cell_id) { cell_stack_.push(cell_id); }
2020 
PopCellStack()2021 void GradExecutor::PopCellStack() {
2022   if (cell_stack_.empty()) {
2023     MS_LOG(EXCEPTION) << "Stack cell_statck_ is empty";
2024   }
2025   cell_stack_.pop();
2026 }
2027 
PushHighOrderGraphStack(const TopCellInfoPtr & top_cell)2028 void GradExecutor::PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) {
2029   high_order_stack_.push(std::make_pair(curr_g_, top_cell));
2030 }
2031 
PopHighOrderGraphStack()2032 TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
2033   if (high_order_stack_.empty()) {
2034     MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty";
2035   }
2036   high_order_stack_.pop();
2037   TopCellInfoPtr top_cell = nullptr;
2038   if (!high_order_stack_.empty()) {
2039     auto t = high_order_stack_.top();
2040     curr_g_ = t.first;
2041     top_cell = t.second;
2042   }
2043   return top_cell;
2044 }
2045 
GetCellId(const py::object & cell,const py::args & args)2046 std::string GradExecutor::GetCellId(const py::object &cell, const py::args &args) {
2047   auto cell_id = GetId(cell);
2048   for (size_t i = 0; i < args.size(); i++) {
2049     const auto &arg_id = GetId(args[i]);
2050     auto it = forward()->node_abs_map().find(arg_id);
2051     if (it != forward()->node_abs_map().end()) {
2052       auto &abs = it->second;
2053       MS_EXCEPTION_IF_NULL(abs);
2054       auto shape = abs->BuildShape();
2055       MS_EXCEPTION_IF_NULL(shape);
2056       auto type = abs->BuildType();
2057       MS_EXCEPTION_IF_NULL(type);
2058       cell_id += "_" + shape->ToString();
2059       cell_id += type->ToString();
2060     } else {
2061       auto value = PyObjToValue(args[i]);
2062       MS_EXCEPTION_IF_NULL(value);
2063       auto abs = value->ToAbstract();
2064       MS_EXCEPTION_IF_NULL(abs);
2065       if (abs->isa<abstract::AbstractTensor>()) {
2066         abs->set_value(kAnyValue);
2067       }
2068       forward()->node_abs_map()[arg_id] = abs;
2069       auto shape = abs->BuildShape();
2070       MS_EXCEPTION_IF_NULL(shape);
2071       auto type = abs->BuildType();
2072       MS_EXCEPTION_IF_NULL(type);
2073       cell_id += "_" + shape->ToString();
2074       cell_id += type->ToString();
2075     }
2076   }
2077   return cell_id;
2078 }
2079 
DumpGraphIR(const std::string & filename,const FuncGraphPtr & graph)2080 void GradExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
2081 #ifdef ENABLE_DUMP_IR
2082   auto ms_context = MsContext::GetInstance();
2083   MS_EXCEPTION_IF_NULL(ms_context);
2084   if (ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
2085     DumpIR(filename, graph);
2086   }
2087 #endif
2088 }
2089 
IsNestedGrad() const2090 inline bool GradExecutor::IsNestedGrad() const {
2091   MS_LOG(DEBUG) << "Grad nested order is " << grad_order_;
2092   return grad_order_ > 1;
2093 }
2094 
IsCellObjIdEq(const std::string & l_cell_id,const std::string & r_cell_id) const2095 bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const {
2096   // just compare obj_id, ignore args id
2097   return l_cell_id.compare(0, PTR_LEN, r_cell_id, 0, PTR_LEN) == 0;
2098 }
2099 
IsBpropGraph(const std::string & cell_id)2100 bool GradExecutor::IsBpropGraph(const std::string &cell_id) {
2101   if (top_cell_ == nullptr) {
2102     return false;
2103   }
2104   return std::any_of(bprop_cell_list_.begin(), bprop_cell_list_.end(),
2105                      [&cell_id](const std::string &value) { return cell_id.find(value) != std::string::npos; });
2106 }
2107 
UpdateTopCellInfo(bool forward_already_run,bool need_compile_graph,bool vm_compiled)2108 void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled) {
2109   top_cell()->set_vm_compiled(vm_compiled);
2110   top_cell()->set_need_compile_graph(need_compile_graph);
2111   top_cell()->set_forward_already_run(forward_already_run);
2112 }
2113 
ClearCellRes(const std::string & cell_id)2114 void GradExecutor::ClearCellRes(const std::string &cell_id) {
2115   static bool clear_all_cell_res = false;
2116   // Grad clean
2117   if (cell_id.empty()) {
2118     MS_LOG(DEBUG) << "Clear all cell resources";
2119     clear_all_cell_res = true;
2120     for (const auto &iter : top_cell_list_) {
2121       MS_EXCEPTION_IF_NULL(iter);
2122       iter->Clear();
2123     }
2124     top_cell_list_.clear();
2125     already_run_top_cell_.clear();
2126     clear_all_cell_res = false;
2127     return;
2128   }
2129   if (clear_all_cell_res) {
2130     MS_LOG(DEBUG) << "In process of clearing all cell resources, so no need to clear single cell resource again";
2131     return;
2132   }
2133   // clear when cell destruction
2134   for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
2135     MS_EXCEPTION_IF_NULL(*it);
2136     const auto &top_cell_id = (*it)->cell_id();
2137     const auto &already_run_cell_id = (*it)->already_run_cell_id();
2138     if (IsCellObjIdEq(cell_id, top_cell_id)) {
2139       MS_LOG(DEBUG) << "Clear top cell resource. Top cell id " << top_cell_id;
2140       (*it)->Clear();
2141       it = top_cell_list_.erase(it);
2142       (void)already_run_top_cell_.erase(already_run_cell_id);
2143       continue;
2144     }
2145     ++it;
2146   }
2147 }
2148 
HandleInputArgsForTopCell(const py::args & args,bool is_bprop_top)2149 void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top) {
2150   if (is_bprop_top) {
2151     // Convert input args to parameters for top cell graph in bprop.
2152     for (size_t i = 0; i < args.size(); ++i) {
2153       auto param = args[i];
2154       auto new_param = curr_g_->add_parameter();
2155       const auto &param_id = GetId(param);
2156       SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
2157       SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
2158       SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
2159     }
2160     return;
2161   }
2162   // Convert input args to parameters for top cell graph in construct.
2163   std::vector<ValuePtr> input_param_values;
2164   const auto &only_tensors = FilterTensorArgs(args);
2165   for (size_t i = 0; i < only_tensors.size(); ++i) {
2166     auto new_param = curr_g_->add_parameter();
2167     auto param_i = only_tensors[i];
2168     const auto &param_i_value = PyObjToValue(param_i);
2169     input_param_values.emplace_back(param_i_value);
2170     auto param_i_abs = param_i_value->ToAbstract();
2171     MS_EXCEPTION_IF_NULL(param_i_abs);
2172     new_param->set_abstract(param_i_abs->Broaden());
2173     const auto &param_i_id = GetId(param_i);
2174     SetTupleArgsToGraphInfoMap(curr_g_, param_i, new_param, true);
2175     SetNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
2176     SetParamNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
2177     SetParamNodeMapInGraphInfoMap(top_cell_->df_builder(), param_i_id, new_param);
2178   }
2179   top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g_->parameters(), input_param_values));
2180 }
2181 
// Prepare graph resources when a cell starts running under grad.
// Depending on the cell/bprop stacks this creates a new top graph, a
// separate graph for a custom-bprop body, or a nested top graph for
// high-order grad; then the cell id is pushed onto the cell stack and,
// for the first cell of a top graph, the k-pynative cell is initialized
// from the input parameters.
void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args) {
  if (cell_stack_.empty() || IsNestedGrad()) {
    if (cell_stack_.empty() && !grad_is_running_) {
      // Outermost cell and no bprop running: start a fresh top graph.
      MS_LOG(DEBUG) << "Make new topest graph";
      MakeNewTopGraph(cell_id, args, true);
    } else if (grad_is_running_ && IsBpropGraph(cell_id)) {
      // Entering a registered custom-bprop cell while grad runs: build the
      // bprop body in its own graph.
      MS_LOG(DEBUG) << "Run bprop cell";
      curr_g_ = std::make_shared<FuncGraph>();
      auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
      top_cell()->graph_info_map()[curr_g_] = graph_info_cg;
      HandleInputArgsForTopCell(args, true);
      // false: GradPynativeOp is skipped for nodes built in this cell
      // (see CreateMakeTupleNodeForMultiOut).
      bprop_grad_stack_.push(std::make_pair(cell_id, false));
    } else if (grad_is_running_ && top_cell()->grad_order() != grad_order_) {
      // Nested grad launched from inside a bprop.
      MS_LOG(DEBUG) << "Nested grad graph existed in bprop";
      MakeNewTopGraph(cell_id, args, false);
      bprop_grad_stack_.push(std::make_pair(cell_id, true));
    } else if (!cell_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_) {
      // Nested grad launched from construct: the new top cell inherits the
      // dynamic flag of the enclosing one.
      MS_LOG(DEBUG) << "Nested grad graph existed in construct";
      auto cur_top_is_dynamic = top_cell()->is_dynamic();
      MakeNewTopGraph(cell_id, args, false);
      top_cell()->set_is_dynamic(cur_top_is_dynamic);
    }
  }

  PushCellStack(cell_id);
  // Init kPynativeCellPtr with input parameters of top cell
  if (!top_cell()->is_init_kpynative()) {
    auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
    top_cell()->graph_info_map()[curr_g_] = graph_info_cg;
    auto graph_info_df = std::make_shared<GraphInfo>(cell_id);
    top_cell()->graph_info_map()[top_cell_->df_builder()] = graph_info_df;
    HandleInputArgsForTopCell(args, false);
    top_cell()->set_need_compile_graph(true);
    top_cell()->set_init_kpynative(true);
  } else {
    // Non-top cell
    top_cell()->sub_cell_list().emplace(cell_id);
  }
}
2221 
// Entry hook called when a cell's construct begins under grad.
// Skips graph building when a cached, non-dynamic top cell can be reused
// or when the cell is a known sub cell; otherwise initializes the
// resources/df_builder for the cell and records whether its construct is
// dynamic.
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
  if (top_cell_ != nullptr && cell_stack_.empty()) {
    // Already run top cell need distinguish high order; high order add "0" otherwise "1"
    const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
    auto top_it = already_run_top_cell_.find(already_run_cell_id);
    if (top_it != already_run_top_cell_.end()) {
      // Top cell forward run.
      const auto &pre_top_cell = top_it->second;
      MS_EXCEPTION_IF_NULL(pre_top_cell);
      if (!pre_top_cell->is_dynamic()) {
        // A non-dynamic (static-structure) top cell can be reused as-is.
        MS_LOG(DEBUG) << "Top cell " << cell_id << " is not dynamic, no need to run NewGraphInner again";
        ResetTopCellInfo(pre_top_cell, args);
        PushHighOrderGraphStack(pre_top_cell);
        set_top_cell(pre_top_cell);
        grad_order_ = pre_top_cell->grad_order();
        return;
      }
    } else if ((top_cell()->IsSubCell(cell_id) || GetHighOrderStackSize() >= 1) &&
               !IsCellObjIdEq(cell_id, check_graph_cell_id_)) {
      // Sub cell ( or may be a temporary cell, but must be non top) forward run in cache process.
      MS_LOG(DEBUG) << "Sub cell no need to run NewGraphInner again";
      return;
    }
  }
  // When the cell has custom bprop, in_custom_bprop_cell is larger than 0
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    custom_bprop_cell_count_ += 1;
  }
  // Make top graph and init resource for resource and df_builder
  InitResourceAndDfBuilder(cell_id, args);
  // Check whether cell has dynamic construct
  if (!top_cell()->is_dynamic()) {
    bool is_dynamic = parse::DynamicParser::IsDynamicCell(cell);
    MS_LOG(DEBUG) << "Current cell dynamic " << is_dynamic;
    if (is_dynamic) {
      top_cell()->set_is_dynamic(is_dynamic);
    }
  }
}
2264 
MakeNewTopGraph(const string & cell_id,const py::args & args,bool is_topest)2265 void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest) {
2266   pipeline::CheckArgsValid(args);
2267   // Record input args info
2268   std::string input_args_id;
2269   for (size_t i = 0; i < args.size(); ++i) {
2270     input_args_id += GetId(args[i]) + "_";
2271   }
2272   // Run forward first need plus 1
2273   if (grad_order_ == 0) {
2274     ++grad_order_;
2275   }
2276   // The number of top cell exceeds MAX_TOP_CELL_COUNTS, delete the last one to keep the maximum length of the list,
2277   // disable backend cache
2278   if (top_cell_list_.size() >= MAX_TOP_CELL_COUNTS) {
2279     EnableOpGraphCache(false);
2280     const auto last_top_cell = top_cell_list_.back();
2281     top_cell_list_.pop_back();
2282     MS_EXCEPTION_IF_NULL(last_top_cell);
2283     last_top_cell->Clear();
2284     (void)already_run_top_cell_.erase(last_top_cell->already_run_cell_id());
2285   }
2286   // Create top cell
2287   curr_g_ = std::make_shared<FuncGraph>();
2288   auto df_builder = std::make_shared<FuncGraph>();
2289   auto resource = std::make_shared<pipeline::Resource>();
2290   const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
2291   auto top_cell =
2292     std::make_shared<TopCellInfo>(is_topest, grad_order_, resource, df_builder, cell_id, already_run_cell_id);
2293   top_cell->set_forward_already_run(true);
2294   top_cell->set_input_args_id(input_args_id);
2295   top_cell_list_.emplace_back(top_cell);
2296   PushHighOrderGraphStack(top_cell);
2297   set_top_cell(top_cell);
2298   MS_LOG(DEBUG) << "New top graph, curr_g ptr " << curr_g_.get() << " resource ptr " << resource.get();
2299 }
2300 
SetTupleArgsToGraphInfoMap(const FuncGraphPtr & g,const py::object & args,const AnfNodePtr & node,bool is_param)2301 void GradExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
2302                                               bool is_param) {
2303   if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
2304     return;
2305   }
2306   auto tuple = args.cast<py::tuple>();
2307   auto tuple_size = static_cast<int64_t>(tuple.size());
2308   for (int64_t i = 0; i < tuple_size; ++i) {
2309     // tuple slice used size_t
2310     auto id = GetId(tuple[static_cast<size_t>(i)]);
2311     if (is_param && node->isa<Parameter>()) {
2312       auto param = node->cast<ParameterPtr>();
2313       MS_EXCEPTION_IF_NULL(param);
2314       SetParamNodeMapInGraphInfoMap(g, id, param);
2315     }
2316     SetNodeMapInGraphInfoMap(g, id, node, i);
2317     SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
2318   }
2319 }
2320 
SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr & g,const py::object & args,const AnfNodePtr & node,const std::vector<int64_t> & index_sequence,bool is_param)2321 void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
2322                                                   const std::vector<int64_t> &index_sequence, bool is_param) {
2323   if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
2324     return;
2325   }
2326   MS_EXCEPTION_IF_NULL(node);
2327   auto tuple = args.cast<py::tuple>();
2328   auto tuple_size = static_cast<int64_t>(tuple.size());
2329   for (int64_t i = 0; i < tuple_size; ++i) {
2330     std::vector<int64_t> tmp = index_sequence;
2331     tmp.emplace_back(i);
2332     // tuple slice used size_t
2333     auto id = GetId(tuple[static_cast<size_t>(i)]);
2334     if (is_param && node->isa<Parameter>()) {
2335       auto param = node->cast<ParameterPtr>();
2336       MS_EXCEPTION_IF_NULL(param);
2337       SetParamNodeMapInGraphInfoMap(g, id, param);
2338     }
2339     SetNodeMapInGraphInfoMap(g, id, node, tmp);
2340     SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
2341   }
2342 }
2343 
CreateMakeTupleNodeForMultiOut(const FuncGraphPtr & curr_g,const py::object & out,const std::string & out_id)2344 void GradExecutor::CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out,
2345                                                   const std::string &out_id) {
2346   MS_EXCEPTION_IF_NULL(curr_g);
2347   const auto &out_tuple = out.cast<py::tuple>();
2348   // get input node and value
2349   std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
2350   ValuePtrList input_args;
2351   std::vector<size_t> value_index;
2352   for (size_t i = 0; i < out_tuple.size(); i++) {
2353     const auto &v = PyObjToValue(out_tuple[i]);
2354     // Graph have no define for grad
2355     if (v->isa<FuncGraph>()) {
2356       continue;
2357     }
2358     value_index.emplace_back(i);
2359     input_args.emplace_back(v);
2360     inputs.emplace_back(GetInput(out_tuple[i], false));
2361   }
2362   py::tuple value_outs(value_index.size());
2363   for (size_t i = 0; i < value_index.size(); ++i) {
2364     value_outs[i] = out_tuple[value_index[i]];
2365   }
2366   auto cnode = curr_g_->NewCNode(inputs);
2367   MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString();
2368   // record node info in graph map
2369   SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
2370   SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
2371   if (grad_is_running_ && !bprop_grad_stack_.top().second) {
2372     MS_LOG(DEBUG) << "Custom bprop, no need GradPynativeOp";
2373     return;
2374   }
2375   // run ad for maketuple node
2376   const auto &out_value = PyObjToValue(value_outs);
2377   ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value);
2378 }
2379 
// Hook called when a cell's construct finishes under grad.
// Ensures the output has a node in the graph, handles custom bprop,
// balances the cell/bprop stacks and, when the top cell finishes,
// updates its real output node and decides whether the bprop graph must
// be recompiled.
void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
  if (cell_stack_.empty()) {
    // NewGraphInner skipped this cell (cached run) — nothing was built.
    MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again";
    if (cell_id == top_cell()->cell_id()) {
      if (top_cell()->is_topest()) {
        set_grad_flag(false);
      }
      if (GetHighOrderStackSize() < ARG_SIZE) {
        // Restore the enclosing top cell when leaving a nested grad scope.
        auto outer_top_cell = PopHighOrderGraphStack();
        if (outer_top_cell != nullptr) {
          set_top_cell(outer_top_cell);
        }
      }
    }
    return;
  }
  // Make output node in this case: x = op1, y = op2, return (x, y)
  const auto &out_id = GetId(out);
  const auto &graph_info = top_cell()->graph_info_map().at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
    if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
      CreateMakeTupleNodeForMultiOut(curr_g_, out, out_id);
    } else {
      MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
      MakeValueNode(out, out_id);
    }
  }
  DoGradForCustomBprop(cell, out, args);
  // Set output node for forward graph when need.
  PopCellStack();
  if (grad_is_running_ && !bprop_grad_stack_.empty()) {
    if (!bprop_grad_stack_.top().second) {
      // Finishing a custom-bprop body graph: set its output and stop here.
      bprop_grad_stack_.pop();
      MS_EXCEPTION_IF_NULL(curr_g_);
      curr_g_->set_output(GetObjNode(out, out_id));
      return;
    } else if (bprop_grad_stack_.top().first == cell_id) {
      bprop_grad_stack_.pop();
    }
  }

  bool is_top_cell_end = cell_id == top_cell()->cell_id();
  // Just only dump the last forward graph
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && is_top_cell_end) {
    curr_g_->set_output(GetObjNode(out, out_id));
#ifdef ENABLE_DUMP_IR
    DumpIR("fg.ir", curr_g_);
#endif
  }

  // Reset grad flag and update output node of top cell
  if (cell_stack_.empty() && is_top_cell_end) {
    MS_LOG(DEBUG) << "Cur top last cell " << cell_id;
    set_grad_flag(false);
    PopHighOrderGraphStack();
    // Update real output node of top cell for generating bprop graph
    AnfNodePtr output_node = GetObjNode(out, out_id);
    MS_EXCEPTION_IF_NULL(output_node);
    auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
    MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
    k_pynative_cell_ptr->UpdateOutputNodeOfTopCell(output_node);
  }

  // Checkout whether need to compile graph when top cell has ran finished
  if (is_top_cell_end) {
    CheckNeedCompileGraph();
  }
}
2452 
DoGradForCustomBprop(const py::object & cell,const py::object & out,const py::args & args)2453 void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) {
2454   if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
2455     return;
2456   }
2457   custom_bprop_cell_count_ -= 1;
2458   if (custom_bprop_cell_count_ != 0) {
2459     return;
2460   }
2461   MS_LOG(DEBUG) << "Do grad for custom bprop";
2462   size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
2463   if (par_number > 0) {
2464     MS_LOG(EXCEPTION) << "When user defines the net bprop, the 'Parameter' data type is not supported in the net.";
2465   }
2466   py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
2467   auto bprop_func_cellid = GetId(bprop_func);
2468   bprop_cell_list_.emplace_back(bprop_func_cellid);
2469   auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
2470   fake_prim->set_hook(bprop_func);
2471   const auto &cell_id = GetCellId(cell, args);
2472   (void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
2473   (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
2474 
2475   py::object code_obj = py::getattr(bprop_func, "__code__");
2476   py::object co_name = py::getattr(code_obj, "co_name");
2477   if (std::string(py::str(co_name)) == "staging_specialize") {
2478     MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
2479   }
2480   // Three parameters self, out and dout need to be excluded
2481   const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
2482   if (inputs_num > args.size()) {
2483     MS_EXCEPTION(TypeError) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
2484                             << args.size() << "]";
2485   }
2486 
2487   py::list cell_inputs;
2488   for (size_t i = 0; i < inputs_num; i += 1) {
2489     cell_inputs.append(args[i]);
2490   }
2491   OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
2492   op_exec_info->op_name = fake_prim->name();
2493   op_exec_info->py_primitive = fake_prim;
2494   op_exec_info->op_inputs = cell_inputs;
2495   auto cnode = forward()->ConstructForwardGraph(op_exec_info);
2496   const auto &v_out = PyObjToValue(out);
2497   DoOpGrad(op_exec_info, cnode, v_out);
2498   const auto &out_obj_id = GetId(out);
2499   SaveOutputNodeMap(out_obj_id, out, cnode);
2500 }
2501 
GetAlreadyRunCellId(const std::string & cell_id)2502 std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) {
2503   std::string already_run_cell_id;
2504   if (IsNestedGrad()) {
2505     already_run_cell_id = cell_id + "0";
2506   } else {
2507     already_run_cell_id = cell_id + "1";
2508   }
2509   already_run_cell_id += "_" + grad_operation_;
2510   MS_LOG(DEBUG) << "Get already run top cell id " << already_run_cell_id;
2511   return already_run_cell_id;
2512 }
2513 
GetGradCellId(bool has_sens,const py::object & cell,const py::args & args)2514 std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args) {
2515   size_t forward_args_size = args.size();
2516   py::args tmp = args;
2517   if (has_sens) {
2518     forward_args_size--;
2519     py::tuple f_args(forward_args_size);
2520     for (size_t i = 0; i < forward_args_size; ++i) {
2521       f_args[i] = args[i];
2522     }
2523     tmp = f_args;
2524   }
2525   const auto &cell_id = GetCellId(cell, tmp);
2526   return cell_id;
2527 }
2528 
// Build, compile and prepare for execution the bprop graph of the
// current top cell: collects the weights requiring derivatives, obtains
// the bprop graph, attaches it to the top cell's resource, and runs the
// task-emit and execute pipeline actions. Skipped entirely when the top
// cell's graph does not need recompiling.
void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                                const py::object &weights, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  MS_EXCEPTION_IF_NULL(grad);
  auto size = args.size();
  const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args);
  MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id;
  if (!top_cell()->need_compile_graph()) {
    MS_LOG(DEBUG) << "No need compile graph";
    UpdateTopCellInfo(false, false, false);
    return;
  }
  top_cell()->set_grad_operation(grad_operation_);
  auto resource = top_cell()->resource();
  MS_EXCEPTION_IF_NULL(resource);
  auto df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(df_builder);
  MS_LOG(DEBUG) << "curr_g ptr " << curr_g_.get() << " resource ptr " << resource.get();

  // Get params(weights) require derivative
  auto w_args = GetWeightsArgs(weights, df_builder);
  if (w_args.empty() && !df_builder->parameters().empty()) {
    // No explicit ParameterTuple given: fall back to all df_builder params.
    MS_LOG(DEBUG) << "Add weights params to w_args";
    w_args.insert(w_args.end(), df_builder->parameters().begin(), df_builder->parameters().end());
  }
  // Get bprop graph of top cell
  auto bprop_graph = GetBpropGraph(grad, cell, w_args, size, args);
  resource->set_func_graph(bprop_graph);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph, true);
  DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
  // Launch bprop graph to backend
  SaveForwardTensorInfoInBpropGraph(resource);
  compile::SetMindRTEnable();
  resource->results()[pipeline::kBackend] = compile::CreateBackend();
  MS_LOG(DEBUG) << "Start task emit action";
  TaskEmitAction(resource);
  MS_LOG(DEBUG) << "Start execute action";
  ExecuteAction(resource);
  MS_LOG(DEBUG) << "Start update top cell info when run finish";
  UpdateTopCellInfo(false, false, true);
  resource->Clean();
  abstract::AnalysisContext::ClearContext();
}
2574 
GetWeightsArgs(const py::object & weights,const FuncGraphPtr & df_builder)2575 std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) {
2576   MS_EXCEPTION_IF_NULL(df_builder);
2577   if (!py::hasattr(weights, "__parameter_tuple__")) {
2578     MS_LOG(DEBUG) << "No parameter tuple get";
2579     return {};
2580   }
2581 
2582   const auto &tuple = weights.cast<py::tuple>();
2583   MS_LOG(DEBUG) << "Get weights tuple size " << tuple.size();
2584   std::vector<AnfNodePtr> w_args;
2585   for (size_t it = 0; it < tuple.size(); ++it) {
2586     auto param = tuple[it];
2587     auto param_id = GetId(param);
2588     auto &graph_info_map = top_cell()->graph_info_map();
2589     if (graph_info_map.find(df_builder) == graph_info_map.end()) {
2590       MS_LOG(EXCEPTION) << "Can not find df_builder " << df_builder.get() << " Top cell " << top_cell().get()
2591                         << " cell id " << top_cell()->cell_id();
2592     }
2593     auto graph_info = graph_info_map.at(df_builder);
2594     MS_EXCEPTION_IF_NULL(graph_info);
2595     AnfNodePtr para_node = nullptr;
2596     if (graph_info->params.find(param_id) != graph_info->params.end()) {
2597       para_node = graph_info->params.at(param_id);
2598       w_args.emplace_back(para_node);
2599       continue;
2600     }
2601     const auto &name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
2602     if (py::isinstance<py::none>(name_attr)) {
2603       MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
2604     }
2605     const auto &param_name = py::cast<std::string>(name_attr);
2606     MS_LOG(DEBUG) << "The input " << it << " parameter weight name " << param_name;
2607     if (graph_info->params.find(param_name) != graph_info->params.end()) {
2608       para_node = graph_info->params.at(param_name);
2609     } else {
2610       MS_LOG(DEBUG) << "Can not find input param in graph info map, make a new parameter";
2611       auto free_param = df_builder->add_parameter();
2612       free_param->set_name(param_name);
2613       auto value = py::cast<tensor::TensorPtr>(param);
2614       free_param->set_default_param(value);
2615       free_param->debug_info()->set_name(param_name);
2616       para_node = free_param;
2617     }
2618     w_args.emplace_back(para_node);
2619   }
2620   return w_args;
2621 }
2622 
// Build the abstract spec list for the bprop graph parameters and refresh
// each parameter's abstract: weights (params with defaults) take their
// default value's abstract, inputs take the abstract of the matching
// python argument. Shape and dtype of non-const inputs are validated
// against the existing parameter abstract.
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph) {
  MS_EXCEPTION_IF_NULL(bprop_graph);
  std::size_t size = args.size();
  abstract::AbstractBasePtrList args_spec;
  const auto &bprop_params = bprop_graph->parameters();
  // bprop_params include inputs, parameters, more than size(inputs)
  if (bprop_params.size() < size) {
    MS_LOG(EXCEPTION) << "Df parameters size " << bprop_params.size() << " less than " << size;
  }
  // Update abstract info for parameters in bprop graph
  size_t index = 0;
  for (const auto &param : bprop_params) {
    auto param_node = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      // update abstract info for weights
      ValuePtr value = param_node->default_param();
      auto ptr = value->ToAbstract();
      MS_EXCEPTION_IF_NULL(ptr);
      args_spec.emplace_back(ptr);
      param_node->set_abstract(ptr->Broaden());
    } else {
      // update abstract info for input params
      const auto &input_value = PyObjToValue(args[index]);
      auto input_abs = abstract::FromValue(input_value, true);
      if (param_node->abstract() != nullptr) {
        auto input_shape = input_abs->BuildShape()->ToString();
        auto param_tensor_abs = param_node->abstract();
        if (param_tensor_abs->isa<abstract::AbstractRef>()) {
          // Compare against the underlying tensor abstract, not the ref.
          param_tensor_abs = param_tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
        }
        auto ir_shape = param_tensor_abs->BuildShape()->ToString();
        // Exclude const input
        if (input_shape != "()" && ir_shape != "()") {
          if (input_shape != ir_shape) {
            MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", "
                                     << param->DebugString();
          }
          auto ir_dtype = param_tensor_abs->BuildType()->ToString();
          auto input_dtype = input_abs->BuildType()->ToString();
          if (input_dtype != ir_dtype) {
            MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", "
                                    << param->DebugString();
          }
        }
      }
      args_spec.emplace_back(input_abs);
      param_node->set_abstract(input_abs->Broaden());
      // Only input (non-default) params consume python arguments.
      index++;
    }
  }
  MS_LOG(DEBUG) << "Args_spec size " << args_spec.size();
  return args_spec;
}
2677 
// Finalize the bprop graph for the current top cell: end the pynative k-graph,
// attach abstracts for its parameters, then run the final bprop optimization
// pass and return the optimized graph.
FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
                                         const std::vector<AnfNodePtr> &weights, size_t arg_size,
                                         const py::args &args) {
  bool build_formal_param = false;
  // Nested grad without a custom bprop builds formal parameters and forces a
  // renormalize of the resulting graph.
  if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !cell_stack_.empty() && IsNestedGrad()) {
    build_formal_param = true;
    need_renormalize_ = true;
  }
  if (top_cell()->ms_function_flag()) {
    need_renormalize_ = true;
  }

  auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  MS_EXCEPTION_IF_NULL(grad);
  FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad->get_all_, grad->get_by_list_,
                                                     grad->sens_param_, build_formal_param);
  MS_EXCEPTION_IF_NULL(bprop_graph);

  MS_LOG(DEBUG) << "Top graph input params size " << arg_size;
  std::ostringstream ss;
  ss << "grad{" << arg_size << "}";
  bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop_graph->debug_info()->set_name(ss.str());
  // Get the parameters items and add the value to args_spec.
  // Called for its side effect of setting parameter abstracts; the returned
  // list itself is discarded.
  (void)GetArgsSpec(FilterTensorArgs(args, grad->sens_param_), bprop_graph);

  // Do opt for final bprop graph
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  resource->set_func_graph(bprop_graph);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph);
  auto optimized_bg = ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);

  // Only the outermost grad (empty cell stack) clears the renormalize flag.
  if (cell_stack_.empty()) {
    need_renormalize_ = false;
  }
  DumpGraphIR("after_final_opt.ir", optimized_bg);
  return optimized_bg;
}
2719 
// Check whether `cell` is a sub cell of the running top cell while a grad is
// in progress. Returns a Python bool; true means the cell id was found (and
// removed) in the top cell's sub-cell list. Also bumps grad_order_ when this
// call does not match the current top cell.
py::object GradExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  BaseRef ret = false;
  check_graph_cell_id_ = GetCellId(cell, args);
  // If this cell is not the current top cell (or no grad has started yet),
  // treat it as entering a deeper grad level.
  if (!(top_cell_ != nullptr && check_graph_cell_id_.find(top_cell_->cell_id()) != std::string::npos &&
        grad_order_ >= 1)) {
    ++grad_order_;
  }
  if (!grad_is_running_) {
    MS_LOG(DEBUG) << "Grad not running yet";
    return BaseRefToPyData(ret);
  }
  MS_LOG(DEBUG) << "Key is " << check_graph_cell_id_;
  if (top_cell_ != nullptr) {
    for (auto it = top_cell_->sub_cell_list().begin(); it != top_cell_->sub_cell_list().end(); ++it) {
      MS_LOG(DEBUG) << "Cur cell id " << *it;
      if (!IsCellObjIdEq(*it, check_graph_cell_id_)) {
        continue;
      }
      MS_LOG(DEBUG) << "Delete cellid from cell graph list, top cell is " << top_cell_;
      // Erase invalidates `it`, so we break immediately afterwards.
      top_cell_->sub_cell_list().erase(it);
      ret = true;
      break;
    }
  }
  return BaseRefToPyData(ret);
}
2746 
// Decide whether the forward pass for `cell` has already run and its top cell
// can be reused. Returns a Python bool: true means forward can be skipped.
// Side effects: records grad_operation_ and may switch the current top cell.
py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                                         const py::args &args) {
  bool forward_run = false;
  // Get cell id and input args info
  const auto &cell_id = GetCellId(cell, args);
  grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_);

  // Fingerprint of the input objects, used to detect changed inputs below.
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }
  // Check whether need to run forward process
  const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
  auto find_top_cell = GetTopCell(check_already_run_cell_id);
  if (find_top_cell != nullptr) {
    forward_run = find_top_cell->forward_already_run();
    auto curr_top_cell = top_cell();
    set_top_cell(find_top_cell);
    bool input_args_changed =
      !find_top_cell->input_args_id().empty() && find_top_cell->input_args_id() != input_args_id;
    // A dynamic cell with changed inputs cannot reuse the cached forward run.
    if (forward_run && input_args_changed && find_top_cell->is_dynamic()) {
      MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, "
                         "forward process will run again";
      forward_run = false;
    }
    // For high-order grad, remember the previous top cell so it can be restored.
    if (forward_run && GetHighOrderStackSize() >= 1) {
      PushHighOrderGraphStack(curr_top_cell);
    }
  }
  MS_LOG(DEBUG) << "Graph have already ran " << forward_run << " top cell id " << cell_id;
  return BaseRefToPyData(forward_run);
}
2779 
// Compare the current top cell against the cached run with the same id and
// decide whether the grad graph must be recompiled. If the recorded op info
// differs, the stale cached top cell is evicted; otherwise the new top cell is
// dropped and the cached one is reused.
void GradExecutor::CheckNeedCompileGraph() {
  auto new_top_cell = top_cell();
  const auto &already_top_cell_id = new_top_cell->already_run_cell_id();
  // Update top cell by current cell op info
  if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
    MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has never been ran, need compile graph";
    already_run_top_cell_[already_top_cell_id] = new_top_cell;
    return;
  }

  MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been ran";
  auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
  MS_EXCEPTION_IF_NULL(pre_top_cell);
  const auto &pre_all_op_info = pre_top_cell->all_op_info();
  const auto &new_all_op_info = new_top_cell->all_op_info();
  MS_LOG(DEBUG) << "Pre all op info : " << pre_all_op_info;
  MS_LOG(DEBUG) << "New all op info : " << new_all_op_info;
  if (pre_all_op_info != new_all_op_info) {
    MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
    // The top cell switches exceeds MAX_TOP_CELL_COUNTS under the control flow, disable backend cache
    if (top_cell_switch_counts_ >= MAX_TOP_CELL_COUNTS) {
      EnableOpGraphCache(false);
    } else {
      // Increase top cell switches counts
      ++top_cell_switch_counts_;
    }
    EraseTopCellFromTopCellList(pre_top_cell);
    pre_top_cell->Clear();
    already_run_top_cell_[already_top_cell_id] = new_top_cell;
    // Cached python object ids may refer to the evicted graph; drop them.
    g_pyobj_id_cache.clear();
  } else {
    MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
    // Reuse the previously compiled top cell; discard the new one.
    pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
    EraseTopCellFromTopCellList(new_top_cell);
    new_top_cell->Clear();
    pre_top_cell->set_forward_already_run(true);
    set_top_cell(pre_top_cell);
  }
}
2819 
// Execute the compiled grad graph of the current top cell through the VM and
// store the Python result in *ret. Also triggers nested-cnode construction or
// top-cell switching for high-order grads, and frees device memory when the
// finished top cell is the outermost one.
void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "Run start cell id " << cell_id;
  // A sens param is present when cell_id extends some known top cell's id.
  auto has_sens = std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) {
    return cell_id.find(value->cell_id()) != std::string::npos && cell_id != value->cell_id();
  });
  MS_LOG(DEBUG) << "Run has sens " << has_sens << " cell id " << cell_id;
  auto resource = top_cell()->resource();
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Run resource ptr " << resource.get();

  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(FilterTensorArgs(args, has_sens));
  pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
  if (resource->results().find(pipeline::kOutput) == resource->results().end()) {
    MS_LOG(EXCEPTION) << "Can't find run graph output";
  }
  if (!resource->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Run graph is not VmEvalFuncPtr";
  }
  compile::VmEvalFuncPtr run = resource->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  MS_EXCEPTION_IF_NULL(run);

  const auto &backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  // grad_is_running_ flags re-entrant calls made while the VM evaluates.
  grad_is_running_ = true;
  BaseRef value = (*run)(arg_list);
  grad_is_running_ = false;
  MS_LOG(DEBUG) << "Eval run end " << value.ToString();
  *ret = BaseRefToPyData(value);
  // Clear device memory resource of top cell when it has been ran.
  auto has_higher_order = std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                                      [](const TopCellInfoPtr &value) { return !value->is_topest(); });
  if (top_cell()->is_topest() && !has_higher_order) {
    top_cell()->ClearDeviceMemory();
  }
  // High order
  if (top_cell()->vm_compiled()) {
    MakeNestedCnode(cell, converted_args, resource, *ret);
  } else if (GetHighOrderStackSize() >= ARG_SIZE) {
    SwitchTopcell();
  }
}
2864 
SwitchTopcell()2865 void GradExecutor::SwitchTopcell() {
2866   const auto &inner_top_cell_all_op_info = top_cell()->all_op_info();
2867   bool inner_top_cell_is_dynamic = top_cell()->is_dynamic();
2868   top_cell()->set_grad_order(1);
2869 
2870   // Get outer top cell
2871   auto outer_top_cell = PopHighOrderGraphStack();
2872   MS_EXCEPTION_IF_NULL(outer_top_cell);
2873   outer_top_cell->all_op_info() += inner_top_cell_all_op_info;
2874   // If inner is dynamic, outer set dynamic too
2875   if (inner_top_cell_is_dynamic) {
2876     outer_top_cell->set_is_dynamic(inner_top_cell_is_dynamic);
2877   }
2878   set_top_cell(outer_top_cell);
2879 }
2880 
// For nested grad: replace parameters of the inner (first) grad graph with the
// matching parameters of the outer (second) graph, collecting the cnode inputs
// and weight values that the nested call will need. Calls SwitchTopcell(), so
// top_cell() refers to the outer cell afterwards.
void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const py::tuple &forward_args,
                                      std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args) {
  MS_EXCEPTION_IF_NULL(inputs);
  MS_EXCEPTION_IF_NULL(weights_args);
  auto first_df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(first_df_builder);
  auto first_graph_info = top_cell()->graph_info_map().at(first_df_builder);
  MS_EXCEPTION_IF_NULL(first_graph_info);
  // After this call top_cell() is the outer cell.
  SwitchTopcell();
  auto second_df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(second_df_builder);
  auto second_graph_info = top_cell()->graph_info_map().at(second_df_builder);
  MS_EXCEPTION_IF_NULL(second_graph_info);

  // Partition the outer graph's params into weights (with default) and inputs.
  std::unordered_set<std::string> params_weights_set;
  std::unordered_set<std::string> params_inputs_set;
  for (const auto &sec : second_graph_info->params) {
    if (sec.second->has_default()) {
      params_weights_set.emplace(sec.first);
    } else {
      params_inputs_set.insert(sec.first);
    }
  }
  auto manager = Manage({first_grad_fg}, false);
  // Replace inputs param
  for (size_t i = 0; i < forward_args.size(); ++i) {
    const auto &id = GetId(forward_args[i]);
    if (params_inputs_set.count(id)) {
      // Can find in second graph
      const auto &input_param_second = second_graph_info->params.at(id);
      manager->Replace(first_graph_info->params.at(id), input_param_second);
      inputs->emplace_back(input_param_second);
    } else {
      inputs->emplace_back(GetInput(forward_args[i], false));
    }
  }

  // Replace weights param
  for (const auto &fir : first_graph_info->params) {
    if (!fir.second->has_default()) {
      continue;
    }
    // Second graph no this weight param, need add to second graph
    if (!params_weights_set.count(fir.first)) {
      SetParamNodeMapInGraphInfoMap(second_df_builder, fir.first, fir.second);
      inputs->emplace_back(fir.second);
      weights_args->emplace_back(fir.second->default_param());
    } else {
      // Need replace
      // Match weights by parameter name, since ids differ between graphs.
      for (const auto &sec : second_graph_info->params) {
        MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
        if (sec.second->has_default() && fir.second->name() == sec.second->name()) {
          manager->Replace(fir.second, sec.second);
          inputs->emplace_back(sec.second);
          weights_args->emplace_back(sec.second->default_param());
          break;
        }
      }
    }
  }
}
2942 
// For nested (high-order) grad: wrap the finished inner grad graph as a cnode
// in the outer graph, differentiate it a second time, and register the result
// with the outer cell's k-pynative cell via KPynativeWithFProp.
void GradExecutor::MakeNestedCnode(const py::object &cell, const py::tuple &forward_args,
                                   const pipeline::ResourcePtr &resource, const py::object &out) {
  if (cell_stack_.empty()) {
    MS_LOG(DEBUG) << "No nested grad find";
    return;
  }
  FuncGraphPtr first_grad_fg = nullptr;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    // Custom bprop: the current graph itself is the first grad graph.
    first_grad_fg = curr_g_;
    MS_LOG(DEBUG) << "Bprop nested";
  } else {
    first_grad_fg = resource->func_graph();
  }
  MS_EXCEPTION_IF_NULL(first_grad_fg);
  DumpGraphIR("first_grad_fg.ir", first_grad_fg);

  std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
  ValuePtrList weights_args;
  // Note: this switches top_cell() to the outer cell.
  DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);

  pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
  r->manager()->AddFuncGraph(first_grad_fg);
  // Disable forward elimination while taking the second-order grad.
  set_eliminate_forward(false);
  FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
  set_eliminate_forward(true);
  DumpGraphIR("second_grad_fg.ir", second_grad_fg);
  r->Clean();

  MS_LOG(DEBUG) << "Get pre graph ptr " << curr_g().get();
  auto cnode = curr_g()->NewCNode(inputs);
  auto out_id = GetId(out);
  SetTupleArgsToGraphInfoMap(curr_g(), out, cnode);
  SetNodeMapInGraphInfoMap(curr_g(), out_id, cnode);
  MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();

  // Get input values
  ValuePtrList input_args;
  for (size_t i = 0; i < forward_args.size(); ++i) {
    const auto &arg = PyObjToValue(forward_args[i]);
    input_args.emplace_back(arg);
  }
  input_args.insert(input_args.end(), weights_args.begin(), weights_args.end());
  // Get output values
  // Custom bprop cells always report a tuple output; wrap a scalar output.
  py::object new_out;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !py::isinstance<py::tuple>(out)) {
    new_out = py::make_tuple(out);
  } else {
    new_out = out;
  }
  const auto &out_value = PyObjToValue(new_out);
  if (!top_cell()->k_pynative_cell_ptr()->KPynativeWithFProp(cnode, input_args, out_value, second_grad_fg)) {
    MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
  }
  need_renormalize_ = true;
}
2998 
EraseTopCellFromTopCellList(const TopCellInfoPtr & top_cell)2999 void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
3000   MS_EXCEPTION_IF_NULL(top_cell);
3001   auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
3002                            [&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
3003   if (iter == top_cell_list_.end()) {
3004     MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id()
3005                     << " from top cell list";
3006   } else {
3007     (void)top_cell_list_.erase(iter);
3008   }
3009 }
3010 
// Hook the output of an executed ms_function graph into the running bprop
// recording. `out` must be a 2-tuple of (actual output, added forward output).
// On first execution for this op info, the ms_function graph and its grad
// graph are cloned and adjoint nodes are created; on later executions only the
// cached forward tensors are refreshed.
void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
                                       const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) {
  // Get actual output value and added output value.
  if (!py::isinstance<py::tuple>(out)) {
    MS_LOG(EXCEPTION) << "The output value of ms_function func graph should be a tuple.";
  }
  auto tuple_out = py::cast<py::tuple>(out);
  constexpr size_t tuple_out_size = 2;
  if (tuple_out.size() != tuple_out_size) {
    MS_LOG(EXCEPTION) << "The tuple size of output value of ms_function func graph should be 2.";
  }
  py::object actual_out = tuple_out[0];
  auto actual_out_v = PyObjToValue(actual_out);
  auto added_out = PyObjToValue(tuple_out[1]);
  MS_LOG(DEBUG) << "Added output value is: " << added_out->ToString();

  // Identity op info for current running ms_func graph.
  OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = phase;
  RecordGradOpInfo(op_exec_info, actual_out_v);
  MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;

  // Step 1: Update actual output tensors used in grad graph.
  MS_LOG(DEBUG) << "ms_function actual output value: " << actual_out_v->ToString();
  UpdateForwardTensorInfoInBpropGraph(op_exec_info, actual_out_v);

  // Step 2: Update output tensors of added forward nodes, which are added to return node of ms_function func graph.
  if (top_cell()->op_info_with_ms_func_forward_tensors().count(op_exec_info->op_info)) {
    // Already recorded: refreshing the cached forward tensors is sufficient.
    UpdateMsFunctionForwardTensors(op_exec_info, added_out);
    return;
  }
  MS_LOG(DEBUG) << "Ms func graph run firstly. The graph phase is: " << graph_phase();
  if (!need_construct_graph()) {
    MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
  }
  ReplaceNewTensorsInGradGraph(top_cell(), op_exec_info, added_out, ms_func_graph, grad_graph);

  // Clone new ms_function func graph and grad graph.
  auto new_ms_func_graph = BasicClone(ms_func_graph);
  auto new_grad_graph = BasicClone(grad_graph, true);
  // Strip the added output: restore the single actual output of the clone.
  auto new_make_tuple = new_ms_func_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(new_make_tuple);
  new_ms_func_graph->set_output(new_make_tuple->input(1));

  // Make Adjoint for grad graph
  MakeAdjointForMsFunction(new_ms_func_graph, new_grad_graph, actual_out, args, actual_out_v);
}
3058 
// Entry point after an ms_function graph has executed: unwrap the actual
// forward output and, if grad is enabled, record the graph for later bprop.
// graph_phase() selects which compiled graph to use and is reset on return.
py::object GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
  // Get actual forward output object.
  if (graph_phase().empty()) {
    MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain ms_function func graph.";
  }
  const auto &phase = graph_phase();
  MS_LOG(DEBUG) << "ms_function func graph phase: " << phase;
  auto executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  FuncGraphPtr ms_func_graph = executor->GetFuncGraph(phase);
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  py::object ret = out;
  // When the graph output was augmented with forward nodes, the real result
  // is the first tuple element.
  if (ms_func_graph->modify_output()) {
    auto tuple_out = py::cast<py::tuple>(out);
    ret = tuple_out[0];
  }

  // Make Adjoint for grad graph of ms_function.
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
    set_graph_phase("");
    return ret;
  }
  FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
  MS_EXCEPTION_IF_NULL(grad_graph);
  GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
  set_graph_phase("");
  return ret;
}
3088 
ClearGrad(const py::object & cell,const py::args & args)3089 void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
3090   MS_LOG(DEBUG) << "Clear top cell grad resource " << GetCellId(cell, args);
3091   if (grad_order_ > 0) {
3092     --grad_order_;
3093   }
3094   check_graph_cell_id_.clear();
3095   grad_operation_.clear();
3096   forward()->node_abs_map().clear();
3097   ad::CleanRes();
3098   pipeline::ReclaimOptimizer();
3099 }
3100 
// Reset the grad executor to its initial state: flags, counters, cached cells
// and all bookkeeping stacks. Used on teardown / after an exception.
void GradExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear grad res";
  grad_flag_ = false;
  enable_op_cache_ = true;
  grad_is_running_ = false;
  need_renormalize_ = false;
  eliminate_forward_ = true;
  custom_bprop_cell_count_ = 0;
  grad_order_ = 0;
  top_cell_switch_counts_ = 0;

  check_graph_cell_id_.clear();
  grad_operation_.clear();
  top_cell_ = nullptr;
  curr_g_ = nullptr;
  bprop_cell_list_.clear();
  already_run_top_cell_.clear();
  ClearCellRes();
  // swap with a fresh stack: std::stack has no clear() member.
  std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
  std::stack<std::string>().swap(cell_stack_);
  std::stack<std::pair<FuncGraphPtr, TopCellInfoPtr>>().swap(high_order_stack_);
}
3123 
// Null-checked accessor for the grad executor singleton member.
GradExecutorPtr PynativeExecutor::grad_executor() const {
  MS_EXCEPTION_IF_NULL(grad_executor_);
  return grad_executor_;
}
// Null-checked accessor for the forward executor singleton member.
ForwardExecutorPtr PynativeExecutor::forward_executor() const {
  MS_EXCEPTION_IF_NULL(forward_executor_);
  return forward_executor_;
}
3132 
// Forwarder: query the grad flag held by the grad executor.
bool PynativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); }
3134 
// Forwarder: set the grad flag on the grad executor.
void PynativeExecutor::set_grad_flag(bool flag) { grad_executor()->set_grad_flag(flag); }
3136 
// Forwarder: record the current ms_function graph phase on the grad executor.
void PynativeExecutor::set_graph_phase(const std::string &graph_phase) {
  grad_executor()->set_graph_phase(graph_phase);
}
3140 
set_py_exe_path(const py::object & py_exe_path)3141 void PynativeExecutor::set_py_exe_path(const py::object &py_exe_path) {
3142   if (!py::isinstance<py::str>(py_exe_path)) {
3143     MS_LOG(EXCEPTION) << "Failed, py_exe_path input is not a str";
3144   }
3145   auto py_exe_path_s = py::cast<std::string>(py_exe_path);
3146   auto ms_context = MsContext::GetInstance();
3147   ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
3148 }
3149 
set_kernel_build_server_dir(const py::object & kernel_build_server_dir)3150 void PynativeExecutor::set_kernel_build_server_dir(const py::object &kernel_build_server_dir) {
3151   if (!py::isinstance<py::str>(kernel_build_server_dir)) {
3152     MS_LOG(EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
3153   }
3154   auto kernel_build_server_dir_s = py::cast<std::string>(kernel_build_server_dir);
3155   auto ms_context = MsContext::GetInstance();
3156   ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
3157 }
3158 
// Forwarder: delegate sub-cell graph checking to the grad executor.
py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  return grad_executor()->CheckGraph(cell, args);
}
3162 
// Forwarder: ask the grad executor whether the forward pass already ran.
py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                                             const py::args &args) {
  return grad_executor()->CheckAlreadyRun(grad, cell, args);
}
3167 
// Run the compiled grad graph for `cell` (bound to Python __call__).
// PynativeExecutorTry wraps the call with the executor's error handling.
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args) {
  py::object ret;
  PynativeExecutorTry(grad_executor()->RunGraph, &ret, cell, args);
  return ret;
}
3173 
// Forwarder: clear cached resources belonging to a single cell id.
void PynativeExecutor::ClearCell(const std::string &cell_id) {
  MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id;
  grad_executor()->ClearCellRes(cell_id);
}
3178 
// Forwarder: clear per-run grad resources for `cell`.
void PynativeExecutor::ClearGrad(const py::object &cell, const py::args &args) {
  MS_LOG(DEBUG) << "Clear grad";
  return grad_executor()->ClearGrad(cell, args);
}
3183 
// Full teardown of the pynative executor: reset task manager, context flags,
// both sub-executors, AD caches, optimizer pool and backend session handles.
// Safe to call during interpreter shutdown (null checks guard late exits).
void PynativeExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear all res";
  session::PynativeTaskManager::GetInstance().Reset();
  SetLazyBuild(false);
  cell_depth_ = 0;

  // Maybe exit in runop step
  auto ms_context = MsContext::GetInstance();
  if (ms_context != nullptr) {
    ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  }
  ConfigManager::GetInstance().ResetIterNum();
  if (forward_executor_ != nullptr) {
    forward_executor_->ClearRes();
  }
  if (grad_executor_ != nullptr) {
    grad_executor_->ClearRes();
  }
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
  kSession = nullptr;
  mind_rt_backend = nullptr;
  g_pyobj_id_cache.clear();
}
3208 
NewGraph(const py::object & cell,const py::args & args)3209 void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
3210   // Make a flag for new cell
3211   if (!grad_executor()->grad_flag()) {
3212     MS_LOG(DEBUG) << "Grad flag is false";
3213     return;
3214   }
3215   py::object ret;
3216   PynativeExecutorTry(grad_executor()->InitGraph, &ret, cell, args);
3217 }
3218 
EndGraph(const py::object & cell,const py::object & out,const py::args & args)3219 void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
3220   if (!grad_executor()->grad_flag()) {
3221     MS_LOG(DEBUG) << "Grad flag is false";
3222     return;
3223   }
3224   MS_LOG(DEBUG) << "Enter end graph process.";
3225   py::object ret;
3226   PynativeExecutorTry(grad_executor()->LinkGraph, &ret, cell, out, args);
3227   MS_LOG(DEBUG) << "Leave end graph process.";
3228 }
3229 
// Forwarder: record an executed ms_function graph for grad on the grad executor.
py::object PynativeExecutor::GradMsFunction(const py::object &out, const py::args &args) {
  return grad_executor()->GradMsFunction(out, args);
}
3233 
// Build the grad graph for `cell` with the given grad operation and weights.
// PynativeExecutorTry wraps the call with the executor's error handling.
void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  py::object ret;
  PynativeExecutorTry(grad_executor()->GradGraph, &ret, grad, cell, weights, args);
}
3239 
// Block until all outstanding device work completes. Uses the legacy session
// when MindRT is disabled, otherwise syncs through the device context of the
// configured target/device id.
void PynativeExecutor::Sync() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    if (kSession == nullptr) {
      MS_EXCEPTION(NotExistsError) << "No session has been created!";
    }
    kSession->SyncStream();
  } else {
    std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    const auto &device_context =
      device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
    MS_EXCEPTION_IF_NULL(device_context);
    // SyncStream's status return is deliberately ignored here.
    (void)device_context->SyncStream();
  }
}
3258 
// Forwarder: toggle asynchronous (lazy) kernel building on the forward executor.
void PynativeExecutor::SetLazyBuild(bool enable) { forward_executor()->set_lazy_build(enable); }
3260 
EnterCell()3261 void PynativeExecutor::EnterCell() {
3262   if (cell_depth_ < UINT32_MAX) {
3263     ++cell_depth_;
3264   } else {
3265     MS_LOG(ERROR) << "Cell call stack too deep";
3266   }
3267 }
3268 
ExitCell()3269 void PynativeExecutor::ExitCell() {
3270   if (cell_depth_ > 0) {
3271     --cell_depth_;
3272   }
3273 }
3274 
// True when execution is currently outside any cell (depth counter is zero).
bool PynativeExecutor::IsTopCell() const { return cell_depth_ == 0; }
3276 
// Flush the lazy-build task queue: run every pending pynative task now.
void PynativeExecutor::ExecuteAllTask() { session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks(); }
3278 
// Expose PynativeExecutor_ to Python via pybind11. Each .def below binds a
// Python-visible method name to the matching C++ member function above.
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("enter_cell", &PynativeExecutor::EnterCell, "enter cell.")
                           .def("exit_cell", &PynativeExecutor::ExitCell, "exit cell.")
                           .def("is_top_cell", &PynativeExecutor::IsTopCell, "check top cell.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
                           .def("check_run", &PynativeExecutor::CheckAlreadyRun, "pynative check graph run before.")
                           .def("grad_ms_function", &PynativeExecutor::GradMsFunction, "pynative grad for ms_function.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear_cell", &PynativeExecutor::ClearCell, "pynative clear status.")
                           .def("clear_res", &PynativeExecutor::ClearRes, "pynative clear exception res.")
                           .def("clear_grad", &PynativeExecutor::ClearGrad, "pynative clear grad status.")
                           .def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
                           .def("set_lazy_build", &PynativeExecutor::SetLazyBuild, "pynative build kernel async")
                           .def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task")
                           .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
                           .def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase")
                           .def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.")
                           .def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
                                py::arg("py_exe_path") = py::str(""), "set python executable path.")
                           .def("set_kernel_build_server_dir", &PynativeExecutor::set_kernel_build_server_dir,
                                py::arg("kernel_build_server_dir") = py::str(""),
                                "set kernel build server directory path.");
                       }));
3308 }  // namespace mindspore::pynative
3309