1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/pynative/pynative_utils.h"
17 #include <algorithm>
18 #include <vector>
19 #include "ops/sparse_ops.h"
20 #include "ops/sequence_ops.h"
21 #include "ops/framework_ops.h"
22 #include "include/backend/optimizer/helper.h"
23 #include "include/backend/optimizer/op_adaptation_info_factory.h"
24 #include "pybind_api/ir/primitive_py.h"
25 #include "pybind_api/gil_scoped_long_running.h"
26 #include "pybind_api/ir/hook_py.h"
27 #include "utils/ms_context.h"
28 #include "ir/cell.h"
29 #include "include/common/utils/utils.h"
30 #include "include/common/utils/convert_utils_py.h"
31 #include "include/common/utils/primfunc_utils.h"
32 #include "include/common/debug/anf_ir_dump.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "include/common/utils/stub_tensor.h"
35 #include "frontend/expander/bprop/bprop.h"
36 #include "frontend/optimizer/environ_conversion.h"
37 #include "frontend/optimizer/fallback_rewriter.h"
38 #include "pipeline/pynative/grad/jit/jit_grad.h"
39 #include "ops/sequence_op_name.h"
40 #include "ops/structure_ops.h"
41 #include "ops/other_ops.h"
42 #include "pipeline/pynative/predict_out_type_map.h"
43 #include "kernel/pyboost/auto_generate/contiguous.h"
44 #include "runtime/pipeline/pipeline.h"
45 #include "ops/auto_generate/gen_ops_primitive.h"
46 #include "include/common/pynative/abstract_converter.h"
47 #include "kernel/pyboost/pyboost_utils.h"
48 
49 namespace mindspore {
50 namespace pynative {
51 namespace PyNativeAlgo {
52 namespace {
53 std::string GetObjIdFromPython(const py::handle &obj) {
54   py::object out = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
55   if (py::isinstance<py::none>(out)) {
56     MS_LOG(EXCEPTION) << "Get pyobj failed";
57   }
58   return out.cast<std::string>();
59 }
60 // For simple infer (simple infer will push the abs into the bprop queue).
61 static AbstractConverter kGradAbstractConverter;
62 
63 std::string GetIdForPyTupleOrList(const py::handle &obj) {
64   auto p_list = py::cast<py::tuple>(obj);
65   string prefix = py::isinstance<py::tuple>(obj) ? "Tuple<" : "List<";
66   if (p_list.empty()) {
67     prefix = "Empty:";
68   } else {
69     for (size_t i = 0; i < p_list.size(); ++i) {
70       prefix += PyParser::GetIdByPyObj(p_list[i]) + ":";
71     }
72   }
73   prefix.pop_back();
74   prefix += ">";
75   return prefix;
76 }
77 
78 std::string GetFnInfoByPyObj(const py::object &obj) {
79   std::string fn_info = obj.attr("__module__").cast<std::string>();
80   fn_info += "_" + obj.attr("__name__").cast<std::string>();
81   fn_info += "_" + obj.attr("__code__").attr("co_filename").cast<std::string>();
82   fn_info += "_" + py::str(obj.attr("__code__").attr("co_firstlineno")).cast<std::string>();
83   if (py::hasattr(obj, "__wrapped__")) {
84     auto wrapped_obj = obj.attr("__wrapped__");
85     fn_info += "_" + wrapped_obj.attr("__name__").cast<std::string>();
86     fn_info += "_" + wrapped_obj.attr("__code__").attr("co_filename").cast<std::string>();
87     fn_info += "_" + py::str(wrapped_obj.attr("__code__").attr("co_firstlineno")).cast<std::string>();
88   }
89   return fn_info;
90 }
91 
92 void AddDynInputsSizesAttr(const FrontendOpRunInfoPtr &op_run_info) {
93   if (op_run_info->base_op_run_info.dyn_input_sizes.empty()) {
94     return;
95   }
96   op_run_info->op_grad_info->op_prim->set_attr(kAttrDynInputSizes,
97                                                MakeValue(op_run_info->base_op_run_info.dyn_input_sizes));
98 }
99 
100 ValuePtr CreateNonTensorByAbstract(const abstract::AbstractBasePtr &abs) {
101   MS_EXCEPTION_IF_NULL(abs);
102   auto type_id = Common::GetTypeFromAbstract(abs);
103   if (abs->isa<abstract::AbstractMonad>()) {
104     return std::make_shared<tensor::Tensor>(0);
105   }
106   if (type_id == kMetaTypeNone) {
107     return kNone;
108   }
109   if (type_id == kMetaTypeNull) {
110     return kNull;
111   }
112   if (abs->isa<abstract::AbstractSequence>()) {
113     const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>()->elements();
114     ValuePtrList value_ptr_list;
115     (void)std::transform(abs_seq.begin(), abs_seq.end(), std::back_inserter(value_ptr_list),
116                          [](const abstract::AbstractBasePtr &elem) { return CreateNonTensorByAbstract(elem); });
117     return std::make_shared<ValueTuple>(value_ptr_list);
118   }
119   if (type_id == kNumberTypeBool) {
120     return MakeValue(true);
121   }
122   if (type_id == kObjectTypeString) {
123     return MakeValue("");
124   }
125   if (type_id >= kNumberTypeInt && type_id <= kNumberTypeUInt64) {
126     return MakeValue(static_cast<int64_t>(0));
127   }
128   if (type_id >= kNumberTypeFloat && type_id <= kNumberTypeFloat64) {
129     return MakeValue(static_cast<float>(0));
130   }
131   if (type_id == kNumberTypeDouble) {
132     return MakeValue(static_cast<double>(0));
133   }
134   MS_LOG(EXCEPTION) << "Get unsupported type " << type_id;
135 }
136 
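// Recursively flatten a (possibly nested) sequence abstract: every tensor element becomes a new
// parameter of bprop_graph and is appended to both *make_tuple and *new_param; non-tensor leaves are skipped.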
137 void PlantTupleParam(const FuncGraphPtr &bprop_graph, const abstract::AbstractSequencePtr &abs_seq,
138                      AnfNodePtrList *make_tuple, AnfNodePtrList *new_param) {
139   MS_EXCEPTION_IF_NULL(bprop_graph);
140   MS_EXCEPTION_IF_NULL(make_tuple);
141   MS_EXCEPTION_IF_NULL(new_param);
142   MS_EXCEPTION_IF_NULL(abs_seq);
143   for (size_t i = 0; i < abs_seq->size(); ++i) {
144     if (abs_seq->elements()[i]->isa<abstract::AbstractSequence>()) {
145       PlantTupleParam(bprop_graph, abs_seq->elements()[i]->cast<abstract::AbstractSequencePtr>(), make_tuple,
146                       new_param);
147     } else if (abs_seq->elements()[i]->isa<abstract::AbstractTensor>()) {
148       auto plant_param = bprop_graph->add_parameter();
149       plant_param->set_abstract(abs_seq->elements()[i]);
150       (void)make_tuple->emplace_back(plant_param);
151       (void)new_param->emplace_back(plant_param);
152     }
153   }
154 }
155 
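// For an Ascend view tensor (storage_info != nullptr), launch a contiguous kernel task and return a new
// tensor bound to the contiguous device address. Return nullptr if the tensor is already contiguous or
// lives on GPU/CPU, where contiguity is handled when the stub node is converted.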
156 ValuePtr GetContiguousGradTensor(const ValuePtr &v) {
157   const auto &tensor = v->cast<tensor::BaseTensorPtr>();
158   MS_EXCEPTION_IF_NULL(tensor);
159   if (tensor->storage_info() == nullptr) {
160     return nullptr;
161   }
162 
163   auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
164   MS_EXCEPTION_IF_NULL(old_device_address);
165   const auto &device_target = old_device_address->device_name();
166   if (device_target != kAscendDevice) {
167     // On GPU/CPU, the tensor is made contiguous when the stub node is converted, i.e. before grad.
168     return nullptr;
169   }
170 
171   MS_LOG(DEBUG) << "tensor id:" << tensor->id();
172   auto stream_id = old_device_address->stream_id();
173   const auto &old_storage_info = old_device_address->GetTensorStorageInfo();
174   MS_EXCEPTION_IF_NULL(old_storage_info);
175 
176   const auto &device_context = runtime::OpRunner::GetDeviceContext(old_device_address->device_name());
177   MS_EXCEPTION_IF_NULL(device_context);
178   auto address_size = GetTypeByte(TypeIdToType(old_device_address->type_id())) * SizeOf(old_storage_info->shape);
179   auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
180     nullptr, address_size, Format::DEFAULT_FORMAT, old_device_address->type_id(), old_storage_info->shape,
181     device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
182   kernel_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(old_device_address->type_id())));
183   kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(old_storage_info->shape));
184   kernel_tensor->set_stream_id(stream_id);
185 
186   auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
187   new_device_address->set_device_shape(old_storage_info->shape);
188   new_device_address->set_original_ref_count(SIZE_MAX);
189   new_device_address->ResetRefCount();
190 
191   device::DeviceAddressPtrList input_addr_list{old_device_address};
192   device::DeviceAddressPtrList output_addr_list{new_device_address};
193   GilReleaseWithCheck release_gil;
194   if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(runtime::KernelTaskType::kCONTIGUOUS_TASK,
195                                                                    input_addr_list, output_addr_list, stream_id)) {
196     MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
197   }
198 
199   MS_LOG(DEBUG) << "Update contiguous address, old_device_address:" << old_device_address
200                 << ", new_device_address:" << new_device_address;
201 
202   auto new_tensor = std::make_shared<tensor::Tensor>(*tensor);
203   new_tensor->set_device_address(new_device_address);
204   return new_tensor;
205 }
206 
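// Make the index-th input of the op contiguous if bprop actually uses it; handles both a single tensor
// and a tuple of tensors.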
207 void RefreshGradContiguousTensor(const FrontendOpRunInfoPtr &op_run_info, size_t index) {
208   const auto &unused_inputs = BpropExpander::GetUnusedInputs(op_run_info->op_grad_info->op_prim->name());
209   // The input is not used in bprop, so there is no need to make it contiguous.
210   if (unused_inputs.find(index) != unused_inputs.end()) {
211     return;
212   }
213 
214   const auto &v = op_run_info->op_grad_info->input_value[index];
215   if (v->isa<tensor::BaseTensor>()) {
216     const auto &new_tensor = GetContiguousGradTensor(v);
217     if (new_tensor != nullptr) {
218       op_run_info->op_grad_info->input_value[index] = new_tensor;
219     }
220   } else if (v->isa<ValueSequence>()) {
221     const auto &vec = v->cast<ValueSequencePtr>()->value();
222     if (vec.empty() || !vec[0]->isa<tensor::BaseTensor>()) {
223       return;
224     }
225     // A tensor tuple needs contiguous tensors.
226     bool need_refresh_tuple = false;
227     std::vector<ValuePtr> new_vec(vec.size());
228     for (size_t i = 0; i < vec.size(); i++) {
229       const auto &new_tensor = GetContiguousGradTensor(vec[i]);
230       if (new_tensor == nullptr) {
231         new_vec[i] = vec[i];
232       } else {
233         // A non-contiguous tensor was found in input_value; refresh the tuple with the contiguous copy.
234         need_refresh_tuple = true;
235         new_vec[i] = new_tensor;
236       }
237     }
238     if (need_refresh_tuple) {
239       op_run_info->op_grad_info->input_value[index] = MakeValue(new_vec);
240     }
241   }
242 }
243 
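// Ops that are structural or bookkeeping ops (tuple/list handling, monads, control edges) rather than
// real compute kernels.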
244 const mindspore::HashSet<std::string> kNotRealOP{
245   kMakeTupleOpName,
246   kMakeListNewOpName,
247   kTupleGetItemOpName,
248   kStopGradientOpName,
249   kUpdateStateOpName,
250   kLoadOpName,
251   kDependOpName,
252   kReturnOpName,
253   kNPUAllocFloatStatusOpName,
254   kNPUGetFloatStatusOpName,
255   kNPUClearFloatStatusOpName,
256   kMirrorOperatorOpName,
257   kSequenceSliceOpName,
258   kSequenceMulOpName,
259   kPyExecuteOpName,
260 };
261 
262 tensor::BaseTensorPtr GetContiguousTensor(const tensor::BaseTensorPtr &input_tensor, const std::string &device_target,
263                                           bool requires_grad) {
264   auto contiguous_op = CREATE_PYBOOST_OP(Contiguous, device_target);
265   auto contiguous_tensor = contiguous_op->Call(input_tensor);
266   if (requires_grad) {
267     const auto &contiguous_run_info = std::make_shared<FrontendOpRunInfo>();
268     contiguous_run_info->requires_grad = true;
269     PyBoost::UpdateOpRunInfo(contiguous_op, contiguous_run_info);
270     contiguous_run_info->base_op_run_info.device_target = device_target;
271     contiguous_run_info->input_size = 1;
272     contiguous_run_info->base_op_run_info.op_name = ops::kNameContiguous;
273     contiguous_run_info->op_grad_info->op_prim = prim::kPrimContiguous;
274     PyBoost::DoGrad(contiguous_op, contiguous_run_info, {input_tensor});
275   }
276   return contiguous_tensor;
277 }
278 
279 void UnsetValueAbstractCache(const ValuePtr &value) {
280   if (value->isa<tensor::BaseTensor>()) {
281     auto tensor = value->cast<tensor::BaseTensorPtr>();
282     tensor->set_abstract(std::weak_ptr<abstract::AbstractBase>());
286   } else if (value->isa<ValueSequence>()) {
287     const auto &seq = value->cast<ValueSequencePtr>();
288     auto elements = seq->value();
289     for (const auto &element : elements) {
290       UnsetValueAbstractCache(element);
291     }
292   }
293 }
294 }  // namespace
295 
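// Recursively reset the value of tensor abstracts (including those nested in sequences and dictionaries)
// to kValueAny so that only shape and type information is kept.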
296 AbstractBasePtr Common::SetAbstractValueToAnyValue(const AbstractBasePtr &abs) {
297   MS_EXCEPTION_IF_NULL(abs);
298   if (abs->isa<abstract::AbstractTensor>()) {
299     abs->set_value(kValueAny);
300   } else if (abs->isa<abstract::AbstractTuple>() || abs->isa<abstract::AbstractList>()) {
301     const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
302     for (const auto &elem : abs_seq->elements()) {
303       (void)SetAbstractValueToAnyValue(elem);
304     }
305   } else if (abs->isa<abstract::AbstractDictionary>()) {
306     const auto &abs_dic = abs->cast<abstract::AbstractDictionaryPtr>();
307     for (const auto &elem : abs_dic->elements()) {
308       (void)SetAbstractValueToAnyValue(elem.first);
309       (void)SetAbstractValueToAnyValue(elem.second);
310     }
311   }
312   return abs;
313 }
314 
315 AnfNodePtr Common::ConvertValueSequenceToMakeTuple(const ValueNodePtr &node, const FuncGraphPtr &func_graph) {
316   MS_EXCEPTION_IF_NULL(node);
317   const auto &v = node->value();
318   if (!v->isa<ValueSequence>()) {
319     return node;
320   }
321   auto value_sequence = v->cast<ValueSequencePtr>();
322   if (!node->abstract()->isa<abstract::AbstractSequence>() ||
323       (node->abstract()->cast<abstract::AbstractSequencePtr>()->size() != value_sequence->size())) {
324     MS_LOG(EXCEPTION) << "Get wrong matched abs " << node->abstract()->ToString() << " and value "
325                       << value_sequence->ToString();
326   }
327 
328   AnfNodePtrList inputs{NewValueNode(prim::kPrimMakeTuple)};
329   for (const auto &value : value_sequence->value()) {
330     MS_EXCEPTION_IF_NULL(value);
331     auto value_node = NewValueNode(value);
332     auto abs = Common::SetAbstractValueToAnyValue(value->ToAbstract());
333     value_node->set_abstract(abs);
334     auto tuple_node = ConvertValueSequenceToMakeTuple(value_node, func_graph);
335     (void)inputs.emplace_back(tuple_node);
336   }
337   MS_EXCEPTION_IF_NULL(func_graph);
338   auto make_tuple_node = func_graph->NewCNode(inputs);
339   make_tuple_node->set_abstract(node->abstract());
340   return make_tuple_node;
341 }
342 
343 std::string Common::GetIdByValue(const ValuePtr &v) {
344   MS_EXCEPTION_IF_NULL(v);
345   if (v->isa<tensor::BaseTensor>()) {
346     return v->cast<tensor::BaseTensorPtr>()->id();
347   } else if (v->isa<stub::StubNode>()) {
348     return GetIdByValue(v->cast<stub::StubNodePtr>()->WaitValue());
349   } else if (v->isa<Cell>()) {
350     return v->cast<CellPtr>()->id();
351   } else if (v->isa<mindspore::Type>()) {
352     auto type_ptr = v->cast<mindspore::TypePtr>();
353     return "Type:" + type_ptr->ToString();
354   } else if (v->isa<StringImm>()) {
355     return "S" + v->cast<StringImmPtr>()->value();
356   } else if (v->isa<BoolImm>()) {
357     return "B" + std::to_string(v->cast<BoolImmPtr>()->value());
358   } else if (v->isa<IntegerImm>()) {
359     return "I" + std::to_string(v->cast<Int64ImmPtr>()->value());
360   } else if (v->isa<FloatImm>()) {
361     return "F" + std::to_string(v->cast<FP32ImmPtr>()->value());
362   } else if (v->isa<None>()) {
363     return "None";
364   } else if (v->isa<Ellipsis>()) {
365     return "Ellipsis";
366   } else if (v->isa<ValueSequence>()) {
367     auto p_list = v->cast<ValueSequencePtr>();
368     string prefix = v->isa<ValueTuple>() ? "Tuple<" : "List<";
369     if (p_list->size() == 0) {
370       prefix = "Empty:";
371     } else {
372       for (size_t i = 0; i < p_list->size(); ++i) {
373         prefix += GetIdByValue(p_list->value()[i]) + ":";
374       }
375     }
376     prefix.pop_back();
377     prefix += ">";
378     return prefix;
379   }
380   MS_LOG(DEBUG) << "Get type " << v->ToString();
381   return v->ToString();
382 }
383 
384 std::string Common::GetCellId(const std::string &obj_id, const std::vector<std::string> &input_arg_id_vec,
385                               const std::vector<ValuePtr> &input_arg_value_vec) {
386   auto cell_id = obj_id;
387   auto fn = [&cell_id](const abstract::AbstractBasePtr &abs) {
388     MS_EXCEPTION_IF_NULL(abs);
389     auto shape = abs->BuildShape();
390     auto type = abs->BuildType();
391     cell_id += "_" + shape->ToString();
392     cell_id += type->ToString();
393   };
394 
395   const auto &forward = GetPyNativeExecutor()->forward_executor();
396   for (size_t i = 0; i < input_arg_id_vec.size(); ++i) {
397     const auto &arg_id = input_arg_id_vec[i];
398     // First try the abstract cached during the current step.
399     auto cache_abs = forward->GetNodeAbsById(arg_id);
400     if (cache_abs != nullptr) {
401       fn(cache_abs);
402     } else {
403       MS_EXCEPTION_IF_NULL(input_arg_value_vec[i]);
404       fn(SetAbstractValueToAnyValue(input_arg_value_vec[i]->ToAbstract()));
405     }
406   }
407   return cell_id;
408 }
409 
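// Split an id string such as "Tuple<a:b:Tuple<c:d>>" at top-level colons into *id_vec, using
// angle-bracket counters so that colons inside nested ids are ignored.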
410 void Common::SplitString(const std::string &str, std::vector<std::string> *id_vec) {
411   constexpr char colon_delim = ':';
412   constexpr char angle_bracket_left_delim = '<';
413   constexpr char angle_bracket_right_delim = '>';
414   auto paren_pos = str.find_first_of(angle_bracket_left_delim);
415   if (paren_pos == std::string::npos) {
416     MS_LOG(EXCEPTION) << "Get wrong str " << str;
417   }
418   size_t str_size = str.size();
419   const auto &sub_str = str.substr(paren_pos + 1, str_size - paren_pos - 2);
420   MS_LOG(DEBUG) << "Ori str " << str << ", get sub str " << sub_str;
421   size_t begin = 0;
422   size_t angle_bracket_left = 0;
423   size_t angle_bracket_right = 0;
424   size_t sub_str_size = sub_str.size();
425   for (size_t i = 0; i < sub_str_size; ++i) {
426     switch (sub_str[i]) {
427       case colon_delim:
428         if (i != 0 && angle_bracket_left == angle_bracket_right) {
429           (void)id_vec->emplace_back(sub_str.substr(begin, i - begin));
430           begin = i + 1;
431           angle_bracket_left = 0;
432           angle_bracket_right = 0;
433         }
434         break;
435       case angle_bracket_left_delim:
436         ++angle_bracket_left;
437         break;
438       case angle_bracket_right_delim:
439         ++angle_bracket_right;
440         break;
441       default: {
442       }
443     }
444   }
445   if (angle_bracket_left == angle_bracket_right) {
446     (void)id_vec->emplace_back(sub_str.substr(begin, sub_str_size - begin));
447   }
448 }
449 
450 bool Common::ValueHasDynamicShape(const ValuePtr &value) {
451   MS_EXCEPTION_IF_NULL(value);
452   if (value->isa<tensor::BaseTensor>()) {
453     return value->cast<tensor::BaseTensorPtr>()->base_shape_ptr() != nullptr;
454   } else if (value->isa<ValueSequence>()) {
455     auto value_seq = value->cast<ValueSequencePtr>();
456     return std::any_of(value_seq->value().begin(), value_seq->value().end(),
457                        [](const ValuePtr &elem) { return ValueHasDynamicShape(elem); });
458   }
459   return false;
460 }
461 
462 bool Common::IsTensor(const ValuePtr &v, bool include_sequence) {
463   MS_EXCEPTION_IF_NULL(v);
464   if (include_sequence) {
465     if (v->isa<tensor::MetaSparseTensor>() || v->isa<tensor::BaseTensor>()) {
466       return true;
467     } else if (v->isa<ValueSequence>()) {
468       auto v_seq = v->cast<ValueSequencePtr>();
469       if (v_seq->size() == 0) {
470         MS_LOG(DEBUG) << "Get empty value sequence";
471         return false;
472       }
473       // A sparse tensor has scalar indices, so just check whether the first element is a sparse tensor.
474       if (v_seq->value().front()->isa<tensor::MetaSparseTensor>()) {
475         return true;
476       }
477       // All values are tensors.
478       return std::all_of(v_seq->value().begin(), v_seq->value().end(),
479                          [](const ValuePtr &e) { return IsTensor(e, true); });
480     } else {
481       MS_LOG(DEBUG) << "Get value " << v->ToString();
482       return false;
483     }
484   }
485   MS_LOG(DEBUG) << "Get value " << v->ToString();
486   return v->isa<tensor::BaseTensor>() || v->isa<tensor::MetaSparseTensor>();
487 }
488 
489 bool Common::IsControlFlowGraph(const FuncGraphPtr &func_graph) {
490   MS_EXCEPTION_IF_NULL(func_graph);
491   return !func_graph->func_graphs_used_total().empty();
492 }
493 
494 ValuePtr Common::FilterSensValues(const ValuePtr &value, bool dict_convert_to_tuple) {
495   MS_EXCEPTION_IF_NULL(value);
496   if (value->isa<tensor::BaseTensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
497     return value;
498   }
499   if (value->isa<ValueSequence>()) {
500     std::vector<ValuePtr> value_list;
501     auto value_seq = value->cast<ValueSequencePtr>();
502     MS_EXCEPTION_IF_NULL(value_seq);
503     for (auto &filter_value : value_seq->value()) {
504       if (auto t = FilterSensValues(filter_value, dict_convert_to_tuple); t != nullptr) {
505         (void)value_list.emplace_back(t);
506       }
507     }
508     return std::make_shared<ValueTuple>(value_list);
509   }
510   if (value->isa<ValueDictionary>()) {
511     if (dict_convert_to_tuple) {
512       return FilterSensValues(DataConvert::ConvertValueDictToValueTuple(value), dict_convert_to_tuple);
513     }
514     return value;
515   }
516   MS_LOG(DEBUG) << "Value type: " << value->ToString();
517   return nullptr;
518 }
519 
520 tensor::BaseTensorPtr Common::GetTensorFromParam(const AnfNodePtr &param_node) {
521   MS_EXCEPTION_IF_NULL(param_node);
522   auto param = param_node->cast<ParameterPtr>();
523   MS_EXCEPTION_IF_NULL(param);
524   if (!param->has_default()) {
525     return nullptr;
526   }
527   auto default_value = param->default_param();
528   MS_EXCEPTION_IF_NULL(default_value);
529   auto tensor_value = default_value->cast<tensor::BaseTensorPtr>();
530   MS_EXCEPTION_IF_NULL(tensor_value);
531   return tensor_value;
532 }
533 
534 const std::shared_ptr<PyNativeExecutor> &Common::GetPyNativeExecutor() {
535   const auto &executor = PyNativeExecutor::GetInstance();
536   MS_EXCEPTION_IF_NULL(executor);
537   return executor;
538 }
539 
540 void Common::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
541 #ifdef ENABLE_DUMP_IR
542   auto context = MsContext::GetInstance();
543   MS_EXCEPTION_IF_NULL(context);
544   if (context->CanDump(kIntroductory)) {
545     DumpIR(filename, graph);
546   }
547 #endif
548 }
549 
550 TypeId Common::GetTypeFromAbstract(const abstract::AbstractBasePtr &abs) {
551   MS_EXCEPTION_IF_NULL(abs);
552   if (abs->isa<abstract::AbstractSequence>()) {
553     auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
554     return GetTypeFromAbstract(abs_seq->elements().front());
555   }
556   const auto &type = abs->BuildType();
557   MS_EXCEPTION_IF_NULL(type);
558   return common::AnfAlgo::GetOutputInferDataType(type, 0);
559 }
560 
561 ShapeVector Common::GetShapeFromAbstract(const abstract::AbstractBasePtr &abs) {
562   MS_EXCEPTION_IF_NULL(abs);
563   if (abs->isa<abstract::AbstractSequence>()) {
564     MS_LOG(EXCEPTION) << "Get abstract sequence";
565   }
566   auto shape = abs->BuildShape();
567   MS_EXCEPTION_IF_NULL(shape);
568   auto shape_ptr = shape->cast<abstract::ShapePtr>();
569   MS_EXCEPTION_IF_NULL(shape_ptr);
570   return shape_ptr->shape();
571 }
572 
573 std::pair<TypePtr, TypeId> Common::GetTypeFromValue(const ValuePtr &v) {
574   MS_EXCEPTION_IF_NULL(v);
575   if (v->isa<tensor::BaseTensor>()) {
576     return std::make_pair(v->cast<tensor::BaseTensorPtr>()->Dtype(), kObjectTypeTensorType);
577   } else if (v->isa<ValueTuple>()) {
578     return std::make_pair(v->type(), kObjectTypeTuple);
579   } else if (v->isa<ValueList>()) {
580     return std::make_pair(v->type(), kObjectTypeList);
581   } else if (v->isa<None>()) {
582     return std::make_pair(kTypeNone, kMetaTypeNone);
583   } else {
584     return std::make_pair(v->type(), v->type()->object_type());
585   }
586 }
587 
588 ShapeVector Common::GetShapeFromValue(const ValuePtr &v) {
589   MS_EXCEPTION_IF_NULL(v);
590   if (v->isa<tensor::BaseTensor>()) {
591     return v->cast<tensor::BaseTensorPtr>()->shape_c();
592   } else if (v->isa<ValueSequence>()) {
593     const auto &v_seq = v->cast<ValueSequencePtr>()->value();
594     ShapeVector plant_shape_vector;
595     for (const auto &item : v_seq) {
596       const auto &shape = GetShapeFromValue(item);
597       (void)std::transform(shape.begin(), shape.end(), std::back_inserter(plant_shape_vector),
598                            [](int64_t s) { return s; });
599     }
600     return plant_shape_vector;
601   } else {
602     return ShapeVector{};
603   }
604 }
605 
606 ValuePtr Common::CreatOutputTensorValueByAbstract(const abstract::AbstractBasePtr &abs) {
607   MS_EXCEPTION_IF_NULL(abs);
608   auto type_id = GetTypeFromAbstract(abs);
609   if (abs->isa<abstract::AbstractMonad>()) {
610     return std::make_shared<tensor::Tensor>(0);
611   }
612   if (abs->isa<abstract::AbstractSequence>()) {
613     auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
614     std::vector<ValuePtr> out;
615     if (!abs_seq->elements().front()->isa<abstract::AbstractTensor>()) {
616       MS_LOG(DEBUG) << "Get non tensor output";
617       return CreateNonTensorByAbstract(abs);
618     }
619     for (size_t i = 0; i < abs_seq->size(); ++i) {
620       (void)out.emplace_back(std::make_shared<tensor::Tensor>(type_id, GetShapeFromAbstract(abs_seq->elements()[i])));
621     }
622     return std::make_shared<ValueTuple>(out);
623   }
624   if (!abs->isa<abstract::AbstractTensor>()) {
625     MS_LOG(DEBUG) << "Get non tensor output";
626     return CreateNonTensorByAbstract(abs);
627   }
628   return std::make_shared<tensor::Tensor>(type_id, GetShapeFromAbstract(abs));
629 }
630 
631 void Common::ReplaceCNodeWithValueNode(const FuncGraphPtr &bprop_graph) {
632   MS_EXCEPTION_IF_NULL(bprop_graph);
633   if (bprop_graph->used_forward_nodes().empty()) {
634     return;
635   }
636   auto mng = MakeManager({bprop_graph}, false);
637   auto tr = mng->Transact();
638   for (const auto &forward_node : bprop_graph->used_forward_nodes()) {
639     auto cnode = forward_node->cast<CNodePtr>();
640     auto v_node = cnode->forward().first;
641     MS_EXCEPTION_IF_NULL(v_node);
642     bprop_graph->AddValueNode(v_node);
643     MS_LOG(DEBUG) << "Replace " << forward_node->DebugString() << " by value node " << v_node->DebugString();
644     auto converted_node = ConvertValueSequenceToMakeTuple(v_node, bprop_graph);
645     (void)tr.Replace(forward_node, converted_node);
646   }
647   tr.Commit();
648   bprop_graph->ClearUsedForwardNodes();
649   DumpGraphIR("replace_cnode_with_valuenode.ir", bprop_graph);
650 }
651 
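// Recursively replace stub nodes with their real values, waiting for asynchronous results; sequences of
// scalars are returned unchanged.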
652 ValuePtr StubNodeToValueInner(const ValuePtr &v) {
653   MS_EXCEPTION_IF_NULL(v);
654   if (utils::isa<stub::StubNode>(v)) {
655     auto stub = utils::cast<stub::StubNodePtr>(v);
656     return stub->WaitValue();
657   }
658   if (utils::isa<ValueSequence>(v)) {
659     const auto &value_seq = utils::cast<ValueSequencePtr>(v);
660     const auto &values = value_seq->value();
661     if (!values.empty() && utils::isa<Scalar>(values[0])) {
662       return v;
663     }
664     ValuePtrList value_list;
665     (void)std::transform(values.begin(), values.end(), std::back_inserter(value_list),
666                          [](const ValuePtr &value) { return StubNodeToValueInner(value); });
667     if (utils::isa<ValueTuple>(v)) {
668       return std::make_shared<ValueTuple>(value_list);
669     }
670     if (utils::isa<ValueList>(v)) {
671       return std::make_shared<ValueList>(value_list);
672     }
673     MS_LOG(EXCEPTION) << "Value not support ValueSequence " << v->ToString();
674   } else {
675     return v;
676   }
677 }
678 
679 void Common::StubNodeToValue(const FrontendOpRunInfoPtr &op_run_info) {
680   MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info);
681   auto old_stream_id = kernel::pyboost::PyBoostUtils::cur_stream_id();
682   kernel::pyboost::PyBoostUtils::set_cur_stream_id(op_run_info->base_op_run_info.stream_id);
683   for (size_t i = 0; i < op_run_info->input_size; i++) {
684     op_run_info->op_grad_info->input_value[i] = StubNodeToValueInner(op_run_info->op_grad_info->input_value[i]);
685     if (!op_run_info->is_view_op) {
686       op_run_info->op_grad_info->input_value[i] =
687         ConvertToContiguousValue(op_run_info->op_grad_info->input_value[i], op_run_info->requires_grad);
688     }
689     kernel::pyboost::PyBoostUtils::set_cur_stream_id(old_stream_id);
690     runtime::DeviceAddressUtils::CreateKernelTensor(op_run_info->op_grad_info->input_value[i]);
691   }
692 }
693 
694 tensor::BaseTensorPtr Common::StubNodeToTensor(const ValuePtr &v) {
695   MS_EXCEPTION_IF_NULL(v);
696   if (utils::isa<stub::StubNode>(v)) {
697     auto stub = utils::cast<stub::StubNodePtr>(v);
698     return stub->WaitValue()->cast<tensor::BaseTensorPtr>();
699   }
700   if (v->isa<tensor::BaseTensor>()) {
701     return v->cast<tensor::BaseTensorPtr>();
702   }
703   MS_LOG(EXCEPTION) << "It should be stub tensor, but got " << v->ToString();
704 }
705 
706 ValuePtr Common::ConvertToContiguousValue(const ValuePtr &v, bool requires_grad) {
707   MS_EXCEPTION_IF_NULL(v);
708   if (v->isa<tensor::BaseTensor>()) {
709     auto tensor = v->cast<tensor::BaseTensorPtr>();
710     MS_EXCEPTION_IF_NULL(tensor);
711     if (tensor->storage_info() == nullptr) {
712       return tensor;
713     }
714 
715     auto contiguous_tensor = ConvertToContiguousTensor(tensor, requires_grad);
716     MS_LOG(DEBUG) << "ConvertToContiguousValue, old tensor id:" << tensor->id()
717                   << ", new tensor id:" << contiguous_tensor->id();
718     return contiguous_tensor;
719   }
720   if (utils::isa<ValueSequence>(v)) {
721     const auto &value_seq = utils::cast<ValueSequencePtr>(v);
722     const auto &values = value_seq->value();
723     if (values.empty() || utils::isa<Scalar>(values[0])) {
724       return v;
725     }
726     ValuePtrList value_list;
727     (void)std::transform(
728       values.begin(), values.end(), std::back_inserter(value_list),
729       [requires_grad](const ValuePtr &value) { return ConvertToContiguousValue(value, requires_grad); });
730     if (utils::isa<ValueTuple>(v)) {
731       return std::make_shared<ValueTuple>(value_list);
732     }
733     if (utils::isa<ValueList>(v)) {
734       return std::make_shared<ValueList>(value_list);
735     }
736     MS_LOG(EXCEPTION) << "Not support ValueSequence " << v->ToString();
737   } else {
738     return v;
739   }
740 }
741 
742 tensor::BaseTensorPtr Common::ConvertToContiguousTensor(const tensor::BaseTensorPtr &tensor, bool requires_grad) {
743   MS_EXCEPTION_IF_NULL(tensor);
744 
745   // A tensor with storage info needs to be converted to contiguous for non-view ops.
746   auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
747   MS_EXCEPTION_IF_NULL(device_address);
748   const auto &device_target = device_address->device_name();
749 
750   return GetContiguousTensor(tensor, device_target, requires_grad);
751 }
752 
753 tensor::BaseTensorPtr Common::ConvertStubNodeToTensor(const ValuePtr &v, bool need_contiguous, bool requires_grad) {
754   const auto &tensor = StubNodeToTensor(v);
755   MS_EXCEPTION_IF_NULL(tensor);
756   if (!need_contiguous || tensor->storage_info() == nullptr) {
757     return tensor;
758   }
759 
760   auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
761   MS_EXCEPTION_IF_NULL(device_address);
762   const auto &device_target = device_address->device_name();
763   if (device_target == kAscendDevice) {
764     return tensor;
765   }
766 
767   return GetContiguousTensor(tensor, device_target, requires_grad);
768 }
769 
770 std::optional<tensor::BaseTensorPtr> Common::ConvertStubNodeToTensor(const std::optional<ValuePtr> &v,
771                                                                      bool need_contiguous, bool requires_grad) {
772   if (!v.has_value()) {
773     return std::nullopt;
774   }
775   return std::make_optional(ConvertStubNodeToTensor(v.value(), need_contiguous, requires_grad));
776 }
777 
778 ValueTuplePtr Common::ConvertStubNodeToValueTuple(const ValueListPtr &v, bool need_contiguous, bool requires_grad) {
779   if (utils::isa<ValueSequence>(v)) {
780     const auto &value_seq = utils::cast<ValueSequencePtr>(v);
781     const auto &values = value_seq->value();
782     std::vector<ValuePtr> tensor_list;
783     (void)std::transform(values.begin(), values.end(), std::back_inserter(tensor_list),
784                          [need_contiguous, requires_grad](const ValuePtr &value) {
785                            return ConvertStubNodeToTensor(value, need_contiguous, requires_grad);
786                          });
787     return std::make_shared<ValueTuple>(tensor_list);
788   }
789   MS_LOG(EXCEPTION) << "It should be stub tensor sequence, but got " << v->ToString();
790 }
791 
792 ValueTuplePtr Common::ConvertStubNodeToValueTuple(const ValueTuplePtr &v, bool need_contiguous, bool requires_grad) {
793   if (utils::isa<ValueSequence>(v)) {
794     const auto &value_seq = utils::cast<ValueSequencePtr>(v);
795     const auto &values = value_seq->value();
796     std::vector<ValuePtr> tensor_list;
797     (void)std::transform(values.begin(), values.end(), std::back_inserter(tensor_list),
798                          [need_contiguous, requires_grad](const ValuePtr &value) {
799                            return ConvertStubNodeToTensor(value, need_contiguous, requires_grad);
800                          });
801     return std::make_shared<ValueTuple>(tensor_list);
802   }
803   MS_LOG(EXCEPTION) << "It should be stub tensor sequence, but got " << v->ToString();
804 }
805 
806 std::optional<ValueTuplePtr> Common::ConvertStubNodeToValueTuple(const std::optional<ValueTuplePtr> &v,
807                                                                  bool need_contiguous, bool requires_grad) {
808   if (!v.has_value()) {
809     return std::nullopt;
810   }
811   return std::make_optional(ConvertStubNodeToValueTuple(v.value(), need_contiguous, requires_grad));
812 }
813 
814 void Common::GetConstInputToAttr(const PrimitivePtr &op_prim, const std::string &op_name,
815                                  const std::string &device_target, bool is_dynamic_shape,
816                                  mindspore::HashSet<size_t> *input_to_attr_index) {
817   if (op_name == prim::kPrimCustom->name()) {
818     // Custom op needs to set its registration info dynamically.
819     mindspore::HashSet<size_t> attr_indexes;
820     PrimitiveReadLock read_lock(op_prim->shared_mutex());
821     opt::GetCustomOpAttrIndex(op_prim, input_to_attr_index);
822     return;
823   }
824 
825   // On Ascend, const-input-to-attr conversion is handled in AscendVmOpAdapter.
826   if (device_target == kAscendDevice) {
827     return;
828   }
829 
830   auto reg_info =
831     opt::OpAdaptationInfoRegister::GetInstance().GetOpAdaptationInfo(op_name, device_target, is_dynamic_shape);
832   if (reg_info == nullptr) {
833     return;
834   } else {
835     MS_EXCEPTION_IF_NULL(input_to_attr_index);
836     for (auto &iter : reg_info->input_attr_map()) {
837       (void)input_to_attr_index->insert(iter.first);
838     }
839   }
840 }
841 
842 ValueNodePtr Common::CreateValueNodeByValue(const ValuePtr &v, const abstract::AbstractBasePtr &abs) {
843   MS_EXCEPTION_IF_NULL(v);
844   auto v_node = NewValueNode(v);
845   if (abs == nullptr) {
846     v_node->set_abstract(SetAbstractValueToAnyValue(v->ToAbstract()));
847   } else {
848     v_node->set_abstract(abs);
849   }
850   return v_node;
851 }
852 
853 tensor::TensorPtr Common::CreateFakeTensorWithoutDeviceAddress(const tensor::TensorPtr &tensor) {
854   MS_EXCEPTION_IF_NULL(tensor);
855   auto t = std::make_shared<tensor::Tensor>(*tensor);
856   if (tensor->is_parameter()) {
857     t->set_param_info(tensor->param_info());
858   }
859   t->set_device_address(nullptr);
860   return t;
861 }
862 
863 void Common::ClearDeviceAddress(const ValuePtr &value) {
864   std::vector<tensor::BaseTensorPtr> tensors;
865   TensorValueToTensor(value, &tensors);
866   for (const auto &tensor : tensors) {
867     tensor->set_device_address(nullptr);
868   }
869 }
870 
871 ValuePtr Common::CreateFakeValueWithoutDeviceAddress(const ValuePtr &value) {
872   MS_EXCEPTION_IF_NULL(value);
873   if (value->isa<tensor::BaseTensor>()) {
874     const auto &v_t = value->cast<tensor::BaseTensorPtr>();
875     auto t = std::make_shared<tensor::Tensor>(*v_t);
876     if (v_t->is_parameter()) {
877       t->set_param_info(v_t->param_info());
878     }
879     t->set_device_address(nullptr);
880     return t;
881   } else if (value->isa<ValueSequence>()) {
882     const auto &value_seq = value->cast<ValueSequencePtr>();
883     ValuePtrList value_list;
884     (void)std::transform(value_seq->value().begin(), value_seq->value().end(), std::back_inserter(value_list),
885                          [](const ValuePtr &elem) { return CreateFakeValueWithoutDeviceAddress(elem); });
886     return std::make_shared<ValueTuple>(value_list);
887   } else if (value->isa<stub::StubNode>()) {
888     const auto &stub_node = value->cast<stub::StubNodePtr>();
889     return CreateFakeValueWithoutDeviceAddress(stub_node->WaitValue());
890   } else if (value->isa<ValueDictionary>()) {
891     auto dic_v = value->cast<ValueDictionaryPtr>();
892     std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
893     for (const auto &v : dic_v->value()) {
894       (void)key_values.emplace_back(v.first, CreateFakeValueWithoutDeviceAddress(v.second));
895     }
896     return std::make_shared<ValueDictionary>(key_values);
897   } else {
898     return value;
899   }
900 }
901 
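// Ensure every tensor nested in the value has AutoGradMetaData with an input type set; parameter tensors
// are registered on the top cell. Returns the resulting input type.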
902 InputType Common::SetValueGradInfo(const ValuePtr &value, const TopCellInfoPtr &top_cell, InputType grad_type) {
903   MS_EXCEPTION_IF_NULL(value);
904   if (value->isa<tensor::BaseTensor>()) {
905     const auto &tensor_value = value->cast<tensor::BaseTensorPtr>();
906     auto auto_grad_meta_data = tensor_value->auto_grad_meta_data();
907     if (auto_grad_meta_data != nullptr) {
908       if (auto_grad_meta_data->input_type() != InputType::kUnkown) {
909         return auto_grad_meta_data->input_type();
910       }
911       MS_LOG(DEBUG) << "Set input type for tensor " << tensor_value->id();
912     } else {
913       MS_LOG(DEBUG) << "Create new auto grad meta for tensor " << tensor_value->id();
914       auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
915       tensor::RegisterHook::UpdateTensorBackwardHook(auto_grad_meta_data, tensor_value->id());
916       tensor_value->set_auto_grad_meta_data(auto_grad_meta_data);
917     }
918 
919     if (tensor_value->is_parameter() && grad_type != InputType::kInput) {
920       grad_type = InputType::kParameter;
921     }
922     auto_grad_meta_data->set_input_type(grad_type);
923     if (top_cell != nullptr && IsParam(grad_type)) {
924       top_cell->AddMetaGradInfo(tensor_value, auto_grad_meta_data);
925     }
926     return grad_type;
927   } else if (value->isa<ValueSequence>()) {
928     const auto &value_seq = value->cast<ValueSequencePtr>()->value();
929     InputType ret_type = grad_type;
930     for (const auto &v : value_seq) {
931       auto ret = SetValueGradInfo(v, top_cell, grad_type);
932       if (IsParam(ret)) {
933         ret_type = ret;
934       }
935     }
936     return ret_type;
937   } else if (value->isa<tensor::COOTensor>()) {
938     const auto &coo_tensor = value->cast<tensor::COOTensorPtr>();
939     const auto &indices_tensor = coo_tensor->GetIndices();
940     return SetValueGradInfo(indices_tensor, top_cell, grad_type);
941   } else if (value->isa<tensor::CSRTensor>()) {
942     const auto &csr_tensor = value->cast<tensor::CSRTensorPtr>();
943     const auto &indices_tensor = csr_tensor->GetIndices();
944     return SetValueGradInfo(indices_tensor, top_cell, grad_type);
945   } else if (value->isa<ValueDictionary>()) {
946     const auto &dic_v = value->cast<ValueDictionaryPtr>()->value();
947     for (const auto &v : dic_v) {
948       (void)SetValueGradInfo(v.second, top_cell, grad_type);
949     }
950   }
951   return grad_type;
952 }
953 
954 InputType Common::SetTensorGradInfo(const tensor::BaseTensorPtr &tensor, const TopCellInfoPtr &top_cell) {
955   MS_EXCEPTION_IF_NULL(tensor);
956   auto auto_grad_meta_data = tensor->auto_grad_meta_data();
957   if (auto_grad_meta_data != nullptr) {
958     if (auto_grad_meta_data->input_type() != InputType::kUnkown) {
959       return auto_grad_meta_data->input_type();
960     }
961     MS_LOG(DEBUG) << "Set input type for tensor " << tensor->id();
962   } else {
963     MS_LOG(DEBUG) << "Create new auto grad meta for tensor " << tensor->id();
964     auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
965     tensor::RegisterHook::UpdateTensorBackwardHook(auto_grad_meta_data, tensor->id());
966     tensor->set_auto_grad_meta_data(auto_grad_meta_data);
967   }
968   // Set weight tensor grad type
969   if (tensor->is_parameter()) {
970     auto_grad_meta_data->set_input_type(InputType::kParameter);
971     if (top_cell != nullptr) {
972       top_cell->AddMetaGradInfo(tensor, auto_grad_meta_data);
973     }
974     return InputType::kParameter;
975   }
976   // It is a constant input tensor, not a constant scalar value.
977   auto_grad_meta_data->set_input_type(InputType::kConstant);
978   return InputType::kConstant;
979 }
980 
981 void Common::SetGraphInputAndWeightsInfo(const FrontendOpRunInfoPtr &op_run_info, const FuncGraphPtr &func_graph,
982                                          const TopCellInfoPtr &top_cell) {
983   MS_EXCEPTION_IF_NULL(func_graph);
984   const auto &original_params = func_graph->parameters();
985   size_t params_size = original_params.size();
986   MS_EXCEPTION_IF_NULL(op_run_info);
987   op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
988   bool need_add_input_abs = op_run_info->op_grad_info->input_abs.empty();
989   for (size_t i = 0; i < params_size; ++i) {
990     if (i < op_run_info->input_size) {  // non-weights node.
991       op_run_info->op_grad_info->input_value_grad_type[i] =
992         SetValueGradInfo(op_run_info->op_grad_info->input_value[i], top_cell, InputType::kConstant);
993       if (need_add_input_abs) {
994         (void)op_run_info->op_grad_info->input_abs.emplace_back(original_params[i]->abstract());
995       }
996       continue;
997     }
998     // Must be a weight parameter.
999     const auto &param = original_params[i]->cast<ParameterPtr>();
1000     const auto tensor_value = GetTensorFromParam(original_params[i]);
1001     MS_EXCEPTION_IF_NULL(tensor_value);
1002     (void)op_run_info->op_grad_info->input_value.emplace_back(tensor_value);
1003     (void)op_run_info->op_grad_info->input_value_grad_type.emplace_back(SetTensorGradInfo(tensor_value, top_cell));
1004     (void)op_run_info->op_grad_info->input_abs.emplace_back(param->abstract());
1005     MS_LOG(DEBUG) << "Set graph weight parameter " << param->DebugString() << ". Its default value is "
1006                   << tensor_value->ToString() << ". Its name is: " << param->name();
1007   }
1008 }
1009 
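// Replace the tuple parameter at the given position of the bprop graph with flattened tensor parameters
// and substitute its uses with a MakeTuple node built from them.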
1010 void Common::ProcessTupleParam(const FuncGraphPtr &bprop_graph, size_t position) {
1011   auto bprop_params = bprop_graph->parameters();
1012   auto target_param = bprop_params[position];
1013   MS_EXCEPTION_IF_NULL(target_param);
1014   const auto &target_abstract = target_param->abstract();
1015   MS_EXCEPTION_IF_NULL(target_abstract);
1016   if (!target_abstract->isa<abstract::AbstractSequence>()) {
1017     MS_LOG(EXCEPTION) << "Get wrong param " << target_abstract->ToString();
1018   }
1019   const auto &abs_seq = target_abstract->cast<abstract::AbstractSequencePtr>();
1020   if (abs_seq->dynamic_len() && abs_seq->dynamic_len_element_abs() != nullptr) {
1021     return;
1022   }
1023   MS_LOG(DEBUG) << "Process tuple param " << target_abstract->ToString();
1024   auto it = std::find(bprop_params.begin(), bprop_params.end(), target_param);
1025   it = bprop_params.erase(it);
1026   AnfNodePtrList make_tuple{NewValueNode(prim::kPrimMakeTuple)};
1027   AnfNodePtrList new_param;
1028   PlantTupleParam(bprop_graph, abs_seq, &make_tuple, &new_param);
1029   (void)bprop_params.insert(it, new_param.begin(), new_param.end());
1030   bprop_graph->set_parameters(bprop_params);
1031   auto make_tuple_param = bprop_graph->NewCNode(make_tuple);
1032   make_tuple_param->set_abstract(target_abstract);
1033   auto manager = bprop_graph->manager();
1034   if (manager == nullptr) {
1035     manager = MakeManager({bprop_graph}, false);
1036   }
1037   MS_EXCEPTION_IF_NULL(manager);
1038   auto tr = manager->Transact();
1039   (void)tr.Replace(target_param, make_tuple_param);
1040   tr.Commit();
1041 }
1042 
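// Replace the dict parameter at the given position with a key tuple parameter and a value tuple parameter,
// and substitute its uses with a MakeDict node.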
1043 void Common::ProcessDictParam(const FuncGraphPtr &bprop_graph, size_t position) {
1044   auto bprop_params = bprop_graph->parameters();
1045   auto target_param = bprop_params[position];
1046   MS_EXCEPTION_IF_NULL(target_param);
1047   const auto &target_abstract = target_param->abstract();
1048   MS_EXCEPTION_IF_NULL(target_abstract);
1049   if (!target_abstract->isa<abstract::AbstractDictionary>()) {
1050     MS_LOG(EXCEPTION) << "Get wrong param " << target_abstract->ToString();
1051   }
1052   MS_LOG(DEBUG) << "Process Dict param " << target_abstract->ToString();
1053   auto it = std::find(bprop_params.begin(), bprop_params.end(), target_param);
1054   it = bprop_params.erase(it);
1055   const auto &abs_dict = target_abstract->cast<abstract::AbstractDictionaryPtr>();
1056   abstract::AbstractBasePtrList local_key_abs_inputs;
1057   abstract::AbstractBasePtrList local_value_abs_inputs;
1058   for (size_t i = 0; i < abs_dict->size(); ++i) {
1059     (void)local_key_abs_inputs.emplace_back(abs_dict->elements()[i].first);
1060     (void)local_value_abs_inputs.emplace_back(abs_dict->elements()[i].second);
1061   }
1062   auto key_param = bprop_graph->add_parameter();
1063   key_param->set_abstract(std::make_shared<abstract::AbstractTuple>(local_key_abs_inputs));
1064   auto value_param = bprop_graph->add_parameter();
1065   value_param->set_abstract(std::make_shared<abstract::AbstractTuple>(local_value_abs_inputs));
1066   auto key_it = bprop_params.insert(it, value_param);
1067   (void)bprop_params.insert(key_it, key_param);
1068   bprop_graph->set_parameters(bprop_params);
1069   auto dict_node = bprop_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), key_param, value_param});
1070   dict_node->set_abstract(abs_dict);
1071   auto manager = bprop_graph->manager();
1072   if (manager == nullptr) {
1073     manager = MakeManager({bprop_graph}, false);
1074   }
1075   auto tr = manager->Transact();
1076   (void)tr.Replace(target_param, dict_node);
1077   tr.Commit();
1078 }
1079 
1080 void Common::FreeFuncGraphForwardNodes(const FuncGraphPtr &func_graph) {
1081   MS_EXCEPTION_IF_NULL(func_graph);
1082   if (func_graph->used_forward_nodes().empty()) {
1083     return;
1084   }
1085   for (const auto &node : func_graph->used_forward_nodes()) {
1086     MS_EXCEPTION_IF_NULL(node);
1087     auto cnode = node->cast<CNodePtr>();
1088     MS_EXCEPTION_IF_NULL(cnode);
1089     cnode->set_forward(nullptr, "");
1090   }
1091   func_graph->ClearUsedForwardNodes();
1092 }
1093 
1094 size_t Common::GetValueSize(const ValuePtr &v) {
1095   MS_EXCEPTION_IF_NULL(v);
1096   if (v->isa<tensor::BaseTensor>() || v->isa<Scalar>()) {
1097     return 1;
1098   } else if (v->isa<ValueSequence>()) {
1099     auto seq = v->cast<ValueSequencePtr>();
1100     size_t output_size = 0;
1101     for (const auto &val : seq->value()) {
1102       output_size += GetValueSize(val);
1103     }
1104     return output_size;
1105   } else if (v->isa<ValueDictionary>()) {
1106     const auto &v_dict = v->cast<ValueDictionaryPtr>();
1107     size_t output_size = 0;
1108     for (const auto &val : v_dict->value()) {
1109       output_size += GetValueSize(val.second);
1110     }
1111     return output_size;
1112   }
1113   return 0;
1114 }
1115 
1116 ValuePtr Common::CreateTensorByConstantValue(const ValuePtr &value) {
1117   MS_EXCEPTION_IF_NULL(value);
1119   auto type = value->type();
1120   if (Common::IsTensor(value, true) || value->isa<Number>() || value->isa<None>() ||
1121       (type != nullptr && type->isa<String>())) {
1122     return value;
1123   }
1124   tensor::TensorPtr tensor_ptr = nullptr;
1125   if (value->isa<Scalar>()) {
1126     tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
1127   } else if (value->isa<ValueTuple>()) {
1128     tensor_ptr = opt::CreateTupleTensor(value->cast<ValueTuplePtr>());
1129   } else if (value->isa<ValueList>()) {
1130     tensor_ptr = opt::CreateTupleTensor(std::make_shared<ValueTuple>(value->cast<ValueListPtr>()->value()));
1131   } else {
1132     MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple, but get type " << value->type_name()
1133                       << ", value " << value->ToString();
1134   }
1135   MS_EXCEPTION_IF_NULL(tensor_ptr);
1136   return tensor_ptr;
1137 }
1138 
1139 void AutoGrad::CacheOutputAbstract(const ValuePtr &v, const abstract::AbstractBasePtr &abs) {
1140   MS_EXCEPTION_IF_NULL(v);
1141   MS_EXCEPTION_IF_NULL(abs);
1142 
1143   if (v->isa<tensor::BaseTensor>()) {
1144     auto tensor = v->cast<tensor::BaseTensorPtr>();
1145     tensor->set_abstract(abs);
1146     kGradAbstractConverter.CacheAbstract(abs);
1147   } else if (v->isa<ValueSequence>()) {
1148     const auto &value_seq = v->cast<ValueSequencePtr>();
1149     const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
1150     if (abs_seq == nullptr) {
1151       MS_LOG(EXCEPTION) << "Abstract is not abstract sequence, get " << abs->ToString();
1152     }
1153     size_t value_size = value_seq->size();
1154     if (value_size != abs_seq->size()) {
1155       MS_LOG(EXCEPTION) << "Abstract size " << abs_seq->size() << " is not equal to value size " << value_size;
1156     }
1157     for (size_t i = 0; i < value_size; ++i) {
1158       CacheOutputAbstract(value_seq->value()[i], abs_seq->elements()[i]);
1159     }
1160   }
1161 }
1162 
1163 namespace {
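// Build input and output abstracts from the cached simple-infer info and cache the output abstract on the
// output tensors.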
1164 void ConvertSimpleInferInfoToAbstract(const OpGradInfoPtr &op_grad_info) {
1165   MS_EXCEPTION_IF_NULL(op_grad_info);
1166   // Get inputs abstract
1167   for (const auto &v : op_grad_info->input_value) {
1168     op_grad_info->input_abs.emplace_back(kGradAbstractConverter.ConvertAbstract(v));
1169   }
1170 
1171   // Get output abstract
1172   MS_EXCEPTION_IF_NULL(op_grad_info->output_value_simple_info);
1173   op_grad_info->out_abs = TransformValueSimpleInfoToAbstract(*op_grad_info->output_value_simple_info);
1174 
1175   // Set abstract to tensor
1176   AutoGrad::CacheOutputAbstract(op_grad_info->out_value, op_grad_info->out_abs);
1177   MS_LOG(DEBUG) << "Get output abstract " << op_grad_info->out_abs->ToString();
1178 }
1179 }  // namespace
1180 
1181 void AutoGrad::CheckAndSetAbstract(const OpGradInfoPtr &op_grad_info) {
1182   MS_EXCEPTION_IF_NULL(op_grad_info);
1183   if (op_grad_info->output_value_simple_info != nullptr) {
1184     MS_LOG(DEBUG) << "Convert op " << op_grad_info->op_prim->name() << " simple infer info to abstract";
1185     ConvertSimpleInferInfoToAbstract(op_grad_info);
1186     return;
1187   }
1188 
1189   // For view ops, the input abs and output abs may be nullptr.
1190   if (MS_UNLIKELY(op_grad_info->input_abs.empty())) {
1191     // Get inputs abstract
1192     MS_LOG(DEBUG) << "Op " << op_grad_info->op_prim->name() << " inputs abstract not set, set it now";
1193     for (const auto &v : op_grad_info->input_value) {
1194       // Reuse the abstract cache stored on the tensor.
1195       op_grad_info->input_abs.emplace_back(kGradAbstractConverter.ConvertAbstract(v));
1196     }
1197   }
1198   if (op_grad_info->out_abs == nullptr) {
1199     MS_LOG(EXCEPTION) << "Get output abs is nullptr";
1200   }
1201 }
1202 
1203 std::string PyParser::GetIdByPyObj(const py::object &obj) {
1204   if (py::isinstance<tensor::BaseTensor>(obj)) {
1205     return obj.cast<tensor::BaseTensorPtr>()->id();
1206   } else if (IsStubTensor(obj)) {
1207     return ConvertStubTensor(obj)->id();
1208   } else if (py::isinstance<Cell>(obj)) {
1209     return obj.cast<CellPtr>()->id();
1210   } else if (py::isinstance<mindspore::Type>(obj)) {
1211     auto type_ptr = obj.cast<mindspore::TypePtr>();
1212     return "Type:" + type_ptr->ToString();
1213   } else if (py::isinstance<py::str>(obj)) {
1214     return "S" + obj.cast<std::string>();
1215   } else if (py::isinstance<py::bool_>(obj)) {
1216     return "B" + py::str(obj).cast<std::string>();
1217   } else if (py::isinstance<py::int_>(obj)) {
1218     return "I" + py::str(obj).cast<std::string>();
1219   } else if (py::isinstance<py::float_>(obj)) {
1220     return "F" + py::str(obj).cast<std::string>();
1221   } else if (py::isinstance<py::none>(obj)) {
1222     return "None";
1223   } else if (py::isinstance<py::ellipsis>(obj)) {
1224     return "Ellipsis";
1225   } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
1226     return GetIdForPyTupleOrList(obj);
1227   } else if (py::isinstance<py::function>(obj)) {
1228     return GetFnInfoByPyObj(obj);
1229   }
1230   // For these sparse tensor types, the id derived from the value and the id of the obj can be the same
1231   if (py::isinstance<tensor::CSRTensor>(obj) || py::isinstance<tensor::COOTensor>(obj) ||
1232       py::isinstance<tensor::RowTensor>(obj)) {
1233     return DataConvert::PyObjToValue(obj)->ToString();
1234   }
1235   return GetObjIdFromPython(obj);
1236 }
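// Illustrative sketch (hypothetical inputs, assuming an initialized Python interpreter):
// the id strings produced by GetIdByPyObj for a few simple objects.
//
//   PyParser::GetIdByPyObj(py::int_(3));       // -> "I3"
//   PyParser::GetIdByPyObj(py::str("abc"));    // -> "Sabc"
//   PyParser::GetIdByPyObj(py::bool_(true));   // -> "BTrue"
//   PyParser::GetIdByPyObj(py::none());        // -> "None"
//   // Tuples/lists recurse through GetIdForPyTupleOrList, e.g. (1, 2) -> "Tuple<I1:I2>".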
1237 
1238 std::pair<std::vector<std::string>, std::vector<ValuePtr>> PyParser::GetArgsIdAndValue(const py::args &args) {
1239   size_t arg_size = args.size();
1240   std::vector<std::string> input_arg_id_vec;
1241   std::vector<ValuePtr> input_arg_value_vec;
1242   input_arg_id_vec.reserve(arg_size);
1243   input_arg_value_vec.reserve(arg_size);
1244   for (size_t i = 0; i < arg_size; ++i) {
1245     if (py::isinstance<py::list>(args[i])) {
1246       (void)input_arg_value_vec.emplace_back(DataConvert::PyObjToValue(py::cast<py::tuple>(args[i])));
1247     } else {
1248       (void)input_arg_value_vec.emplace_back(DataConvert::PyObjToValue(args[i]));
1249     }
1250     (void)input_arg_id_vec.emplace_back(Common::GetIdByValue(input_arg_value_vec.back()));
1251   }
1252   return {input_arg_id_vec, input_arg_value_vec};
1253 }
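// Illustrative sketch (hypothetical call site): GetArgsIdAndValue returns two parallel vectors,
// one id string and one ValuePtr per positional argument.
//
//   auto [arg_ids, arg_values] = PyParser::GetArgsIdAndValue(args);   // args: py::args, e.g. (Tensor, 2)
//   // arg_ids[i] identifies argument i (lists are converted to tuples first),
//   // arg_values[i] is the corresponding ValuePtr consumed by the grad pipeline.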
1254 
1255 void PyParser::SetPrim(const FrontendOpRunInfoPtr &op_run_info, const py::object &prim_arg) {
1256   MS_EXCEPTION_IF_NULL(op_run_info);
1257   const auto &adapter = prim_arg.cast<PrimitivePyAdapterPtr>();
1258   MS_EXCEPTION_IF_NULL(adapter);
1259   auto prim = adapter->attached_primitive();
1260   if (prim == nullptr) {
1261     prim = std::make_shared<PrimitivePy>(prim_arg);
1262     adapter->set_attached_primitive(prim);
1263   }
1264   if (!prim->HasPyObj()) {
1265     MS_LOG(EXCEPTION) << "Pyobj is empty";
1266   }
1267   prim->EnableSharedMutex();
1268   op_run_info->op_grad_info->op_prim = prim;
1269   op_run_info->base_op_run_info.op_name = prim->name();
1270   op_run_info->signatures = prim->signatures();
1271   op_run_info->base_op_run_info.py_prim_id_ = adapter->id();
1272 }
1273 
1274 std::string PyParser::BuilidPyInputTypeString(const py::object &obj) {
1275   if (py::isinstance<py::bool_>(obj)) {
1276     return "bool";
1277   }
1278 
1279   if (py::isinstance<py::int_>(obj)) {
1280     return "int";
1281   }
1282 
1283   if (py::isinstance<py::float_>(obj)) {
1284     return "float";
1285   }
1286 
1287   if (py::isinstance<py::str>(obj)) {
1288     return "string";
1289   }
1290 
1291   if (py::isinstance<py::none>(obj)) {
1292     return "None";
1293   }
1294 
1295   if (py::isinstance<mindspore::tensor::BaseTensor>(obj)) {
1296     return "Tensor";
1297   }
1298 
1299   if (IsStubTensor(obj)) {
1300     return "Tensor";
1301   }
1302 
1303   if (py::isinstance<py::tuple>(obj)) {
1304     std::stringstream ss;
1305     ss << "tuple<";
1306     auto tuple = obj.cast<py::tuple>();
1307     for (size_t i = 0; i < tuple.size(); i++) {
1308       if (i == 0) {
1309         ss << BuilidPyInputTypeString(tuple[i]);
1310       } else {
1311         ss << ", " << BuilidPyInputTypeString(tuple[i]);
1312       }
1313     }
1314     ss << ">";
1315     return ss.str();
1316   }
1317 
1318   if (py::isinstance<py::list>(obj)) {
1319     std::stringstream ss;
1320     ss << "list<";
1321     auto list = obj.cast<py::list>();
1322     for (size_t i = 0; i < list.size(); i++) {
1323       if (i == 0) {
1324         ss << BuilidPyInputTypeString(list[i]);
1325       } else {
1326         ss << ", " << BuilidPyInputTypeString(list[i]);
1327       }
1328     }
1329     ss << ">";
1330     return ss.str();
1331   }
1332 
1333   std::stringstream ss;
1334   ss << obj.get_type();
1335   return ss.str();
1336 }
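// Illustrative sketch (hypothetical inputs): the type strings built for error messages.
//
//   PyParser::BuilidPyInputTypeString(py::int_(1));              // -> "int"
//   PyParser::BuilidPyInputTypeString(py::make_tuple(1, 2.0));   // -> "tuple<int, float>"
//   // Anything unrecognized falls through to the repr of obj.get_type().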
1337 
1338 void PyParser::PrintTypeCastError(const ops::OpDefPtr &op_def, const py::list &op_inputs, size_t idx) {
1339   auto const &op_arg = op_def->args_[idx];
1340   bool is_support_tensor_cast = std::any_of(op_arg.cast_dtype_.begin(), op_arg.cast_dtype_.end(),
1341                                             [](const auto &type) { return type == ops::DT_TENSOR; });
1342   if (is_support_tensor_cast) {
1343     auto tensor = parse::ConvertTensorValue(op_inputs[idx]);
1344     auto PrintVectorFunc = [](const ShapeVector &shape) -> std::string {
1345       std::stringstream ss;
1346       ss << "[";
1347       for (size_t i = 0; i < shape.size(); i++) {
1348         if (i != 0) {
1349           ss << ", " << shape[i];
1350         } else {
1351           ss << shape[i];
1352         }
1353       }
1354       ss << "]";
1355       return ss.str();
1356     };
1357     if (tensor != nullptr) {
1358       MS_EXCEPTION(TypeError) << "For " << op_def->name_ << ", the " << idx << "'th input is a Tensor whose shape is "
1359                               << PrintVectorFunc(tensor->shape()) << " and dtype is ["
1360                               << TypeIdToString(tensor->data_type()) << "], which can not be converted to "
1361                               << ops::EnumToString(op_arg.arg_dtype_) << ".";
1362     }
1363   }
1364   std::vector<std::string> op_type_list;
1365   for (size_t index = 0; index < op_inputs.size(); ++index) {
1366     (void)op_type_list.emplace_back(PyParser::BuilidPyInputTypeString(op_inputs[index]));
1367   }
1368   MS_EXCEPTION(TypeError) << ops::BuildOpErrorMsg(op_def, op_type_list);
1369 }
1370 
1371 inline ValuePtr ConvertScalarToTensor(const ValuePtr &value) {
1372   auto fp32_imm = value->cast<FP32ImmPtr>();
1373   if (fp32_imm != nullptr) {
1374     return std::make_shared<tensor::Tensor>(fp32_imm->value());
1375   }
1376 
1377   auto bool_imm = value->cast<BoolImmPtr>();
1378   if (bool_imm != nullptr) {
1379     return std::make_shared<tensor::Tensor>(bool_imm->value());
1380   }
1381 
1382   auto int64_imm = value->cast<Int64ImmPtr>();
1383   if (int64_imm != nullptr) {
1384     return std::make_shared<tensor::Tensor>(int64_imm->value());
1385   }
1386 
1387   MS_LOG(EXCEPTION) << "Unsupported type: " << value->ToString();
1388 }
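// Illustrative sketch (hypothetical caller): ConvertScalarToTensor wraps a scalar imm into a
// 0-dimensional Tensor; unsupported scalar types raise an exception.
//
//   ValuePtr v = std::make_shared<FP32Imm>(1.5f);
//   auto t = ConvertScalarToTensor(v);   // 0-d float32 Tensor holding 1.5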
1389 
1390 inline ValuePtr ConvertBySignature(const py::object &obj, const FrontendOpRunInfoPtr &op_run_info, size_t index) {
1391   if (op_run_info->signatures.size() <= index) {
1392     return nullptr;
1393   }
1394 
1395   if (op_run_info->signatures[index].dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
1396     auto convert_func = parse::GetConverterByType(static_cast<int32_t>(ops::DT_NUMBER));
1397     MS_EXCEPTION_IF_NULL(convert_func);
1398     return convert_func(obj);
1399   }
1400   return nullptr;
1401 }
1402 
1403 void ParseOpInputByOpDef(const ops::OpDefPtr &op_def, const py::list &op_inputs, bool stub,
1404                          const FrontendOpRunInfoPtr &op_run_info) {
1405   size_t input_size = op_inputs.size();
1406   if (input_size != op_def->args_.size()) {
1407     MS_LOG(EXCEPTION) << "For Operator[" << op_def->name_ << "], the number of inputs should be " << op_def->args_.size()
1408                       << " but got " << op_inputs.size() << ".";
1409   }
1410   (void)op_run_info->op_grad_info->input_value.resize(input_size);
1411   for (size_t i = 0; i < op_def->args_.size(); i++) {
1412     auto const &op_arg = op_def->args_[i];
1413     op_run_info->none_init_inputs_num += static_cast<size_t>(!op_arg.as_init_arg_);
1414 
1415     // An optional argument accepts None as a valid input.
1416     if (op_arg.is_optional_ && py::isinstance<py::none>(op_inputs[i])) {
1417       op_run_info->op_grad_info->input_value[i] = kNone;
1418       continue;
1419     }
1420 
1421     ValuePtr value = nullptr;
1422     parse::OpDefConvertFunc convert_func = parse::GetConverterByType(static_cast<int32_t>(op_arg.arg_dtype_));
1423     MS_EXCEPTION_IF_NULL(convert_func);
1424     value = convert_func(op_inputs[i]);
1425     if (value != nullptr) {
1426       op_run_info->op_grad_info->input_value[i] = value;
1427       continue;
1428     }
1429 
1430     // Type cast has lower priority than signature cast.
1431     if (!op_arg.cast_dtype_.empty()) {
1432       for (auto cast_dtype : op_arg.cast_dtype_) {
1433         convert_func = parse::GetConverterByType(parse::CombineTypesForTypeCast(cast_dtype, op_arg.arg_dtype_));
1434         MS_EXCEPTION_IF_NULL(convert_func);
1435         value = convert_func(op_inputs[i]);
1436         if (value != nullptr) {
1437           op_run_info->op_grad_info->input_value[i] = value;
1438           op_run_info->source_type[i] = cast_dtype;
1439           break;
1440         }
1441       }
1442     }
1443 
1444     if (value == nullptr) {
1445       PyParser::PrintTypeCastError(op_def, op_inputs, i);
1446     }
1447   }
1448 }
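// Conversion order used above, summarized as a descriptive sketch:
//   1. An optional argument given as None is stored as kNone.
//   2. The converter for op_arg.arg_dtype_ is tried first.
//   3. If that fails and cast_dtype_ is non-empty, each combined type-cast converter is tried,
//      and the matching source type is recorded in op_run_info->source_type[i].
//   4. If nothing matches, PrintTypeCastError reports a TypeError.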
1449 
1450 void PyParser::ParseOpInputByPythonObj(const FrontendOpRunInfoPtr &op_run_info, const py::list &op_inputs, bool stub) {
1451   MS_EXCEPTION_IF_NULL(op_run_info);
1452   op_run_info->input_size = op_inputs.size();
1453   op_run_info->op_grad_info->input_abs.resize(op_run_info->input_size);
1454   op_run_info->source_type.resize(op_run_info->input_size);
1455   op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
1456 
1457   auto op_def = mindspore::ops::GetOpDef(op_run_info->base_op_run_info.op_name);
1458   if (op_def == nullptr) {
1459     op_run_info->op_grad_info->input_value.resize(op_run_info->input_size);
1460     op_run_info->none_init_inputs_num = op_run_info->input_size;
1461     for (size_t i = 0; i < op_run_info->input_size; ++i) {
1462       op_run_info->op_grad_info->input_value[i] = DataConvert::PyObjToValue(op_inputs[i], stub);
1463     }
1464   } else {
1465     op_run_info->none_init_inputs_num = 0;
1466     ParseOpInputByOpDef(op_def, op_inputs, stub, op_run_info);
1467   }
1468 }
1469 
1470 py::object DataConvert::ValueToPyObj(const ValuePtr &v) { return ValueToPyData(v); }
1471 
1472 ValuePtr DataConvert::PyObjToValue(const py::object &obj, bool stub) {
1473   ValuePtr converted_ret;
1474   if (stub) {
1475     converted_ret = parse::data_converter::PyDataToStubNode(obj);
1476   } else {
1477     converted_ret = parse::data_converter::PyDataToValue(obj);
1478   }
1479   if (converted_ret == nullptr) {
1480     MS_LOG(EXCEPTION) << "Attribute conversion error with type: " << ConvertPyObjToString(obj);
1481   }
1482   return converted_ret;
1483 }
1484 
1485 ValuePtr DataConvert::BaseRefToValue(const BaseRef &value, bool requires_grad, bool is_out_sequence) {
1486   MS_EXCEPTION_IF_NULL(value);
1487   ValuePtr ret;
1488   if (utils::isa<tensor::BaseTensorPtr>(value)) {
1489     auto t = utils::cast<tensor::BaseTensorPtr>(value);
1490     if (requires_grad) {
1491       t->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
1492       t->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
1493     }
1494     ret = t;
1495   } else if (utils::isa<ValuePtr>(value)) {
1496     ret = utils::cast<ValuePtr>(value);
1497   } else if (utils::isa<VectorRef>(value)) {
1498     auto vec_ref = utils::cast<VectorRef>(value);
1499     ret = VectorRefToValue(vec_ref, requires_grad, is_out_sequence);
1500   } else if (utils::isa<int>(value)) {
1501     ret = MakeValue(utils::cast<int>(value));
1502   } else if (utils::isa<float>(value)) {
1503     ret = MakeValue(utils::cast<float>(value));
1504   } else if (utils::isa<double>(value)) {
1505     ret = MakeValue(utils::cast<double>(value));
1506   } else if (utils::isa<bool>(value)) {
1507     ret = MakeValue(utils::cast<bool>(value));
1508   } else {
1509     MS_LOG(EXCEPTION) << "Value is not a supported type: " << value.ToString();
1510   }
1511   return ret;
1512 }
1513 
1514 ValuePtr DataConvert::VectorRefToValue(const VectorRef &vec_ref, bool requires_grad, bool is_out_sequence) {
1515   MS_EXCEPTION_IF_NULL(vec_ref);
1516 
1517   size_t value_size = vec_ref.size();
1518   if (value_size == 1 && !is_out_sequence) {
1519     return BaseRefToValue(vec_ref[0], requires_grad, is_out_sequence);
1520   }
1521   std::vector<ValuePtr> v_list(value_size);
1522   for (size_t i = 0; i < value_size; ++i) {
1523     v_list[i] = BaseRefToValue(vec_ref[i], requires_grad, is_out_sequence);
1524   }
1525   return std::make_shared<ValueTuple>(v_list);
1526 }
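// Illustrative sketch (hypothetical caller): a VectorRef of two tensors becomes a ValueTuple,
// while a single-element VectorRef with is_out_sequence=false collapses to the element itself.
//
//   // ref holds {t0, t1}; t0, t1: tensor::BaseTensorPtr
//   auto tuple_value = DataConvert::VectorRefToValue(ref, false, true);            // ValueTuple(t0, t1)
//   auto single_value = DataConvert::VectorRefToValue(single_ref, false, false);   // just t0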
1527 
1528 void DataConvert::FlattenValueSeqArg(const ValuePtr &v, bool is_only_flatten_tensor_seq, bool is_filter_tensor,
1529                                      std::vector<ValuePtr> *flatten_v) {
1530   MS_EXCEPTION_IF_NULL(v);
1531   MS_EXCEPTION_IF_NULL(flatten_v);
1532   MS_LOG(DEBUG) << "Get is only flatten tensor seq " << is_only_flatten_tensor_seq;
1533   if (v->isa<tensor::BaseTensor>()) {
1534     (void)flatten_v->emplace_back(v);
1535   } else if (v->isa<ValueSequence>()) {
1536     const auto &v_vec = v->cast<ValueSequencePtr>()->value();
1537     if (v_vec.empty() && !is_filter_tensor) {
1538       MS_LOG(DEBUG) << "Get empty tuple value";
1539       (void)flatten_v->emplace_back(v);
1540       MS_LOG(DEBUG) << "Get empty value sequence";
1541       return;
1542     }
1543     if (is_only_flatten_tensor_seq && !v_vec.front()->isa<tensor::BaseTensor>()) {
1544       (void)flatten_v->emplace_back(v);
1545     } else {
1546       for (const auto &elem : v_vec) {
1547         FlattenValueSeqArg(elem, is_only_flatten_tensor_seq, is_filter_tensor, flatten_v);
1548       }
1549     }
1550   } else if (is_only_flatten_tensor_seq) {
1551     if (v->isa<ValueDictionary>()) {
1552       auto dic_v = v->cast<ValueDictionaryPtr>();
1553       for (const auto &elem : dic_v->value()) {
1554         FlattenValueSeqArg(elem.second, is_only_flatten_tensor_seq, is_filter_tensor, flatten_v);
1555       }
1556     } else {
1557       (void)flatten_v->emplace_back(v);
1558     }
1559   } else if (!is_filter_tensor) {
1560     MS_LOG(DEBUG) << "Get non-tensor value: " << v->ToString();
1561     (void)flatten_v->emplace_back(v);
1562   }
1563 }
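// Illustrative sketch (hypothetical input, pseudo notation for the nested value): with
// is_only_flatten_tensor_seq=false and is_filter_tensor=false, a nested sequence is
// flattened depth-first.
//
//   // nested = ValueTuple( t0, ValueTuple( t1, Int64Imm(2) ) )
//   std::vector<ValuePtr> flat;
//   DataConvert::FlattenValueSeqArg(nested, false, false, &flat);
//   // flat == { t0, t1, Int64Imm(2) }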
1564 
1565 ValuePtrList DataConvert::FlattenTensorSeqInValue(const ValuePtr &v) {
1566   MS_EXCEPTION_IF_NULL(v);
1567   ValuePtrList outputs;
1568   FlattenValueSeqArg(v, true, false, &outputs);
1569   return outputs;
1570 }
1571 
1572 ValuePtrList DataConvert::FlattenTensorSeqInValueSeq(const ValuePtrList &v, bool only_flatten_tensor) {
1573   ValuePtrList outputs;
1574   for (const auto &item : v) {
1575     FlattenValueSeqArg(item, only_flatten_tensor, false, &outputs);
1576   }
1577   return outputs;
1578 }
1579 
1580 void DataConvert::FlattenArgs(const std::vector<ValuePtr> &v_vec, std::vector<ValuePtr> *flatten_v, bool has_sens) {
1581   MS_EXCEPTION_IF_NULL(flatten_v);
1582   if (v_vec.empty()) {
1583     MS_LOG(EXCEPTION) << "For the bprop graph, the input value size should be greater than 0, but got empty.";
1584   }
1585   size_t input_size = has_sens ? v_vec.size() - 1 : v_vec.size();
1586   for (size_t i = 0; i < input_size; ++i) {
1587     const auto &v = v_vec[i];
1588     MS_EXCEPTION_IF_NULL(v);
1589     MS_LOG(DEBUG) << "Get v is " << v->ToString();
1590     (void)flatten_v->emplace_back(v);
1591   }
1592   if (has_sens) {
1593     if (Common::IsTensor(v_vec[input_size])) {
1594       (void)flatten_v->emplace_back(v_vec[input_size]);
1595     } else if (v_vec[input_size]->isa<ValueSequence>()) {
1596       MS_LOG(DEBUG) << "Get value tuple size " << v_vec[input_size]->cast<ValueSequencePtr>()->size();
1597       FlattenValueSeqArg(v_vec[input_size], false, false, flatten_v);
1598     }
1599   }
1600 }
1601 
1602 bool DataConvert::RunOpConvertConstInputToAttr(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v,
1603                                                size_t input_index) {
1604   MS_EXCEPTION_IF_NULL(op_run_info);
1605   if (op_run_info->input_to_attr.empty()) {
1606     return false;
1607   }
1608   MS_EXCEPTION_IF_NULL(v);
1609   if (op_run_info->input_to_attr.find(input_index) == op_run_info->input_to_attr.end()) {
1610     return false;
1611   }
1612   const auto &input_names_value = op_run_info->op_grad_info->op_prim->GetAttr(kAttrInputNames);
1613   if (input_names_value == nullptr) {
1614     return false;
1615   }
1616   const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
1617   if (input_index >= input_names_vec.size()) {
1618     MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
1619   }
1620   const auto &input_name = input_names_vec[input_index];
1621   if (v->isa<tensor::BaseTensor>()) {
1622     auto tensor = v->cast<tensor::BaseTensorPtr>();
1623     if (tensor->data().const_data() == nullptr && !tensor->has_user_data(kTensorValueIsEmpty)) {
1624       return false;
1625     }
1626   }
1627   (void)op_run_info->op_grad_info->op_prim->AddAttr(input_name, v);
1628   return true;
1629 }
1630 
1631 FrontendOpRunInfoPtr PyBoost::Init(const PrimitivePtr &prim, const py::list &args) {
1632   const auto &pynative_executor = Common::GetPyNativeExecutor();
1633   const auto &forward_executor = pynative_executor->forward_executor();
1634   const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
1635   prim->EnableSharedMutex();
1636   op_run_info->op_grad_info->op_prim = prim;
1637   op_run_info->base_op_run_info.op_name = prim->name();
1638   pynative_executor->StoreAsyncStatus(op_run_info);
1639   forward_executor->InitOpRunInfo(op_run_info);
1640   return op_run_info;
1641 }
1642 
1643 void PyBoost::MakeOutputValue(const FrontendOpRunInfoPtr &op_run_info, const kernel::pyboost::OpPtr &op) {
1644   size_t size = op->outputs().size();
1645   // If the op is Contiguous or Cast (precision / implicit cast), which are internal ops, it has no stub output.
1646   bool is_tuple_output = op_run_info->stub_output != nullptr ? op_run_info->stub_output->isa<stub::SequenceNode>()
1647                                                              : PredictOutTypeByName(op->primitive()->name()) == kTuple;
1648   if (op->output_value_simple_info() != nullptr) {
1649     op_run_info->op_grad_info->output_value_simple_info = op->output_value_simple_info();
1650     op_run_info->op_grad_info->output_value_simple_info->is_tuple_output_ = is_tuple_output;
1651   }
1652   if (!is_tuple_output) {
1653     MS_EXCEPTION_IF_CHECK_FAIL(size == kSizeOne, "The size is more than one!");
1654     if (op->output_abs() != nullptr || op->output_value_simple_info() != nullptr) {
1655       // Set auto grad meta data for op output
1656       if (op_run_info->requires_grad) {
1657         op->outputs()[0]->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
1658         op->outputs()[0]->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
1659       }
1660       op_run_info->real_out = op->outputs()[0];
1661       return;
1662     }
1663   }
1664   std::vector<ValuePtr> output_values(size);
1665   for (size_t i = 0; i < size; ++i) {
1666     const auto &output_tensor = op->outputs()[i];
1667     MS_EXCEPTION_IF_NULL(output_tensor);
1668     // Set auto grad meta data for op outputs
1669     if (op_run_info->requires_grad) {
1670       output_tensor->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
1671       output_tensor->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
1672     }
1673     output_values[i] = output_tensor;
1674   }
1675   op_run_info->real_out = std::make_shared<ValueTuple>(output_values);
1676 }
1677 
1678 void PyBoost::UpdateStubOutput(const FrontendOpRunInfoPtr &op_run_info, const AbstractBasePtr &abstract,
1679                                const kernel::pyboost::OpPtr &op) {
1680   MS_EXCEPTION_IF_NULL(op);
1681   if (op_run_info->stub_output == nullptr) {
1682     return;
1683   }
1684   if (MS_UNLIKELY(op->output_value_simple_info() != nullptr)) {
1685     op_run_info->stub_output->SetValueSimpleInfo(op->output_value_simple_info());
1686   } else {
1687     MS_EXCEPTION_IF_NULL(abstract);
1688     auto success = op_run_info->stub_output->SetAbstract(abstract);
1689     if (!success) {
1690       const auto &op_name = op_run_info->base_op_run_info.op_name;
1691       MS_EXCEPTION(TypeError) << "The predicted type and the inferred type do not match, the predicted type is "
1692                               << PredictOutType(op_run_info) << ", the inferred type is " << abstract->BuildType()
1693                               << ", the name of operator is [" << op_name
1694                               << "]. Please modify or add predict type of operator in predict_out_type_map.h.";
1695     }
1696     MS_LOG(DEBUG) << "Update StubNode abstract " << abstract->ToString();
1697   }
1698   op_run_info->stub_output->SetValue(op_run_info->real_out);
1699 }
1700 
1701 void PyBoost::UpdateOpRunInfo(const kernel::pyboost::OpPtr &op, const FrontendOpRunInfoPtr &op_run_info) {
1702   MS_EXCEPTION_IF_NULL(op);
1703   MS_EXCEPTION_IF_NULL(op_run_info);
1704   // Create output value
1705   MakeOutputValue(op_run_info, op);
1706 
1707   // Set output value to python
1708   UpdateStubOutput(op_run_info, op->output_abs(), op);
1709 }
1710 
1711 void PyBoost::DataSyncForGraph(const kernel::pyboost::OpPtr &op, ValuePtrList &&op_inputs) {
1712   auto ms_context = MsContext::GetInstance();
1713   MS_EXCEPTION_IF_NULL(ms_context);
1714   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
1715       !runtime::OpExecutor::GetInstance().async_for_graph()) {
1716     // If the execution mode in MsContext is Graph Mode, the tensor will be an input of a graph that executes in
1717     // Graph Mode; if the graph contains no CNode after optimization, the tensor needs to be synced to host.
1718     for (const auto &output : op->outputs()) {
1719       auto device_address = std::static_pointer_cast<device::DeviceAddress>(output->device_address());
1720       runtime::DeviceAddressUtils::CreateKernelTensor(device_address, output);
1721       output->data_sync(true);
1722       output->set_abstract(std::weak_ptr<abstract::AbstractBase>());
1723     }
1724     for (const auto &input : op_inputs) {
1725       if (input->isa<tensor::BaseTensor>()) {
1726         auto tensor = input->cast<tensor::BaseTensorPtr>();
1727         auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
1728         runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
1729       }
1730       UnsetValueAbstractCache(input);
1731     }
1732   }
1733 }
1734 
1735 PrimitivePtr PyBoost::ConvertPrimitive(const py::object &obj) {
1736   const auto &adapter = obj.cast<PrimitivePyAdapterPtr>();
1737   MS_EXCEPTION_IF_NULL(adapter);
1738 
1739   auto prim = adapter->attached_primitive();
1740   if (prim == nullptr) {
1741 #ifndef ENABLE_TEST
1742     return std::make_shared<Primitive>(adapter->name(), adapter->attrs());
1743 #else
1744     prim = std::make_shared<PrimitivePy>(obj);
1745     adapter->set_attached_primitive(prim);
1746 #endif
1747   }
1748   if (!prim->HasPyObj()) {
1749     MS_LOG(EXCEPTION) << "Pyobj is empty";
1750   }
1751   prim->EnableSharedMutex();
1752   return prim;
1753 }
1754 
1755 py::object PyBoost::RunPyFunction(const PrimitivePtr &prim, const py::list &args) {
1756   py::tuple wrap_args(kIndex3);
1757   if (prim->isa<PrimitivePy>()) {
1758     auto prim_py = prim->cast<PrimitivePyPtr>();
1759     if (!prim_py->HasPyObj()) {
1760       MS_LOG(EXCEPTION) << "Prim has no python obj!";
1761     }
1762     wrap_args[kIndex0] = prim_py->GetPyObj();
1763   } else {
1764     wrap_args[kIndex0] = std::make_shared<PrimitivePyAdapter>(prim->name());
1765   }
1766   wrap_args[kIndex1] = prim->name();
1767   wrap_args[kIndex2] = args;
1768   const auto &pynative_executor = Common::GetPyNativeExecutor();
1769   return pynative_executor->RunOpStub(wrap_args);
1770 }
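// Layout of wrap_args built above, shown as a descriptive sketch:
//   wrap_args[0] : the Python primitive object (or a fresh PrimitivePyAdapter for the name)
//   wrap_args[1] : the primitive name, e.g. "ReLU"
//   wrap_args[2] : the original Python argument list
// The tuple is then dispatched through the pynative executor's RunOpStub.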
1771 
1772 void PyBoost::SetAnyValueForAbstract(const kernel::pyboost::OpPtr &op) {
1773   const auto &input_abs = op->input_abs();
1774   for (const auto &abs : input_abs) {
1775     Common::SetAbstractValueToAnyValue(abs);
1776   }
1777   Common::SetAbstractValueToAnyValue(op->output_abs());
1778 }
1779 
1780 void PyBoost::DoGrad(const kernel::pyboost::OpPtr &op, const FrontendOpRunInfoPtr &op_run_info,
1781                      ValuePtrList &&op_inputs) {
1782   static const std::string kDoGradName = "DoGrad";
1783   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeFrontendTask,
1784                                      kDoGradName, false);
1785   MS_EXCEPTION_IF_NULL(op);
1786   // Update op grad info
1787   op_run_info->op_grad_info->input_value = std::move(op_inputs);
1788   op_run_info->op_grad_info->out_value = op_run_info->real_out;
1789 
1790   const auto &pynative_executor = Common::GetPyNativeExecutor();
1791   const auto &forward = pynative_executor->forward_executor();
1792   op_run_info->op_grad_info->output_size = op->outputs().size();
1793   if (op->output_value_simple_info() == nullptr) {
1794     if (op->input_abs().size() != op_run_info->input_size) {
1795       MS_LOG(EXCEPTION) << "Op " << op_run_info->base_op_run_info.op_name << " input size is "
1796                         << op_run_info->input_size << " but got input abstract size " << op->input_abs().size();
1797     }
1798     SetAnyValueForAbstract(op);
1799     op_run_info->op_grad_info->input_abs = op->input_abs();
1800     op_run_info->base_op_run_info.abstract = op->output_abs();
1801   }
1802 
1803   if (MS_LIKELY(!forward->grad()->top_cell()->is_bprop_need_get_forward_graph())) {
1804     // Check and set input auto grad meta info and InputType
1805     op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
1806     for (size_t index = 0; index < op_run_info->input_size; ++index) {
1807       // Inplace input_value with contiguous tensor.
1808       RefreshGradContiguousTensor(op_run_info, index);
1809       const ValuePtr &input_object = op_run_info->op_grad_info->input_value[index];
1810       DataConvert::MarkInputs(op_run_info, input_object, index, forward->grad()->top_cell());
1811     }
1812   }
1813   forward->ForwardOpGradImpl(op_run_info);
1814 }
1815 
1816 void DataConvert::PlantTensorTupleToVector(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &value_seq,
1817                                            size_t index, const TopCellInfoPtr &top_cell) {
1818   MS_EXCEPTION_IF_NULL(op_run_info);
1819   MS_EXCEPTION_IF_NULL(value_seq);
1820   if (op_run_info->requires_grad) {
1821     op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kOpOutput;
1822   }
1823   for (const auto &v : value_seq->value()) {
1824     if (!v->isa<tensor::BaseTensor>()) {
1825       MS_LOG(EXCEPTION) << "The input object is not a tensor!";
1826     }
1827     InputType input_type = InputType::kInput;
1828     auto tensor = v->cast<tensor::BaseTensorPtr>();
1829     MS_EXCEPTION_IF_NULL(tensor);
1830     if (tensor->is_parameter()) {
1831       input_type = InputType::kParameter;
1832     }
1833     if (op_run_info->requires_grad) {
1834       auto grad_type = Common::SetTensorGradInfo(tensor, top_cell);
1835       if (Common::IsParam(grad_type)) {
1836         op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kParameter;
1837       }
1838     }
1839     (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(tensor);
1840     (void)op_run_info->base_op_run_info.input_types.emplace_back(input_type);
1841   }
1842 
1843   if (!op_run_info->base_op_run_info.dyn_input_sizes.empty()) {
1844     int64_t elem_size = SizeToLong(value_seq->size());
1845     if (op_run_info->base_op_run_info.dyn_input_sizes.size() != op_run_info->input_size) {
1846       for (size_t i = op_run_info->base_op_run_info.dyn_input_sizes.size(); i < index; ++i) {
1847         (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(-1);
1848       }
1849       (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(elem_size);
1850     } else {
1851       op_run_info->base_op_run_info.dyn_input_sizes[index] = elem_size;
1852     }
1853   } else {
1854     for (size_t i = 0; i < index; ++i) {
1855       (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(-1);
1856     }
1857     (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(SizeToLong(value_seq->size()));
1858   }
1859 }
1860 
1861 ValuePtr DataConvert::ConvertValueDictToValueTuple(const ValuePtr &v) {
1862   MS_EXCEPTION_IF_NULL(v);
1863   const auto &dic_v = v->cast<ValueDictionaryPtr>();
1864   MS_EXCEPTION_IF_NULL(dic_v);
1865   std::vector<ValuePtr> v_list;
1866   (void)std::transform(dic_v->value().begin(), dic_v->value().end(), std::back_inserter(v_list),
1867                        [](const std::pair<ValuePtr, ValuePtr> &elem) { return elem.second; });
1868   return std::make_shared<ValueTuple>(v_list);
1869 }
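// Illustrative sketch (hypothetical input): only the dictionary values are kept, in order.
//   {"a": t0, "b": t1}  --ConvertValueDictToValueTuple-->  ValueTuple(t0, t1)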
1870 
1871 void DataConvert::ConvertMapTensor(const FrontendOpRunInfoPtr &op_run_info, const tensor::MapTensorPtr &map_tensor,
1872                                    const TopCellInfoPtr &top_cell, size_t index) {
1873   MS_EXCEPTION_IF_NULL(op_run_info);
1874   MS_EXCEPTION_IF_NULL(map_tensor);
1875   constexpr int input_num = 1;
1876   const auto input_names = op_run_info->op_grad_info->op_prim->GetAttr(kAttrInputNames);
1877   if (input_names == nullptr) {
1878     MS_LOG(DEBUG) << "input_names is nullptr";
1879     return;
1880   }
1881   (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(map_tensor);
1882   const auto it = op_run_info->base_op_run_info.input_types.end();
1883   (void)op_run_info->base_op_run_info.input_types.insert(it, input_num, InputType::kParameter);
1884   if (op_run_info->requires_grad) {
1885     op_run_info->op_grad_info->input_value_grad_type[index] = Common::SetTensorGradInfo(map_tensor, top_cell);
1886   }
1887 }
1888 
1889 void DataConvert::ConvertCSRTensorToTensorList(const FrontendOpRunInfoPtr &op_run_info,
1890                                                const tensor::CSRTensorPtr &csr_tensor, const TopCellInfoPtr &top_cell,
1891                                                size_t index) {
1892   MS_EXCEPTION_IF_NULL(op_run_info);
1893   MS_EXCEPTION_IF_NULL(csr_tensor);
1894   constexpr int input_num = 3;
1895   const auto input_names = op_run_info->op_grad_info->op_prim->GetAttr(kAttrInputNames);
1896   if (input_names == nullptr) {
1897     MS_LOG(DEBUG) << "input_names is nullptr";
1898     return;
1899   }
1900 
1901   (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(csr_tensor->GetIndptr());
1902   (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(csr_tensor->GetIndices());
1903   (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(csr_tensor->GetValues());
1904   const auto it = op_run_info->base_op_run_info.input_types.end();
1905   (void)op_run_info->base_op_run_info.input_types.insert(it, input_num, InputType::kInput);
1906   op_run_info->op_grad_info->op_prim->set_attr("is_csr", MakeValue(true));
1907   op_run_info->op_grad_info->op_prim->set_attr("dense_shape", MakeValue(csr_tensor->shape()));
1908   if (op_run_info->requires_grad) {
1909     op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kOpOutput;
1910     for (int i = 0; i < input_num; ++i) {
1911       auto iter = op_run_info->base_op_run_info.expanded_input_values.rbegin() + i;
1912       auto grad_type = Common::SetTensorGradInfo((*iter)->cast<tensor::BaseTensorPtr>(), top_cell);
1913       if (Common::IsParam(grad_type)) {
1914         op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kParameter;
1915       }
1916     }
1917   }
1918 }
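// Illustrative sketch (descriptive only): a CSRTensor input is expanded into three dense tensor
// inputs, in this order, and the primitive gets the "is_csr" / "dense_shape" attributes.
//   CSRTensor(indptr, indices, values, shape)
//     -> expanded_input_values += { indptr, indices, values }
//     -> op_prim attrs: is_csr = true, dense_shape = shape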
1919 
1920 void DataConvert::ConvertValueTensorId(const ValuePtr &value, std::vector<std::string> *converted_tensor_id) {
1921   if (value->isa<tensor::BaseTensor>()) {
1922     (void)converted_tensor_id->emplace_back(value->cast<tensor::BaseTensorPtr>()->id());
1923     MS_LOG(DEBUG) << "Get top cell output tensor id " << converted_tensor_id->back();
1924   } else if (value->isa<ValueSequence>()) {
1925     const auto &seq = value->cast<ValueSequencePtr>();
1926     for (const auto &val : seq->value()) {
1927       ConvertValueTensorId(val, converted_tensor_id);
1928     }
1929   } else if (value->isa<ValueDictionary>()) {
1930     ConvertValueTensorId(ConvertValueDictToValueTuple(value), converted_tensor_id);
1931   }
1932 }
1933 
1934 void DataConvert::ConvertTupleValueToTensor(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &value_seq,
1935                                             size_t index, const TopCellInfoPtr &top_cell) {
1936   MS_EXCEPTION_IF_NULL(op_run_info);
1937   MS_EXCEPTION_IF_NULL(value_seq);
1938 
1939   const auto &tuple_inputs = value_seq->value();
1940   if (tuple_inputs.empty()) {
1941     (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(value_seq);
1942     (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1943     return;
1944   }
1945   if (tuple_inputs[0]->isa<tensor::BaseTensor>()) {
1946     PlantTensorTupleToVector(op_run_info, value_seq, index, top_cell);
1947   } else {
1948     (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(value_seq);
1949     (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1950   }
1951 }
1952 
1953 void DataConvert::MarkInputs(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v, size_t index,
1954                              const TopCellInfoPtr &top_cell) {
1955   MS_EXCEPTION_IF_NULL(op_run_info);
1956   MS_EXCEPTION_IF_NULL(v);
1957   tensor::BaseTensorPtr tensor_ptr = nullptr;
1958   InputType input_type = InputType::kInput;
1959   if (v->isa<tensor::BaseTensor>()) {
1960     tensor_ptr = v->cast<tensor::BaseTensorPtr>();
1961     if (tensor_ptr->is_parameter()) {
1962       input_type = InputType::kParameter;
1963     }
1964     if (op_run_info->requires_grad) {
1965       op_run_info->op_grad_info->input_value_grad_type[index] = Common::SetTensorGradInfo(tensor_ptr, top_cell);
1966     }
1967   } else if (v->isa<BoolImm>() || v->isa<FloatImm>() || v->isa<Type>() || v->isa<StringImm>() || v->isa<None>()) {
1968     (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(v);
1969     (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1970     return;
1971   } else if (v->isa<IntegerImm>()) {
1972     if (op_run_info->base_op_run_info.op_name == prim::kPrimCSRReduceSum->name()) {
1973       int64_t input = v->cast<Int64ImmPtr>()->value();
1974       op_run_info->op_grad_info->op_prim->set_attr("axis", MakeValue(input));
1975       return;
1976     }
1977     (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(v);
1978     (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1979     return;
1980   } else if (v->isa<ValueSequence>()) {
1981     ConvertTupleValueToTensor(op_run_info, v->cast<ValueSequencePtr>(), index, top_cell);
1982     return;
1983   } else if (v->isa<tensor::MapTensor>()) {
1984     ConvertMapTensor(op_run_info, v->cast<tensor::MapTensorPtr>(), top_cell, index);
1985     return;
1986   } else if (v->isa<tensor::CSRTensor>()) {
1987     ConvertCSRTensorToTensorList(op_run_info, v->cast<tensor::CSRTensorPtr>(), top_cell, index);
1988     return;
1989   } else if (v->isa<Monad>()) {
1990     return;
1991   } else if (v->isa<parse::InterpretedObject>()) {
1992     MS_EXCEPTION(TypeError) << "Not support for " << v->ToString();
1993   } else {
1994     MS_LOG(EXCEPTION) << "Run op input type is invalid!";
1995   }
1996   MS_EXCEPTION_IF_NULL(tensor_ptr);
1997   (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(tensor_ptr);
1998   (void)op_run_info->base_op_run_info.input_types.emplace_back(input_type);
1999 }
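// Summary sketch of MarkInputs (descriptive only):
//   Tensor                    -> pushed as kInput (kParameter if it is a parameter)
//   bool/float/str/None/Type  -> pushed as kConstant
//   int                       -> pushed as kConstant (except CSRReduceSum, where it becomes the "axis" attr)
//   tuple/list                -> ConvertTupleValueToTensor (tensor tuples are plant-expanded)
//   MapTensor/CSRTensor       -> dedicated conversion helpers
//   Monad                     -> skipped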
2000 
2001 void ReplaceReduceAxis(const FrontendOpRunInfoPtr &op_run_info) {
2002   MS_EXCEPTION_IF_NULL(op_run_info);
2003   if (!common::AnfAlgo::IsReduceOp(op_run_info->base_op_run_info.op_name)) {
2004     return;
2005   }
2006   const auto &inputs = op_run_info->base_op_run_info.expanded_input_values;
2007   constexpr size_t kReduceOpInputNum = 2;
2008   if (inputs.size() < kReduceOpInputNum) {
2009     MS_LOG(EXCEPTION) << "Invalid input tensor size " << inputs.size() << " of Op "
2010                       << op_run_info->base_op_run_info.op_name;
2011   }
2012 
2013   MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info);
2014   const auto &op_prim = op_run_info->op_grad_info->op_prim;
2015   MS_EXCEPTION_IF_NULL(op_prim);
2016   if (op_prim->HasAttr(kAttrSkipMode) && GetValue<bool>(op_prim->GetAttr(kAttrSkipMode))) {
2017     return;
2018   }
2019 
2020   // If the 2nd input tensor is {} or None, reduce over all axes.
2021   bool reduce_all_axis = false;
2022   if (inputs[kIndex1]->isa<ValueSequence>()) {
2023     auto seq_size = inputs[1]->cast<ValueSequencePtr>()->size();
2024     reduce_all_axis = seq_size == 0;
2025   } else if (inputs[kIndex1]->isa<None>()) {
2026     reduce_all_axis = true;
2027   }
2028   if (reduce_all_axis) {
2029     auto size = inputs[0]->cast<tensor::BaseTensorPtr>()->shape().size();
2030     // For example, input 0 is Tensor(shape=[], value=1), the axis to reduce is 0.
2031     std::vector<ValuePtr> axis = {std::make_shared<Int64Imm>(0)};
2032     for (size_t i = 1; i < size; ++i) {
2033       axis.push_back(std::make_shared<Int64Imm>(static_cast<int64_t>(i)));
2034     }
2035     op_run_info->base_op_run_info.expanded_input_values[1] = std::make_shared<ValueTuple>(axis);
2036   }
2037 }
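// Illustrative sketch (hypothetical shapes): for a reduce op whose axis input is an empty
// tuple or None, the axis is rewritten to cover every dimension of input 0.
//   input 0 shape = (2, 3, 4), axis = ()  ==>  axis becomes (0, 1, 2)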
2038 
2039 void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const TopCellInfoPtr &top_cell) {
2040   MS_EXCEPTION_IF_NULL(op_run_info);
2041 
2042   (void)op_run_info->base_op_run_info.expanded_input_values.reserve(op_run_info->input_size);
2043   (void)op_run_info->base_op_run_info.input_types.reserve(op_run_info->input_size);
2044   // Get input tensors.
2045   op_run_info->op_grad_info->op_prim->BeginRecordAddAttr();
2046   for (size_t index = 0; index < op_run_info->input_size; ++index) {
2047     const ValuePtr &input_object = op_run_info->op_grad_info->input_value[index];
2048     // convert const input to attr
2049     if (RunOpConvertConstInputToAttr(op_run_info, input_object, index)) {
2050       continue;
2051     }
2052     // Mark tensors, common tensor data : 0, weight param: 1, valuenode(float_, int_): 2
2053     MarkInputs(op_run_info, input_object, index, top_cell);
2054     // -1 indicates input_object is not a dynInput
2055     if (!op_run_info->base_op_run_info.dyn_input_sizes.empty() && !input_object->isa<ValueSequence>()) {
2056       (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(-1);
2057     }
2058   }
2059   op_run_info->op_grad_info->op_prim->EndRecordAddAttr();
2060   ReplaceReduceAxis(op_run_info);
2061   AddDynInputsSizesAttr(op_run_info);
2062 }
2063 
2064 namespace {
2065 const mindspore::HashSet<std::string> kGradBlackList{kMakeTupleOpName,         kMakeListOpName,
2066                                                      kTupleGetItemOpName,      kStopGradientOpName,
2067                                                      kUpdateStateOpName,       kNPUAllocFloatStatusOpName,
2068                                                      kNPUGetFloatStatusOpName, kNPUClearFloatStatusOpName};
2069 
2070 mindspore::HashMap<std::string, pipeline::ResourcePtr> jit_call_graph_compile_cache_;
2071 
2072 AnfNodePtr CreateMakeTupleNode(const KernelGraphPtr &tape, const ValueSequencePtr &tuple,
2073                                const abstract::AbstractSequencePtr &abs_seq, const SpecialType &type) {
2074   AnfNodePtrList args{NewValueNode(prim::kPrimMakeTuple)};
2075   for (size_t i = 0; i < tuple->size(); ++i) {
2076     AnfNodePtr special_like_value = AutoGrad::BuildSpecialNode(tape, tuple->value()[i], abs_seq->elements()[i], type);
2077     (void)args.emplace_back(special_like_value);
2078   }
2079   auto special_like_value = tape->FuncGraph::NewCNode(args);
2080   special_like_value->set_abstract(abs_seq);
2081   return special_like_value;
2082 }
2083 
2084 AnfNodePtr CreateMakeDictNode(const KernelGraphPtr &tape, const ValueDictionaryPtr &v_dict,
2085                               const abstract::AbstractDictionaryPtr &abs_dict, const SpecialType &type) {
2086   MS_EXCEPTION_IF_NULL(tape);
2087   MS_EXCEPTION_IF_NULL(v_dict);
2088   MS_EXCEPTION_IF_NULL(abs_dict);
2089   AnfNodePtrList key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2090   AnfNodePtrList value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2091   abstract::AbstractBasePtrList local_key_abs_inputs;
2092   abstract::AbstractBasePtrList local_value_abs_inputs;
2093   for (size_t i = 0; i < v_dict->size(); ++i) {
2094     (void)key_inputs.emplace_back(
2095       Common::CreateValueNodeByValue(v_dict->value()[i].first, abs_dict->elements()[i].first));
2096     (void)local_key_abs_inputs.emplace_back(abs_dict->elements()[i].first);
2097     AnfNodePtr special_like_value =
2098       AutoGrad::BuildSpecialNode(tape, v_dict->value()[i].second, abs_dict->elements()[i].second, type);
2099     (void)value_inputs.emplace_back(special_like_value);
2100     (void)local_value_abs_inputs.emplace_back(abs_dict->elements()[i].second);
2101   }
2102   auto local_key_node = tape->NewCNode(key_inputs);
2103   local_key_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_key_abs_inputs));
2104   auto local_value_node = tape->NewCNode(value_inputs);
2105   local_value_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_value_abs_inputs));
2106   auto dict_node = tape->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
2107   dict_node->set_abstract(abs_dict);
2108   return dict_node;
2109 }
2110 
2111 ValueNodePtr GetSparseTensorShapeNode(const ShapeVector &shape) {
2112   auto value_shape = NewValueNode(shape);
2113   std::vector<abstract::AbstractBasePtr> abstract_shape;
2114   (void)std::transform(
2115     shape.begin(), shape.end(), std::back_inserter(abstract_shape),
2116     [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
2117   auto abs_shape = std::make_shared<abstract::AbstractTuple>(abstract_shape);
2118   value_shape->set_abstract(abs_shape);
2119   return value_shape;
2120 }
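// Illustrative sketch (hypothetical shape): the shape (3, 4) becomes a ValueNode whose abstract
// is an AbstractTuple of two AbstractScalar elements.
//   auto shape_node = GetSparseTensorShapeNode(ShapeVector{3, 4});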
2121 
2122 ValuePtr WrapCOOTensor(const ValuePtr &coo_out, const ValuePtr &value) {
2123   MS_EXCEPTION_IF_NULL(coo_out);
2124   auto coo_tensor = coo_out->cast<tensor::COOTensorPtr>();
2125   MS_EXCEPTION_IF_NULL(coo_tensor);
2126   auto value_tensor = value->cast<tensor::TensorPtr>();
2127   MS_EXCEPTION_IF_NULL(value_tensor);
2128   auto indices_tensor = coo_tensor->GetIndices();
2129   auto shape_vector = coo_tensor->shape();
2130   return std::make_shared<tensor::COOTensor>(indices_tensor, value_tensor, shape_vector);
2131 }
2132 
2133 ValuePtr WrapCSRTensor(const ValuePtr &csr_out, const ValuePtr &value) {
2134   MS_EXCEPTION_IF_NULL(csr_out);
2135   auto csr_tensor = csr_out->cast<tensor::CSRTensorPtr>();
2136   MS_EXCEPTION_IF_NULL(csr_tensor);
2137   auto value_tensor = value->cast<tensor::TensorPtr>();
2138   MS_EXCEPTION_IF_NULL(value_tensor);
2139   auto indptr_tensor = csr_tensor->GetIndptr();
2140   auto indices_tensor = csr_tensor->GetIndices();
2141   auto shape_vector = csr_tensor->shape();
2142   return std::make_shared<tensor::CSRTensor>(indptr_tensor, indices_tensor, value_tensor, shape_vector);
2143 }
2144 }  // namespace
2145 
2146 bool AutoGrad::IsPrimNeedGrad(const PrimitivePtr &prim) {
2147   MS_EXCEPTION_IF_NULL(prim);
2148   return kGradBlackList.find(prim->name()) == kGradBlackList.end();
2149 }
2150 
2151 bool AutoGrad::NeedGrad(const std::vector<ValuePtr> &input_values) {
2152   for (const ValuePtr &input_arg : input_values) {
2153     MS_EXCEPTION_IF_NULL(input_arg);
2154     if (input_arg->isa<tensor::BaseTensor>()) {
2155       tensor::BaseTensorPtr input_tensor = nullptr;
2156       input_tensor = input_arg->cast<tensor::BaseTensorPtr>();
2157       auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
2158       MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
2159       if (auto_grad_meta_data->input_type() == InputType::kParameter && Common::IsParamRequiresGrad(input_tensor)) {
2160         return true;
2161       }
2162       auto variable = auto_grad_meta_data->variable();
2163       if (variable != nullptr) {
2164         return true;
2165       }
2166     } else if (input_arg->isa<ValueSequence>()) {
2167       auto value_seq = input_arg->cast<ValueSequencePtr>()->value();
2168       if (NeedGrad(value_seq)) {
2169         return true;
2170       }
2171     } else if (input_arg->isa<tensor::COOTensor>() || input_arg->isa<tensor::CSRTensor>()) {
2172       return true;
2173     }
2174     MS_LOG(DEBUG) << "Get value " << input_arg->ToString();
2175   }
2176   return false;
2177 }
2178 
2179 bool AutoGrad::IsZerosLikeNode(const AnfNodePtr &node) {
2180   MS_EXCEPTION_IF_NULL(node);
2181   if (!node->isa<CNode>()) {
2182     return false;
2183   }
2184   auto cnode = node->cast<CNodePtr>();
2185   if (IsPrimitiveCNode(cnode, prim::kPrimZerosLike)) {
2186     return true;
2187   }
2188   if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
2189     return std::all_of(cnode->inputs().begin() + 1, cnode->inputs().end(),
2190                        [](const auto &node) { return IsZerosLikeNode(node); });
2191   }
2192   if (IsPrimitiveCNode(cnode, prim::kPrimMakeDict)) {
2193     return IsZerosLikeNode(cnode->input(kIndex2));
2194   }
2195   return false;
2196 }
2197 
2198 ValuePtr AutoGrad::GetFakeZeroTensor() {
2199   static ValuePtr fake_v = std::make_shared<tensor::Tensor>(0);
2200   return fake_v;
2201 }
2202 
2203 ValuePtr AutoGrad::BuildSpecialValueGrad(const ValuePtr &value, const tensor::BaseTensorPtr &grad,
2204                                          autograd::FuncBuilder *func_builder, const SpecialType &type) {
2205   MS_EXCEPTION_IF_NULL(value);
2206   if (grad != nullptr) {
2207     return grad;
2208   }
2209   if (value->isa<tensor::BaseTensor>()) {
2210     return (type == SpecialType::kZerosLikeType ? func_builder->Zeros(value) : func_builder->Ones(value));
2211   }
2212   if (value->isa<ValueSequence>()) {
2213     ValuePtr zero_value = nullptr;
2214     auto v_seq = value->cast<ValueSequencePtr>();
2215     ValuePtrList v_list;
2216     for (const auto &item : v_seq->value()) {
2217       (void)v_list.emplace_back(BuildSpecialValueGrad(item, grad, func_builder, type));
2218     }
2219     return std::make_shared<ValueTuple>(v_list);
2220   }
2221   if (value->isa<Scalar>()) {
2222     auto fake_tensor = std::make_shared<tensor::Tensor>(0, value->type());
2223     return BuildSpecialValueGrad(fake_tensor, grad, func_builder, type);
2224   }
2225   if (value->isa<tensor::CSRTensor>()) {
2226     auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
2227     return WrapCSRTensor(csr_tensor, BuildSpecialValueGrad(csr_tensor->GetValues(), grad, func_builder, type));
2228   }
2229   if (value->isa<tensor::COOTensor>()) {
2230     auto coo_tensor = value->cast<tensor::COOTensorPtr>();
2231     return WrapCOOTensor(coo_tensor, BuildSpecialValueGrad(coo_tensor->GetValues(), grad, func_builder, type));
2232   }
2233   MS_LOG(INFO) << "For value " << value->ToString() << ", the type is not tensor or scalar";
2234   auto fake_tensor = std::make_shared<tensor::Tensor>(0, value->type());
2235   return BuildSpecialValueGrad(fake_tensor, grad, func_builder, type);
2236 }
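// Illustrative sketch (hypothetical caller): building a zeros-like gradient for a tuple output
// when no explicit grad is supplied; func_builder is an autograd::FuncBuilder*.
//   // grad == nullptr, value is ValueTuple(t0, t1)
//   auto zeros = AutoGrad::BuildSpecialValueGrad(value, nullptr, func_builder, SpecialType::kZerosLikeType);
//   // zeros is a ValueTuple whose elements are zero tensors shaped like t0 and t1.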
2237 
2238 AnfNodePtr AutoGrad::BuildSpecialNode(const KernelGraphPtr &tape, const ValuePtr &value,
2239                                       const abstract::AbstractBasePtr &abs, const SpecialType &type) {
2240   MS_EXCEPTION_IF_NULL(value);
2241   if (value->isa<tensor::BaseTensor>()) {
2242     auto prim_node =
2243       (type == SpecialType::kZerosLikeType ? NewValueNode(std::make_shared<Primitive>(*prim::kPrimZerosLike))
2244                                            : NewValueNode(std::make_shared<Primitive>(*prim::kPrimOnesLike)));
2245     auto value_node = Common::CreateValueNodeByValue(value, abs);
2246     auto special_like_value = tape->FuncGraph::NewCNode({prim_node, value_node});
2247     special_like_value->set_abstract(value_node->abstract());
2248     return special_like_value;
2249   }
2250   if (value->isa<ValueSequence>()) {
2251     auto tuple = value->cast<ValueSequencePtr>();
2252     abstract::AbstractSequencePtr abs_seq;
2253     if (abs == nullptr) {
2254       abs_seq = Common::SetAbstractValueToAnyValue(value->ToAbstract())->cast<abstract::AbstractSequencePtr>();
2255     } else {
2256       abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2257     }
2258     return CreateMakeTupleNode(tape, tuple, abs_seq, type);
2259   }
2260   if (value->isa<Scalar>()) {
2261     auto fake_tensor = GetFakeZeroTensor();
2262     return BuildSpecialNode(tape, fake_tensor, nullptr, type);
2263   }
2264   if (value->isa<tensor::CSRTensor>()) {
2265     auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
2266     MS_EXCEPTION_IF_NULL(csr_tensor);
2267     auto data = csr_tensor->GetValues();
2268     return BuildSpecialNode(tape, data, nullptr, type);
2269   }
2270   if (value->isa<tensor::COOTensor>()) {
2271     auto coo_tensor = value->cast<tensor::COOTensorPtr>();
2272     MS_EXCEPTION_IF_NULL(coo_tensor);
2273     auto data = coo_tensor->GetValues();
2274     return BuildSpecialNode(tape, data, nullptr, type);
2275   }
2276   if (value->isa<ValueDictionary>()) {
2277     auto v_dict = value->cast<ValueDictionaryPtr>();
2278     abstract::AbstractDictionaryPtr abs_dict;
2279     if (abs == nullptr) {
2280       abs_dict = Common::SetAbstractValueToAnyValue(value->ToAbstract())->cast<abstract::AbstractDictionaryPtr>();
2281     } else {
2282       abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
2283     }
2284     return CreateMakeDictNode(tape, v_dict, abs_dict, type);
2285   }
2286   MS_LOG(INFO) << "For value " << value->ToString() << ", the type is not tensor or scalar";
2287   return BuildSpecialNode(tape, GetFakeZeroTensor(), nullptr, type);
2288 }
2289 
2290 AnfNodePtr AutoGrad::BuildSparseTensorNode(const KernelGraphPtr &tape, const ValuePtr &sparse_value,
2291                                            const AnfNodePtr &dout_value_node) {
2292   MS_EXCEPTION_IF_NULL(tape);
2293   MS_EXCEPTION_IF_NULL(sparse_value);
2294   if (sparse_value->isa<tensor::CSRTensor>()) {
2295     auto csr_tensor = sparse_value->cast<tensor::CSRTensorPtr>();
2296     MS_EXCEPTION_IF_NULL(csr_tensor);
2297     auto indptr_node = Common::CreateValueNodeByValue(csr_tensor->GetIndptr());
2298     auto indices_node = Common::CreateValueNodeByValue(csr_tensor->GetIndices());
2299     auto value_shape = GetSparseTensorShapeNode(csr_tensor->shape());
2300     auto special_like_csr_node = tape->FuncGraph::NewCNode(
2301       {NewValueNode(prim::kPrimMakeTuple), indptr_node, indices_node, dout_value_node, value_shape});
2302     special_like_csr_node->set_abstract(sparse_value->ToAbstract()->Broaden());
2303     return special_like_csr_node;
2304   }
2305   if (sparse_value->isa<tensor::COOTensor>()) {
2306     auto coo_tensor = sparse_value->cast<tensor::COOTensorPtr>();
2307     MS_EXCEPTION_IF_NULL(coo_tensor);
2308     auto indices_node = Common::CreateValueNodeByValue(coo_tensor->GetIndices());
2309     auto value_shape = GetSparseTensorShapeNode(coo_tensor->shape());
2310     auto special_like_coo_node =
2311       tape->FuncGraph::NewCNode({NewValueNode(prim::kPrimMakeTuple), indices_node, dout_value_node, value_shape});
2312     special_like_coo_node->set_abstract(sparse_value->ToAbstract()->Broaden());
2313     return special_like_coo_node;
2314   }
2315   MS_LOG(EXCEPTION) << "Got an invalid sparse tensor";
2316 }
2317 
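// Attach the variable (and the parameter, when given) to the auto_grad_meta_data of every tensor contained in
// value, creating the meta data if it is missing and recursing into sequences and dictionaries.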
2318 void AutoGrad::SetGradMetaData(const ValuePtr &value, const VariablePtr &variable, const ParameterPtr &param) {
2319   if (value->isa<tensor::BaseTensor>()) {
2320     tensor::BaseTensorPtr tensor = nullptr;
2321     tensor = value->cast<tensor::BaseTensorPtr>();
2322     auto auto_grad_meta_data = tensor->auto_grad_meta_data();
2323     if (auto_grad_meta_data == nullptr) {
2324       MS_LOG(DEBUG) << "tensor has no auto_grad_meta_data";
2325       auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
2326       tensor->set_auto_grad_meta_data(auto_grad_meta_data);
2327     }
2328     auto_grad_meta_data->set_variable(variable);
2329     if (param != nullptr) {
2330       auto_grad_meta_data->set_parameter(param);
2331       auto_grad_meta_data->set_input_type(InputType::kParameter);
2332     }
2333   } else if (value->isa<ValueSequence>()) {
2334     auto value_sequence = value->cast<ValueSequencePtr>();
2335     for (const auto &val : value_sequence->value()) {
2336       SetGradMetaData(val, variable);
2337     }
2338   } else if (value->isa<ValueDictionary>()) {
2339     auto value_dict = value->cast<ValueDictionaryPtr>();
2340     for (const auto &val : value_dict->value()) {
2341       SetGradMetaData(val.second, variable);
2342     }
2343   }
2344 }
2345 
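// Record the variable and parameter on the input tensor's existing auto_grad_meta_data; for COO/CSR tensors the
// info is attached to the indices tensor.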
2346 void AutoGrad::SetGradInfoForInputs(const ValuePtr &value, const VariablePtr &variable, const ParameterPtr &param) {
2347   if (value->isa<tensor::BaseTensor>()) {
2348     const auto &input_tensor = value->cast<tensor::BaseTensorPtr>();
2349     const auto &auto_grad_meta_data = input_tensor->auto_grad_meta_data();
2350     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
2351     auto_grad_meta_data->set_variable(variable);
2352     auto_grad_meta_data->set_parameter(param);
2353   } else if (value->isa<tensor::COOTensor>()) {
2354     const auto &coo_tensor = value->cast<tensor::COOTensorPtr>();
2355     const auto &indices_tensor = coo_tensor->GetIndices();
2356     SetGradInfoForInputs(indices_tensor, variable, param);
2357   } else if (value->isa<tensor::CSRTensor>()) {
2358     const auto &csr_tensor = value->cast<tensor::CSRTensorPtr>();
2359     const auto &indices_tensor = csr_tensor->GetIndices();
2360     SetGradInfoForInputs(indices_tensor, variable, param);
2361   }
2362 }
2363 
2364 // Create a fake bprop that passes dout through as the gradient of every input
2365 void AutoGrad::BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
2366   auto prim = GetCNodePrimitive(cnode);
2367   if (prim == nullptr) {
2368     MS_LOG(EXCEPTION) << "Should be a primitive, but got: " << cnode->DebugString();
2369   }
2370   size_t dout_index = cnode->size() - 1;
2371   const auto &dout = cnode->input(dout_index);
2372   const auto &dout_cnode = dout->cast<CNodePtr>();
2373   MS_EXCEPTION_IF_NULL(dout_cnode);
2374   // Size is the same as the op_arg size
2375   size_t input_size = cnode->size() - 2;
2376   for (size_t i = 1; i < input_size; ++i) {
2377     (void)outputs->emplace_back(dout_cnode);
2378   }
2379 }
2380 
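// Return a callback that runs call_graph as the jit call graph of the bprop: on the first call (or a cache miss)
// the graph is compiled through TaskEmitAction/ExecuteAction and the pipeline resource is cached by cache_key;
// afterwards the cached VM eval function is invoked with the argument list.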
2381 CallBackFn AutoGrad::CreateGraphCallBack(const FuncGraphPtr &call_graph, const std::string &cache_key,
2382                                          const GraphCallCondition &graph_call_condition) {
2383   // kFlagJitCallGraph is set to true to avoid compiling call_graph when compiling the main graph
2384   call_graph->set_flag(kFlagJitCallGraph, true);
2385   // The call graph is not inlined into the grad top graph
2386   call_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
2387   // Pynative bprop graph flag
2388   call_graph->set_flag(kFlagIsPynativeBpropGraph, true);
2389   // Running the graph by single op uses this kFlagPyNativeBpropGraphWithBpropCut flag
2390   if (graph_call_condition.is_dynamic_shape_process_) {
2391     call_graph->set_flag(kFlagPyNativeBpropGraphWithBpropCut, false);
2392     if (!graph_call_condition.is_jit_graph_) {
2393       call_graph->set_flag(kFlagEnableRunGraphBySingleOp, true);
2394     }
2395   }
2396   pipeline::ResourcePtr resource;
2397   constexpr auto kNeedCompile = "NeedCompile";
2398   const auto it = jit_call_graph_compile_cache_.find(cache_key);
2399   bool need_compile = (it == jit_call_graph_compile_cache_.end());
2400   if (need_compile) {
2401     resource = std::make_shared<pipeline::Resource>();
2402     resource->set_func_graph(call_graph);
2403     if (graph_call_condition.is_func_grad_) {
2404       auto manager = resource->manager();
2405       manager->AddFuncGraph(call_graph, false);
2406       (void)opt::EnvironConversion(resource);
2407       if (graph_call_condition.jit_out_has_dict_) {
2408         MS_LOG(DEBUG) << "Jit output is a dict, need to convert make dict to pyexecute";
2409         (void)mindspore::opt::RewriterAfterOptA(resource->func_graph(), resource);
2410       }
2411     }
2412     if (graph_call_condition.is_jit_graph_ || !graph_call_condition.is_dynamic_shape_process_) {
2413       (void)jit_call_graph_compile_cache_.emplace(cache_key, resource);
2414     }
2415     resource->SetResult(kNeedCompile, true);
2416   } else {
2417     resource = it->second;
2418     // The resource func graph has not been compiled yet (run grad graph was not called), but the cache was hit
2419     need_compile = resource->GetResult(kNeedCompile).cast<bool>();
2420   }
2421   MS_EXCEPTION_IF_NULL(resource);
2422   bool is_control_flow = graph_call_condition.is_control_flow_;
2423   auto fn = [resource, need_compile, is_control_flow, kNeedCompile](const VectorRef &arg_list) -> VectorRef {
2424     if (need_compile) {
2425       MS_LOG(DEBUG) << "Start emit action for graph " << resource->func_graph()->ToString();
2426       auto manager = resource->manager();
2427       manager->AddFuncGraph(resource->func_graph(), true);
2428       resource->SetBackendAsync([]() { return compile::CreateBackend(); });
2429       // kFlagJitCallGraph is set to false to compile sub graphs in control flow
2430       if (is_control_flow) {
2431         for (const auto &g : manager->func_graphs()) {
2432           g->set_flag(kFlagJitCallGraph, false);
2433         }
2434       }
2435       (void)TaskEmitAction(resource);
2436       (void)ExecuteAction(resource);
2437       resource->SetResult(kNeedCompile, false);
2438     }
2439     MS_LOG(DEBUG) << "Start execute action for graph " << resource->func_graph()->ToString();
2440     compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>();
2441     return utils::cast<VectorRef>((*run)(arg_list));
2442   };
2443   return fn;
2444 }
2445 
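// Build a bprop_cut primitive from a PrimitivePy: copy its hook function, register the bprop_cut back on the
// original prim, and propagate the cell_id, custom_op_bprop and recompute attributes.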
2446 PrimitivePyPtr AutoGrad::BuildBpropCutPrim(const PrimitivePtr &prim, bool is_need_recompute) {
2447   MS_EXCEPTION_IF_NULL(prim);
2448   auto prim_py = prim->cast<PrimitivePyPtr>();
2449   MS_EXCEPTION_IF_NULL(prim_py);
2450   auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
2451   bprop_cut->CopyHookFunction(prim_py);
2452   prim_py->AddBpropCutPrim(bprop_cut);
2453   if (prim->HasAttr("cell_id")) {
2454     auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
2455     if (!cell_id.empty()) {
2456       (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
2457       (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
2458     }
2459   }
2460   // Only custom ops need to add this attr; hook functions do not.
2461   if (prim->HasAttr("custom_op_bprop")) {
2462     (void)bprop_cut->AddAttr("custom_op_bprop", MakeValue(true));
2463   }
2464   (void)bprop_cut->AddAttr("custom_op_name", MakeValue(prim->name()));
2465   if (is_need_recompute) {
2466     (void)bprop_cut->AddAttr("is_recompute", MakeValue(true));
2467   }
2468   return bprop_cut;
2469 }
2470 
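// For recompute cells, reject tuple/list inputs that contain tensors requiring gradients, since computing a
// tensor's gradient from inside a tuple is not supported in that case.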
2471 void AutoGrad::CheckRecomputeInputs(const GradParamPtr &grad_param) {
2472   if (grad_param->op_grad_info->is_need_recompute) {
2473     for (const auto &input : grad_param->op_grad_info->input_value) {
2474       if (input->isa<ValueSequence>()) {
2475         const auto &seq = input->cast<ValueSequencePtr>();
2476         const auto val = seq->value();
2477         if (AutoGrad::NeedGrad(val)) {
2478           MS_LOG(EXCEPTION) << "For a recompute cell, computing a tensor's gradient from a tuple is not yet "
2479                                "supported. Please check the inputs of the recompute cell's construct function "
2480                                "and do not put tensors that need gradients into a tuple!";
2481         }
2482       }
2483     }
2484   }
2485 }
2486 
2487 void AutoGrad::ClearAutoGradStaticCache() { jit_call_graph_compile_cache_.clear(); }
2488 
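// A node is a real op if it is backed by a primitive that is not in the kNotRealOP set.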
2489 bool GradCommon::IsRealOp(const AnfNodePtr &cnode) {
2490   MS_EXCEPTION_IF_NULL(cnode);
2491   const auto &prim = GetCNodePrimitive(cnode);
2492   if (prim == nullptr) {
2493     return false;
2494   }
2495   return kNotRealOP.find(prim->name()) == kNotRealOP.end();
2496 }
2497 
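// For each node in node_list, create an output tensor value from its abstract and record it as the cnode's
// forward result.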
2498 void GradCommon::SetForward(const AnfNodePtrList &node_list) {
2499   for (const auto &cn : node_list) {
2500     auto out = Common::CreatOutputTensorValueByAbstract(cn->abstract());
2501     const auto &c_node = cn->cast<CNodePtr>();
2502     MS_EXCEPTION_IF_NULL(c_node);
2503     c_node->set_forward(Common::CreateValueNodeByValue(out, cn->abstract()), "");
2504   }
2505 }
2506 
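// Collect the cnode inputs used by the single op bprop graph (expanding MakeTuple inputs and keeping only real
// ops), and append the node itself when its own output is used by the bprop graph.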
2507 void GradCommon::GetUsedCNodeInBpropGraph(const CNodePtr &cnode, const mindspore::HashSet<size_t> &unused_inputs,
2508                                           AnfNodePtrList *node_list) {
2509   MS_EXCEPTION_IF_NULL(cnode);
2510   MS_EXCEPTION_IF_NULL(node_list);
2511   // Check the inputs used in the single op bprop graph. For example,
2512   // A = a * b;
2513   // B = A * c;
2514   // So, A can also be replaced by its output
2515   size_t input_num = cnode->size() - 1;
2516   for (size_t i = 0; i < input_num; ++i) {
2517     if (unused_inputs.find(i) == unused_inputs.end() && cnode->input(i + 1)->isa<CNode>()) {
2518       // The input is used by the bprop graph, and it is a cnode that produces a real output
2519       const auto &input_c = cnode->input(i + 1)->cast<CNodePtr>();
2520       MS_EXCEPTION_IF_NULL(input_c);
2521       if (IsPrimitive(input_c, prim::kPrimMakeTuple)) {
2522         size_t tuple_input_num = input_c->size() - 1;
2523         for (size_t j = 0; j < tuple_input_num; ++j) {
2524           if (auto f_node = common::AnfAlgo::VisitKernel(input_c, j).first; f_node->isa<CNode>() && IsRealOp(f_node)) {
2525             MS_LOG(DEBUG) << "Get used input node " << f_node->DebugString();
2526             (void)node_list->emplace_back(f_node);
2527           }
2528         }
2529       } else {
2530         if (auto f_node = common::AnfAlgo::VisitKernel(input_c, 0).first; f_node->isa<CNode>() && IsRealOp(f_node)) {
2531           MS_LOG(DEBUG) << "Get used input node " << f_node->DebugString();
2532           (void)node_list->emplace_back(f_node);
2533         }
2534       }
2535     }
2536   }
2537   // Check output used in single op bprop graph
2538   if (unused_inputs.find(cnode->size() - 1) == unused_inputs.end()) {
2539     MS_LOG(DEBUG) << "Get used output node " << cnode->DebugString();
2540     (void)node_list->emplace_back(cnode);
2541   }
2542 }
2543 }  // namespace PyNativeAlgo
2544 
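// Dispatch a frontend task: run it synchronously when sync execution is required and graph-level async is off,
// otherwise record the flow data and push it to the asynchronous frontend pipeline stage.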
2545 void DispatchOp(const std::shared_ptr<runtime::AsyncTask> &task) {
2546   static bool need_sync = runtime::OpExecutor::NeedSync();
2547   if (need_sync && !runtime::OpExecutor::GetInstance().async_for_graph()) {
2548     MS_LOG(INFO) << "PyBoost sync run frontend task";
2549     runtime::OpExecutor::GetInstance().WaitAll();
2550     task->Run();
2551   } else {
2552     runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id());
2553     runtime::Pipeline::Get().frontend_stage()->Push(task);
2554   }
2555 }
2556 }  // namespace pynative
2557 }  // namespace mindspore
2558