/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/pynative_execute.h"

#include <typeinfo>
#include <set>
#include <memory>
#include <sstream>
#include <unordered_set>
#include <algorithm>

#include "debug/trace.h"
#include "debug/anf_ir_dump.h"
#include "pybind_api/api_register.h"
#include "pybind_api/pybind_patch.h"
#include "pybind_api/ir/tensor_py.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/cell.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
#include "utils/context/context_extends.h"
#include "utils/config_manager.h"
#include "utils/convert_utils_py.h"
#include "utils/scoped_long_running.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/action.h"
#include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/parse_dynamic.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "pipeline/jit/pipeline.h"
#include "pipeline/jit/resource.h"
#include "pipeline/pynative/base.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/hardware/device_context_manager.h"
#include "vm/transform.h"

#ifdef ENABLE_GE
#include "pipeline/pynative/pynative_execute_ge.h"
#endif

using mindspore::tensor::TensorPy;

namespace mindspore::pynative {
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr;
GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;

namespace {
const size_t PTR_LEN = 15;
const size_t ARG_SIZE = 2;
const size_t MAX_TOP_CELL_COUNTS = 20;

// Primitives whose constant input value cannot be inferred in PyNative mode.
const std::set<std::string> kVmOperators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
                                            "mixed_precision_cast"};
const char kOpsFunctionModelName[] = "mindspore.ops.functional";
std::shared_ptr<session::SessionBasic> kSession = nullptr;
std::shared_ptr<compile::MindRTBackend> mind_rt_backend = nullptr;
PyObjectIdCache g_pyobj_id_cache;

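// PynativeExecutorTry funnels every failure in an executor method through one recovery path:
// the executor state is cleared via ClearRes() and the exception is rethrown so the Python
// interpreter can handle it. pybind11 exception types are rethrown as their original class so
// the Python-side exception type is preserved.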
template <typename T, typename... Args>
void PynativeExecutorTry(const std::function<void(T *ret, const Args &...)> &method, T *ret, const Args &... args) {
  const auto inst = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  MS_EXCEPTION_IF_NULL(method);
  try {
    method(ret, args...);
  } catch (const py::error_already_set &ex) {
    // Print the function call stack info before release.
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // Call py::print to output the function call stack to STDOUT. Even when the log is written to
    // a file, the user can read this info from the screen without opening the log file.
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    inst->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(py::error_already_set(ex));
  } catch (const py::type_error &ex) {
    inst->ClearRes();
    throw py::type_error(ex);
  } catch (const py::value_error &ex) {
    inst->ClearRes();
    throw py::value_error(ex);
  } catch (const py::index_error &ex) {
    inst->ClearRes();
    throw py::index_error(ex);
  } catch (const py::name_error &ex) {
    inst->ClearRes();
    throw py::name_error(ex);
  } catch (const std::exception &ex) {
    inst->ClearRes();
    // Re-throw this exception to the Python interpreter to handle it.
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    inst->ClearRes();
    auto exception_type = abi::__cxa_current_exception_type();
    MS_EXCEPTION_IF_NULL(exception_type);
    std::string ex_name(exception_type->name());
    MS_LOG(EXCEPTION) << "Error occurred when compiling graph. Exception name: " << ex_name;
  }
}

inline ValuePtr PyObjToValue(const py::object &obj) {
  ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
  if (!converted_ret) {
    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
  }
  return converted_ret;
}

std::string GetPyObjId(const py::handle &obj) {
  py::object out = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
  if (py::isinstance<py::none>(out)) {
    MS_LOG(EXCEPTION) << "Get pyobj failed";
  }
  return out.cast<std::string>();
}

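// GetId builds a stable string id for a Python object, used as a key in caches such as
// node_abs_map_. Tensors use their own id, types and scalars encode their value, and
// tuples/lists concatenate the ids of their elements recursively, so a tuple (1, 2) yields
// roughly "tuple:1:2:". Cell and function objects are resolved once through the Python helper
// and memoized in g_pyobj_id_cache.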
std::string GetId(const py::handle &obj) {
  if (py::isinstance<tensor::Tensor>(obj)) {
    auto tensor_ptr = py::cast<tensor::TensorPtr>(obj);
    return tensor_ptr->id();
  } else if (py::isinstance<mindspore::Type>(obj)) {
    auto type_ptr = py::cast<mindspore::TypePtr>(obj);
    return "type" + type_ptr->ToString();
  } else if (py::isinstance<py::str>(obj) || py::isinstance<py::int_>(obj) || py::isinstance<py::float_>(obj)) {
    return std::string(py::str(obj));
  } else if (py::isinstance<py::none>(obj)) {
    return "none";
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    auto p_list = py::cast<py::tuple>(obj);
    string prefix = py::isinstance<py::tuple>(obj) ? "tuple:" : "list:";
    if (p_list.empty()) {
      prefix = "empty";
    } else {
      std::string key;
      for (size_t i = 0; i < p_list.size(); ++i) {
        key += std::string(py::str(GetId(p_list[i]))) + ":";
      }
      prefix += key;
    }
    return prefix;
  }

  if (py::isinstance<Cell>(obj) || py::isinstance<py::function>(obj)) {
    const auto &it = g_pyobj_id_cache.find(obj);
    if (it == g_pyobj_id_cache.end()) {
      auto &&id = GetPyObjId(obj);
      g_pyobj_id_cache[obj] = id;
      return std::move(id);
    } else {
      return it->second;
    }
  } else {
    return GetPyObjId(obj);
  }
}

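// Implicit type promotion: GetTypeIndex groups the argument positions that share a dtype
// signature, and GetDstType below picks the destination TypeId for each group by scanning the
// actual arguments for the highest-priority tensor type, with JudgeMaxType adjusting for scalar
// float32/int64 operands and for the uint8 + int8 combination (promoted to int16).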
void GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes,
                  std::unordered_map<SignatureEnumDType, std::vector<size_t>> *type_indexes) {
  MS_EXCEPTION_IF_NULL(type_indexes);
  for (size_t i = 0; i < dtypes.size(); ++i) {
    auto it = type_indexes->find(dtypes[i]);
    if (it == type_indexes->end()) {
      (void)type_indexes->emplace(std::make_pair(dtypes[i], std::vector<size_t>{i}));
    } else {
      it->second.emplace_back(i);
    }
  }
}

TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
  if (max_type == TypeId::kNumberTypeBool) {
    if (has_scalar_int64) {
      max_type = TypeId::kNumberTypeInt64;
    }
    if (has_scalar_float32) {
      max_type = TypeId::kNumberTypeFloat32;
    }
  }
  if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
      max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
    max_type = TypeId::kNumberTypeFloat32;
  }
  if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
    max_type = TypeId::kNumberTypeInt16;
  }
  return max_type;
}

void GetDstType(const py::tuple &py_args,
                const std::unordered_map<SignatureEnumDType, std::vector<size_t>> &type_indexes,
                std::unordered_map<SignatureEnumDType, TypeId> *dst_type) {
  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
    const auto &type = it->first;
    const auto &indexes = it->second;
    if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < ARG_SIZE) {
      continue;
    }
    size_t priority = 0;
    TypeId max_type = TypeId::kTypeUnknown;
    bool has_scalar_float32 = false;
    bool has_scalar_int64 = false;
    bool has_tensor_int8 = false;
    // Find the type with the highest priority among arguments that share this dtype signature.
    for (size_t index : indexes) {
      if (index >= py_args.size()) {
        MS_LOG(EXCEPTION) << "The index " << index << " exceeds the size of py_args " << py_args.size();
      }
      const auto &obj = py_args[index];
      if (py::isinstance<py::float_>(obj)) {
        has_scalar_float32 = true;
      }
      if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
        has_scalar_int64 = true;
      }
      if (py::isinstance<tensor::Tensor>(obj)) {
        auto arg = py::cast<tensor::TensorPtr>(obj);
        TypeId arg_type_id = arg->data_type();
        auto type_priority = prim::type_map.find(arg_type_id);
        if (type_priority == prim::type_map.end()) {
          continue;
        }
        if (arg_type_id == kNumberTypeInt8) {
          has_tensor_int8 = true;
        }
        if (type_priority->second > priority) {
          max_type = type_priority->first;
          priority = type_priority->second;
        }
      }
    }
    max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
    MS_EXCEPTION_IF_NULL(dst_type);
    (void)dst_type->emplace(std::make_pair(type, max_type));
  }
}

const std::string &TypeIdToMsTypeStr(const TypeId &type_id) {
  const auto &type_name = type_name_map.find(type_id);
  if (type_name == type_name_map.end()) {
    MS_LOG(EXCEPTION) << "For implicit type conversion, converting to the type " << TypeIdToType(type_id)
                      << " is not supported";
  }
  return type_name->second;
}

bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(dtypes);
  const auto &signature = prim->signatures();
  bool has_sig_dtype = false;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes),
                       [&has_sig_dtype](const Signature &sig) {
                         auto dtype = sig.dtype;
                         if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
                           has_sig_dtype = true;
                         }
                         return dtype;
                       });
  return has_sig_dtype;
}

void PynativeInfer(const PrimitivePyPtr &prim, OpExecInfo *const op_exec_info,
                   const abstract::AbstractBasePtrList &args_spec_list) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list);
  prim->BeginRecordAddAttr();
  auto eval_ret = EvalOnePrim(prim, args_spec_list);
  MS_EXCEPTION_IF_NULL(eval_ret);
  AbstractBasePtr infer_res = eval_ret->abstract();
  MS_EXCEPTION_IF_NULL(infer_res);
  prim->EndRecordAddAttr();
  MS_EXCEPTION_IF_NULL(op_exec_info);
  op_exec_info->abstract = infer_res;
  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}

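// GetSingleOpGraphInfo composes the cache key for a single-op graph from the op name, each
// input's shape/dtype/padding type (plus the device dtype and format when a device address
// exists), the values of constant inputs, all attribute values, and, when there is a constant
// input, the output shape and type. For example, two calls with the same shapes but different
// device formats get different keys.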
void GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector<tensor::TensorPtr> &input_tensors,
                          const std::vector<int64_t> &tensors_mask, std::string *graph_info_key) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(graph_info_key);
  auto &graph_info = *graph_info_key;
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  std::ostringstream buf;
  buf << op_exec_info->op_name;
  bool has_const_input = false;
  for (size_t index = 0; index < input_tensors.size(); ++index) {
    MS_EXCEPTION_IF_NULL(input_tensors[index]);
    buf << input_tensors[index]->shape();
    buf << input_tensors[index]->data_type();
    buf << input_tensors[index]->padding_type();
    // Distinguish inputs that have the same shape but inconsistent dtype or format on the device.
    auto tensor_addr = input_tensors[index]->device_address();
    if (tensor_addr != nullptr) {
      auto p_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr);
      MS_EXCEPTION_IF_NULL(p_address);
      buf << p_address->type_id();
      buf << p_address->format();
    }
    // For constant input
    if (tensors_mask[index] == kValueNodeTensorMask) {
      has_const_input = true;
      auto dtype = input_tensors[index]->Dtype();
      MS_EXCEPTION_IF_NULL(dtype);
      if (dtype->type_id() == kNumberTypeInt64) {
        buf << *reinterpret_cast<int64_t *>(input_tensors[index]->data_c());
      } else if (dtype->type_id() == kNumberTypeFloat32 || dtype->type_id() == kNumberTypeFloat16) {
        buf << *reinterpret_cast<float *>(input_tensors[index]->data_c());
      } else {
        MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64, float32 or float16!";
      }
    }
    buf << "_";
  }
  // The value of an attribute affects operator selection.
  const auto &op_prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &attr_map = op_prim->attrs();
  (void)std::for_each(attr_map.begin(), attr_map.end(),
                      [&buf](const auto &element) { buf << element.second->ToString(); });

  // A constant input affects the output for operators like DropoutGenMask, whose output depends
  // on the input values even when the input shapes are the same.
  if (has_const_input) {
    buf << "_";
    auto abstr = op_exec_info->abstract;
    MS_EXCEPTION_IF_NULL(abstr);
    auto build_shape = abstr->BuildShape();
    MS_EXCEPTION_IF_NULL(build_shape);
    buf << build_shape->ToString();
    auto build_type = abstr->BuildType();
    MS_EXCEPTION_IF_NULL(build_type);
    buf << build_type->type_id();
  }
  graph_info = buf.str();
}

py::list FilterTensorArgs(const py::args &args, bool has_sens = false) {
  size_t size = args.size();
  if (size == 0 && has_sens) {
    MS_LOG(EXCEPTION) << "The size of args is 0 when the sens flag is set to True";
  }
  py::list only_tensors;
  size_t forward_args_size = has_sens ? size - 1 : size;
  for (size_t i = 0; i < forward_args_size; ++i) {
    if (py::isinstance<tensor::Tensor>(args[i])) {
      only_tensors.append(args[i]);
    }
  }
  if (has_sens) {
    only_tensors.append(args[forward_args_size]);
  }
  return only_tensors;
}

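// RunOpConvertConstInputToAttr moves a constant input into a primitive attribute when the op is
// registered for const-to-attr conversion: if input_index is listed in input_attrs, the input
// value is attached to the primitive under the matching input name and the caller skips the
// input.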
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                  const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &input_names_value = op_prim->GetAttr(kAttrInputNames);
  if (input_names_value == nullptr) {
    return false;
  }
  const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  if (input_index >= input_names_vec.size()) {
    MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
  }

  if (input_attrs.find(input_index) != input_attrs.end()) {
    const auto &value = PyObjToValue(input_object);
    auto input_name = input_names_vec[input_index];
    op_prim->AddAttr(input_name, value);
    return true;
  }
  return false;
}

void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
                              std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    input_tensors->emplace_back(tensor);
  }
  op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(tuple_inputs.size())}));
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  const auto &input_value = PyObjToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }
  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                                  std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);

  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }
  auto tuple_inputs = py::cast<py::tuple>(input_object);
  if (tuple_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }
  if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
    PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
  } else {
    ConvertValueTupleToTensor(input_object, input_tensors);
    *tensor_mask = kValueNodeTensorMask;
  }
}

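// ConvertPyObjectToTensor converts one Python input into tensors: tensors pass through, float
// and int scalars become 0-d tensors marked as value nodes, numpy arrays are wrapped via
// TensorPy, tuples and lists are flattened recursively, and None is skipped.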
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
                             std::vector<tensor::TensorPtr> *input_tensors, int64_t *const tensor_mask) {
  MS_EXCEPTION_IF_NULL(op_prim);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  tensor::TensorPtr tensor_ptr = nullptr;
  if (py::isinstance<tensor::Tensor>(input_object)) {
    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  } else if (py::isinstance<py::float_>(input_object)) {
    double input_value = py::cast<py::float_>(input_object);
    tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::int_>(input_object)) {
    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<int64_t>(input_object), kInt64);
    *tensor_mask = kValueNodeTensorMask;
  } else if (py::isinstance<py::array>(input_object)) {
    tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr);
  } else if (py::isinstance<py::list>(input_object)) {
    auto list_inputs = py::cast<py::list>(input_object);
    py::tuple tuple_inputs(list_inputs.size());
    for (size_t i = 0; i < tuple_inputs.size(); ++i) {
      tuple_inputs[i] = list_inputs[i];
    }
    ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::tuple>(input_object)) {
    ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
    return;
  } else if (py::isinstance<py::none>(input_object)) {
    return;
  } else {
    MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  }
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  input_tensors->emplace_back(tensor_ptr);
}

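// ConstructInputTensor assembles the flat tensor inputs for a single op: registered constant
// inputs are converted to primitive attributes (with the per-op and per-device exceptions
// handled inside), everything else goes through ConvertPyObjectToTensor, and tensors_mask
// records one mask value per produced tensor.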
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  // Check whether const-to-attr conversion is needed.
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  if (op_run_info->is_dynamic_shape &&
      dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
    MS_LOG(DEBUG) << "Current node is dynamic shape: " << op_run_info->op_name;
    reg_exist = false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target != kCPUDevice && op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
    reg_exist = false;
  }
  // GatherD needs const-to-attr conversion only on the GPU device.
  if (device_target != kGPUDevice && op_run_info->op_name == prim::kPrimGatherD->name()) {
    reg_exist = false;
  }
  // Get input tensors.
  MS_EXCEPTION_IF_NULL(op_prim);
  op_prim->BeginRecordAddAttr();
  size_t input_num = op_run_info->op_inputs.size();
  if (input_num != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "The op input size is " << input_num << ", but the input mask size is "
                      << op_run_info->inputs_mask.size();
  }
  for (size_t index = 0; index < input_num; ++index) {
    // Convert const input to attr.
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // Convert const and tuple input to tensor.
    int64_t tensor_mask = op_run_info->inputs_mask[index];
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
    // Mark tensors: common tensor data: 0, weight param: 1, value node (float_, int_): 2.
    op_run_info->inputs_mask[index] = tensor_mask;
    std::vector<int64_t> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
  op_prim->EndRecordAddAttr();
}

void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);

  const auto &op_name = op_run_info->op_name;
  auto attrs = op_prim->attrs();
  for (auto attr : attrs) {
    bool converted = CheckAndConvertUtils::ConvertAttrValueToString(op_name, attr.first, &attr.second);
    if (converted) {
      op_prim->set_attr(attr.first, attr.second);
    }
    bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(op_name, attr.first, &attr.second);
    if (converted_ir_attr) {
      op_prim->set_attr(attr.first, attr.second);
    }
  }
}

size_t GetTupleSize(const py::tuple &args) {
  size_t count = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      count += GetTupleSize(args[i]);
    } else {
      count += 1;
    }
  }
  return count;
}

void ConvertTupleArg(py::tuple *res, size_t *const index, const py::tuple &arg) {
  MS_EXCEPTION_IF_NULL(res);
  MS_EXCEPTION_IF_NULL(index);
  auto res_size = res->size();
  for (size_t i = 0; i < arg.size(); i++) {
    if (py::isinstance<py::tuple>(arg[i])) {
      ConvertTupleArg(res, index, arg[i]);
    } else {
      if (*index >= res_size) {
        MS_LOG(EXCEPTION) << "Convert tuple error, index is greater than tuple size, index " << (*index)
                          << ", tuple size " << res_size;
      }
      (*res)[(*index)++] = arg[i];
    }
  }
}

py::tuple ConvertArgs(const py::tuple &args) {
  size_t tuple_size = GetTupleSize(args);
  py::tuple res(tuple_size);
  size_t index = 0;
  for (size_t i = 0; i < args.size(); i++) {
    if (py::isinstance<py::tuple>(args[i])) {
      ConvertTupleArg(&res, &index, args[i]);
    } else {
      if (index >= tuple_size) {
        MS_LOG(EXCEPTION) << "Convert error, index is greater than tuple size, index " << index << ", tuple size "
                          << tuple_size;
      }
      res[index++] = args[i];
    }
  }
  return res;
}

void ResetTopCellInfo(const TopCellInfoPtr &top_cell, const py::args &args) {
  MS_EXCEPTION_IF_NULL(top_cell);
  top_cell->set_op_num(0);
  top_cell->all_op_info().clear();
  top_cell->set_forward_already_run(true);
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }
  top_cell->set_input_args_id(input_args_id);
}

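// RunReplace walks the inputs of the added MakeTuple node, re-registers each cnode's forward
// output value node in the grad graph, and overwrites the node's value with the freshly produced
// output tensors: a single tensor for one output, a ValueTuple for several.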
void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
                const FuncGraphPtr &grad_graph) {
  MS_EXCEPTION_IF_NULL(grad_graph);
  MS_EXCEPTION_IF_NULL(added_make_tuple);
  size_t index = 0;
  for (size_t i = 1; i < added_make_tuple->size(); ++i) {
    const auto &input_i = added_make_tuple->input(i);
    MS_EXCEPTION_IF_NULL(input_i);
    auto cnode = input_i->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(DEBUG) << "Replace new output tensors for cnode: " << cnode->DebugString();
    auto output_vnode = cnode->forward().first;
    MS_EXCEPTION_IF_NULL(output_vnode);
    grad_graph->AddValueNode(output_vnode);
    MS_LOG(DEBUG) << "Original output value node: " << output_vnode << " info: " << output_vnode->ToString();
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    if (index + output_num > total_output_tensors.size()) {
      MS_LOG(EXCEPTION) << "The size of total_output_tensors: " << total_output_tensors.size()
                        << ", but the current index: " << index << ", output num: " << output_num;
    }
    // Get new tensors.
    std::vector<ValuePtr> new_values;
    for (size_t j = index; j < index + output_num; ++j) {
      new_values.push_back(total_output_tensors[j]);
    }
    index = index + output_num;
    // Replace new tensors.
    if (output_num == 1) {
      output_vnode->set_value(new_values[0]);
    } else if (output_num > 1) {
      output_vnode->set_value(std::make_shared<ValueTuple>(new_values));
    } else {
      MS_LOG(EXCEPTION) << "The output value of forward cnode is empty, forward cnode info: " << cnode->ToString();
    }
    MS_LOG(DEBUG) << "New output value node: " << output_vnode << " info: " << output_vnode->ToString();
  }
  // Save op info with new tensors for the currently running ms_function func graph.
  if (index != total_output_tensors.size()) {
    MS_LOG(EXCEPTION) << "The index: " << index
                      << " should be equal to the size of total_output_tensors: " << total_output_tensors.size();
  }
}

void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecInfoPtr &op_exec_info,
                                  const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
                                  const FuncGraphPtr &grad_graph) {
  MS_EXCEPTION_IF_NULL(top_cell);
  MS_EXCEPTION_IF_NULL(grad_graph);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  // Get added forward nodes.
  auto merge_node = ms_func_graph->output();
  MS_EXCEPTION_IF_NULL(merge_node);
  auto merge_make_tuple = merge_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(merge_make_tuple);
  constexpr size_t merge_output_size = 3;
  if (merge_make_tuple->size() != merge_output_size) {
    MS_LOG(EXCEPTION) << "The input size of merge make tuple node should be 3, but it is: " << merge_make_tuple->size();
  }
  constexpr size_t added_output_index = 2;
  const auto &added_forward_node = merge_make_tuple->input(added_output_index);
  MS_EXCEPTION_IF_NULL(added_forward_node);
  if (added_forward_node->isa<ValueNode>()) {
    MS_LOG(DEBUG) << "The added forward output node is value node: " << added_forward_node->DebugString();
    std::vector<tensor::TensorPtr> total_output_tensors;
    TensorValueToTensor(added_out, &total_output_tensors);
    top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
    return;
  }
  // Replace new output tensors for forward nodes; this also works in the grad graph, which shares
  // the same value nodes.
  auto added_make_tuple = added_forward_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(added_make_tuple);
  MS_LOG(DEBUG) << "The added forward make tuple node info: " << added_make_tuple->DebugString();
  std::vector<tensor::TensorPtr> total_output_tensors;
  TensorValueToTensor(added_out, &total_output_tensors);
  RunReplace(added_make_tuple, total_output_tensors, grad_graph);
  top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
}

void SaveOpInfo(const TopCellInfoPtr &top_cell, const std::string &op_info,
                const std::vector<tensor::TensorPtr> &op_out_tensors) {
  MS_EXCEPTION_IF_NULL(top_cell);
  auto &op_info_with_tensor_id = top_cell->op_info_with_tensor_id();
  if (op_info_with_tensor_id.find(op_info) != op_info_with_tensor_id.end()) {
    MS_LOG(EXCEPTION) << "Top cell: " << top_cell.get() << " already records op info with tensor id, but got op info "
                      << op_info << " in op_info_with_tensor_id map";
  }
  // Record the relationship between the forward op and its output tensor ids.
  std::for_each(op_out_tensors.begin(), op_out_tensors.end(),
                [&op_info_with_tensor_id, &op_info](const tensor::TensorPtr &tensor) {
                  op_info_with_tensor_id[op_info].emplace_back(tensor->id());
                });
}

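// UpdateTensorInfo propagates a newly produced tensor into previously recorded tensors that share
// its id: shape and dtype are copied, and on non-CPU devices the device address is swapped in,
// while on CPU the data is copied into the old device address so existing references stay valid.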
void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<tensor::TensorPtr> &pre_tensors) {
  MS_EXCEPTION_IF_NULL(new_tensor);
  if (pre_tensors.empty()) {
    MS_LOG(EXCEPTION) << "The list of pre tensors is empty.";
  }

  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  for (auto &pre_tensor : pre_tensors) {
    MS_EXCEPTION_IF_NULL(pre_tensor);
    MS_LOG(DEBUG) << "Replace old tensor " << pre_tensor.get() << " id " << pre_tensor->id()
                  << " device_address: " << pre_tensor->device_address() << " shape and type "
                  << pre_tensor->GetShapeAndDataTypeInfo() << " with new tensor " << new_tensor.get() << " id "
                  << new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
                  << new_tensor->GetShapeAndDataTypeInfo();
    pre_tensor->set_shape(new_tensor->shape());
    pre_tensor->set_data_type(new_tensor->data_type());
    if (device_target != kCPUDevice) {
      pre_tensor->set_device_address(new_tensor->device_address());
      continue;
    }
    // Replace the data in the device address when running on the CPU device.
    if (pre_tensor->device_address() != nullptr) {
      auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(pre_tensor->device_address());
      auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
      MS_EXCEPTION_IF_NULL(old_device_address);
      auto old_ptr = old_device_address->GetMutablePtr();
      MS_EXCEPTION_IF_NULL(old_ptr);
      MS_EXCEPTION_IF_NULL(new_device_address);
      auto new_ptr = new_device_address->GetPtr();
      MS_EXCEPTION_IF_NULL(new_ptr);
      auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
      if (ret != EOK) {
        MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
      }
    }
  }
}

void CheckPyNativeContext() {
  const auto &parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL &&
      ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    MS_LOG(EXCEPTION) << "PyNative only supports STAND_ALONE and DATA_PARALLEL, but got: " << parallel_mode;
  }
}

py::object GetDstType(const TypeId &type_id) {
  ValuePtr value = nullptr;
  if (type_id == kNumberTypeFloat16) {
    value = std::make_shared<Float>(16);
  } else if (type_id == kNumberTypeFloat32) {
    value = std::make_shared<Float>(32);
  } else if (type_id == kNumberTypeFloat64) {
    value = std::make_shared<Float>(64);
  } else if (type_id == kNumberTypeBool) {
    value = std::make_shared<Bool>();
  } else if (type_id == kNumberTypeInt8) {
    value = std::make_shared<Int>(8);
  } else if (type_id == kNumberTypeUInt8) {
    value = std::make_shared<UInt>(8);
  } else if (type_id == kNumberTypeInt16) {
    value = std::make_shared<Int>(16);
  } else if (type_id == kNumberTypeInt32) {
    value = std::make_shared<Int>(32);
  } else if (type_id == kNumberTypeInt64) {
    value = std::make_shared<Int>(64);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported dst type";
  }
  MS_EXCEPTION_IF_NULL(value);
  return py::cast(value);
}
}  // namespace

py::object RealRunOp(const py::args &args) {
  CheckPyNativeContext();
  const auto &executor = PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  OpExecInfoPtr op_exec_info = executor->forward_executor()->GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  py::object ret = py::none();
  PynativeExecutorTry(executor->forward_executor()->RunOpS, &ret, op_exec_info);
  return ret;
}

GradExecutorPtr ForwardExecutor::grad() const {
  auto grad_executor = grad_executor_.lock();
  MS_EXCEPTION_IF_NULL(grad_executor);
  return grad_executor;
}

bool TopCellInfo::IsSubCell(const std::string &cell_id) const {
  if (sub_cell_list_.empty()) {
    MS_LOG(DEBUG) << "The sub cell list is empty, there is no sub cell";
    return false;
  }
  if (sub_cell_list_.find(cell_id) != sub_cell_list_.end()) {
    return true;
  }
  return false;
}

void TopCellInfo::ClearDeviceMemory() {
  MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target == kCPUDevice) {
    MS_LOG(DEBUG) << "No need to clear device address when running on the CPU device.";
    return;
  }
  k_pynative_cell_ptr_ = nullptr;
  // Get all tensor objects in the value nodes of the running graph.
  std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
  MS_EXCEPTION_IF_NULL(resource_);
  const auto &bprop_graph = resource_->func_graph();
  MS_EXCEPTION_IF_NULL(bprop_graph);
  const auto &value_node_list = bprop_graph->value_nodes();
  for (const auto &elem : value_node_list) {
    auto &node = elem.first;
    MS_EXCEPTION_IF_NULL(node);
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }
  for (const auto &tensor : tensors_in_bprop_graph) {
    MS_EXCEPTION_IF_NULL(tensor);
    MS_LOG(DEBUG) << "Clear device address for tensor: " << tensor->ToString();
    tensor->set_device_address(nullptr);
  }
}

void TopCellInfo::Clear() {
  MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
  op_num_ = 0;
  is_dynamic_ = false;
  vm_compiled_ = false;
  ms_function_flag_ = false;
  is_init_kpynative_ = false;
  need_compile_graph_ = false;
  forward_already_run_ = false;
  input_args_id_.clear();
  all_op_info_.clear();
  resource_ = nullptr;
  df_builder_ = nullptr;
  k_pynative_cell_ptr_ = nullptr;
  graph_info_map_.clear();
  sub_cell_list_.clear();
  op_info_with_tensor_id_.clear();
  tensor_id_with_tensor_object_.clear();
  op_info_with_ms_func_forward_tensors_.clear();
}

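// RunOpInner is the single-op forward pipeline: cast inputs, construct the forward CNode, collect
// input abstracts, infer (or fetch from cache) the output abstract, then execute the op and
// record grad info. MixedPrecisionCast is dispatched directly to the backend.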
void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(ret);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name;
  if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
    RunMixedPrecisionCastOp(op_exec_info, ret);
    return;
  }

  // 1. Set cast for inputs.
  SetCastForInputs(op_exec_info);
  // 2. Construct the graph; in the first step the abstract will be updated by the node.
  auto cnode = ConstructForwardGraph(op_exec_info);
  // 3. Get the inputs' abstracts.
  abstract::AbstractBasePtrList args_spec_list;
  GetInputsArgsSpec(op_exec_info, &args_spec_list);
  // 4. Get the output abstract.
  bool prim_cache_hit = false;
  GetOpOutputAbstract(op_exec_info, args_spec_list, &prim_cache_hit);
  // 5. Get the output.
  GetOpOutput(op_exec_info, args_spec_list, cnode, prim_cache_hit, ret);
}

OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(EXCEPTION) << "Three args are needed by RunOp";
  }
  const auto &op_exec_info = std::make_shared<OpExecInfo>();
  const auto &op_name = py::cast<std::string>(args[PY_NAME]);
  op_exec_info->op_name = op_name;

  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]);
  MS_EXCEPTION_IF_NULL(adapter);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    prim = std::make_shared<PrimitivePy>(args[PY_PRIM], adapter);
    adapter->set_attached_primitive(prim);
  }

  if (!prim->HasPyObj()) {
    MS_LOG(EXCEPTION) << "Pyobj is empty";
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_inputs = args[PY_INPUTS];
  op_exec_info->lazy_build = lazy_build_;
  return op_exec_info;
}

void ForwardExecutor::SetCastForInputs(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  // No need to cast the Cast op itself.
  if (op_exec_info->op_name == prim::kPrimCast->name()) {
    return;
  }

  // Mixed precision conversion for tensors that have a cast dtype.
  SetTensorMixPrecisionCast(op_exec_info);
  // Implicit transform.
  SetImplicitCast(op_exec_info);
}

void ForwardExecutor::RunMixedPrecisionCastOp(const OpExecInfoPtr &op_exec_info, py::object *ret) {
  py::tuple res = RunOpWithInitBackendPolicy(op_exec_info);
  MS_EXCEPTION_IF_NULL(ret);
  if (res.size() == 1) {
    *ret = res[0];
    return;
  }
  *ret = std::move(res);
}

void ForwardExecutor::SetNonCostantValueAbs(const AbstractBasePtr &abs, size_t i, const std::string &id) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    abs->set_value(kAnyValue);
  } else if (abs->isa<abstract::AbstractTuple>() || abs->isa<abstract::AbstractList>()) {
    const auto &abs_seq = abs->cast<abstract::AbstractSequeuePtr>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    for (auto &item : abs_seq->elements()) {
      MS_EXCEPTION_IF_NULL(item);
      if (item->isa<abstract::AbstractTensor>()) {
        item->set_value(kAnyValue);
      }
    }
  }
  MS_LOG(DEBUG) << "Set " << i << "th abs " << abs->ToString();
  node_abs_map_[id] = abs;
}

void ForwardExecutor::GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info,
                                        abstract::AbstractBasePtrList *args_spec_list) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(args_spec_list);
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    abstract::AbstractBasePtr abs = nullptr;
    const auto &obj = op_exec_info->op_inputs[i];
    const auto &id = GetId(obj);
    MS_LOG(DEBUG) << "Set input abs " << id;
    auto it = node_abs_map_.find(id);
    if (it != node_abs_map_.end()) {
      abs = it->second;
    }
    const auto const_input_index = prim->get_const_input_indexes();
    bool have_const_input = !const_input_index.empty();
    bool is_const_prim = prim->is_const_prim();
    MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                  << prim->is_const_prim();
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    if (abs == nullptr || is_const_prim || is_const_input) {
      abs = PyObjToValue(obj)->ToAbstract();
      if (!is_const_prim && !is_const_input) {
        SetNonCostantValueAbs(abs, i, id);
      }
    }
    args_spec_list->emplace_back(abs);
  }
}

CNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  std::vector<AnfNodePtr> inputs;
  std::vector<int64_t> op_masks;
  inputs.emplace_back(NewValueNode(prim));
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
    const auto &obj = op_exec_info->op_inputs[i];
    bool op_mask = false;
    tensor::MetaTensorPtr meta_tensor = nullptr;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor) {
        op_mask = meta_tensor->is_parameter();
      }
    }
    MS_LOG(DEBUG) << "Args i " << i << ", op mask " << op_mask;
    op_masks.emplace_back(static_cast<int64_t>(op_mask));

    // Construct grad graph
    if (grad()->need_construct_graph()) {
      const auto &id = GetId(obj);
      AnfNodePtr input_node = nullptr;
      input_node = grad()->GetInput(obj, op_mask);
      // update abstract
      if (input_node != nullptr) {
        if (input_node->abstract() != nullptr) {
          abstract::AbstractBasePtr abs = input_node->abstract();
          node_abs_map_[id] = abs;
        }
        inputs.emplace_back(input_node);
      }
    }
  }
  op_exec_info->inputs_mask = std::move(op_masks);
  CNodePtr cnode = nullptr;
  if (grad()->need_construct_graph()) {
    cnode = grad()->curr_g()->NewCNodeInOrder(inputs);
    MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << ", new cnode is " << cnode->DebugString();
  }
  return cnode;
}

void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
                                          const abstract::AbstractBasePtrList &args_spec_list, bool *prim_cache_hit) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(prim_cache_hit);
  auto op_name = op_exec_info->op_name;
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);

  AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
  auto temp = prim_abs_list_.find(key);
  if (temp != prim_abs_list_.end()) {
    MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
    auto iter = temp->second.find(args_spec_list);
    if (iter != temp->second.end()) {
      MS_LOG(DEBUG) << "Match prim ok " << op_name;
      op_exec_info->abstract = iter->second.abs;
      prim->set_evaluate_added_attrs(iter->second.attrs);
      *prim_cache_hit = true;
    }
  }

  if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
    // Use the python infer method.
    if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
      PynativeInfer(prim, op_exec_info.get(), args_spec_list);
    }
  }
  // Get output dynamic shape info.
  auto abstract = op_exec_info->abstract;
  MS_EXCEPTION_IF_NULL(abstract);
  auto shape = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);

  if (shape->IsDynamic()) {
    op_exec_info->is_dynamic_shape = true;
    // A dynamic shape operator in the current top cell disables the backend cache.
    grad()->EnableOpGraphCache(false);
  }
}

void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
                                  const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
                                  bool prim_cache_hit, py::object *ret) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  // Infer the output value by constant folding.
  MS_EXCEPTION_IF_NULL(ret);
  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
  if (!output["value"].is_none()) {
    *ret = output["value"];
    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
    return;
  }
  if (prim->is_const_prim()) {
    *ret = py::cast("");
    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
    return;
  }

  // Add output abstract info into the cache; a const value needs to be inferred every step.
  if (grad()->enable_op_cache() && !prim_cache_hit && !op_exec_info->is_dynamic_shape) {
    AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
    auto &out = prim_abs_list_[key];
    out[args_spec_list].abs = op_exec_info->abstract;
    out[args_spec_list].attrs = prim->evaluate_added_attrs();
  }
  // Run the op with the selected backend.
  auto result = RunOpWithInitBackendPolicy(op_exec_info);
  py::object out_real = result;
  if (result.size() == 1 && op_exec_info->abstract != nullptr &&
      !op_exec_info->abstract->isa<abstract::AbstractSequeue>()) {
    out_real = result[0];
  }
  // Get the output value.
  ValuePtr out_real_value = nullptr;
  if (grad()->grad_flag()) {
    out_real_value = PyObjToValue(out_real);
  }
  // Save cnode info and build the grad graph.
  if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &obj_id = GetId(out_real);
    cnode->set_abstract(op_exec_info->abstract);
    node_abs_map_[obj_id] = op_exec_info->abstract;
    grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
    grad()->DoOpGrad(op_exec_info, cnode, out_real_value);
  } else {
    node_abs_map_.clear();
  }
  // Record op info to judge whether the construct of the cell has been changed.
  grad()->RecordGradOpInfo(op_exec_info, out_real_value);
  grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
  *ret = out_real;
}

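// DoAutoCast synthesizes a Cast op call: it reuses the cached mindspore.ops.functional cast
// primitive, packs (arg, dst_type) as the inputs, and runs it through RunOpInner, so casts flow
// through the same forward pipeline (and grad recording) as user ops.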
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
                                       size_t index) {
  static py::object cast_prim = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
  const auto &op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = prim::kPrimCast->name();
  const auto &adapter = py::cast<PrimitivePyAdapterPtr>(cast_prim);
  MS_EXCEPTION_IF_NULL(adapter);
  auto prim = adapter->attached_primitive();
  if (prim == nullptr) {
    prim = std::make_shared<PrimitivePy>(cast_prim, adapter);
    adapter->set_attached_primitive(prim);
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->is_mixed_precision_cast = true;
  op_exec_info->next_op_name = op_name;
  op_exec_info->next_input_index = index;
  py::object dst_type = GetDstType(type_id);
  py::tuple inputs(ARG_SIZE);
  inputs[0] = arg;
  inputs[1] = dst_type;
  op_exec_info->op_inputs = inputs;
  op_exec_info->lazy_build = lazy_build_;
  py::object ret = py::none();
  RunOpInner(&ret, op_exec_info);
  return ret;
}

py::object ForwardExecutor::DoAutoCastTuple(const py::tuple &tuple, const TypeId &type_id, const std::string &op_name,
                                            size_t index) {
  auto tuple_size = tuple.size();
  py::tuple result(tuple_size);
  for (size_t i = 0; i < tuple_size; i++) {
    if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
      result[i] = DoAutoCastTuple(tuple[i], type_id, op_name, index);
    } else {
      result[i] = DoAutoCast(tuple[i], type_id, op_name, index);
    }
  }
  return std::move(result);
}

py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name,
                                                    size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  const auto &tensor = py::cast<tensor::TensorPtr>(obj);
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &cast_type = tensor->cast_dtype();
  if (cast_type != nullptr) {
    auto source_element = tensor->Dtype();
    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
      MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
      *is_cast = true;
      return DoAutoCast(obj, cast_type->type_id(), op_name, index);
    }
  }
  return obj;
}

py::object ForwardExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple,
                                                         const std::string &op_name, size_t index) {
  MS_EXCEPTION_IF_NULL(is_cast);
  auto tuple_size = tuple.size();
  py::tuple result(tuple_size);
  for (size_t i = 0; i < tuple_size; i++) {
    if (py::isinstance<tensor::MetaTensor>(tuple[i])) {
      MS_LOG(DEBUG) << "Call cast for item " << i;
      result[i] = DoParamMixPrecisionCast(is_cast, tuple[i], op_name, index);
    } else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) {
      result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i], op_name, index);
    } else {
      result[i] = tuple[i];
    }
  }
  return std::move(result);
}

void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim,
                                      const std::unordered_map<SignatureEnumDType, TypeId> &dst_type,
                                      const std::vector<SignatureEnumDType> &dtypes,
                                      const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &signature = prim->signatures();
  auto &input_args = op_exec_info->op_inputs;
  size_t input_args_size = input_args.size();
  if (!dtypes.empty() && input_args_size > dtypes.size()) {
    MS_LOG(EXCEPTION) << "The input args size " << input_args_size << " exceeds the size of dtypes " << dtypes.size();
  }
  for (size_t i = 0; i < input_args_size; ++i) {
    // No need to implicitly cast if there is no dtype.
    if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      continue;
    }
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      continue;
    }
    MS_LOG(DEBUG) << "Check inputs " << i;
    const auto &obj = input_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "The index exceeds the signature size, signature size " << signature.size()
                                 << ", index " << i;
      }
      sig = signature[i].rw;
    }
    TypeId arg_type_id = kTypeUnknown;
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      const auto &arg = py::cast<tensor::MetaTensorPtr>(obj);
      arg_type_id = arg->data_type();
    }
    // Implicit cast
    bool is_same_type = false;
    if (arg_type_id != kTypeUnknown) {
      is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second);
    }
    if (sig == SignatureEnumRW::kRWWrite && arg_type_id != kTypeUnknown && !is_same_type) {
      prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id),
                                             TypeIdToMsTypeStr(it->second));
    }
    if (is_same_type) {
      continue;
    }

    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
      MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i
                              << "th input is a type that does not support implicit conversion: "
                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
                              << py::cast<py::str>(obj) << ".";
    }
    py::object cast_output = DoAutoCast(input_args[i], it->second, op_exec_info->op_name, i);
    input_args[i] = cast_output;
  }
}

SetTensorMixPrecisionCast(const OpExecInfoPtr & op_exec_info)1271 void ForwardExecutor::SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info) {
1272 MS_EXCEPTION_IF_NULL(op_exec_info);
1273 const auto &prim = op_exec_info->py_primitive;
1274 MS_EXCEPTION_IF_NULL(prim);
1275 const auto &signature = prim->signatures();
1276 for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
1277 const auto &obj = op_exec_info->op_inputs[i];
1278 auto sig = SignatureEnumRW::kRWDefault;
1279 if (!signature.empty()) {
1280 if (i >= signature.size()) {
1281 MS_EXCEPTION(ValueError) << "Signature size is not equal to index, signature size " << signature.size()
1282 << ", index " << i;
1283 }
1284 sig = signature[i].rw;
1285 }
1286 MS_LOG(DEBUG) << "Check mix precision " << op_exec_info->op_name << " input " << i;
1287 // mix precision for non param
1288 bool is_cast = false;
1289 py::object cast_output;
1290 if (py::isinstance<tensor::MetaTensor>(obj)) {
1291 auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
1292 if (meta_tensor && meta_tensor->is_parameter()) {
1293 // If parameter write(not kRWRead), no need cast
1294 if (sig != SignatureEnumRW::kRWRead) {
1295 continue;
1296 }
1297 }
1298 cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i);
1299 } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
1300 // mix precision for tuple inputs
1301 cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i);
1302 }
1303 if (is_cast) {
1304 op_exec_info->op_inputs[i] = cast_output;
1305 }
1306 }
1307 }
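
// Sketch of the mixed-precision path above. Assumption (Python side, not
// shown in this file): the cell was configured for mixed precision, e.g.
// via to_float(mstype.float16), which is what makes DoParamMixPrecisionCast
// actually cast. Under that assumption, a float32 weight that an op only
// reads (kRWRead) is cast to float16 before launch, while a weight the op
// writes is skipped by the `continue` above:
//   // before: op_inputs = {fp32 activation, fp32 weight (kRWRead)}
//   SetTensorMixPrecisionCast(op_exec_info);
//   // after:  both inputs fp16; is_cast was set by the helper.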

void ForwardExecutor::SetImplicitCast(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &prim = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &it = implicit_cast_map_.find(prim->name());
  if (it == implicit_cast_map_.end()) {
    MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name << " the first time";
    const auto &signature = prim->signatures();
    auto sig_size = signature.size();
    // Ignore monad signatures
    for (const auto &sig : signature) {
      if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
        --sig_size;
      }
    }
    auto size = op_exec_info->op_inputs.size();
    if (sig_size > 0 && sig_size != size) {
      MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the required "
                               << "signature size " << sig_size;
    }
    std::vector<SignatureEnumDType> dtypes;
    std::unordered_map<SignatureEnumDType, std::vector<size_t>> type_indexes;
    bool has_dtype_sig = GetSignatureType(op_exec_info->py_primitive, &dtypes);
    if (has_dtype_sig) {
      std::unordered_map<SignatureEnumDType, TypeId> dst_type;
      GetTypeIndex(dtypes, &type_indexes);
      GetDstType(op_exec_info->op_inputs, type_indexes, &dst_type);
      DoSignatrueCast(op_exec_info->py_primitive, dst_type, dtypes, op_exec_info);
    }
    PrimSignature sig_value{has_dtype_sig, dtypes, type_indexes};
    implicit_cast_map_[prim->name()] = sig_value;
  } else {
    if (!it->second.has_dtype_sig) {
      MS_LOG(DEBUG) << op_exec_info->op_name << " has no dtype signature";
      return;
    }
    MS_LOG(DEBUG) << "Do signature for " << op_exec_info->op_name << " with cache";
    std::unordered_map<SignatureEnumDType, TypeId> dst_type;
    GetDstType(op_exec_info->op_inputs, it->second.type_indexes, &dst_type);
    DoSignatrueCast(op_exec_info->py_primitive, dst_type, it->second.dtypes, op_exec_info);
  }
}
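
// The implicit_cast_map_ cache above is keyed by primitive name: the first
// call for a primitive parses its signature into (has_dtype_sig, dtypes,
// type_indexes); later calls skip parsing and only recompute the destination
// types, which depend on the actual inputs. Rough timeline (op names are
// illustrative assumptions):
//   SetImplicitCast(add_info);   // miss: parse signature, cast, cache "Add"
//   SetImplicitCast(add_info2);  // hit: GetDstType + DoSignatrueCast only
//   SetImplicitCast(relu_info);  // miss; if ReLU has no dtype signature,
//                                // later hits return immediately.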

AnfNodePtr GradExecutor::GetInput(const py::object &obj, bool op_mask) {
  AnfNodePtr node = nullptr;
  const auto &obj_id = GetId(obj);

  if (op_mask) {
    MS_LOG(DEBUG) << "Cell parameters(weights)";
    // Get the parameter name from the parameter object
    auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
    }
    const auto &param_name = py::cast<std::string>(name_attr);
    auto df_builder = top_cell()->df_builder();
    MS_EXCEPTION_IF_NULL(df_builder);
    auto graph_info = top_cell()->graph_info_map().at(df_builder);
    MS_EXCEPTION_IF_NULL(graph_info);
    if (graph_info->params.find(obj_id) == graph_info->params.end()) {
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      free_param->debug_info()->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(obj);
      free_param->set_default_param(value);
      MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
      SetParamNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetParamNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      SetNodeMapInGraphInfoMap(df_builder, obj_id, free_param);
      SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
      return free_param;
    }
    node = graph_info->params.at(obj_id);
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
    return node;
  }

  auto curr_graph_info = top_cell()->graph_info_map().at(curr_g_);
  MS_EXCEPTION_IF_NULL(curr_graph_info);
  if (curr_graph_info->node_map.find(obj_id) != curr_graph_info->node_map.end()) {
    // The object is already the output of a recorded node, e.g.:
    // op(x, y)
    // out = op(op1(x, y))
    // out = op(cell1(x, y))
    // out = op(cell1(x, y)[0])
    node = GetObjNode(obj, obj_id);
  } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
    // out = op((x, y))
    // out = cell((x, y))
    auto tuple = obj.cast<py::tuple>();
    // cell((1, 2)): mixing scalar and tensor in a tuple is not supported
    if (!tuple.empty() && !py::isinstance<tensor::Tensor>(tuple[0])) {
      return MakeValueNode(obj, obj_id);
    }
    std::vector<AnfNodePtr> args;
    args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    auto tuple_size = tuple.size();
    for (size_t i = 0; i < tuple_size; i++) {
      args.emplace_back(GetInput(tuple[i], false));
    }
    auto cnode = curr_g_->NewCNode(args);
    SetNodeMapInGraphInfoMap(curr_g_, GetId(obj), cnode);
    node = cnode;
  } else {
    node = MakeValueNode(obj, obj_id);
  }
  node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
                  : MS_LOG(DEBUG) << "Get input node " << node->ToString() << ", id " << obj_id;
  return node;
}
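
// Branch summary for GetInput above (a sketch; the argument names are
// assumptions):
//   GetInput(weight_w, true);    // -> free Parameter of df_builder, reused
//                                //    on later calls via graph_info->params
//   GetInput((t1, t2), false);   // -> CNode MakeTuple(GetInput(t1), ...),
//                                //    cached in node_map under the tuple id
//   GetInput((1, 2), false);     // -> ValueNode: scalar tuples are constants
//   GetInput(prev_out, false);   // -> GetObjNode: output of a recorded op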

AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
  auto graph_info = top_cell()->graph_info_map().at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  const auto &out = graph_info->node_map.at(obj_id);
  if (out.second.size() == 1 && out.second[0] == -1) {
    return out.first;
  }
  MS_LOG(DEBUG) << "Output size " << out.second.size();

  // Params node
  if (graph_info->params.find(obj_id) != graph_info->params.end()) {
    auto para_node = out.first;
    for (auto &v : out.second) {
      std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, NewValueNode(v)};
      MS_EXCEPTION_IF_NULL(curr_g_);
      para_node = curr_g_->NewCNode(tuple_get_item_inputs);
    }
    return para_node;
  }

  // Normal node
  auto node = out.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto abs = node->abstract();
  ValuePtr out_obj = nullptr;
  if (node->forward().first != nullptr) {
    out_obj = node->forward().first->value();
  } else {
    out_obj = PyObjToValue(obj);
  }
  for (const auto idx : out.second) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
    node = curr_g_->NewCNode(tuple_get_item_inputs);
    if (out_obj->isa<ValueTuple>()) {
      node->add_input_value(out_obj, "");
      node->add_input_value(MakeValue(idx), "");
      auto out_tuple = out_obj->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(out_tuple);
      if (static_cast<size_t>(idx) >= out_tuple->size()) {
        MS_LOG(EXCEPTION) << "Index exceeds the size of tuple. Index " << idx << ", tuple size " << out_tuple->size();
      }
      out_obj = (*out_tuple)[static_cast<size_t>(idx)];
      node->set_forward(NewValueNode(out_obj), "");
    }
    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
      auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
      MS_EXCEPTION_IF_NULL(abs_tuple);
      const auto &elements = abs_tuple->elements();
      if (static_cast<size_t>(idx) >= elements.size()) {
        MS_LOG(EXCEPTION) << "Index exceeds the size of elements. Index " << idx << ", elements size "
                          << elements.size();
      }
      auto prim_abs = elements[static_cast<size_t>(idx)];
      MS_EXCEPTION_IF_NULL(prim_abs);
      MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
      node->set_abstract(prim_abs);
    }
  }
  if (node->abstract() != nullptr) {
    forward()->node_abs_map()[obj_id] = node->abstract();
  }
  MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString();
  return node;
}
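
// Index-chain sketch for GetObjNode above: if node_map records an object id
// as (cnode, {1, 0}), the returned node is
//   TupleGetItem(TupleGetItem(cnode, 1), 0)
// one getitem per recorded index, with forward values and abstracts
// propagated along the chain so the bprop graph can replay them later.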

AnfNodePtr GradExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) {
  ValuePtr converted_ret = nullptr;
  parse::ConvertData(obj, &converted_ret);
  auto node = NewValueNode(converted_ret);
  SetNodeMapInGraphInfoMap(curr_g_, obj_id, node);
  return node;
}

TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id) {
  TopCellInfoPtr find_top_cell = nullptr;
  for (const auto &top_cell : top_cell_list_) {
    MS_EXCEPTION_IF_NULL(top_cell);
    // Complete match: the grad operation has run before
    if (top_cell->already_run_cell_id() == already_run_cell_id) {
      return top_cell;
    }
    // Partial match: only the forward has run
    if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos &&
        top_cell->already_run_cell_id().back() == '_') {
      find_top_cell = top_cell;
      break;
    }
  }
  // Same top cell info but a different grad operation: construct the backward graph again
  if (find_top_cell != nullptr) {
    if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) {
      MS_LOG(DEBUG) << "Existing grad operation " << find_top_cell->grad_operation() << " is different from the new "
                    << grad_operation_;
      EraseTopCellFromTopCellList(find_top_cell);
      (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
      return nullptr;
    } else {
      return find_top_cell;
    }
  }
  return nullptr;
}
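
// Matching rules above, spelled out (ids are illustrative; the exact format
// comes from GetAlreadyRunCellId further below). A stored id that matches
// the query completely means grad was already requested for this cell; a
// query that merely contains a stored id ending in '_' (i.e. one recorded
// with an empty grad operation) means only the forward ran, so that top
// cell is reused -- unless its recorded grad_operation differs from the
// current one, in which case it is evicted and rebuilt.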

void GradExecutor::EnableOpGraphCache(bool is_enable) {
  MS_LOG(DEBUG) << "Op graph cache is enabled: " << is_enable;
  enable_op_cache_ = is_enable;
  const auto inst = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
}

void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_out);
  std::string input_args_info;
  // Record input args info (weight or data)
  for (const auto mask : op_exec_info->inputs_mask) {
    if (mask) {
      input_args_info += "w";
      continue;
    }
    input_args_info += "d";
  }
  // Record op name and index
  op_exec_info->op_info.clear();
  const auto &curr_op_num = top_cell()->op_num();
  op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
  // The output shape is appended to identify ops that change their shape
  auto out_abs = op_out->ToAbstract();
  if (out_abs != nullptr) {
    auto out_shape = out_abs->BuildShape()->ToString();
    if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
      op_exec_info->op_info += "-" + out_shape;
    }
  }
  top_cell()->all_op_info() += "-" + op_exec_info->op_info;
  top_cell()->set_op_num(curr_op_num + 1);
}
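
// Resulting op_info format (values are illustrative): for the third op
// executed in the top cell, a MatMul taking one data input and one weight,
// the recorded string is roughly
//   "MatMul-2-dw-(64, 10)"
// i.e. name - running op index - one 'd'/'w' per input - output shape.
// The concatenation of these strings (all_op_info) later tells the executor
// whether the same op sequence was replayed across runs.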

void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const CNodePtr &cnode) {
  if (cell_stack_.empty()) {
    MS_LOG(DEBUG) << "No need to save output";
    return;
  }
  MS_EXCEPTION_IF_NULL(cnode);
  MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << " id " << obj_id;
  if (py::isinstance<py::tuple>(out_real)) {
    auto value = py::cast<py::tuple>(out_real);
    auto size = static_cast<int64_t>(value.size());
    if (size > 1) {
      for (int64_t i = 0; i < size; ++i) {
        auto value_id = GetId(value[static_cast<size_t>(i)]);
        SetNodeMapInGraphInfoMap(curr_g_, value_id, cnode, i);
      }
    }
  }
  SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode);
}

// Run ad grad for the current op and connect its grad graph with the previous op
void GradExecutor::DoOpGrad(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const ValuePtr &op_out) {
  MS_EXCEPTION_IF_NULL(op_out);
  if (grad_is_running_ && !bprop_grad_stack_.top().second) {
    MS_LOG(DEBUG) << "Custom bprop, no need to do op grad";
    return;
  }
  ValuePtrList input_args;
  for (size_t i = 0; i < op_exec_info->op_inputs.size(); ++i) {
    const auto &arg = PyObjToValue(op_exec_info->op_inputs[i]);
    input_args.emplace_back(arg);
  }

  if (!ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, op_out)) {
    MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_exec_info->op_name;
  }
}
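
// Per-op autodiff sketch: every executed op contributes one adjoint to the
// k_pynative_cell as it runs, so the tape is already fully connected when
// the user finally asks for gradients (names illustrative):
//   DoOpGrad(mul_info, mul_cnode, mul_out);  // adjoint for Mul
//   DoOpGrad(add_info, add_cnode, add_out);  // adjoint for Add, linked to
//                                            // Mul through shared inputs
// The tape itself is opened by ad::GradPynativeCellBegin when the top cell
// starts (see HandleInputArgsForTopCell below).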

void GradExecutor::UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info,
                                                  const ValuePtr &new_forward_value) {
  MS_LOG(DEBUG) << "Ms func graph has already run before. The graph phase is: " << graph_phase();
  MS_EXCEPTION_IF_NULL(new_forward_value);
  MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
  std::vector<tensor::TensorPtr> new_tensors;
  TensorValueToTensor(new_forward_value, &new_tensors);
  if (new_tensors.empty()) {
    MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
    return;
  }

  MS_EXCEPTION_IF_NULL(op_exec_info);
  const auto &old_tensors = top_cell()->op_info_with_ms_func_forward_tensors().at(op_exec_info->op_info);
  if (old_tensors.size() != new_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
                      << ", but the size of new tensors is: " << new_tensors.size()
                      << ", the current op info is: " << op_exec_info->op_info;
  }
  for (size_t i = 0; i < new_tensors.size(); ++i) {
    UpdateTensorInfo(new_tensors[i], {old_tensors[i]});
    old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
  }
}

void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
                                          ValuePtrList *input_values, CNodePtr *ms_function_cnode) {
  // Get input node info of ms_function
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
  MS_EXCEPTION_IF_NULL(input_values);
  for (size_t i = 0; i < args.size(); ++i) {
    auto input_i_node = GetInput(args[i], false);
    MS_EXCEPTION_IF_NULL(input_i_node);
    MS_LOG(DEBUG) << "The input " << i << " node of ms_function graph is: " << input_i_node->DebugString();
    input_nodes.emplace_back(input_i_node);
    const auto &inp_i_value = PyObjToValue(args[i]);
    MS_LOG(DEBUG) << "The input " << i << " value of ms_function graph is: " << inp_i_value->ToString();
    (*input_values).emplace_back(inp_i_value);
  }

  // Get df_builder and the graph info map
  auto df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(df_builder);
  const auto &graph_info = top_cell()->graph_info_map().at(df_builder);
  MS_EXCEPTION_IF_NULL(graph_info);
  // Get weights info of ms_function
  std::vector<AnfNodePtr> new_params;
  auto manage = Manage(ms_func_graph, false);
  for (const auto &anf_node : ms_func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(anf_node);
    auto param = anf_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    if (!param->has_default()) {
      new_params.push_back(param);
      continue;
    }
    auto param_info = param->param_info();
    MS_EXCEPTION_IF_NULL(param_info);
    auto param_name = param_info->name();
    if (graph_info->params.count(param_name)) {
      // Share the same weight parameter across different ms_function calls.
      auto same_param = graph_info->params.at(param_name);
      manage->Replace(anf_node, same_param);
      param = same_param;
    } else {
      df_builder->add_parameter(param);
      param->debug_info()->set_name(param_name);
    }
    new_params.push_back(param);
    input_nodes.emplace_back(param);
    (*input_values).emplace_back(param->default_param());
    SetParamNodeMapInGraphInfoMap(df_builder, param_name, param);
    MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
                  << param->default_param()->ToString() << ". Its name is: " << param_name;
  }
  ms_func_graph->set_parameters(new_params);
  manage->Clear();

  // Make a CNode which includes the ms_function fprop graph and the input nodes
  MS_EXCEPTION_IF_NULL(ms_function_cnode);
  *ms_function_cnode = curr_g_->NewCNode(input_nodes);
  MS_LOG(DEBUG) << "Make ms function forward cnode: " << (*ms_function_cnode)->DebugString();
}

// Make an adjoint for the ms_function fprop graph and connect it with the previous op
void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
                                            const py::object &actual_out, const py::args &args,
                                            const ValuePtr &actual_out_v) {
  ValuePtrList input_values;
  CNodePtr ms_function_cnode = nullptr;
  MakeCNodeForMsFunction(ms_func_graph, args, &input_values, &ms_function_cnode);
  MS_EXCEPTION_IF_NULL(ms_function_cnode);
  SetTupleArgsToGraphInfoMap(curr_g_, actual_out, ms_function_cnode);
  SetNodeMapInGraphInfoMap(curr_g_, GetId(actual_out), ms_function_cnode);

  // Connect the grad graph of ms_function to the context.
  auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  MS_EXCEPTION_IF_NULL(grad_graph);
  if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, input_values, actual_out_v, grad_graph)) {
    MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
                      << ms_function_cnode->DebugString();
  }
  top_cell()->set_ms_function_flag(true);
}

void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "The grad flag is false, no need to update forward op info in bprop graph";
    return;
  }
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_out);
  const auto &op_info = op_exec_info->op_info;
  MS_LOG(DEBUG) << "Current op info: " << op_info;

  std::vector<tensor::TensorPtr> all_op_tensors;
  // Get output tensors
  TensorValueToTensor(op_out, &all_op_tensors);
  // Save all tensors info of the current op
  if (need_construct_graph()) {
    SaveOpInfo(top_cell_, op_info, all_op_tensors);
  }

  // First run of the top cell
  if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
    MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " runs for the first time";
    if (!need_construct_graph()) {
      MS_LOG(EXCEPTION) << "The cell stack is empty when running a new top cell " << top_cell_->cell_id();
    }
    return;
  }
  // Non-first run
  const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id());
  MS_EXCEPTION_IF_NULL(pre_top_cell);
  if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) {
    MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell "
                  << top_cell_->cell_id();
    return;
  }

  // Update new output tensor info in the bprop graph
  const auto &pre_op_tensor_id = pre_top_cell->op_info_with_tensor_id().at(op_info);
  if (pre_op_tensor_id.size() != all_op_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of pre op tensor id: " << pre_op_tensor_id.size()
                      << " is not equal to the size of all tensors of the current op " << all_op_tensors.size();
  }
  const auto &pre_tensor_id_with_tensor_object = pre_top_cell->tensor_id_with_tensor_object();
  for (size_t i = 0; i < pre_op_tensor_id.size(); ++i) {
    auto pre_id = pre_op_tensor_id[i];
    if (pre_tensor_id_with_tensor_object.find(pre_id) == pre_tensor_id_with_tensor_object.end()) {
      continue;
    }
    const auto &new_tensor = all_op_tensors[i];
    const auto &pre_tensor_object = pre_tensor_id_with_tensor_object.at(pre_id);
    UpdateTensorInfo(new_tensor, pre_tensor_object);
  }
}

void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
  MS_EXCEPTION_IF_NULL(resource);
  // Get the ids of all forward op tensors
  std::unordered_set<std::string> forward_op_tensor_id;
  const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
  for (const auto &record : op_info_with_tensor_id) {
    std::for_each(record.second.begin(), record.second.end(),
                  [&forward_op_tensor_id](const std::string &tensor_id) { forward_op_tensor_id.emplace(tensor_id); });
  }
  // Get all tensor objects held by value nodes of the bprop graph
  const auto &bprop_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(bprop_graph);
  const auto &value_node_list = bprop_graph->value_nodes();
  std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
  for (const auto &elem : value_node_list) {
    auto value_node = elem.first->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
  }

  auto &tensor_id_with_tensor_object = top_cell()->tensor_id_with_tensor_object();
  if (!tensor_id_with_tensor_object.empty()) {
    MS_LOG(EXCEPTION) << "When compiling a top graph, the tensor_id_with_tensor_object map should be empty. Top cell: "
                      << top_cell()->cell_id();
  }
  // Save the tensors held by value nodes of the bprop graph
  for (const auto &tensor : tensors_in_bprop_graph) {
    MS_EXCEPTION_IF_NULL(tensor);
    if (forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end() || tensor->device_address() == nullptr) {
      continue;
    }
    tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor);
    MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
                  << " device address: " << tensor->device_address() << " shape and dtype "
                  << tensor->GetShapeAndDataTypeInfo();
  }
}
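
// What the map built above is for (sketch): the forward run recorded
// op_info -> ids of each op's output tensors; here every tensor held by a
// ValueNode of the bprop graph whose id matches a forward id is registered
// in tensor_id_with_tensor_object. On the next forward run,
// UpdateForwardTensorInfoInBpropGraph looks those ids up again and rewrites
// the stored tensors' device info in place, so the compiled bprop graph can
// be reused without rebuilding its constants.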

py::tuple ForwardExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  auto backend_policy = InitEnv(op_exec_info);
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  // Returns a null py::tuple on error
  py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name;
  }
  MS_LOG(DEBUG) << "RunOp end";
  return result;
}

MsBackendPolicy ForwardExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(DEBUG) << "RunOp start, op name is: " << op_exec_info->op_name;
  parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!context::IsTsdOpened(ms_context)) {
    if (!context::OpenTsd(ms_context)) {
      MS_LOG(EXCEPTION) << "Open tsd failed";
    }
  }
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  context::PynativeInitGe(ms_context);
  backend_policy = kMsBackendGeOnly;
#endif
  if (kVmOperators.find(op_exec_info->op_name) != kVmOperators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  return backend_policy;
}
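
// Policy selection above, summarized (sketch): without ENABLE_GE, the
// context's backend_policy() yields kMsBackendMsPrior ("ms") or
// kMsBackendVmOnly; with ENABLE_GE, kMsBackendGeOnly. Ops in kVmOperators
// (HookBackward, stop_gradient, ...) are always forced onto the VM so their
// Python-side semantics are preserved, e.g.:
//   op_exec_info->op_name = "HookBackward";
//   auto policy = InitEnv(op_exec_info);  // -> kMsBackendVmOnly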

py::object ForwardExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                                   PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // Use the VM only
      MS_LOG(DEBUG) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // Use GE first; fall back to the VM when GE fails
      MS_LOG(DEBUG) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // Use MS first; fall back to others when MS fails
      MS_LOG(DEBUG) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp with Ms backend failed";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}

py::object ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_LOG(DEBUG) << "RunOpInVM start";
  MS_EXCEPTION_IF_NULL(status);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);

  auto &op_inputs = op_exec_info->op_inputs;
  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
      op_exec_info->op_name == "stop_gradient") {
    py::tuple result(op_inputs.size());
    for (size_t i = 0; i < op_inputs.size(); i++) {
      py::object input = op_inputs[i];
      auto tensor = py::cast<tensor::TensorPtr>(input);
      MS_EXCEPTION_IF_NULL(tensor);
      if (op_exec_info->op_name == "HookBackward") {
        // The input object is not an output of a forward cnode, e.g. a parameter
        result[i] = tensor;
      } else {
        // The input object is an output of a forward cnode
        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
        new_tensor->set_device_address(tensor->device_address());
        new_tensor->set_sync_status(tensor->sync_status());
        result[i] = new_tensor;
      }
    }
    *status = PYNATIVE_SUCCESS;
    MS_LOG(DEBUG) << "RunOpInVM end";
    return std::move(result);
  }

  auto primitive = op_exec_info->py_primitive;
  MS_EXCEPTION_IF_NULL(primitive);
  auto result = primitive->RunPyComputeFunction(op_inputs);
  MS_LOG(DEBUG) << "RunOpInVM end";
  if (py::isinstance<py::none>(result)) {
    MS_LOG(ERROR) << "VM got result None; check whether the op's compute function was found";
    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
    py::tuple err_ret(0);
    return std::move(err_ret);
  }
  *status = PYNATIVE_SUCCESS;
  if (py::isinstance<py::tuple>(result)) {
    return result;
  }
  py::tuple tuple_result = py::make_tuple(result);
  return std::move(tuple_result);
}

py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_EXCEPTION_IF_NULL(status);
  MS_LOG(DEBUG) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
  compile::SetMindRTEnable();

  if (kSession == nullptr && !ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    kSession = session::SessionFactory::Get().Create(device_target);
    MS_EXCEPTION_IF_NULL(kSession);
    kSession->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  }

  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int64_t> tensors_mask;
  std::string graph_info;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  ConvertAttrToUnifyMindIR(op_exec_info);
  // Get graph info to check whether the graph already exists in the cache
  GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask, &graph_info);
#if defined(__APPLE__)
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->lazy_build,
                                    op_exec_info->next_op_name,
                                    static_cast<int>(op_exec_info->next_input_index)};
#else
  session::OpRunInfo op_run_info = {op_exec_info->op_name,
                                    op_exec_info->py_primitive,
                                    op_exec_info->abstract,
                                    op_exec_info->is_dynamic_shape,
                                    op_exec_info->is_mixed_precision_cast,
                                    op_exec_info->lazy_build,
                                    op_exec_info->next_op_name,
                                    op_exec_info->next_input_index};
#endif
  VectorRef outputs;
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    kSession->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
  } else {
    if (mind_rt_backend == nullptr) {
      const auto &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
      uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
      mind_rt_backend = std::make_shared<compile::MindRTBackend>("ms", device_target, device_id);
    }

    mindspore::ScopedLongRunning long_running;
    const compile::ActorInfo &actor_info =
      mind_rt_backend->CompileGraph(op_run_info, graph_info, &tensors_mask, &input_tensors);
    mind_rt_backend->RunGraph(actor_info, &op_run_info, &tensors_mask, &input_tensors, &outputs);
  }

  if (op_exec_info->is_dynamic_shape) {
    op_exec_info->abstract = op_run_info.abstract;
  }
  auto result = BaseRefToPyData(outputs);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  *status = PYNATIVE_SUCCESS;
  MS_LOG(DEBUG) << "End run op [" << op_exec_info->op_name << "] with backend policy ms";
  return result;
}
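
// Two execution paths above (sketch): with MS_CTX_ENABLE_MINDRT off, a
// session is created once per process and each op runs through
// kSession->RunOp(...); with MindRT on, the single op is compiled into an
// actor set keyed by graph_info and then executed:
//   auto actor_info = mind_rt_backend->CompileGraph(op_run_info, graph_info,
//                                                   &tensors_mask, &input_tensors);
//   mind_rt_backend->RunGraph(actor_info, &op_run_info, &tensors_mask,
//                             &input_tensors, &outputs);
// graph_info doubles as the cache key, so a repeated op with the same shapes
// and attributes skips recompilation.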

void ForwardExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear forward res";
  lazy_build_ = false;
  implicit_cast_map_.clear();
  prim_abs_list_.clear();
  node_abs_map_.clear();
}

ForwardExecutorPtr GradExecutor::forward() const {
  auto forward_executor = forward_executor_.lock();
  MS_EXCEPTION_IF_NULL(forward_executor);
  return forward_executor;
}

TopCellInfoPtr GradExecutor::top_cell() const {
  MS_EXCEPTION_IF_NULL(top_cell_);
  return top_cell_;
}

FuncGraphPtr GradExecutor::curr_g() const {
  MS_EXCEPTION_IF_NULL(curr_g_);
  return curr_g_;
}

void GradExecutor::PushCellStack(const std::string &cell_id) { cell_stack_.push(cell_id); }

void GradExecutor::PopCellStack() {
  if (cell_stack_.empty()) {
    MS_LOG(EXCEPTION) << "Stack cell_stack_ is empty";
  }
  cell_stack_.pop();
}

void GradExecutor::PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) {
  high_order_stack_.push(std::make_pair(curr_g_, top_cell));
}

TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
  if (high_order_stack_.empty()) {
    MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty";
  }
  high_order_stack_.pop();
  TopCellInfoPtr top_cell = nullptr;
  if (!high_order_stack_.empty()) {
    auto t = high_order_stack_.top();
    curr_g_ = t.first;
    top_cell = t.second;
  }
  return top_cell;
}

std::string GradExecutor::GetCellId(const py::object &cell, const py::args &args) {
  auto cell_id = GetId(cell);
  for (size_t i = 0; i < args.size(); i++) {
    const auto &arg_id = GetId(args[i]);
    auto it = forward()->node_abs_map().find(arg_id);
    if (it != forward()->node_abs_map().end()) {
      auto &abs = it->second;
      MS_EXCEPTION_IF_NULL(abs);
      auto shape = abs->BuildShape();
      MS_EXCEPTION_IF_NULL(shape);
      auto type = abs->BuildType();
      MS_EXCEPTION_IF_NULL(type);
      cell_id += "_" + shape->ToString();
      cell_id += type->ToString();
    } else {
      auto value = PyObjToValue(args[i]);
      MS_EXCEPTION_IF_NULL(value);
      auto abs = value->ToAbstract();
      MS_EXCEPTION_IF_NULL(abs);
      if (abs->isa<abstract::AbstractTensor>()) {
        abs->set_value(kAnyValue);
      }
      forward()->node_abs_map()[arg_id] = abs;
      auto shape = abs->BuildShape();
      MS_EXCEPTION_IF_NULL(shape);
      auto type = abs->BuildType();
      MS_EXCEPTION_IF_NULL(type);
      cell_id += "_" + shape->ToString();
      cell_id += type->ToString();
    }
  }
  return cell_id;
}
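
// Cell id layout (illustrative): the cell's object id followed by one
// "_<shape><type>" chunk per argument. For a cell called with one (2, 3)
// float32 tensor this looks roughly like
//   "<obj id>_(2, 3)Tensor[Float32]"
// The exact spellings come from the abstract's BuildShape()/BuildType()
// ToString() results, so the chunk shown here is indicative only.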

void GradExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
#ifdef ENABLE_DUMP_IR
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR(filename, graph);
  }
#endif
}

inline bool GradExecutor::IsNestedGrad() const {
  MS_LOG(DEBUG) << "Grad nested order is " << grad_order_;
  return grad_order_ > 1;
}

bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const {
  // Just compare the obj ids; ignore the args ids
  return l_cell_id.compare(0, PTR_LEN, r_cell_id, 0, PTR_LEN) == 0;
}

bool GradExecutor::IsBpropGraph(const std::string &cell_id) {
  if (top_cell_ == nullptr) {
    return false;
  }
  return std::any_of(bprop_cell_list_.begin(), bprop_cell_list_.end(),
                     [&cell_id](const std::string &value) { return cell_id.find(value) != std::string::npos; });
}

void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled) {
  top_cell()->set_vm_compiled(vm_compiled);
  top_cell()->set_need_compile_graph(need_compile_graph);
  top_cell()->set_forward_already_run(forward_already_run);
}

void GradExecutor::ClearCellRes(const std::string &cell_id) {
  static bool clear_all_cell_res = false;
  // Grad clean
  if (cell_id.empty()) {
    MS_LOG(DEBUG) << "Clear all cell resources";
    clear_all_cell_res = true;
    for (const auto &iter : top_cell_list_) {
      MS_EXCEPTION_IF_NULL(iter);
      iter->Clear();
    }
    top_cell_list_.clear();
    already_run_top_cell_.clear();
    clear_all_cell_res = false;
    return;
  }
  if (clear_all_cell_res) {
    MS_LOG(DEBUG) << "All cell resources are being cleared, no need to clear a single cell resource again";
    return;
  }
  // Clear when a cell is destructed
  for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
    MS_EXCEPTION_IF_NULL(*it);
    const auto &top_cell_id = (*it)->cell_id();
    const auto &already_run_cell_id = (*it)->already_run_cell_id();
    if (IsCellObjIdEq(cell_id, top_cell_id)) {
      MS_LOG(DEBUG) << "Clear top cell resource. Top cell id " << top_cell_id;
      (*it)->Clear();
      it = top_cell_list_.erase(it);
      (void)already_run_top_cell_.erase(already_run_cell_id);
      continue;
    }
    ++it;
  }
}

void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top) {
  if (is_bprop_top) {
    // Convert input args to parameters for top cell graph in bprop.
    for (size_t i = 0; i < args.size(); ++i) {
      auto param = args[i];
      auto new_param = curr_g_->add_parameter();
      const auto &param_id = GetId(param);
      SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
      SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
      SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
    }
    return;
  }
  // Convert input args to parameters for top cell graph in construct.
  std::vector<ValuePtr> input_param_values;
  const auto &only_tensors = FilterTensorArgs(args);
  for (size_t i = 0; i < only_tensors.size(); ++i) {
    auto new_param = curr_g_->add_parameter();
    auto param_i = only_tensors[i];
    const auto &param_i_value = PyObjToValue(param_i);
    input_param_values.emplace_back(param_i_value);
    auto param_i_abs = param_i_value->ToAbstract();
    MS_EXCEPTION_IF_NULL(param_i_abs);
    new_param->set_abstract(param_i_abs->Broaden());
    const auto &param_i_id = GetId(param_i);
    SetTupleArgsToGraphInfoMap(curr_g_, param_i, new_param, true);
    SetNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
    SetParamNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
    SetParamNodeMapInGraphInfoMap(top_cell_->df_builder(), param_i_id, new_param);
  }
  top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g_->parameters(), input_param_values));
}

void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args) {
  if (cell_stack_.empty() || IsNestedGrad()) {
    if (cell_stack_.empty() && !grad_is_running_) {
      MS_LOG(DEBUG) << "Make new topmost graph";
      MakeNewTopGraph(cell_id, args, true);
    } else if (grad_is_running_ && IsBpropGraph(cell_id)) {
      MS_LOG(DEBUG) << "Run bprop cell";
      curr_g_ = std::make_shared<FuncGraph>();
      auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
      top_cell()->graph_info_map()[curr_g_] = graph_info_cg;
      HandleInputArgsForTopCell(args, true);
      bprop_grad_stack_.push(std::make_pair(cell_id, false));
    } else if (grad_is_running_ && top_cell()->grad_order() != grad_order_) {
      MS_LOG(DEBUG) << "Nested grad graph exists in bprop";
      MakeNewTopGraph(cell_id, args, false);
      bprop_grad_stack_.push(std::make_pair(cell_id, true));
    } else if (!cell_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_) {
      MS_LOG(DEBUG) << "Nested grad graph exists in construct";
      auto cur_top_is_dynamic = top_cell()->is_dynamic();
      MakeNewTopGraph(cell_id, args, false);
      top_cell()->set_is_dynamic(cur_top_is_dynamic);
    }
  }

  PushCellStack(cell_id);
  // Init kPynativeCellPtr with the input parameters of the top cell
  if (!top_cell()->is_init_kpynative()) {
    auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
    top_cell()->graph_info_map()[curr_g_] = graph_info_cg;
    auto graph_info_df = std::make_shared<GraphInfo>(cell_id);
    top_cell()->graph_info_map()[top_cell_->df_builder()] = graph_info_df;
    HandleInputArgsForTopCell(args, false);
    top_cell()->set_need_compile_graph(true);
    top_cell()->set_init_kpynative(true);
  } else {
    // Non-top cell
    top_cell()->sub_cell_list().emplace(cell_id);
  }
}

void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
  if (top_cell_ != nullptr && cell_stack_.empty()) {
    // An already-run top cell needs to distinguish high-order grad: high order appends "0", otherwise "1"
    const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
    auto top_it = already_run_top_cell_.find(already_run_cell_id);
    if (top_it != already_run_top_cell_.end()) {
      // Top cell forward run.
      const auto &pre_top_cell = top_it->second;
      MS_EXCEPTION_IF_NULL(pre_top_cell);
      if (!pre_top_cell->is_dynamic()) {
        MS_LOG(DEBUG) << "Top cell " << cell_id << " is not dynamic, no need to run NewGraphInner again";
        ResetTopCellInfo(pre_top_cell, args);
        PushHighOrderGraphStack(pre_top_cell);
        set_top_cell(pre_top_cell);
        grad_order_ = pre_top_cell->grad_order();
        return;
      }
    } else if ((top_cell()->IsSubCell(cell_id) || GetHighOrderStackSize() >= 1) &&
               !IsCellObjIdEq(cell_id, check_graph_cell_id_)) {
      // A sub cell (possibly a temporary cell, but in any case non-top) forward runs in the cache process.
      MS_LOG(DEBUG) << "Sub cell no need to run NewGraphInner again";
      return;
    }
  }
  // When the cell has a custom bprop, custom_bprop_cell_count_ is larger than 0
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    custom_bprop_cell_count_ += 1;
  }
  // Make the top graph and init the resource and df_builder
  InitResourceAndDfBuilder(cell_id, args);
  // Check whether the cell has a dynamic construct
  if (!top_cell()->is_dynamic()) {
    bool is_dynamic = parse::DynamicParser::IsDynamicCell(cell);
    MS_LOG(DEBUG) << "Current cell dynamic " << is_dynamic;
    if (is_dynamic) {
      top_cell()->set_is_dynamic(is_dynamic);
    }
  }
}

void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest) {
  pipeline::CheckArgsValid(args);
  // Record input args info
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }
  // When the forward runs first, grad_order_ needs to be bumped to 1
  if (grad_order_ == 0) {
    ++grad_order_;
  }
  // If the number of top cells exceeds MAX_TOP_CELL_COUNTS, delete the last one to keep the maximum length of the
  // list, and disable the backend cache
  if (top_cell_list_.size() >= MAX_TOP_CELL_COUNTS) {
    EnableOpGraphCache(false);
    const auto last_top_cell = top_cell_list_.back();
    top_cell_list_.pop_back();
    MS_EXCEPTION_IF_NULL(last_top_cell);
    last_top_cell->Clear();
    (void)already_run_top_cell_.erase(last_top_cell->already_run_cell_id());
  }
  // Create the top cell
  curr_g_ = std::make_shared<FuncGraph>();
  auto df_builder = std::make_shared<FuncGraph>();
  auto resource = std::make_shared<pipeline::Resource>();
  const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
  auto top_cell =
    std::make_shared<TopCellInfo>(is_topest, grad_order_, resource, df_builder, cell_id, already_run_cell_id);
  top_cell->set_forward_already_run(true);
  top_cell->set_input_args_id(input_args_id);
  top_cell_list_.emplace_back(top_cell);
  PushHighOrderGraphStack(top_cell);
  set_top_cell(top_cell);
  MS_LOG(DEBUG) << "New top graph, curr_g ptr " << curr_g_.get() << " resource ptr " << resource.get();
}

void GradExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                              bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    // Tuple indexing uses size_t
    auto id = GetId(tuple[static_cast<size_t>(i)]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, i);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
  }
}

void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                                  const std::vector<int64_t> &index_sequence, bool is_param) {
  if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  auto tuple = args.cast<py::tuple>();
  auto tuple_size = static_cast<int64_t>(tuple.size());
  for (int64_t i = 0; i < tuple_size; ++i) {
    std::vector<int64_t> tmp = index_sequence;
    tmp.emplace_back(i);
    // Tuple indexing uses size_t
    auto id = GetId(tuple[static_cast<size_t>(i)]);
    if (is_param && node->isa<Parameter>()) {
      auto param = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      SetParamNodeMapInGraphInfoMap(g, id, param);
    }
    SetNodeMapInGraphInfoMap(g, id, node, tmp);
    SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
  }
}
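
// Nested index bookkeeping above (sketch): for args = ((a, b), c) bound to
// a single node, the two functions record
//   id((a, b)) -> (node, {0})
//   id(a)      -> (node, {0, 0})
//   id(b)      -> (node, {0, 1})
//   id(c)      -> (node, {1})
// so GetObjNode can later rebuild any element as a TupleGetItem chain.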

void GradExecutor::CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out,
                                                  const std::string &out_id) {
  MS_EXCEPTION_IF_NULL(curr_g);
  const auto &out_tuple = out.cast<py::tuple>();
  // Get the input nodes and values
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
  ValuePtrList input_args;
  std::vector<size_t> value_index;
  for (size_t i = 0; i < out_tuple.size(); i++) {
    const auto &v = PyObjToValue(out_tuple[i]);
    // FuncGraph values have no gradient definition
    if (v->isa<FuncGraph>()) {
      continue;
    }
    value_index.emplace_back(i);
    input_args.emplace_back(v);
    inputs.emplace_back(GetInput(out_tuple[i], false));
  }
  py::tuple value_outs(value_index.size());
  for (size_t i = 0; i < value_index.size(); ++i) {
    value_outs[i] = out_tuple[value_index[i]];
  }
  auto cnode = curr_g_->NewCNode(inputs);
  MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString();
  // Record node info in the graph map
  SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
  SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
  if (grad_is_running_ && !bprop_grad_stack_.top().second) {
    MS_LOG(DEBUG) << "Custom bprop, no need GradPynativeOp";
    return;
  }
  // Run ad for the MakeTuple node
  const auto &out_value = PyObjToValue(value_outs);
  ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args, out_value);
}

void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
  if (cell_stack_.empty()) {
    MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again";
    if (cell_id == top_cell()->cell_id()) {
      if (top_cell()->is_topest()) {
        set_grad_flag(false);
      }
      if (GetHighOrderStackSize() < ARG_SIZE) {
        auto outer_top_cell = PopHighOrderGraphStack();
        if (outer_top_cell != nullptr) {
          set_top_cell(outer_top_cell);
        }
      }
    }
    return;
  }
  // Make an output node for this case: x = op1, y = op2, return (x, y)
  const auto &out_id = GetId(out);
  const auto &graph_info = top_cell()->graph_info_map().at(curr_g_);
  MS_EXCEPTION_IF_NULL(graph_info);
  if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
    if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
      CreateMakeTupleNodeForMultiOut(curr_g_, out, out_id);
    } else {
      MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
      MakeValueNode(out, out_id);
    }
  }
  DoGradForCustomBprop(cell, out, args);
  // Set the output node for the forward graph when needed.
  PopCellStack();
  if (grad_is_running_ && !bprop_grad_stack_.empty()) {
    if (!bprop_grad_stack_.top().second) {
      bprop_grad_stack_.pop();
      MS_EXCEPTION_IF_NULL(curr_g_);
      curr_g_->set_output(GetObjNode(out, out_id));
      return;
    } else if (bprop_grad_stack_.top().first == cell_id) {
      bprop_grad_stack_.pop();
    }
  }

  bool is_top_cell_end = cell_id == top_cell()->cell_id();
  // Only dump the last forward graph
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && is_top_cell_end) {
    curr_g_->set_output(GetObjNode(out, out_id));
#ifdef ENABLE_DUMP_IR
    DumpIR("fg.ir", curr_g_);
#endif
  }

  // Reset the grad flag and update the output node of the top cell
  if (cell_stack_.empty() && is_top_cell_end) {
    MS_LOG(DEBUG) << "Current top cell ends: " << cell_id;
    set_grad_flag(false);
    PopHighOrderGraphStack();
    // Update the real output node of the top cell for generating the bprop graph
    AnfNodePtr output_node = GetObjNode(out, out_id);
    MS_EXCEPTION_IF_NULL(output_node);
    auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
    MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
    k_pynative_cell_ptr->UpdateOutputNodeOfTopCell(output_node);
  }

  // Check whether the graph needs compiling when the top cell has finished running
  if (is_top_cell_end) {
    CheckNeedCompileGraph();
  }
}

void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) {
  if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    return;
  }
  custom_bprop_cell_count_ -= 1;
  if (custom_bprop_cell_count_ != 0) {
    return;
  }
  MS_LOG(DEBUG) << "Do grad for custom bprop";
  size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
  if (par_number > 0) {
    MS_LOG(EXCEPTION) << "When the user defines the net bprop, the 'Parameter' data type is not supported in the net.";
  }
  py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
  auto bprop_func_cellid = GetId(bprop_func);
  bprop_cell_list_.emplace_back(bprop_func_cellid);
  auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
  fake_prim->set_hook(bprop_func);
  const auto &cell_id = GetCellId(cell, args);
  (void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
  (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));

  py::object code_obj = py::getattr(bprop_func, "__code__");
  py::object co_name = py::getattr(code_obj, "co_name");
  if (std::string(py::str(co_name)) == "staging_specialize") {
    MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
  }
  // The three parameters self, out and dout need to be excluded
  const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
  if (inputs_num > args.size()) {
    MS_EXCEPTION(TypeError) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
                            << args.size() << "]";
  }

  py::list cell_inputs;
  for (size_t i = 0; i < inputs_num; i += 1) {
    cell_inputs.append(args[i]);
  }
  OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = fake_prim->name();
  op_exec_info->py_primitive = fake_prim;
  op_exec_info->op_inputs = cell_inputs;
  auto cnode = forward()->ConstructForwardGraph(op_exec_info);
  const auto &v_out = PyObjToValue(out);
  DoOpGrad(op_exec_info, cnode, v_out);
  const auto &out_obj_id = GetId(out);
  SaveOutputNodeMap(out_obj_id, out, cnode);
}

std::string GradExecutor::GetAlreadyRunCellId(const std::string &cell_id) {
  std::string already_run_cell_id;
  if (IsNestedGrad()) {
    already_run_cell_id = cell_id + "0";
  } else {
    already_run_cell_id = cell_id + "1";
  }
  already_run_cell_id += "_" + grad_operation_;
  MS_LOG(DEBUG) << "Get already run top cell id " << already_run_cell_id;
  return already_run_cell_id;
}
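
// Id layout produced above (illustrative): cell id + "0" (nested grad) or
// "1" + "_" + grad_operation_, e.g.
//   "<cell id>1_<grad operation>"
// This is the string GetTopCell matches against; an id recorded while
// grad_operation_ was still empty ends in '_', which is exactly the
// forward-only case GetTopCell treats as a partial match.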

std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, const py::args &args) {
  size_t forward_args_size = args.size();
  py::args tmp = args;
  if (has_sens) {
    forward_args_size--;
    py::tuple f_args(forward_args_size);
    for (size_t i = 0; i < forward_args_size; ++i) {
      f_args[i] = args[i];
    }
    tmp = f_args;
  }
  const auto &cell_id = GetCellId(cell, tmp);
  return cell_id;
}

void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                                const py::object &weights, const py::args &args) {
  MS_EXCEPTION_IF_NULL(ret);
  MS_EXCEPTION_IF_NULL(grad);
  auto size = args.size();
  const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args);
  MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id;
  if (!top_cell()->need_compile_graph()) {
    MS_LOG(DEBUG) << "No need to compile graph";
    UpdateTopCellInfo(false, false, false);
    return;
  }
  top_cell()->set_grad_operation(grad_operation_);
  auto resource = top_cell()->resource();
  MS_EXCEPTION_IF_NULL(resource);
  auto df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(df_builder);
  MS_LOG(DEBUG) << "curr_g ptr " << curr_g_.get() << " resource ptr " << resource.get();

  // Get the params (weights) that require derivatives
  auto w_args = GetWeightsArgs(weights, df_builder);
  if (w_args.empty() && !df_builder->parameters().empty()) {
    MS_LOG(DEBUG) << "Add weights params to w_args";
    w_args.insert(w_args.end(), df_builder->parameters().begin(), df_builder->parameters().end());
  }
  // Get the bprop graph of the top cell
  auto bprop_graph = GetBpropGraph(grad, cell, w_args, size, args);
  resource->set_func_graph(bprop_graph);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph, true);
  DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
  // Launch the bprop graph to the backend
  SaveForwardTensorInfoInBpropGraph(resource);
  compile::SetMindRTEnable();
  resource->results()[pipeline::kBackend] = compile::CreateBackend();
  MS_LOG(DEBUG) << "Start task emit action";
  TaskEmitAction(resource);
  MS_LOG(DEBUG) << "Start execute action";
  ExecuteAction(resource);
  MS_LOG(DEBUG) << "Start update top cell info when run finish";
  UpdateTopCellInfo(false, false, true);
  resource->Clean();
  abstract::AnalysisContext::ClearContext();
}
2574
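// Maps the Python-side weights (a ParameterTuple) onto parameter nodes of df_builder. Lookup is
// tried by object id first, then by parameter name; a weight never seen before is appended to
// df_builder as a new free parameter with its tensor as the default value.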
std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder) {
  MS_EXCEPTION_IF_NULL(df_builder);
  if (!py::hasattr(weights, "__parameter_tuple__")) {
    MS_LOG(DEBUG) << "No parameter tuple found";
    return {};
  }

  const auto &tuple = weights.cast<py::tuple>();
  MS_LOG(DEBUG) << "Get weights tuple size " << tuple.size();
  std::vector<AnfNodePtr> w_args;
  for (size_t it = 0; it < tuple.size(); ++it) {
    auto param = tuple[it];
    auto param_id = GetId(param);
    auto &graph_info_map = top_cell()->graph_info_map();
    if (graph_info_map.find(df_builder) == graph_info_map.end()) {
      MS_LOG(EXCEPTION) << "Can not find df_builder " << df_builder.get() << " Top cell " << top_cell().get()
                        << " cell id " << top_cell()->cell_id();
    }
    auto graph_info = graph_info_map.at(df_builder);
    MS_EXCEPTION_IF_NULL(graph_info);
    AnfNodePtr para_node = nullptr;
    if (graph_info->params.find(param_id) != graph_info->params.end()) {
      para_node = graph_info->params.at(param_id);
      w_args.emplace_back(para_node);
      continue;
    }
    const auto &name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
    if (py::isinstance<py::none>(name_attr)) {
      MS_LOG(EXCEPTION) << "Parameter object should have a name attribute";
    }
    const auto &param_name = py::cast<std::string>(name_attr);
    MS_LOG(DEBUG) << "The input " << it << " parameter weight name " << param_name;
    if (graph_info->params.find(param_name) != graph_info->params.end()) {
      para_node = graph_info->params.at(param_name);
    } else {
      MS_LOG(DEBUG) << "Can not find input param in graph info map, make a new parameter";
      auto free_param = df_builder->add_parameter();
      free_param->set_name(param_name);
      auto value = py::cast<tensor::TensorPtr>(param);
      free_param->set_default_param(value);
      free_param->debug_info()->set_name(param_name);
      para_node = free_param;
    }
    w_args.emplace_back(para_node);
  }
  return w_args;
}

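// Builds the abstract specification for every parameter of the bprop graph and validates the
// actual inputs against the recorded IR. A hedged example of the checks below (hypothetical
// shapes): feeding a (2, 3) tensor where the IR recorded (2, 4) raises ValueError, feeding
// Float16 where the IR recorded Float32 raises TypeError; scalar "()" shapes are skipped as
// const inputs.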
abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph) {
  MS_EXCEPTION_IF_NULL(bprop_graph);
  std::size_t size = args.size();
  abstract::AbstractBasePtrList args_spec;
  const auto &bprop_params = bprop_graph->parameters();
  // bprop_params includes both inputs and weight parameters, so it can be larger than size(inputs)
  if (bprop_params.size() < size) {
    MS_LOG(EXCEPTION) << "Df parameters size " << bprop_params.size() << " less than " << size;
  }
  // Update abstract info for parameters in bprop graph
  size_t index = 0;
  for (const auto &param : bprop_params) {
    auto param_node = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      // Update abstract info for weights
      ValuePtr value = param_node->default_param();
      auto ptr = value->ToAbstract();
      MS_EXCEPTION_IF_NULL(ptr);
      args_spec.emplace_back(ptr);
      param_node->set_abstract(ptr->Broaden());
    } else {
      // Update abstract info for input params
      const auto &input_value = PyObjToValue(args[index]);
      auto input_abs = abstract::FromValue(input_value, true);
      if (param_node->abstract() != nullptr) {
        auto input_shape = input_abs->BuildShape()->ToString();
        auto param_tensor_abs = param_node->abstract();
        if (param_tensor_abs->isa<abstract::AbstractRef>()) {
          param_tensor_abs = param_tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
        }
        auto ir_shape = param_tensor_abs->BuildShape()->ToString();
        // Exclude const input
        if (input_shape != "()" && ir_shape != "()") {
          if (input_shape != ir_shape) {
            MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", "
                                     << param->DebugString();
          }
          auto ir_dtype = param_tensor_abs->BuildType()->ToString();
          auto input_dtype = input_abs->BuildType()->ToString();
          if (input_dtype != ir_dtype) {
            MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", "
                                    << param->DebugString();
          }
        }
      }
      args_spec.emplace_back(input_abs);
      param_node->set_abstract(input_abs->Broaden());
      index++;
    }
  }
  MS_LOG(DEBUG) << "Args_spec size " << args_spec.size();
  return args_spec;
}

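// Finalizes auto-diff for the top cell: GradPynativeCellEnd assembles the raw bprop graph
// according to the GradOperation flags (get_all_, get_by_list_, sens_param_), after which a
// dedicated optimizer pass produces the graph that is actually compiled and launched.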
FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
                                         const std::vector<AnfNodePtr> &weights, size_t arg_size,
                                         const py::args &args) {
  bool build_formal_param = false;
  if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !cell_stack_.empty() && IsNestedGrad()) {
    build_formal_param = true;
    need_renormalize_ = true;
  }
  if (top_cell()->ms_function_flag()) {
    need_renormalize_ = true;
  }

  auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  MS_EXCEPTION_IF_NULL(grad);
  FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(k_pynative_cell_ptr, weights, grad->get_all_, grad->get_by_list_,
                                                     grad->sens_param_, build_formal_param);
  MS_EXCEPTION_IF_NULL(bprop_graph);

  MS_LOG(DEBUG) << "Top graph input params size " << arg_size;
  std::ostringstream ss;
  ss << "grad{" << arg_size << "}";
  bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop_graph->debug_info()->set_name(ss.str());
  // Get the parameter items and add their values to args_spec
  (void)GetArgsSpec(FilterTensorArgs(args, grad->sens_param_), bprop_graph);

  // Do optimization for the final bprop graph
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  resource->set_func_graph(bprop_graph);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph);
  auto optimized_bg = ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);

  if (cell_stack_.empty()) {
    need_renormalize_ = false;
  }
  DumpGraphIR("after_final_opt.ir", optimized_bg);
  return optimized_bg;
}

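// Called from the Python side before a grad run to decide whether the cell's graph must be
// rebuilt; it appears to return True when the cell is found in (and removed from) the top
// cell's sub-cell list. grad_order_ may also be bumped here for nested grad bookkeeping.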
py::object GradExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  BaseRef ret = false;
  check_graph_cell_id_ = GetCellId(cell, args);
  if (!(top_cell_ != nullptr && check_graph_cell_id_.find(top_cell_->cell_id()) != std::string::npos &&
        grad_order_ >= 1)) {
    ++grad_order_;
  }
  if (!grad_is_running_) {
    MS_LOG(DEBUG) << "Grad not running yet";
    return BaseRefToPyData(ret);
  }
  MS_LOG(DEBUG) << "Key is " << check_graph_cell_id_;
  if (top_cell_ != nullptr) {
    for (auto it = top_cell_->sub_cell_list().begin(); it != top_cell_->sub_cell_list().end(); ++it) {
      MS_LOG(DEBUG) << "Cur cell id " << *it;
      if (!IsCellObjIdEq(*it, check_graph_cell_id_)) {
        continue;
      }
      MS_LOG(DEBUG) << "Delete cell id from cell graph list, top cell is " << top_cell_;
      top_cell_->sub_cell_list().erase(it);
      ret = true;
      break;
    }
  }
  return BaseRefToPyData(ret);
}

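// Decides whether the forward pass can be skipped for this grad call. A sketch of the assumed
// input id below (illustrative): the ids of all args joined with '_', e.g. "140ab_141cd_"; if a
// dynamic cell sees a different input id than last time, the forward pass is forced to rerun.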
py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                                         const py::args &args) {
  bool forward_run = false;
  // Get cell id and input args info
  const auto &cell_id = GetCellId(cell, args);
  grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_);

  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id += GetId(args[i]) + "_";
  }
  // Check whether the forward process needs to run
  const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
  auto find_top_cell = GetTopCell(check_already_run_cell_id);
  if (find_top_cell != nullptr) {
    forward_run = find_top_cell->forward_already_run();
    auto curr_top_cell = top_cell();
    set_top_cell(find_top_cell);
    bool input_args_changed =
      !find_top_cell->input_args_id().empty() && find_top_cell->input_args_id() != input_args_id;
    if (forward_run && input_args_changed && find_top_cell->is_dynamic()) {
      MS_LOG(WARNING) << "The construct of the running cell is dynamic and the input info of this cell has changed, "
                         "the forward process will run again";
      forward_run = false;
    }
    if (forward_run && GetHighOrderStackSize() >= 1) {
      PushHighOrderGraphStack(curr_top_cell);
    }
  }
  MS_LOG(DEBUG) << "Graph has already run " << forward_run << " top cell id " << cell_id;
  return BaseRefToPyData(forward_run);
}

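// Compares the recorded op info of the new top cell with the cached run under the same
// already-run id: differing op sequences trigger recompilation (and, past MAX_TOP_CELL_COUNTS
// switches under control flow, disable the op graph cache), while identical sequences reuse
// the previously compiled top cell.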
void GradExecutor::CheckNeedCompileGraph() {
  auto new_top_cell = top_cell();
  const auto &already_top_cell_id = new_top_cell->already_run_cell_id();
  // Update top cell by current cell op info
  if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
    MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has never been run, need to compile graph";
    already_run_top_cell_[already_top_cell_id] = new_top_cell;
    return;
  }

  MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been run";
  auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
  MS_EXCEPTION_IF_NULL(pre_top_cell);
  const auto &pre_all_op_info = pre_top_cell->all_op_info();
  const auto &new_all_op_info = new_top_cell->all_op_info();
  MS_LOG(DEBUG) << "Pre all op info : " << pre_all_op_info;
  MS_LOG(DEBUG) << "New all op info : " << new_all_op_info;
  if (pre_all_op_info != new_all_op_info) {
    MS_LOG(DEBUG) << "The op info has changed, need to compile graph again";
    // If the top cell switch count exceeds MAX_TOP_CELL_COUNTS under control flow, disable the backend cache
    if (top_cell_switch_counts_ >= MAX_TOP_CELL_COUNTS) {
      EnableOpGraphCache(false);
    } else {
      // Increase the top cell switch count
      ++top_cell_switch_counts_;
    }
    EraseTopCellFromTopCellList(pre_top_cell);
    pre_top_cell->Clear();
    already_run_top_cell_[already_top_cell_id] = new_top_cell;
    g_pyobj_id_cache.clear();
  } else {
    MS_LOG(DEBUG) << "The op info has not changed, no need to compile graph again";
    pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
    EraseTopCellFromTopCellList(new_top_cell);
    new_top_cell->Clear();
    pre_top_cell->set_forward_already_run(true);
    set_top_cell(pre_top_cell);
  }
}

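// Runs the compiled bprop graph through the VM. has_sens appears to be inferred here by
// checking whether this cell id strictly extends a known top cell id, i.e. extra trailing
// sensitivity args were appended to the id.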
void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args) {
  MS_EXCEPTION_IF_NULL(ret);
  const auto &cell_id = GetCellId(cell, args);
  MS_LOG(DEBUG) << "Run start cell id " << cell_id;
  auto has_sens = std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) {
    return cell_id.find(value->cell_id()) != std::string::npos && cell_id != value->cell_id();
  });
  MS_LOG(DEBUG) << "Run has sens " << has_sens << " cell id " << cell_id;
  auto resource = top_cell()->resource();
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Run resource ptr " << resource.get();

  VectorRef arg_list;
  py::tuple converted_args = ConvertArgs(FilterTensorArgs(args, has_sens));
  pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
  if (resource->results().find(pipeline::kOutput) == resource->results().end()) {
    MS_LOG(EXCEPTION) << "Can't find run graph output";
  }
  if (!resource->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
    MS_LOG(EXCEPTION) << "Run graph is not VmEvalFuncPtr";
  }
  compile::VmEvalFuncPtr run = resource->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>();
  MS_EXCEPTION_IF_NULL(run);

  const auto &backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  grad_is_running_ = true;
  BaseRef value = (*run)(arg_list);
  grad_is_running_ = false;
  MS_LOG(DEBUG) << "Eval run end " << value.ToString();
  *ret = BaseRefToPyData(value);
  // Clear the device memory resource of the top cell when it has been run.
  auto has_higher_order = std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
                                      [](const TopCellInfoPtr &value) { return !value->is_topest(); });
  if (top_cell()->is_topest() && !has_higher_order) {
    top_cell()->ClearDeviceMemory();
  }
  // High order
  if (top_cell()->vm_compiled()) {
    MakeNestedCnode(cell, converted_args, resource, *ret);
  } else if (GetHighOrderStackSize() >= ARG_SIZE) {
    SwitchTopcell();
  }
}

void GradExecutor::SwitchTopcell() {
  const auto &inner_top_cell_all_op_info = top_cell()->all_op_info();
  bool inner_top_cell_is_dynamic = top_cell()->is_dynamic();
  top_cell()->set_grad_order(1);

  // Get the outer top cell
  auto outer_top_cell = PopHighOrderGraphStack();
  MS_EXCEPTION_IF_NULL(outer_top_cell);
  outer_top_cell->all_op_info() += inner_top_cell_all_op_info;
  // If the inner cell is dynamic, mark the outer cell dynamic too
  if (inner_top_cell_is_dynamic) {
    outer_top_cell->set_is_dynamic(inner_top_cell_is_dynamic);
  }
  set_top_cell(outer_top_cell);
}

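// For nested (higher-order) grad: rewires the parameters of the first (inner) grad graph onto
// the parameter nodes of the second (outer) df_builder, collecting the actual input nodes and
// weight values that the nested cnode will consume.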
void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const py::tuple &forward_args,
                                      std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args) {
  MS_EXCEPTION_IF_NULL(inputs);
  MS_EXCEPTION_IF_NULL(weights_args);
  auto first_df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(first_df_builder);
  auto first_graph_info = top_cell()->graph_info_map().at(first_df_builder);
  MS_EXCEPTION_IF_NULL(first_graph_info);
  SwitchTopcell();
  auto second_df_builder = top_cell()->df_builder();
  MS_EXCEPTION_IF_NULL(second_df_builder);
  auto second_graph_info = top_cell()->graph_info_map().at(second_df_builder);
  MS_EXCEPTION_IF_NULL(second_graph_info);

  std::unordered_set<std::string> params_weights_set;
  std::unordered_set<std::string> params_inputs_set;
  for (const auto &sec : second_graph_info->params) {
    if (sec.second->has_default()) {
      params_weights_set.emplace(sec.first);
    } else {
      params_inputs_set.insert(sec.first);
    }
  }
  auto manager = Manage({first_grad_fg}, false);
  // Replace input params
  for (size_t i = 0; i < forward_args.size(); ++i) {
    const auto &id = GetId(forward_args[i]);
    if (params_inputs_set.count(id)) {
      // Can be found in the second graph
      const auto &input_param_second = second_graph_info->params.at(id);
      manager->Replace(first_graph_info->params.at(id), input_param_second);
      inputs->emplace_back(input_param_second);
    } else {
      inputs->emplace_back(GetInput(forward_args[i], false));
    }
  }

  // Replace weight params
  for (const auto &fir : first_graph_info->params) {
    if (!fir.second->has_default()) {
      continue;
    }
    // The second graph does not have this weight param, so add it to the second graph
    if (!params_weights_set.count(fir.first)) {
      SetParamNodeMapInGraphInfoMap(second_df_builder, fir.first, fir.second);
      inputs->emplace_back(fir.second);
      weights_args->emplace_back(fir.second->default_param());
    } else {
      // Replacement is needed
      for (const auto &sec : second_graph_info->params) {
        MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
        if (sec.second->has_default() && fir.second->name() == sec.second->name()) {
          manager->Replace(fir.second, sec.second);
          inputs->emplace_back(sec.second);
          weights_args->emplace_back(sec.second->default_param());
          break;
        }
      }
    }
  }
}

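// Builds the cnode that embeds the inner grad graph into the outer graph and registers it with
// the outer k-pynative cell, so a second ad::Grad pass yields second-order derivatives. A
// hedged usage sketch from the Python side (hypothetical net, illustrative only):
//   grad_fn = GradOperation()(GradOperation()(net))  # nested grad triggers this path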
void GradExecutor::MakeNestedCnode(const py::object &cell, const py::tuple &forward_args,
                                   const pipeline::ResourcePtr &resource, const py::object &out) {
  if (cell_stack_.empty()) {
    MS_LOG(DEBUG) << "No nested grad found";
    return;
  }
  FuncGraphPtr first_grad_fg = nullptr;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    first_grad_fg = curr_g_;
    MS_LOG(DEBUG) << "Bprop nested";
  } else {
    first_grad_fg = resource->func_graph();
  }
  MS_EXCEPTION_IF_NULL(first_grad_fg);
  DumpGraphIR("first_grad_fg.ir", first_grad_fg);

  std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
  ValuePtrList weights_args;
  DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);

  pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
  r->manager()->AddFuncGraph(first_grad_fg);
  set_eliminate_forward(false);
  FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
  set_eliminate_forward(true);
  DumpGraphIR("second_grad_fg.ir", second_grad_fg);
  r->Clean();

  MS_LOG(DEBUG) << "Get pre graph ptr " << curr_g().get();
  auto cnode = curr_g()->NewCNode(inputs);
  auto out_id = GetId(out);
  SetTupleArgsToGraphInfoMap(curr_g(), out, cnode);
  SetNodeMapInGraphInfoMap(curr_g(), out_id, cnode);
  MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();

  // Get input values
  ValuePtrList input_args;
  for (size_t i = 0; i < forward_args.size(); ++i) {
    const auto &arg = PyObjToValue(forward_args[i]);
    input_args.emplace_back(arg);
  }
  input_args.insert(input_args.end(), weights_args.begin(), weights_args.end());
  // Get output values
  py::object new_out;
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !py::isinstance<py::tuple>(out)) {
    new_out = py::make_tuple(out);
  } else {
    new_out = out;
  }
  const auto &out_value = PyObjToValue(new_out);
  if (!top_cell()->k_pynative_cell_ptr()->KPynativeWithFProp(cnode, input_args, out_value, second_grad_fg)) {
    MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
  }
  need_renormalize_ = true;
}

void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
  MS_EXCEPTION_IF_NULL(top_cell);
  auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                           [&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
  if (iter == top_cell_list_.end()) {
    MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id()
                    << " from top cell list";
  } else {
    (void)top_cell_list_.erase(iter);
  }
}

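// Grad bookkeeping for an ms_function call. The decorated graph is assumed to return a pair
// (actual_output, added_forward_output): the first is handed back to Python, the second keeps
// forward tensors referenced by the grad graph alive and up to date.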
void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
                                       const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) {
  // Get actual output value and added output value.
  if (!py::isinstance<py::tuple>(out)) {
    MS_LOG(EXCEPTION) << "The output value of the ms_function func graph should be a tuple.";
  }
  auto tuple_out = py::cast<py::tuple>(out);
  constexpr size_t tuple_out_size = 2;
  if (tuple_out.size() != tuple_out_size) {
    MS_LOG(EXCEPTION) << "The tuple size of the output value of the ms_function func graph should be 2.";
  }
  py::object actual_out = tuple_out[0];
  auto actual_out_v = PyObjToValue(actual_out);
  auto added_out = PyObjToValue(tuple_out[1]);
  MS_LOG(DEBUG) << "Added output value is: " << added_out->ToString();

  // Identify op info for the currently running ms_function graph.
  OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = phase;
  RecordGradOpInfo(op_exec_info, actual_out_v);
  MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;

  // Step 1: Update actual output tensors used in grad graph.
  MS_LOG(DEBUG) << "ms_function actual output value: " << actual_out_v->ToString();
  UpdateForwardTensorInfoInBpropGraph(op_exec_info, actual_out_v);

  // Step 2: Update output tensors of added forward nodes, which are added to the return node of the ms_function
  // func graph.
  if (top_cell()->op_info_with_ms_func_forward_tensors().count(op_exec_info->op_info)) {
    UpdateMsFunctionForwardTensors(op_exec_info, added_out);
    return;
  }
  MS_LOG(DEBUG) << "Ms func graph runs for the first time. The graph phase is: " << graph_phase();
  if (!need_construct_graph()) {
    MS_LOG(EXCEPTION) << "The flag of need construct graph is false.";
  }
  ReplaceNewTensorsInGradGraph(top_cell(), op_exec_info, added_out, ms_func_graph, grad_graph);

  // Clone new ms_function func graph and grad graph.
  auto new_ms_func_graph = BasicClone(ms_func_graph);
  auto new_grad_graph = BasicClone(grad_graph, true);
  auto new_make_tuple = new_ms_func_graph->output()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(new_make_tuple);
  new_ms_func_graph->set_output(new_make_tuple->input(1));

  // Make adjoint for grad graph
  MakeAdjointForMsFunction(new_ms_func_graph, new_grad_graph, actual_out, args, actual_out_v);
}

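// Python-facing hook called after an ms_function graph executes. A minimal sketch of the
// assumed flow on the Python side (hypothetical names, illustrative only):
//   @ms_function
//   def square(x): return x * x
//   out = square(x)  # under grad_flag, the runtime routes the output through GradMsFunction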
py::object GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
  // Get actual forward output object.
  if (graph_phase().empty()) {
    MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain ms_function func graph.";
  }
  const auto &phase = graph_phase();
  MS_LOG(DEBUG) << "ms_function func graph phase: " << phase;
  auto executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(executor);
  FuncGraphPtr ms_func_graph = executor->GetFuncGraph(phase);
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  py::object ret = out;
  if (ms_func_graph->modify_output()) {
    auto tuple_out = py::cast<py::tuple>(out);
    ret = tuple_out[0];
  }

  // Make adjoint for the grad graph of ms_function.
  if (!grad_flag_) {
    MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
    set_graph_phase("");
    return ret;
  }
  FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
  MS_EXCEPTION_IF_NULL(grad_graph);
  GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
  set_graph_phase("");
  return ret;
}

void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
  MS_LOG(DEBUG) << "Clear top cell grad resource " << GetCellId(cell, args);
  if (grad_order_ > 0) {
    --grad_order_;
  }
  check_graph_cell_id_.clear();
  grad_operation_.clear();
  forward()->node_abs_map().clear();
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
}

void GradExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear grad res";
  grad_flag_ = false;
  enable_op_cache_ = true;
  grad_is_running_ = false;
  need_renormalize_ = false;
  eliminate_forward_ = true;
  custom_bprop_cell_count_ = 0;
  grad_order_ = 0;
  top_cell_switch_counts_ = 0;

  check_graph_cell_id_.clear();
  grad_operation_.clear();
  top_cell_ = nullptr;
  curr_g_ = nullptr;
  bprop_cell_list_.clear();
  already_run_top_cell_.clear();
  ClearCellRes();
  std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
  std::stack<std::string>().swap(cell_stack_);
  std::stack<std::pair<FuncGraphPtr, TopCellInfoPtr>>().swap(high_order_stack_);
}

GradExecutorPtr PynativeExecutor::grad_executor() const {
  MS_EXCEPTION_IF_NULL(grad_executor_);
  return grad_executor_;
}
ForwardExecutorPtr PynativeExecutor::forward_executor() const {
  MS_EXCEPTION_IF_NULL(forward_executor_);
  return forward_executor_;
}

bool PynativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); }

void PynativeExecutor::set_grad_flag(bool flag) { grad_executor()->set_grad_flag(flag); }

void PynativeExecutor::set_graph_phase(const std::string &graph_phase) {
  grad_executor()->set_graph_phase(graph_phase);
}

void PynativeExecutor::set_py_exe_path(const py::object &py_exe_path) {
  if (!py::isinstance<py::str>(py_exe_path)) {
    MS_LOG(EXCEPTION) << "Failed, py_exe_path input is not a str";
  }
  auto py_exe_path_s = py::cast<std::string>(py_exe_path);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
}

void PynativeExecutor::set_kernel_build_server_dir(const py::object &kernel_build_server_dir) {
  if (!py::isinstance<py::str>(kernel_build_server_dir)) {
    MS_LOG(EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
  }
  auto kernel_build_server_dir_s = py::cast<std::string>(kernel_build_server_dir);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
}

py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) {
  return grad_executor()->CheckGraph(cell, args);
}

py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell,
                                             const py::args &args) {
  return grad_executor()->CheckAlreadyRun(grad, cell, args);
}

py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args) {
  py::object ret;
  PynativeExecutorTry(grad_executor()->RunGraph, &ret, cell, args);
  return ret;
}

void PynativeExecutor::ClearCell(const std::string &cell_id) {
  MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id;
  grad_executor()->ClearCellRes(cell_id);
}

void PynativeExecutor::ClearGrad(const py::object &cell, const py::args &args) {
  MS_LOG(DEBUG) << "Clear grad";
  return grad_executor()->ClearGrad(cell, args);
}

void PynativeExecutor::ClearRes() {
  MS_LOG(DEBUG) << "Clear all res";
  session::PynativeTaskManager::GetInstance().Reset();
  SetLazyBuild(false);
  cell_depth_ = 0;

  // May exit during the RunOp step
  auto ms_context = MsContext::GetInstance();
  if (ms_context != nullptr) {
    ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  }
  ConfigManager::GetInstance().ResetIterNum();
  if (forward_executor_ != nullptr) {
    forward_executor_->ClearRes();
  }
  if (grad_executor_ != nullptr) {
    grad_executor_->ClearRes();
  }
  ad::CleanRes();
  pipeline::ReclaimOptimizer();
  kSession = nullptr;
  mind_rt_backend = nullptr;
  g_pyobj_id_cache.clear();
}

void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  // Only build a graph for the new cell when the grad flag is set
  if (!grad_executor()->grad_flag()) {
    MS_LOG(DEBUG) << "Grad flag is false";
    return;
  }
  py::object ret;
  PynativeExecutorTry(grad_executor()->InitGraph, &ret, cell, args);
}

void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  if (!grad_executor()->grad_flag()) {
    MS_LOG(DEBUG) << "Grad flag is false";
    return;
  }
  MS_LOG(DEBUG) << "Enter end graph process.";
  py::object ret;
  PynativeExecutorTry(grad_executor()->LinkGraph, &ret, cell, out, args);
  MS_LOG(DEBUG) << "Leave end graph process.";
}

py::object PynativeExecutor::GradMsFunction(const py::object &out, const py::args &args) {
  return grad_executor()->GradMsFunction(out, args);
}

void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  py::object ret;
  PynativeExecutorTry(grad_executor()->GradGraph, &ret, grad, cell, weights, args);
}

void PynativeExecutor::Sync() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    if (kSession == nullptr) {
      MS_EXCEPTION(NotExistsError) << "No session has been created!";
    }
    kSession->SyncStream();
  } else {
    std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    const auto &device_context =
      device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
    MS_EXCEPTION_IF_NULL(device_context);
    (void)device_context->SyncStream();
  }
}

void PynativeExecutor::SetLazyBuild(bool enable) { forward_executor()->set_lazy_build(enable); }

void PynativeExecutor::EnterCell() {
  if (cell_depth_ < UINT32_MAX) {
    ++cell_depth_;
  } else {
    MS_LOG(ERROR) << "Cell call stack too deep";
  }
}

void PynativeExecutor::ExitCell() {
  if (cell_depth_ > 0) {
    --cell_depth_;
  }
}

bool PynativeExecutor::IsTopCell() const { return cell_depth_ == 0; }

void PynativeExecutor::ExecuteAllTask() { session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks(); }

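// Python bindings for the executor. A hedged sketch of how the front end is assumed to drive
// them (illustrative only, not the exact mindspore front-end code):
//   executor = PynativeExecutor_.get_instance()
//   executor.set_grad_flag(True)
//   executor.new_graph(cell, *inputs)            # before running the cell
//   executor.end_graph(cell, output, *inputs)    # after running the cell
//   executor.grad_net(grad_op, cell, weights, *inputs)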
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
                         (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
                           .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
                           .def("enter_cell", &PynativeExecutor::EnterCell, "enter cell.")
                           .def("exit_cell", &PynativeExecutor::ExitCell, "exit cell.")
                           .def("is_top_cell", &PynativeExecutor::IsTopCell, "check top cell.")
                           .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.")
                           .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.")
                           .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.")
                           .def("check_run", &PynativeExecutor::CheckAlreadyRun, "pynative check if graph has run before.")
                           .def("grad_ms_function", &PynativeExecutor::GradMsFunction, "pynative grad for ms_function.")
                           .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.")
                           .def("clear_cell", &PynativeExecutor::ClearCell, "pynative clear status.")
                           .def("clear_res", &PynativeExecutor::ClearRes, "pynative clear exception res.")
                           .def("clear_grad", &PynativeExecutor::ClearGrad, "pynative clear grad status.")
                           .def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
                           .def("set_lazy_build", &PynativeExecutor::SetLazyBuild, "pynative build kernel async.")
                           .def("execute_all_task", &PynativeExecutor::ExecuteAllTask, "clear all task.")
                           .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
                           .def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase.")
                           .def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag.")
                           .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
                                "Executor set grad flag.")
                           .def("set_py_exe_path", &PynativeExecutor::set_py_exe_path,
                                py::arg("py_exe_path") = py::str(""), "set python executable path.")
                           .def("set_kernel_build_server_dir", &PynativeExecutor::set_kernel_build_server_dir,
                                py::arg("kernel_build_server_dir") = py::str(""),
                                "set kernel build server directory path.");
                       }));
}  // namespace mindspore::pynative