1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/pynative/pynative_utils.h"
17 #include <algorithm>
18 #include <vector>
19 #include "ops/sparse_ops.h"
20 #include "ops/sequence_ops.h"
21 #include "ops/framework_ops.h"
22 #include "include/backend/optimizer/helper.h"
23 #include "include/backend/optimizer/op_adaptation_info_factory.h"
24 #include "pybind_api/ir/primitive_py.h"
25 #include "pybind_api/gil_scoped_long_running.h"
26 #include "pybind_api/ir/hook_py.h"
27 #include "utils/ms_context.h"
28 #include "ir/cell.h"
29 #include "include/common/utils/utils.h"
30 #include "include/common/utils/convert_utils_py.h"
31 #include "include/common/utils/primfunc_utils.h"
32 #include "include/common/debug/anf_ir_dump.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "include/common/utils/stub_tensor.h"
35 #include "frontend/expander/bprop/bprop.h"
36 #include "frontend/optimizer/environ_conversion.h"
37 #include "frontend/optimizer/fallback_rewriter.h"
38 #include "pipeline/pynative/grad/jit/jit_grad.h"
39 #include "ops/sequence_op_name.h"
40 #include "ops/structure_ops.h"
41 #include "ops/other_ops.h"
42 #include "pipeline/pynative/predict_out_type_map.h"
43 #include "kernel/pyboost/auto_generate/contiguous.h"
44 #include "runtime/pipeline/pipeline.h"
45 #include "ops/auto_generate/gen_ops_primitive.h"
46 #include "include/common/pynative/abstract_converter.h"
47 #include "kernel/pyboost/pyboost_utils.h"
48
49 namespace mindspore {
50 namespace pynative {
51 namespace PyNativeAlgo {
52 namespace {
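// Falls back to the Python-side parse module to compute an id string for objects that have no
// dedicated C++ handling in GetIdByPyObj below.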
53 std::string GetObjIdFromPython(const py::handle &obj) {
54 py::object out = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
55 if (py::isinstance<py::none>(out)) {
56 MS_LOG(EXCEPTION) << "Get pyobj failed";
57 }
58 return out.cast<std::string>();
59 }
60 // For simple infer (simple infer will push abstracts into the bprop queue).
61 static AbstractConverter kGradAbstractConverter;
62
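// Builds a composite id for a Python tuple/list by joining the element ids, e.g. "Tuple<id0:id1>" (illustrative).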
63 std::string GetIdForPyTupleOrList(const py::handle &obj) {
64 auto p_list = py::cast<py::tuple>(obj);
65 string prefix = py::isinstance<py::tuple>(obj) ? "Tuple<" : "List<";
66 if (p_list.empty()) {
67 prefix = "Empty:";
68 } else {
69 for (size_t i = 0; i < p_list.size(); ++i) {
70 prefix += PyParser::GetIdByPyObj(p_list[i]) + ":";
71 }
72 }
73 prefix.pop_back();
74 prefix += ">";
75 return prefix;
76 }
77
78 std::string GetFnInfoByPyObj(const py::object &obj) {
79 std::string fn_info = obj.attr("__module__").cast<std::string>();
80 fn_info += "_" + obj.attr("__name__").cast<std::string>();
81 fn_info += "_" + obj.attr("__code__").attr("co_filename").cast<std::string>();
82 fn_info += "_" + py::str(obj.attr("__code__").attr("co_firstlineno")).cast<std::string>();
83 if (py::hasattr(obj, "__warpped__")) {
84 auto warpped_obj = obj.attr("__warpped__");
85 fn_info += "_" + warpped_obj.attr("__name__").cast<std::string>();
86 fn_info += "_" + warpped_obj.attr("__code__").attr("co_filename").cast<std::string>();
87 fn_info += "_" + py::str(warpped_obj.attr("__code__").attr("co_firstlineno")).cast<std::string>();
88 }
89 return fn_info;
90 }
91
92 void AddDynInputsSizesAttr(const FrontendOpRunInfoPtr &op_run_info) {
93 if (op_run_info->base_op_run_info.dyn_input_sizes.empty()) {
94 return;
95 }
96 op_run_info->op_grad_info->op_prim->set_attr(kAttrDynInputSizes,
97 MakeValue(op_run_info->base_op_run_info.dyn_input_sizes));
98 }
99
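// Creates a default placeholder value (scalar, string, nested tuple, etc.) for a non-tensor abstract;
// used when a fake output value has to be materialized from an abstract alone.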
100 ValuePtr CreateNonTensorByAbstract(const abstract::AbstractBasePtr &abs) {
101 MS_EXCEPTION_IF_NULL(abs);
102 auto type_id = Common::GetTypeFromAbstract(abs);
103 if (abs->isa<abstract::AbstractMonad>()) {
104 return std::make_shared<tensor::Tensor>(0);
105 }
106 if (type_id == kMetaTypeNone) {
107 return kNone;
108 }
109 if (type_id == kMetaTypeNull) {
110 return kNull;
111 }
112 if (abs->isa<abstract::AbstractSequence>()) {
113 const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>()->elements();
114 ValuePtrList value_ptr_list;
115 (void)std::transform(abs_seq.begin(), abs_seq.end(), std::back_inserter(value_ptr_list),
116 [](const abstract::AbstractBasePtr &elem) { return CreateNonTensorByAbstract(elem); });
117 return std::make_shared<ValueTuple>(value_ptr_list);
118 }
119 if (type_id == kNumberTypeBool) {
120 return MakeValue(true);
121 }
122 if (type_id == kObjectTypeString) {
123 return MakeValue("");
124 }
125 if (type_id >= kNumberTypeInt && type_id <= kNumberTypeUInt64) {
126 return MakeValue(static_cast<int64_t>(0));
127 }
128 if (type_id >= kNumberTypeFloat && type_id <= kNumberTypeFloat64) {
129 return MakeValue(static_cast<float>(0));
130 }
131 if (type_id == kNumberTypeDouble) {
132 return MakeValue(static_cast<double>(0));
133 }
134 MS_LOG(EXCEPTION) << "Get unsupported type " << type_id;
135 }
136
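// Recursively flattens a (possibly nested) sequence abstract into individual tensor parameters of
// bprop_graph; the new parameters are collected in new_param and also appended to make_tuple so the
// original tuple value can be rebuilt.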
137 void PlantTupleParam(const FuncGraphPtr &bprop_graph, const abstract::AbstractSequencePtr &abs_seq,
138 AnfNodePtrList *make_tuple, AnfNodePtrList *new_param) {
139 MS_EXCEPTION_IF_NULL(bprop_graph);
140 MS_EXCEPTION_IF_NULL(make_tuple);
141 MS_EXCEPTION_IF_NULL(new_param);
142 MS_EXCEPTION_IF_NULL(abs_seq);
143 for (size_t i = 0; i < abs_seq->size(); ++i) {
144 if (abs_seq->elements()[i]->isa<abstract::AbstractSequence>()) {
145 PlantTupleParam(bprop_graph, abs_seq->elements()[i]->cast<abstract::AbstractSequencePtr>(), make_tuple,
146 new_param);
147 } else if (abs_seq->elements()[i]->isa<abstract::AbstractTensor>()) {
148 auto plant_param = bprop_graph->add_parameter();
149 plant_param->set_abstract(abs_seq->elements()[i]);
150 (void)make_tuple->emplace_back(plant_param);
151 (void)new_param->emplace_back(plant_param);
152 }
153 }
154 }
155
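// For an Ascend view tensor (storage_info != nullptr), launches a Contiguous kernel task on the old
// device address and returns a tensor copy bound to the new contiguous address; returns nullptr when
// no conversion is needed.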
156 ValuePtr GetContiguousGradTensor(const ValuePtr &v) {
157 const auto &tensor = v->cast<tensor::BaseTensorPtr>();
158 MS_EXCEPTION_IF_NULL(tensor);
159 if (tensor->storage_info() == nullptr) {
160 return nullptr;
161 }
162
163 auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
164 MS_EXCEPTION_IF_NULL(old_device_address);
165 const auto &device_target = old_device_address->device_name();
166 if (device_target != kAscendDevice) {
167 // On GPU/CPU the tensor is made contiguous when the stub node is converted, i.e. before grad, so nothing to do here.
168 return nullptr;
169 }
170
171 MS_LOG(DEBUG) << "tensor id:" << tensor->id();
172 auto stream_id = old_device_address->stream_id();
173 const auto &old_storage_info = old_device_address->GetTensorStorageInfo();
174 MS_EXCEPTION_IF_NULL(old_storage_info);
175
176 const auto &device_context = runtime::OpRunner::GetDeviceContext(old_device_address->device_name());
177 MS_EXCEPTION_IF_NULL(device_context);
178 auto address_size = GetTypeByte(TypeIdToType(old_device_address->type_id())) * SizeOf(old_storage_info->shape);
179 auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
180 nullptr, address_size, Format::DEFAULT_FORMAT, old_device_address->type_id(), old_storage_info->shape,
181 device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
182 kernel_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(old_device_address->type_id())));
183 kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(old_storage_info->shape));
184 kernel_tensor->set_stream_id(stream_id);
185
186 auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
187 new_device_address->set_device_shape(old_storage_info->shape);
188 new_device_address->set_original_ref_count(SIZE_MAX);
189 new_device_address->ResetRefCount();
190
191 device::DeviceAddressPtrList input_addr_list{old_device_address};
192 device::DeviceAddressPtrList output_addr_list{new_device_address};
193 GilReleaseWithCheck release_gil;
194 if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(runtime::KernelTaskType::kCONTIGUOUS_TASK,
195 input_addr_list, output_addr_list, stream_id)) {
196 MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
197 }
198
199 MS_LOG(DEBUG) << "Update contiguous address, old_device_address:" << old_device_address
200 << ", new_device_address:" << new_device_address;
201
202 auto new_tensor = std::make_shared<tensor::Tensor>(*tensor);
203 new_tensor->set_device_address(new_device_address);
204 return new_tensor;
205 }
206
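// Replaces a view (non-contiguous) input of op_run_info with its contiguous copy before grad recording,
// skipping inputs that the bprop of this op does not use.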
207 void RefreshGradContiguousTensor(const FrontendOpRunInfoPtr &op_run_info, size_t index) {
208 const auto &unused_inputs = BpropExpander::GetUnusedInputs(op_run_info->op_grad_info->op_prim->name());
209 // The input is not used in bprop, so it does not need to be made contiguous.
210 if (unused_inputs.find(index) != unused_inputs.end()) {
211 return;
212 }
213
214 const auto &v = op_run_info->op_grad_info->input_value[index];
215 if (v->isa<tensor::BaseTensor>()) {
216 const auto &new_tensor = GetContiguousGradTensor(v);
217 if (new_tensor != nullptr) {
218 op_run_info->op_grad_info->input_value[index] = new_tensor;
219 }
220 } else if (v->isa<ValueSequence>()) {
221 const auto &vec = v->cast<ValueSequencePtr>()->value();
222 if (vec.empty() || !vec[0]->isa<tensor::BaseTensor>()) {
223 return;
224 }
225 // A tuple of tensors may contain view tensors that need to be made contiguous.
226 bool need_refresh_tuple = false;
227 std::vector<ValuePtr> new_vec(vec.size());
228 for (size_t i = 0; i < vec.size(); i++) {
229 const auto &new_tensor = GetContiguousGradTensor(vec[i]);
230 if (new_tensor == nullptr) {
231 new_vec[i] = vec[i];
232 } else {
233 // A non-contiguous tensor was found in input_value; refresh the tuple after the contiguous conversion.
234 need_refresh_tuple = true;
235 new_vec[i] = new_tensor;
236 }
237 }
238 if (need_refresh_tuple) {
239 op_run_info->op_grad_info->input_value[index] = MakeValue(new_vec);
240 }
241 }
242 }
243
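// Framework/structural ops that do not correspond to a real computation kernel.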
244 const mindspore::HashSet<std::string> kNotRealOP{
245 kMakeTupleOpName,
246 kMakeListNewOpName,
247 kTupleGetItemOpName,
248 kStopGradientOpName,
249 kUpdateStateOpName,
250 kLoadOpName,
251 kDependOpName,
252 kReturnOpName,
253 kNPUAllocFloatStatusOpName,
254 kNPUGetFloatStatusOpName,
255 kNPUClearFloatStatusOpName,
256 kMirrorOperatorOpName,
257 kSequenceSliceOpName,
258 kSequenceMulOpName,
259 kPyExecuteOpName,
260 };
261
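// Runs the pyboost Contiguous op on input_tensor; when requires_grad is set, it also records the op
// so the contiguous conversion is visible to autograd.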
262 tensor::BaseTensorPtr GetContiguousTensor(const tensor::BaseTensorPtr &input_tensor, const std::string &device_target,
263 bool requires_grad) {
264 auto contiguous_op = CREATE_PYBOOST_OP(Contiguous, device_target);
265 auto contiguous_tensor = contiguous_op->Call(input_tensor);
266 if (requires_grad) {
267 const auto &contiguous_run_info = std::make_shared<FrontendOpRunInfo>();
268 contiguous_run_info->requires_grad = true;
269 PyBoost::UpdateOpRunInfo(contiguous_op, contiguous_run_info);
270 contiguous_run_info->base_op_run_info.device_target = device_target;
271 contiguous_run_info->input_size = 1;
272 contiguous_run_info->base_op_run_info.op_name = ops::kNameContiguous;
273 contiguous_run_info->op_grad_info->op_prim = prim::kPrimContiguous;
274 PyBoost::DoGrad(contiguous_op, contiguous_run_info, {input_tensor});
275 }
276 return contiguous_tensor;
277 }
278
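// Clears the abstract cached on a tensor (recursively for value sequences) so a stale abstract is not reused.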
279 void UnsetValueAbstractCache(const ValuePtr &value) {
280 if (value->isa<tensor::BaseTensor>()) {
281 auto tensor = value->cast<tensor::BaseTensorPtr>();
282 tensor->set_abstract(std::weak_ptr<abstract::AbstractBase>());
286 } else if (value->isa<ValueSequence>()) {
287 const auto &seq = value->cast<ValueSequencePtr>();
288 auto elements = seq->value();
289 for (const auto &element : elements) {
290 UnsetValueAbstractCache(element);
291 }
292 }
293 }
294 } // namespace
295
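// Recursively sets the value of tensor abstracts to kValueAny so only shape/type information is kept.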
296 AbstractBasePtr Common::SetAbstractValueToAnyValue(const AbstractBasePtr &abs) {
297 MS_EXCEPTION_IF_NULL(abs);
298 if (abs->isa<abstract::AbstractTensor>()) {
299 abs->set_value(kValueAny);
300 } else if (abs->isa<abstract::AbstractTuple>() || abs->isa<abstract::AbstractList>()) {
301 const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
302 for (const auto &elem : abs_seq->elements()) {
303 (void)SetAbstractValueToAnyValue(elem);
304 }
305 } else if (abs->isa<abstract::AbstractDictionary>()) {
306 const auto &abs_dic = abs->cast<abstract::AbstractDictionaryPtr>();
307 for (const auto &elem : abs_dic->elements()) {
308 (void)SetAbstractValueToAnyValue(elem.first);
309 (void)SetAbstractValueToAnyValue(elem.second);
310 }
311 }
312 return abs;
313 }
314
315 AnfNodePtr Common::ConvertValueSequenceToMakeTuple(const ValueNodePtr &node, const FuncGraphPtr &func_graph) {
316 MS_EXCEPTION_IF_NULL(node);
317 const auto &v = node->value();
318 if (!v->isa<ValueSequence>()) {
319 return node;
320 }
321 auto value_sequence = v->cast<ValueSequencePtr>();
322 if (!node->abstract()->isa<abstract::AbstractSequence>() ||
323 (node->abstract()->cast<abstract::AbstractSequencePtr>()->size() != value_sequence->size())) {
324 MS_LOG(EXCEPTION) << "Get wrong matched abs " << node->abstract()->ToString() << " and value "
325 << value_sequence->ToString();
326 }
327
328 AnfNodePtrList inputs{NewValueNode(prim::kPrimMakeTuple)};
329 for (const auto &value : value_sequence->value()) {
330 MS_EXCEPTION_IF_NULL(value);
331 auto value_node = NewValueNode(value);
332 auto abs = Common::SetAbstractValueToAnyValue(value->ToAbstract());
333 value_node->set_abstract(abs);
334 auto tuple_node = ConvertValueSequenceToMakeTuple(value_node, func_graph);
335 (void)inputs.emplace_back(tuple_node);
336 }
337 MS_EXCEPTION_IF_NULL(func_graph);
338 auto make_tuple_node = func_graph->NewCNode(inputs);
339 make_tuple_node->set_abstract(node->abstract());
340 return make_tuple_node;
341 }
342
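// Computes a stable id string for a value, e.g. a bool True -> "B1", an int 3 -> "I3", and a tuple of
// values -> "Tuple<id0:id1>" (illustrative examples derived from the branches below).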
343 std::string Common::GetIdByValue(const ValuePtr &v) {
344 MS_EXCEPTION_IF_NULL(v);
345 if (v->isa<tensor::BaseTensor>()) {
346 return v->cast<tensor::BaseTensorPtr>()->id();
347 } else if (v->isa<stub::StubNode>()) {
348 return GetIdByValue(v->cast<stub::StubNodePtr>()->WaitValue());
349 } else if (v->isa<Cell>()) {
350 return v->cast<CellPtr>()->id();
351 } else if (v->isa<mindspore::Type>()) {
352 auto type_ptr = v->cast<mindspore::TypePtr>();
353 return "Type:" + type_ptr->ToString();
354 } else if (v->isa<StringImm>()) {
355 return "S" + v->cast<StringImmPtr>()->value();
356 } else if (v->isa<BoolImm>()) {
357 return "B" + std::to_string(v->cast<BoolImmPtr>()->value());
358 } else if (v->isa<IntegerImm>()) {
359 return "I" + std::to_string(v->cast<Int64ImmPtr>()->value());
360 } else if (v->isa<FloatImm>()) {
361 return "F" + std::to_string(v->cast<FP32ImmPtr>()->value());
362 } else if (v->isa<None>()) {
363 return "None";
364 } else if (v->isa<Ellipsis>()) {
365 return "Ellipsis";
366 } else if (v->isa<ValueSequence>()) {
367 auto p_list = v->cast<ValueSequencePtr>();
368 string prefix = v->isa<ValueTuple>() ? "Tuple<" : "List<";
369 if (p_list->size() == 0) {
370 prefix = "Empty:";
371 } else {
372 for (size_t i = 0; i < p_list->size(); ++i) {
373 prefix += GetIdByValue(p_list->value()[i]) + ":";
374 }
375 }
376 prefix.pop_back();
377 prefix += ">";
378 return prefix;
379 }
380 MS_LOG(DEBUG) << "Get type " << v->ToString();
381 return v->ToString();
382 }
383
384 std::string Common::GetCellId(const std::string &obj_id, const std::vector<std::string> &input_arg_id_vec,
385 const std::vector<ValuePtr> &input_arg_value_vec) {
386 auto cell_id = obj_id;
387 auto fn = [&cell_id](const abstract::AbstractBasePtr &abs) {
388 MS_EXCEPTION_IF_NULL(abs);
389 auto shape = abs->BuildShape();
390 auto type = abs->BuildType();
391 cell_id += "_" + shape->ToString();
392 cell_id += type->ToString();
393 };
394
395 const auto &forward = GetPyNativeExecutor()->forward_executor();
396 for (size_t i = 0; i < input_arg_id_vec.size(); ++i) {
397 const auto &arg_id = input_arg_id_vec[i];
398 // Find in step process
399 auto cache_abs = forward->GetNodeAbsById(arg_id);
400 if (cache_abs != nullptr) {
401 fn(cache_abs);
402 } else {
403 MS_EXCEPTION_IF_NULL(input_arg_value_vec[i]);
404 fn(SetAbstractValueToAnyValue(input_arg_value_vec[i]->ToAbstract()));
405 }
406 }
407 return cell_id;
408 }
409
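// Splits a composite id produced by GetIdByValue/GetIdByPyObj into its top-level element ids, tracking
// '<'/'>' depth so nested ids stay intact, e.g. "Tuple<I1:Tuple<I2:I3>:B1>" -> {"I1", "Tuple<I2:I3>", "B1"} (illustrative).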
410 void Common::SplitString(const std::string &str, std::vector<std::string> *id_vec) {
411 constexpr char colon_delim = ':';
412 constexpr char angle_bracket_left_delim = '<';
413 constexpr char angle_bracket_right_delim = '>';
414 auto paren_pos = str.find_first_of(angle_bracket_left_delim);
415 if (paren_pos == std::string::npos) {
416 MS_LOG(EXCEPTION) << "Get wrong str " << str;
417 }
418 size_t str_size = str.size();
419 const auto &sub_str = str.substr(paren_pos + 1, str_size - paren_pos - 2);
420 MS_LOG(DEBUG) << "Ori str " << str << ", get sub str " << sub_str;
421 size_t begin = 0;
422 size_t angle_bracket_left = 0;
423 size_t angle_bracket_right = 0;
424 size_t sub_str_size = sub_str.size();
425 for (size_t i = 0; i < sub_str_size; ++i) {
426 switch (sub_str[i]) {
427 case colon_delim:
428 if (i != 0 && angle_bracket_left == angle_bracket_right) {
429 (void)id_vec->emplace_back(sub_str.substr(begin, i - begin));
430 begin = i + 1;
431 angle_bracket_left = 0;
432 angle_bracket_right = 0;
433 }
434 break;
435 case angle_bracket_left_delim:
436 ++angle_bracket_left;
437 break;
438 case angle_bracket_right_delim:
439 ++angle_bracket_right;
440 break;
441 default: {
442 }
443 }
444 }
445 if (angle_bracket_left == angle_bracket_right) {
446 (void)id_vec->emplace_back(sub_str.substr(begin, sub_str_size - begin));
447 }
448 }
449
450 bool Common::ValueHasDynamicShape(const ValuePtr &value) {
451 MS_EXCEPTION_IF_NULL(value);
452 if (value->isa<tensor::BaseTensor>()) {
453 return value->cast<tensor::BaseTensorPtr>()->base_shape_ptr() != nullptr;
454 } else if (value->isa<ValueSequence>()) {
455 auto value_seq = value->cast<ValueSequencePtr>();
456 return std::any_of(value_seq->value().begin(), value_seq->value().end(),
457 [](const ValuePtr &elem) { return ValueHasDynamicShape(elem); });
458 }
459 return false;
460 }
461
462 bool Common::IsTensor(const ValuePtr &v, bool include_sequence) {
463 MS_EXCEPTION_IF_NULL(v);
464 if (include_sequence) {
465 if (v->isa<tensor::MetaSparseTensor>() || v->isa<tensor::BaseTensor>()) {
466 return true;
467 } else if (v->isa<ValueSequence>()) {
468 auto v_seq = v->cast<ValueSequencePtr>();
469 if (v_seq->size() == 0) {
470 MS_LOG(DEBUG) << "Get empty value sequence";
471 return false;
472 }
473 // A sparse tensor may carry scalar indices, so only check whether the first element is a sparse tensor.
474 if (v_seq->value().front()->isa<tensor::MetaSparseTensor>()) {
475 return true;
476 }
477 // All values are tensors.
478 return std::all_of(v_seq->value().begin(), v_seq->value().end(),
479 [](const ValuePtr &e) { return IsTensor(e, true); });
480 } else {
481 MS_LOG(DEBUG) << "Get value " << v->ToString();
482 return false;
483 }
484 }
485 MS_LOG(DEBUG) << "Get value " << v->ToString();
486 return v->isa<tensor::BaseTensor>() || v->isa<tensor::MetaSparseTensor>();
487 }
488
489 bool Common::IsControlFlowGraph(const FuncGraphPtr &func_graph) {
490 MS_EXCEPTION_IF_NULL(func_graph);
491 return !func_graph->func_graphs_used_total().empty();
492 }
493
494 ValuePtr Common::FilterSensValues(const ValuePtr &value, bool dict_convert_to_tuple) {
495 MS_EXCEPTION_IF_NULL(value);
496 if (value->isa<tensor::BaseTensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
497 return value;
498 }
499 if (value->isa<ValueSequence>()) {
500 std::vector<ValuePtr> value_list;
501 auto value_seq = value->cast<ValueSequencePtr>();
502 MS_EXCEPTION_IF_NULL(value_seq);
503 for (auto &filter_value : value_seq->value()) {
504 if (auto t = FilterSensValues(filter_value, dict_convert_to_tuple); t != nullptr) {
505 (void)value_list.emplace_back(t);
506 }
507 }
508 return std::make_shared<ValueTuple>(value_list);
509 }
510 if (value->isa<ValueDictionary>()) {
511 if (dict_convert_to_tuple) {
512 return FilterSensValues(DataConvert::ConvertValueDictToValueTuple(value), dict_convert_to_tuple);
513 }
514 return value;
515 }
516 MS_LOG(DEBUG) << "Value type: " << value->ToString();
517 return nullptr;
518 }
519
520 tensor::BaseTensorPtr Common::GetTensorFromParam(const AnfNodePtr &param_node) {
521 MS_EXCEPTION_IF_NULL(param_node);
522 auto param = param_node->cast<ParameterPtr>();
523 MS_EXCEPTION_IF_NULL(param);
524 if (!param->has_default()) {
525 return nullptr;
526 }
527 auto default_value = param->default_param();
528 MS_EXCEPTION_IF_NULL(default_value);
529 auto tensor_value = default_value->cast<tensor::BaseTensorPtr>();
530 MS_EXCEPTION_IF_NULL(tensor_value);
531 return tensor_value;
532 }
533
534 const std::shared_ptr<PyNativeExecutor> &Common::GetPyNativeExecutor() {
535 const auto &executor = PyNativeExecutor::GetInstance();
536 MS_EXCEPTION_IF_NULL(executor);
537 return executor;
538 }
539
540 void Common::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
541 #ifdef ENABLE_DUMP_IR
542 auto context = MsContext::GetInstance();
543 MS_EXCEPTION_IF_NULL(context);
544 if (context->CanDump(kIntroductory)) {
545 DumpIR(filename, graph);
546 }
547 #endif
548 }
549
550 TypeId Common::GetTypeFromAbstract(const abstract::AbstractBasePtr &abs) {
551 MS_EXCEPTION_IF_NULL(abs);
552 if (abs->isa<abstract::AbstractSequence>()) {
553 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
554 return GetTypeFromAbstract(abs_seq->elements().front());
555 }
556 const auto &type = abs->BuildType();
557 MS_EXCEPTION_IF_NULL(type);
558 return common::AnfAlgo::GetOutputInferDataType(type, 0);
559 }
560
561 ShapeVector Common::GetShapeFromAbstract(const abstract::AbstractBasePtr &abs) {
562 MS_EXCEPTION_IF_NULL(abs);
563 if (abs->isa<abstract::AbstractSequence>()) {
564 MS_LOG(EXCEPTION) << "Get abstract sequence";
565 }
566 auto shape = abs->BuildShape();
567 MS_EXCEPTION_IF_NULL(shape);
568 auto shape_ptr = shape->cast<abstract::ShapePtr>();
569 MS_EXCEPTION_IF_NULL(shape_ptr);
570 return shape_ptr->shape();
571 }
572
573 std::pair<TypePtr, TypeId> Common::GetTypeFromValue(const ValuePtr &v) {
574 MS_EXCEPTION_IF_NULL(v);
575 if (v->isa<tensor::BaseTensor>()) {
576 return std::make_pair(v->cast<tensor::BaseTensorPtr>()->Dtype(), kObjectTypeTensorType);
577 } else if (v->isa<ValueTuple>()) {
578 return std::make_pair(v->type(), kObjectTypeTuple);
579 } else if (v->isa<ValueList>()) {
580 return std::make_pair(v->type(), kObjectTypeList);
581 } else if (v->isa<None>()) {
582 return std::make_pair(kTypeNone, kMetaTypeNone);
583 } else {
584 return std::make_pair(v->type(), v->type()->object_type());
585 }
586 }
587
588 ShapeVector Common::GetShapeFromValue(const ValuePtr &v) {
589 MS_EXCEPTION_IF_NULL(v);
590 if (v->isa<tensor::BaseTensor>()) {
591 return v->cast<tensor::BaseTensorPtr>()->shape_c();
592 } else if (v->isa<ValueSequence>()) {
593 const auto &v_seq = v->cast<ValueSequencePtr>()->value();
594 ShapeVector plant_shape_vector;
595 for (const auto &item : v_seq) {
596 const auto &shape = GetShapeFromValue(item);
597 (void)std::transform(shape.begin(), shape.end(), std::back_inserter(plant_shape_vector),
598 [](int64_t s) { return s; });
599 }
600 return plant_shape_vector;
601 } else {
602 return ShapeVector{};
603 }
604 }
605
606 ValuePtr Common::CreatOutputTensorValueByAbstract(const abstract::AbstractBasePtr &abs) {
607 MS_EXCEPTION_IF_NULL(abs);
608 auto type_id = GetTypeFromAbstract(abs);
609 if (abs->isa<abstract::AbstractMonad>()) {
610 return std::make_shared<tensor::Tensor>(0);
611 }
612 if (abs->isa<abstract::AbstractSequence>()) {
613 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
614 std::vector<ValuePtr> out;
615 if (!abs_seq->elements().front()->isa<abstract::AbstractTensor>()) {
616 MS_LOG(DEBUG) << "Get non tensor output";
617 return CreateNonTensorByAbstract(abs);
618 }
619 for (size_t i = 0; i < abs_seq->size(); ++i) {
620 (void)out.emplace_back(std::make_shared<tensor::Tensor>(type_id, GetShapeFromAbstract(abs_seq->elements()[i])));
621 }
622 return std::make_shared<ValueTuple>(out);
623 }
624 if (!abs->isa<abstract::AbstractTensor>()) {
625 MS_LOG(DEBUG) << "Get non tensor output";
626 return CreateNonTensorByAbstract(abs);
627 }
628 return std::make_shared<tensor::Tensor>(type_id, GetShapeFromAbstract(abs));
629 }
630
631 void Common::ReplaceCNodeWithValueNode(const FuncGraphPtr &bprop_graph) {
632 MS_EXCEPTION_IF_NULL(bprop_graph);
633 if (bprop_graph->used_forward_nodes().empty()) {
634 return;
635 }
636 auto mng = MakeManager({bprop_graph}, false);
637 auto tr = mng->Transact();
638 for (const auto &forward_node : bprop_graph->used_forward_nodes()) {
639 auto cnode = forward_node->cast<CNodePtr>();
640 auto v_node = cnode->forward().first;
641 MS_EXCEPTION_IF_NULL(v_node);
642 bprop_graph->AddValueNode(v_node);
643 MS_LOG(DEBUG) << "Replace " << forward_node->DebugString() << " by value node " << v_node->DebugString();
644 auto converted_node = ConvertValueSequenceToMakeTuple(v_node, bprop_graph);
645 (void)tr.Replace(forward_node, converted_node);
646 }
647 tr.Commit();
648 bprop_graph->ClearUsedForwardNodes();
649 DumpGraphIR("replace_cnode_with_valuenode.ir", bprop_graph);
650 }
651
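// Waits on stub nodes and replaces them with their real values, recursing into value sequences
// (sequences of scalars are returned unchanged).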
652 ValuePtr StubNodeToValueInner(const ValuePtr &v) {
653 MS_EXCEPTION_IF_NULL(v);
654 if (utils::isa<stub::StubNode>(v)) {
655 auto stub = utils::cast<stub::StubNodePtr>(v);
656 return stub->WaitValue();
657 }
658 if (utils::isa<ValueSequence>(v)) {
659 const auto &value_seq = utils::cast<ValueSequencePtr>(v);
660 const auto &values = value_seq->value();
661 if (!values.empty() && utils::isa<Scalar>(values[0])) {
662 return v;
663 }
664 ValuePtrList value_list;
665 (void)std::transform(values.begin(), values.end(), std::back_inserter(value_list),
666 [](const ValuePtr &value) { return StubNodeToValueInner(value); });
667 if (utils::isa<ValueTuple>(v)) {
668 return std::make_shared<ValueTuple>(value_list);
669 }
670 if (utils::isa<ValueList>(v)) {
671 return std::make_shared<ValueList>(value_list);
672 }
673 MS_LOG(EXCEPTION) << "Value not support ValueSequence " << v->ToString();
674 } else {
675 return v;
676 }
677 }
678
679 void Common::StubNodeToValue(const FrontendOpRunInfoPtr &op_run_info) {
680 MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info);
681 auto old_stream_id = kernel::pyboost::PyBoostUtils::cur_stream_id();
682 kernel::pyboost::PyBoostUtils::set_cur_stream_id(op_run_info->base_op_run_info.stream_id);
683 for (size_t i = 0; i < op_run_info->input_size; i++) {
684 op_run_info->op_grad_info->input_value[i] = StubNodeToValueInner(op_run_info->op_grad_info->input_value[i]);
685 if (!op_run_info->is_view_op) {
686 op_run_info->op_grad_info->input_value[i] =
687 ConvertToContiguousValue(op_run_info->op_grad_info->input_value[i], op_run_info->requires_grad);
688 }
689 kernel::pyboost::PyBoostUtils::set_cur_stream_id(old_stream_id);
690 runtime::DeviceAddressUtils::CreateKernelTensor(op_run_info->op_grad_info->input_value[i]);
691 }
692 }
693
694 tensor::BaseTensorPtr Common::StubNodeToTensor(const ValuePtr &v) {
695 MS_EXCEPTION_IF_NULL(v);
696 if (utils::isa<stub::StubNode>(v)) {
697 auto stub = utils::cast<stub::StubNodePtr>(v);
698 return stub->WaitValue()->cast<tensor::BaseTensorPtr>();
699 }
700 if (v->isa<tensor::BaseTensor>()) {
701 return v->cast<tensor::BaseTensorPtr>();
702 }
703 MS_LOG(EXCEPTION) << "It should be stub tensor, but got " << v->ToString();
704 }
705
706 ValuePtr Common::ConvertToContiguousValue(const ValuePtr &v, bool requires_grad) {
707 MS_EXCEPTION_IF_NULL(v);
708 if (v->isa<tensor::BaseTensor>()) {
709 auto tensor = v->cast<tensor::BaseTensorPtr>();
710 MS_EXCEPTION_IF_NULL(tensor);
711 if (tensor->storage_info() == nullptr) {
712 return tensor;
713 }
714
715 auto contiguous_tensor = ConvertToContiguousTensor(tensor, requires_grad);
716 MS_LOG(DEBUG) << "ConvertToContiguousValue, old tensor id:" << tensor->id()
717 << ", new tensor id:" << contiguous_tensor->id();
718 return contiguous_tensor;
719 }
720 if (utils::isa<ValueSequence>(v)) {
721 const auto &value_seq = utils::cast<ValueSequencePtr>(v);
722 const auto &values = value_seq->value();
723 if (values.empty() || utils::isa<Scalar>(values[0])) {
724 return v;
725 }
726 ValuePtrList value_list;
727 (void)std::transform(
728 values.begin(), values.end(), std::back_inserter(value_list),
729 [requires_grad](const ValuePtr &value) { return ConvertToContiguousValue(value, requires_grad); });
730 if (utils::isa<ValueTuple>(v)) {
731 return std::make_shared<ValueTuple>(value_list);
732 }
733 if (utils::isa<ValueList>(v)) {
734 return std::make_shared<ValueList>(value_list);
735 }
736 MS_LOG(EXCEPTION) << "Not support ValueSequence " << v->ToString();
737 } else {
738 return v;
739 }
740 }
741
742 tensor::BaseTensorPtr Common::ConvertToContiguousTensor(const tensor::BaseTensorPtr &tensor, bool requires_grad) {
743 MS_EXCEPTION_IF_NULL(tensor);
744
745 // A tensor with storage info (a view) needs to be converted to contiguous for a non-view op.
746 auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
747 MS_EXCEPTION_IF_NULL(device_address);
748 const auto &device_target = device_address->device_name();
749
750 return GetContiguousTensor(tensor, device_target, requires_grad);
751 }
752
753 tensor::BaseTensorPtr Common::ConvertStubNodeToTensor(const ValuePtr &v, bool need_contiguous, bool requires_grad) {
754 const auto &tensor = StubNodeToTensor(v);
755 MS_EXCEPTION_IF_NULL(tensor);
756 if (!need_contiguous || tensor->storage_info() == nullptr) {
757 return tensor;
758 }
759
760 auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
761 MS_EXCEPTION_IF_NULL(device_address);
762 const auto &device_target = device_address->device_name();
763 if (device_target == kAscendDevice) {
764 return tensor;
765 }
766
767 return GetContiguousTensor(tensor, device_target, requires_grad);
768 }
769
770 std::optional<tensor::BaseTensorPtr> Common::ConvertStubNodeToTensor(const std::optional<ValuePtr> &v,
771 bool need_contiguous, bool requires_grad) {
772 if (!v.has_value()) {
773 return std::nullopt;
774 }
775 return std::make_optional(ConvertStubNodeToTensor(v.value(), need_contiguous, requires_grad));
776 }
777
778 ValueTuplePtr Common::ConvertStubNodeToValueTuple(const ValueListPtr &v, bool need_contiguous, bool requires_grad) {
779 if (utils::isa<ValueSequence>(v)) {
780 const auto &value_seq = utils::cast<ValueSequencePtr>(v);
781 const auto &values = value_seq->value();
782 std::vector<ValuePtr> tensor_list;
783 (void)std::transform(values.begin(), values.end(), std::back_inserter(tensor_list),
784 [need_contiguous, requires_grad](const ValuePtr &value) {
785 return ConvertStubNodeToTensor(value, need_contiguous, requires_grad);
786 });
787 return std::make_shared<ValueTuple>(tensor_list);
788 }
789 MS_LOG(EXCEPTION) << "It should be stub tensor sequence, but got " << v->ToString();
790 }
791
792 ValueTuplePtr Common::ConvertStubNodeToValueTuple(const ValueTuplePtr &v, bool need_contiguous, bool requires_grad) {
793 if (utils::isa<ValueSequence>(v)) {
794 const auto &value_seq = utils::cast<ValueSequencePtr>(v);
795 const auto &values = value_seq->value();
796 std::vector<ValuePtr> tensor_list;
797 (void)std::transform(values.begin(), values.end(), std::back_inserter(tensor_list),
798 [need_contiguous, requires_grad](const ValuePtr &value) {
799 return ConvertStubNodeToTensor(value, need_contiguous, requires_grad);
800 });
801 return std::make_shared<ValueTuple>(tensor_list);
802 }
803 MS_LOG(EXCEPTION) << "It should be stub tensor sequence, but got " << v->ToString();
804 }
805
806 std::optional<ValueTuplePtr> Common::ConvertStubNodeToValueTuple(const std::optional<ValueTuplePtr> &v,
807 bool need_contiguous, bool requires_grad) {
808 if (!v.has_value()) {
809 return std::nullopt;
810 }
811 return std::make_optional(ConvertStubNodeToValueTuple(v.value(), need_contiguous, requires_grad));
812 }
813
814 void Common::GetConstInputToAttr(const PrimitivePtr &op_prim, const std::string &op_name,
815 const std::string &device_target, bool is_dynamic_shape,
816 mindspore::HashSet<size_t> *input_to_attr_index) {
817 if (op_name == prim::kPrimCustom->name()) {
818 // Custom op needs to set reg dynamically
819 mindspore::HashSet<size_t> attr_indexes;
820 PrimitiveReadLock read_lock(op_prim->shared_mutex());
821 opt::GetCustomOpAttrIndex(op_prim, input_to_attr_index);
822 return;
823 }
824
825 // On Ascend, const-input-to-attr conversion is handled in AscendVmOpAdapter.
826 if (device_target == kAscendDevice) {
827 return;
828 }
829
830 auto reg_info =
831 opt::OpAdaptationInfoRegister::GetInstance().GetOpAdaptationInfo(op_name, device_target, is_dynamic_shape);
832 if (reg_info == nullptr) {
833 return;
834 } else {
835 MS_EXCEPTION_IF_NULL(input_to_attr_index);
836 for (auto &iter : reg_info->input_attr_map()) {
837 (void)input_to_attr_index->insert(iter.first);
838 }
839 }
840 }
841
842 ValueNodePtr Common::CreateValueNodeByValue(const ValuePtr &v, const abstract::AbstractBasePtr &abs) {
843 MS_EXCEPTION_IF_NULL(v);
844 auto v_node = NewValueNode(v);
845 if (abs == nullptr) {
846 v_node->set_abstract(SetAbstractValueToAnyValue(v->ToAbstract()));
847 } else {
848 v_node->set_abstract(abs);
849 }
850 return v_node;
851 }
852
853 tensor::TensorPtr Common::CreateFakeTensorWithoutDeviceAddress(const tensor::TensorPtr &tensor) {
854 MS_EXCEPTION_IF_NULL(tensor);
855 auto t = std::make_shared<tensor::Tensor>(*tensor);
856 if (tensor->is_parameter()) {
857 t->set_param_info(tensor->param_info());
858 }
859 t->set_device_address(nullptr);
860 return t;
861 }
862
863 void Common::ClearDeviceAddress(const ValuePtr &value) {
864 std::vector<tensor::BaseTensorPtr> tensors;
865 TensorValueToTensor(value, &tensors);
866 for (const auto &tensor : tensors) {
867 tensor->set_device_address(nullptr);
868 }
869 }
870
871 ValuePtr Common::CreateFakeValueWithoutDeviceAddress(const ValuePtr &value) {
872 MS_EXCEPTION_IF_NULL(value);
873 if (value->isa<tensor::BaseTensor>()) {
874 const auto &v_t = value->cast<tensor::BaseTensorPtr>();
875 auto t = std::make_shared<tensor::Tensor>(*v_t);
876 if (v_t->is_parameter()) {
877 t->set_param_info(v_t->param_info());
878 }
879 t->set_device_address(nullptr);
880 return t;
881 } else if (value->isa<ValueSequence>()) {
882 const auto &value_seq = value->cast<ValueSequencePtr>();
883 ValuePtrList value_list;
884 (void)std::transform(value_seq->value().begin(), value_seq->value().end(), std::back_inserter(value_list),
885 [](const ValuePtr &elem) { return CreateFakeValueWithoutDeviceAddress(elem); });
886 return std::make_shared<ValueTuple>(value_list);
887 } else if (value->isa<stub::StubNode>()) {
888 const auto &stub_node = value->cast<stub::StubNodePtr>();
889 return CreateFakeValueWithoutDeviceAddress(stub_node->WaitValue());
890 } else if (value->isa<ValueDictionary>()) {
891 auto dic_v = value->cast<ValueDictionaryPtr>();
892 std::vector<std::pair<ValuePtr, ValuePtr>> key_values;
893 for (const auto &v : dic_v->value()) {
894 (void)key_values.emplace_back(v.first, CreateFakeValueWithoutDeviceAddress(v.second));
895 }
896 return std::make_shared<ValueDictionary>(key_values);
897 } else {
898 return value;
899 }
900 }
901
902 InputType Common::SetValueGradInfo(const ValuePtr &value, const TopCellInfoPtr &top_cell, InputType grad_type) {
903 MS_EXCEPTION_IF_NULL(value);
904 if (value->isa<tensor::BaseTensor>()) {
905 const auto &tensor_value = value->cast<tensor::BaseTensorPtr>();
906 auto auto_grad_meta_data = tensor_value->auto_grad_meta_data();
907 if (auto_grad_meta_data != nullptr) {
908 if (auto_grad_meta_data->input_type() != InputType::kUnkown) {
909 return auto_grad_meta_data->input_type();
910 }
911 MS_LOG(DEBUG) << "Set input type for tensor " << tensor_value->id();
912 } else {
913 MS_LOG(DEBUG) << "Create new auto grad meta for tensor " << tensor_value->id();
914 auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
915 tensor::RegisterHook::UpdateTensorBackwardHook(auto_grad_meta_data, tensor_value->id());
916 tensor_value->set_auto_grad_meta_data(auto_grad_meta_data);
917 }
918
919 if (tensor_value->is_parameter() && grad_type != InputType::kInput) {
920 grad_type = InputType::kParameter;
921 }
922 auto_grad_meta_data->set_input_type(grad_type);
923 if (top_cell != nullptr && IsParam(grad_type)) {
924 top_cell->AddMetaGradInfo(tensor_value, auto_grad_meta_data);
925 }
926 return grad_type;
927 } else if (value->isa<ValueSequence>()) {
928 const auto &value_seq = value->cast<ValueSequencePtr>()->value();
929 InputType ret_type = grad_type;
930 for (const auto &v : value_seq) {
931 auto ret = SetValueGradInfo(v, top_cell, grad_type);
932 if (IsParam(ret)) {
933 ret_type = ret;
934 }
935 }
936 return ret_type;
937 } else if (value->isa<tensor::COOTensor>()) {
938 const auto &coo_tensor = value->cast<tensor::COOTensorPtr>();
939 const auto &indices_tensor = coo_tensor->GetIndices();
940 return SetValueGradInfo(indices_tensor, top_cell, grad_type);
941 } else if (value->isa<tensor::CSRTensor>()) {
942 const auto &csr_tensor = value->cast<tensor::CSRTensorPtr>();
943 const auto &indices_tensor = csr_tensor->GetIndices();
944 return SetValueGradInfo(indices_tensor, top_cell, grad_type);
945 } else if (value->isa<ValueDictionary>()) {
946 const auto &dic_v = value->cast<ValueDictionaryPtr>()->value();
947 for (const auto &v : dic_v) {
948 (void)SetValueGradInfo(v.second, top_cell, grad_type);
949 }
950 }
951 return grad_type;
952 }
953
954 InputType Common::SetTensorGradInfo(const tensor::BaseTensorPtr &tensor, const TopCellInfoPtr &top_cell) {
955 MS_EXCEPTION_IF_NULL(tensor);
956 auto auto_grad_meta_data = tensor->auto_grad_meta_data();
957 if (auto_grad_meta_data != nullptr) {
958 if (auto_grad_meta_data->input_type() != InputType::kUnkown) {
959 return auto_grad_meta_data->input_type();
960 }
961 MS_LOG(DEBUG) << "Set input type for tensor " << tensor->id();
962 } else {
963 MS_LOG(DEBUG) << "Create new auto grad meta for tensor " << tensor->id();
964 auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
965 tensor::RegisterHook::UpdateTensorBackwardHook(auto_grad_meta_data, tensor->id());
966 tensor->set_auto_grad_meta_data(auto_grad_meta_data);
967 }
968 // Set weight tensor grad type
969 if (tensor->is_parameter()) {
970 auto_grad_meta_data->set_input_type(InputType::kParameter);
971 if (top_cell != nullptr) {
972 top_cell->AddMetaGradInfo(tensor, auto_grad_meta_data);
973 }
974 return InputType::kParameter;
975 }
976 // A constant input tensor (but not a constant scalar value).
977 auto_grad_meta_data->set_input_type(InputType::kConstant);
978 return InputType::kConstant;
979 }
980
981 void Common::SetGraphInputAndWeightsInfo(const FrontendOpRunInfoPtr &op_run_info, const FuncGraphPtr &func_graph,
982 const TopCellInfoPtr &top_cell) {
983 MS_EXCEPTION_IF_NULL(func_graph);
984 const auto &original_params = func_graph->parameters();
985 size_t params_size = original_params.size();
986 MS_EXCEPTION_IF_NULL(op_run_info);
987 op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
988 bool need_add_input_abs = op_run_info->op_grad_info->input_abs.empty();
989 for (size_t i = 0; i < params_size; ++i) {
990 if (i < op_run_info->input_size) { // non-weights node.
991 op_run_info->op_grad_info->input_value_grad_type[i] =
992 SetValueGradInfo(op_run_info->op_grad_info->input_value[i], top_cell, InputType::kConstant);
993 if (need_add_input_abs) {
994 (void)op_run_info->op_grad_info->input_abs.emplace_back(original_params[i]->abstract());
995 }
996 continue;
997 }
998 // Must weight param
999 const auto &param = original_params[i]->cast<ParameterPtr>();
1000 const auto tensor_value = GetTensorFromParam(original_params[i]);
1001 MS_EXCEPTION_IF_NULL(tensor_value);
1002 (void)op_run_info->op_grad_info->input_value.emplace_back(tensor_value);
1003 (void)op_run_info->op_grad_info->input_value_grad_type.emplace_back(SetTensorGradInfo(tensor_value, top_cell));
1004 (void)op_run_info->op_grad_info->input_abs.emplace_back(param->abstract());
1005 MS_LOG(DEBUG) << "Set graph weight parameter " << param->DebugString() << ". Its default value is "
1006 << tensor_value->ToString() << ". Its name is: " << param->name();
1007 }
1008 }
1009
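// Expands the bprop graph parameter at the given position, whose abstract is a tuple, into one parameter
// per tensor element, and replaces all uses of the original parameter with a MakeTuple of the new parameters.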
1010 void Common::ProcessTupleParam(const FuncGraphPtr &bprop_graph, size_t position) {
1011 auto bprop_params = bprop_graph->parameters();
1012 auto target_param = bprop_params[position];
1013 MS_EXCEPTION_IF_NULL(target_param);
1014 const auto &target_abstract = target_param->abstract();
1015 MS_EXCEPTION_IF_NULL(target_abstract);
1016 if (!target_abstract->isa<abstract::AbstractSequence>()) {
1017 MS_LOG(EXCEPTION) << "Get wrong param " << target_abstract->ToString();
1018 }
1019 const auto &abs_seq = target_abstract->cast<abstract::AbstractSequencePtr>();
1020 if (abs_seq->dynamic_len() && abs_seq->dynamic_len_element_abs() != nullptr) {
1021 return;
1022 }
1023 MS_LOG(DEBUG) << "Process tuple param " << target_abstract->ToString();
1024 auto it = std::find(bprop_params.begin(), bprop_params.end(), target_param);
1025 it = bprop_params.erase(it);
1026 AnfNodePtrList make_tuple{NewValueNode(prim::kPrimMakeTuple)};
1027 AnfNodePtrList new_param;
1028 PlantTupleParam(bprop_graph, abs_seq, &make_tuple, &new_param);
1029 (void)bprop_params.insert(it, new_param.begin(), new_param.end());
1030 bprop_graph->set_parameters(bprop_params);
1031 auto make_tuple_param = bprop_graph->NewCNode(make_tuple);
1032 make_tuple_param->set_abstract(target_abstract);
1033 auto manager = bprop_graph->manager();
1034 if (manager == nullptr) {
1035 manager = MakeManager({bprop_graph}, false);
1036 }
1037 MS_EXCEPTION_IF_NULL(manager);
1038 auto tr = manager->Transact();
1039 (void)tr.Replace(target_param, make_tuple_param);
1040 tr.Commit();
1041 }
1042
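// Expands a dict parameter of the bprop graph into a key-tuple parameter and a value-tuple parameter,
// and replaces uses of the original parameter with a MakeDict of the two.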
1043 void Common::ProcessDictParam(const FuncGraphPtr &bprop_graph, size_t position) {
1044 auto bprop_params = bprop_graph->parameters();
1045 auto target_param = bprop_params[position];
1046 MS_EXCEPTION_IF_NULL(target_param);
1047 const auto &target_abstract = target_param->abstract();
1048 MS_EXCEPTION_IF_NULL(target_abstract);
1049 if (!target_abstract->isa<abstract::AbstractDictionary>()) {
1050 MS_LOG(EXCEPTION) << "Get wrong param " << target_abstract->ToString();
1051 }
1052 MS_LOG(DEBUG) << "Process Dict param " << target_abstract->ToString();
1053 auto it = std::find(bprop_params.begin(), bprop_params.end(), target_param);
1054 it = bprop_params.erase(it);
1055 const auto &abs_dict = target_abstract->cast<abstract::AbstractDictionaryPtr>();
1056 abstract::AbstractBasePtrList local_key_abs_inputs;
1057 abstract::AbstractBasePtrList local_value_abs_inputs;
1058 for (size_t i = 0; i < abs_dict->size(); ++i) {
1059 (void)local_key_abs_inputs.emplace_back(abs_dict->elements()[i].first);
1060 (void)local_value_abs_inputs.emplace_back(abs_dict->elements()[i].second);
1061 }
1062 auto key_param = bprop_graph->add_parameter();
1063 key_param->set_abstract(std::make_shared<abstract::AbstractTuple>(local_key_abs_inputs));
1064 auto value_param = bprop_graph->add_parameter();
1065 value_param->set_abstract(std::make_shared<abstract::AbstractTuple>(local_value_abs_inputs));
1066 auto key_it = bprop_params.insert(it, value_param);
1067 (void)bprop_params.insert(key_it, key_param);
1068 bprop_graph->set_parameters(bprop_params);
1069 auto dict_node = bprop_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), key_param, value_param});
1070 dict_node->set_abstract(abs_dict);
1071 auto manager = bprop_graph->manager();
1072 if (manager == nullptr) {
1073 manager = MakeManager({bprop_graph}, false);
1074 }
1075 auto tr = manager->Transact();
1076 (void)tr.Replace(target_param, dict_node);
1077 tr.Commit();
1078 }
1079
1080 void Common::FreeFuncGraphForwardNodes(const FuncGraphPtr &func_graph) {
1081 MS_EXCEPTION_IF_NULL(func_graph);
1082 if (func_graph->used_forward_nodes().empty()) {
1083 return;
1084 }
1085 for (const auto &node : func_graph->used_forward_nodes()) {
1086 MS_EXCEPTION_IF_NULL(node);
1087 auto cnode = node->cast<CNodePtr>();
1088 MS_EXCEPTION_IF_NULL(cnode);
1089 cnode->set_forward(nullptr, "");
1090 }
1091 func_graph->ClearUsedForwardNodes();
1092 }
1093
1094 size_t Common::GetValueSize(const ValuePtr &v) {
1095 MS_EXCEPTION_IF_NULL(v);
1096 if (v->isa<tensor::BaseTensor>() || v->isa<Scalar>()) {
1097 return 1;
1098 } else if (v->isa<ValueSequence>()) {
1099 auto seq = v->cast<ValueSequencePtr>();
1100 size_t output_size = 0;
1101 for (const auto &val : seq->value()) {
1102 output_size += GetValueSize(val);
1103 }
1104 return output_size;
1105 } else if (v->isa<ValueDictionary>()) {
1106 const auto &v_dict = v->cast<ValueDictionaryPtr>();
1107 size_t output_size = 0;
1108 for (const auto &val : v_dict->value()) {
1109 output_size += GetValueSize(val.second);
1110 }
1111 return output_size;
1112 }
1113 return 0;
1114 }
1115
1116 ValuePtr Common::CreateTensorByConstantValue(const ValuePtr &value) {
1117 MS_EXCEPTION_IF_NULL(value);
1119 auto type = value->type();
1120 if (Common::IsTensor(value, true) || value->isa<Number>() || value->isa<None>() ||
1121 (type != nullptr && type->isa<String>())) {
1122 return value;
1123 }
1124 tensor::TensorPtr tensor_ptr = nullptr;
1125 if (value->isa<Scalar>()) {
1126 tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
1127 } else if (value->isa<ValueTuple>()) {
1128 tensor_ptr = opt::CreateTupleTensor(value->cast<ValueTuplePtr>());
1129 } else if (value->isa<ValueList>()) {
1130 tensor_ptr = opt::CreateTupleTensor(std::make_shared<ValueTuple>(value->cast<ValueListPtr>()->value()));
1131 } else {
1132 MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple, but get type " << value->type_name()
1133 << ", value " << value->ToString();
1134 }
1135 MS_EXCEPTION_IF_NULL(tensor_ptr);
1136 return tensor_ptr;
1137 }
1138
1139 void AutoGrad::CacheOutputAbstract(const ValuePtr &v, const abstract::AbstractBasePtr &abs) {
1140 MS_EXCEPTION_IF_NULL(v);
1141 MS_EXCEPTION_IF_NULL(abs);
1142
1143 if (v->isa<tensor::BaseTensor>()) {
1144 auto tensor = v->cast<tensor::BaseTensorPtr>();
1145 tensor->set_abstract(abs);
1146 kGradAbstractConverter.CacheAbstract(abs);
1147 } else if (v->isa<ValueSequence>()) {
1148 const auto &value_seq = v->cast<ValueSequencePtr>();
1149 const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
1150 if (abs_seq == nullptr) {
1151 MS_LOG(EXCEPTION) << "Abstract is not abstract sequence, get " << abs->ToString();
1152 }
1153 size_t value_size = value_seq->size();
1154 if (value_size != abs_seq->size()) {
1155 MS_LOG(EXCEPTION) << "Abstract size " << abs_seq->size() << " is not equal to value size " << value_size;
1156 }
1157 for (size_t i = 0; i < value_size; ++i) {
1158 CacheOutputAbstract(value_seq->value()[i], abs_seq->elements()[i]);
1159 }
1160 }
1161 }
1162
1163 namespace {
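// Fills the input/output abstracts of op_grad_info from the simple-infer result and caches the output
// abstract on the output tensors.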
1164 void ConvertSimpleInferInfoToAbstract(const OpGradInfoPtr &op_grad_info) {
1165 MS_EXCEPTION_IF_NULL(op_grad_info);
1166 // Get inputs abstract
1167 for (const auto &v : op_grad_info->input_value) {
1168 op_grad_info->input_abs.emplace_back(kGradAbstractConverter.ConvertAbstract(v));
1169 }
1170
1171 // Get output abstract
1172 MS_EXCEPTION_IF_NULL(op_grad_info->output_value_simple_info);
1173 op_grad_info->out_abs = TransformValueSimpleInfoToAbstract(*op_grad_info->output_value_simple_info);
1174
1175 // Set abstract to tensor
1176 AutoGrad::CacheOutputAbstract(op_grad_info->out_value, op_grad_info->out_abs);
1177 MS_LOG(DEBUG) << "Get output abstract " << op_grad_info->out_abs->ToString();
1178 }
1179 } // namespace
1180
1181 void AutoGrad::CheckAndSetAbstract(const OpGradInfoPtr &op_grad_info) {
1182 MS_EXCEPTION_IF_NULL(op_grad_info);
1183 if (op_grad_info->output_value_simple_info != nullptr) {
1184 MS_LOG(DEBUG) << "Convert op " << op_grad_info->op_prim->name() << " simple infer info to abstract";
1185 ConvertSimpleInferInfoToAbstract(op_grad_info);
1186 return;
1187 }
1188
1189 // For view ops, the input and output abstracts may be nullptr.
1190 if (MS_UNLIKELY(op_grad_info->input_abs.empty())) {
1191 // Get inputs abstract
1192 MS_LOG(DEBUG) << "Op " << op_grad_info->op_prim->name() << " inputs abstract not set, set it now";
1193 for (const auto &v : op_grad_info->input_value) {
1194 // For use abstract cache on tensor
1195 op_grad_info->input_abs.emplace_back(kGradAbstractConverter.ConvertAbstract(v));
1196 }
1197 }
1198 if (op_grad_info->out_abs == nullptr) {
1199 MS_LOG(EXCEPTION) << "Get output abs is nullptr";
1200 }
1201 }
1202
1203 std::string PyParser::GetIdByPyObj(const py::object &obj) {
1204 if (py::isinstance<tensor::BaseTensor>(obj)) {
1205 return obj.cast<tensor::BaseTensorPtr>()->id();
1206 } else if (IsStubTensor(obj)) {
1207 return ConvertStubTensor(obj)->id();
1208 } else if (py::isinstance<Cell>(obj)) {
1209 return obj.cast<CellPtr>()->id();
1210 } else if (py::isinstance<mindspore::Type>(obj)) {
1211 auto type_ptr = obj.cast<mindspore::TypePtr>();
1212 return "Type:" + type_ptr->ToString();
1213 } else if (py::isinstance<py::str>(obj)) {
1214 return "S" + obj.cast<std::string>();
1215 } else if (py::isinstance<py::bool_>(obj)) {
1216 return "B" + py::str(obj).cast<std::string>();
1217 } else if (py::isinstance<py::int_>(obj)) {
1218 return "I" + py::str(obj).cast<std::string>();
1219 } else if (py::isinstance<py::float_>(obj)) {
1220 return "F" + py::str(obj).cast<std::string>();
1221 } else if (py::isinstance<py::none>(obj)) {
1222 return "None";
1223 } else if (py::isinstance<py::ellipsis>(obj)) {
1224 return "Ellipsis";
1225 } else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
1226 return GetIdForPyTupleOrList(obj);
1227 } else if (py::isinstance<py::function>(obj)) {
1228 return GetFnInfoByPyObj(obj);
1229 }
1230 // For these types, the id derived from the value and the id of the Python object can be the same.
1231 if (py::isinstance<tensor::CSRTensor>(obj) || py::isinstance<tensor::COOTensor>(obj) ||
1232 py::isinstance<tensor::RowTensor>(obj)) {
1233 return DataConvert::PyObjToValue(obj)->ToString();
1234 }
1235 return GetObjIdFromPython(obj);
1236 }
1237
1238 std::pair<std::vector<std::string>, std::vector<ValuePtr>> PyParser::GetArgsIdAndValue(const py::args &args) {
1239 size_t arg_size = args.size();
1240 std::vector<std::string> input_arg_id_vec;
1241 std::vector<ValuePtr> input_arg_value_vec;
1242 input_arg_id_vec.reserve(arg_size);
1243 input_arg_value_vec.reserve(arg_size);
1244 for (size_t i = 0; i < arg_size; ++i) {
1245 if (py::isinstance<py::list>(args[i])) {
1246 (void)input_arg_value_vec.emplace_back(DataConvert::PyObjToValue(py::cast<py::tuple>(args[i])));
1247 } else {
1248 (void)input_arg_value_vec.emplace_back(DataConvert::PyObjToValue(args[i]));
1249 }
1250 (void)input_arg_id_vec.emplace_back(Common::GetIdByValue(input_arg_value_vec.back()));
1251 }
1252 return {input_arg_id_vec, input_arg_value_vec};
1253 }
1254
1255 void PyParser::SetPrim(const FrontendOpRunInfoPtr &op_run_info, const py::object &prim_arg) {
1256 MS_EXCEPTION_IF_NULL(op_run_info);
1257 const auto &adapter = prim_arg.cast<PrimitivePyAdapterPtr>();
1258 MS_EXCEPTION_IF_NULL(adapter);
1259 auto prim = adapter->attached_primitive();
1260 if (prim == nullptr) {
1261 prim = std::make_shared<PrimitivePy>(prim_arg);
1262 adapter->set_attached_primitive(prim);
1263 }
1264 if (!prim->HasPyObj()) {
1265 MS_LOG(EXCEPTION) << "Pyobj is empty";
1266 }
1267 prim->EnableSharedMutex();
1268 op_run_info->op_grad_info->op_prim = prim;
1269 op_run_info->base_op_run_info.op_name = prim->name();
1270 op_run_info->signatures = prim->signatures();
1271 op_run_info->base_op_run_info.py_prim_id_ = adapter->id();
1272 }
1273
1274 std::string PyParser::BuilidPyInputTypeString(const py::object &obj) {
1275 if (py::isinstance<py::bool_>(obj)) {
1276 return "bool";
1277 }
1278
1279 if (py::isinstance<py::int_>(obj)) {
1280 return "int";
1281 }
1282
1283 if (py::isinstance<py::float_>(obj)) {
1284 return "float";
1285 }
1286
1287 if (py::isinstance<py::str>(obj)) {
1288 return "string";
1289 }
1290
1291 if (py::isinstance<py::none>(obj)) {
1292 return "None";
1293 }
1294
1295 if (py::isinstance<mindspore::tensor::BaseTensor>(obj)) {
1296 return "Tensor";
1297 }
1298
1299 if (IsStubTensor(obj)) {
1300 return "Tensor";
1301 }
1302
1303 if (py::isinstance<py::tuple>(obj)) {
1304 std::stringstream ss;
1305 ss << "tuple<";
1306 auto tuple = obj.cast<py::tuple>();
1307 for (size_t i = 0; i < tuple.size(); i++) {
1308 if (i == 0) {
1309 ss << BuilidPyInputTypeString(tuple[i]);
1310 } else {
1311 ss << ", " << BuilidPyInputTypeString(tuple[i]);
1312 }
1313 }
1314 ss << ">";
1315 return ss.str();
1316 }
1317
1318 if (py::isinstance<py::list>(obj)) {
1319 std::stringstream ss;
1320 ss << "list<";
1321 auto list = obj.cast<py::list>();
1322 for (size_t i = 0; i < list.size(); i++) {
1323 if (i == 0) {
1324 ss << BuilidPyInputTypeString(list[i]);
1325 } else {
1326 ss << ", " << BuilidPyInputTypeString(list[i]);
1327 }
1328 }
1329 ss << ">";
1330 return ss.str();
1331 }
1332
1333 std::stringstream ss;
1334 ss << obj.get_type();
1335 return ss.str();
1336 }
1337
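// Raise a TypeError for input idx of op_def: if the argument supports casting from Tensor and the input
// is a Tensor, report its shape/dtype together with the expected dtype; otherwise build the per-input
// type list and delegate the message to ops::BuildOpErrorMsg.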
PrintTypeCastError(const ops::OpDefPtr & op_def,const py::list & op_inputs,size_t idx)1338 void PyParser::PrintTypeCastError(const ops::OpDefPtr &op_def, const py::list &op_inputs, size_t idx) {
1339 auto const &op_arg = op_def->args_[idx];
1340 bool is_support_tensor_cast = std::any_of(op_arg.cast_dtype_.begin(), op_arg.cast_dtype_.end(),
1341 [](const auto &type) { return type == ops::DT_TENSOR; });
1342 if (is_support_tensor_cast) {
1343 auto tensor = parse::ConvertTensorValue(op_inputs[idx]);
1344 auto PrintVectorFunc = [](const ShapeVector &shape) -> std::string {
1345 std::stringstream ss;
1346 ss << "[";
1347 for (size_t i = 0; i < shape.size(); i++) {
1348 if (i != 0) {
1349 ss << ", " << shape[i];
1350 } else {
1351 ss << shape[i];
1352 }
1353 }
1354 ss << "]";
1355 return ss.str();
1356 };
1357 if (tensor != nullptr) {
1358 MS_EXCEPTION(TypeError) << "For " << op_def->name_ << ", the " << idx << "'th input is a Tensor whose shape is "
1359 << PrintVectorFunc(tensor->shape()) << " and dtype is ["
1360 << TypeIdToString(tensor->data_type()) << "], which cannot be converted to "
1361 << ops::EnumToString(op_arg.arg_dtype_) << ".";
1362 }
1363 }
1364 std::vector<std::string> op_type_list;
1365 for (size_t index = 0; index < op_inputs.size(); ++index) {
1366 (void)op_type_list.emplace_back(PyParser::BuilidPyInputTypeString(op_inputs[index]));
1367 }
1368 MS_EXCEPTION(TypeError) << ops::BuildOpErrorMsg(op_def, op_type_list);
1369 }
1370
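// Wrap a scalar value (FP32Imm, BoolImm or Int64Imm) into a 0-d Tensor; any other type raises.
// Illustrative example (hypothetical): ConvertScalarToTensor(std::make_shared<FP32Imm>(1.5f)) would
// yield a float32 Tensor holding 1.5.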
ConvertScalarToTensor(const ValuePtr & value)1371 inline ValuePtr ConvertScalarToTensor(const ValuePtr &value) {
1372 auto fp32_imm = value->cast<FP32ImmPtr>();
1373 if (fp32_imm != nullptr) {
1374 return std::make_shared<tensor::Tensor>(fp32_imm->value());
1375 }
1376
1377 auto bool_imm = value->cast<BoolImmPtr>();
1378 if (bool_imm != nullptr) {
1379 return std::make_shared<tensor::Tensor>(bool_imm->value());
1380 }
1381
1382 auto int64_imm = value->cast<Int64ImmPtr>();
1383 if (int64_imm != nullptr) {
1384 return std::make_shared<tensor::Tensor>(int64_imm->value());
1385 }
1386
1387 MS_LOG(EXCEPTION) << "Unsupported type: " << value->ToString();
1388 }
1389
ConvertBySignature(const py::object & obj,const FrontendOpRunInfoPtr & op_run_info,size_t index)1390 inline ValuePtr ConvertBySignature(const py::object &obj, const FrontendOpRunInfoPtr &op_run_info, size_t index) {
1391 if (op_run_info->signatures.size() <= index) {
1392 return nullptr;
1393 }
1394
1395 if (op_run_info->signatures[index].dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
1396 auto convert_func = parse::GetConverterByType(static_cast<int32_t>(ops::DT_NUMBER));
1397 MS_EXCEPTION_IF_NULL(convert_func);
1398 return convert_func(obj);
1399 }
1400 return nullptr;
1401 }
1402
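// Convert each Python input according to the op definition. Per argument the order is: an optional
// argument given as None becomes kNone; then the direct converter for arg_dtype_; then, if cast_dtype_
// is declared, the combined type-cast converters (recording source_type); otherwise a type-cast error
// is reported.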
ParseOpInputByOpDef(const ops::OpDefPtr & op_def,const py::list & op_inputs,bool stub,const FrontendOpRunInfoPtr & op_run_info)1403 void ParseOpInputByOpDef(const ops::OpDefPtr &op_def, const py::list &op_inputs, bool stub,
1404 const FrontendOpRunInfoPtr &op_run_info) {
1405 size_t input_size = op_inputs.size();
1406 if (input_size != op_def->args_.size()) {
1407 MS_LOG(EXCEPTION) << "For Operator[" << op_def->name_ << "], the number of inputs should be " << op_def->args_.size()
1408 << " but got " << op_inputs.size() << ".";
1409 }
1410 (void)op_run_info->op_grad_info->input_value.resize(input_size);
1411 for (size_t i = 0; i < op_def->args_.size(); i++) {
1412 auto const &op_arg = op_def->args_[i];
1413 op_run_info->none_init_inputs_num += static_cast<size_t>(!op_arg.as_init_arg_);
1414
1415 // An optional argument accepts None as input.
1416 if (op_arg.is_optional_ && py::isinstance<py::none>(op_inputs[i])) {
1417 op_run_info->op_grad_info->input_value[i] = kNone;
1418 continue;
1419 }
1420
1421 ValuePtr value = nullptr;
1422 parse::OpDefConvertFunc convert_func = parse::GetConverterByType(static_cast<int32_t>(op_arg.arg_dtype_));
1423 MS_EXCEPTION_IF_NULL(convert_func);
1424 value = convert_func(op_inputs[i]);
1425 if (value != nullptr) {
1426 op_run_info->op_grad_info->input_value[i] = value;
1427 continue;
1428 }
1429
1430 // Type cast has lower priority than signature cast.
1431 if (!op_arg.cast_dtype_.empty()) {
1432 for (auto cast_dtype : op_arg.cast_dtype_) {
1433 convert_func = parse::GetConverterByType(parse::CombineTypesForTypeCast(cast_dtype, op_arg.arg_dtype_));
1434 MS_EXCEPTION_IF_NULL(convert_func);
1435 value = convert_func(op_inputs[i]);
1436 if (value != nullptr) {
1437 op_run_info->op_grad_info->input_value[i] = value;
1438 op_run_info->source_type[i] = cast_dtype;
1439 break;
1440 }
1441 }
1442 }
1443
1444 if (value == nullptr) {
1445 PyParser::PrintTypeCastError(op_def, op_inputs, i);
1446 }
1447 }
1448 }
1449
ParseOpInputByPythonObj(const FrontendOpRunInfoPtr & op_run_info,const py::list & op_inputs,bool stub)1450 void PyParser::ParseOpInputByPythonObj(const FrontendOpRunInfoPtr &op_run_info, const py::list &op_inputs, bool stub) {
1451 MS_EXCEPTION_IF_NULL(op_run_info);
1452 op_run_info->input_size = op_inputs.size();
1453 op_run_info->op_grad_info->input_abs.resize(op_run_info->input_size);
1454 op_run_info->source_type.resize(op_run_info->input_size);
1455 op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
1456
1457 auto op_def = mindspore::ops::GetOpDef(op_run_info->base_op_run_info.op_name);
1458 if (op_def == nullptr) {
1459 op_run_info->op_grad_info->input_value.resize(op_run_info->input_size);
1460 op_run_info->none_init_inputs_num = op_run_info->input_size;
1461 for (size_t i = 0; i < op_run_info->input_size; ++i) {
1462 op_run_info->op_grad_info->input_value[i] = DataConvert::PyObjToValue(op_inputs[i], stub);
1463 }
1464 } else {
1465 op_run_info->none_init_inputs_num = 0;
1466 ParseOpInputByOpDef(op_def, op_inputs, stub, op_run_info);
1467 }
1468 }
1469
ValueToPyObj(const ValuePtr & v)1470 py::object DataConvert::ValueToPyObj(const ValuePtr &v) { return ValueToPyData(v); }
1471
PyObjToValue(const py::object & obj,bool stub)1472 ValuePtr DataConvert::PyObjToValue(const py::object &obj, bool stub) {
1473 ValuePtr converted_ret;
1474 if (stub) {
1475 converted_ret = parse::data_converter::PyDataToStubNode(obj);
1476 } else {
1477 converted_ret = parse::data_converter::PyDataToValue(obj);
1478 }
1479 if (converted_ret == nullptr) {
1480 MS_LOG(EXCEPTION) << "Attribute convert error with type: " << ConvertPyObjToString(obj);
1481 }
1482 return converted_ret;
1483 }
1484
BaseRefToValue(const BaseRef & value,bool requires_grad,bool is_out_sequence)1485 ValuePtr DataConvert::BaseRefToValue(const BaseRef &value, bool requires_grad, bool is_out_sequence) {
1486 MS_EXCEPTION_IF_NULL(value);
1487 ValuePtr ret;
1488 if (utils::isa<tensor::BaseTensorPtr>(value)) {
1489 auto t = utils::cast<tensor::BaseTensorPtr>(value);
1490 if (requires_grad) {
1491 t->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
1492 t->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
1493 }
1494 ret = t;
1495 } else if (utils::isa<ValuePtr>(value)) {
1496 ret = utils::cast<ValuePtr>(value);
1497 } else if (utils::isa<VectorRef>(value)) {
1498 auto vec_ref = utils::cast<VectorRef>(value);
1499 ret = VectorRefToValue(vec_ref, requires_grad, is_out_sequence);
1500 } else if (utils::isa<int>(value)) {
1501 ret = MakeValue(utils::cast<int>(value));
1502 } else if (utils::isa<float>(value)) {
1503 ret = MakeValue(utils::cast<float>(value));
1504 } else if (utils::isa<double>(value)) {
1505 ret = MakeValue(utils::cast<double>(value));
1506 } else if (utils::isa<bool>(value)) {
1507 ret = MakeValue(utils::cast<bool>(value));
1508 } else {
1509 MS_LOG(EXCEPTION) << "Value is not a supported type: " << value.ToString();
1510 }
1511 return ret;
1512 }
1513
VectorRefToValue(const VectorRef & vec_ref,bool requires_grad,bool is_out_sequence)1514 ValuePtr DataConvert::VectorRefToValue(const VectorRef &vec_ref, bool requires_grad, bool is_out_sequence) {
1515 MS_EXCEPTION_IF_NULL(vec_ref);
1516
1517 size_t value_size = vec_ref.size();
1518 if (value_size == 1 && !is_out_sequence) {
1519 return BaseRefToValue(vec_ref[0], requires_grad, is_out_sequence);
1520 }
1521 std::vector<ValuePtr> v_list(value_size);
1522 for (size_t i = 0; i < value_size; ++i) {
1523 v_list[i] = BaseRefToValue(vec_ref[i], requires_grad, is_out_sequence);
1524 }
1525 return std::make_shared<ValueTuple>(v_list);
1526 }
1527
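// Recursively flatten a value into *flatten_v. Tensors are always kept; sequences are expanded unless
// is_only_flatten_tensor_seq is set and the sequence does not start with a tensor; dictionaries are
// flattened by value when is_only_flatten_tensor_seq is set; non-tensor leaves are dropped when
// is_filter_tensor is set. Illustrative example (assumed inputs): flattening ((t1, t2), 3) with both
// flags false yields {t1, t2, 3}.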
FlattenValueSeqArg(const ValuePtr & v,bool is_only_flatten_tensor_seq,bool is_filter_tensor,std::vector<ValuePtr> * flatten_v)1528 void DataConvert::FlattenValueSeqArg(const ValuePtr &v, bool is_only_flatten_tensor_seq, bool is_filter_tensor,
1529 std::vector<ValuePtr> *flatten_v) {
1530 MS_EXCEPTION_IF_NULL(v);
1531 MS_EXCEPTION_IF_NULL(flatten_v);
1532 MS_LOG(DEBUG) << "Get is only flatten tensor seq " << is_only_flatten_tensor_seq;
1533 if (v->isa<tensor::BaseTensor>()) {
1534 (void)flatten_v->emplace_back(v);
1535 } else if (v->isa<ValueSequence>()) {
1536 const auto &v_vec = v->cast<ValueSequencePtr>()->value();
1537 if (v_vec.empty() && !is_filter_tensor) {
1538 MS_LOG(DEBUG) << "Get empty value sequence";
1539 (void)flatten_v->emplace_back(v);
1541 return;
1542 }
1543 if (is_only_flatten_tensor_seq && !v_vec.front()->isa<tensor::BaseTensor>()) {
1544 (void)flatten_v->emplace_back(v);
1545 } else {
1546 for (const auto &elem : v_vec) {
1547 FlattenValueSeqArg(elem, is_only_flatten_tensor_seq, is_filter_tensor, flatten_v);
1548 }
1549 }
1550 } else if (is_only_flatten_tensor_seq) {
1551 if (v->isa<ValueDictionary>()) {
1552 auto dic_v = v->cast<ValueDictionaryPtr>();
1553 for (const auto &elem : dic_v->value()) {
1554 FlattenValueSeqArg(elem.second, is_only_flatten_tensor_seq, is_filter_tensor, flatten_v);
1555 }
1556 } else {
1557 (void)flatten_v->emplace_back(v);
1558 }
1559 } else if (!is_filter_tensor) {
1560 MS_LOG(DEBUG) << "Get not tensor value: " << v->ToString();
1561 (void)flatten_v->emplace_back(v);
1562 }
1563 }
1564
FlattenTensorSeqInValue(const ValuePtr & v)1565 ValuePtrList DataConvert::FlattenTensorSeqInValue(const ValuePtr &v) {
1566 MS_EXCEPTION_IF_NULL(v);
1567 ValuePtrList outputs;
1568 FlattenValueSeqArg(v, true, false, &outputs);
1569 return outputs;
1570 }
1571
FlattenTensorSeqInValueSeq(const ValuePtrList & v,bool only_flatten_tensor)1572 ValuePtrList DataConvert::FlattenTensorSeqInValueSeq(const ValuePtrList &v, bool only_flatten_tensor) {
1573 ValuePtrList outputs;
1574 for (const auto &item : v) {
1575 FlattenValueSeqArg(item, only_flatten_tensor, false, &outputs);
1576 }
1577 return outputs;
1578 }
1579
FlattenArgs(const std::vector<ValuePtr> & v_vec,std::vector<ValuePtr> * flatten_v,bool has_sens)1580 void DataConvert::FlattenArgs(const std::vector<ValuePtr> &v_vec, std::vector<ValuePtr> *flatten_v, bool has_sens) {
1581 MS_EXCEPTION_IF_NULL(flatten_v);
1582 if (v_vec.empty()) {
1583 MS_LOG(EXCEPTION) << "For bprop graph, input value size should be greater than 0, but got empty.";
1584 }
1585 size_t input_size = has_sens ? v_vec.size() - 1 : v_vec.size();
1586 for (size_t i = 0; i < input_size; ++i) {
1587 const auto &v = v_vec[i];
1588 MS_EXCEPTION_IF_NULL(v);
1589 MS_LOG(DEBUG) << "Get v is " << v->ToString();
1590 (void)flatten_v->emplace_back(v);
1591 }
1592 if (has_sens) {
1593 if (Common::IsTensor(v_vec[input_size])) {
1594 (void)flatten_v->emplace_back(v_vec[input_size]);
1595 } else if (v_vec[input_size]->isa<ValueSequence>()) {
1596 MS_LOG(DEBUG) << "Get value tuple size " << v_vec[input_size]->cast<ValueSequencePtr>()->size();
1597 FlattenValueSeqArg(v_vec[input_size], false, false, flatten_v);
1598 }
1599 }
1600 }
1601
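// If input_index is registered in input_to_attr and the primitive carries kAttrInputNames, move the
// const input value into a primitive attribute (tensors whose data is not materialized are skipped)
// and return true so that the caller drops it from the expanded inputs.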
RunOpConvertConstInputToAttr(const FrontendOpRunInfoPtr & op_run_info,const ValuePtr & v,size_t input_index)1602 bool DataConvert::RunOpConvertConstInputToAttr(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v,
1603 size_t input_index) {
1604 MS_EXCEPTION_IF_NULL(op_run_info);
1605 if (op_run_info->input_to_attr.empty()) {
1606 return false;
1607 }
1608 MS_EXCEPTION_IF_NULL(v);
1609 if (op_run_info->input_to_attr.find(input_index) == op_run_info->input_to_attr.end()) {
1610 return false;
1611 }
1612 const auto &input_names_value = op_run_info->op_grad_info->op_prim->GetAttr(kAttrInputNames);
1613 if (input_names_value == nullptr) {
1614 return false;
1615 }
1616 const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
1617 if (input_index >= input_names_vec.size()) {
1618 MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!";
1619 }
1620 const auto &input_name = input_names_vec[input_index];
1621 if (v->isa<tensor::BaseTensor>()) {
1622 auto tensor = v->cast<tensor::BaseTensorPtr>();
1623 if (tensor->data().const_data() == nullptr && !tensor->has_user_data(kTensorValueIsEmpty)) {
1624 return false;
1625 }
1626 }
1627 (void)op_run_info->op_grad_info->op_prim->AddAttr(input_name, v);
1628 return true;
1629 }
1630
Init(const PrimitivePtr & prim,const py::list & args)1631 FrontendOpRunInfoPtr PyBoost::Init(const PrimitivePtr &prim, const py::list &args) {
1632 const auto &pynative_executor = Common::GetPyNativeExecutor();
1633 const auto &forward_executor = pynative_executor->forward_executor();
1634 const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
1635 prim->EnableSharedMutex();
1636 op_run_info->op_grad_info->op_prim = prim;
1637 op_run_info->base_op_run_info.op_name = prim->name();
1638 pynative_executor->StoreAsyncStatus(op_run_info);
1639 forward_executor->InitOpRunInfo(op_run_info);
1640 return op_run_info;
1641 }
1642
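// Build op_run_info->real_out from the pyboost op outputs: a single tensor for non-tuple outputs
// (attaching auto-grad meta data when requires_grad), otherwise a ValueTuple over all outputs.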
MakeOutputValue(const FrontendOpRunInfoPtr & op_run_info,const kernel::pyboost::OpPtr & op)1643 void PyBoost::MakeOutputValue(const FrontendOpRunInfoPtr &op_run_info, const kernel::pyboost::OpPtr &op) {
1644 size_t size = op->outputs().size();
1645 // Ops such as Contiguous and Cast (precision cast, implicit cast) are internal ops and do not have a stub output.
1646 bool is_tuple_output = op_run_info->stub_output != nullptr ? op_run_info->stub_output->isa<stub::SequenceNode>()
1647 : PredictOutTypeByName(op->primitive()->name()) == kTuple;
1648 if (op->output_value_simple_info() != nullptr) {
1649 op_run_info->op_grad_info->output_value_simple_info = op->output_value_simple_info();
1650 op_run_info->op_grad_info->output_value_simple_info->is_tuple_output_ = is_tuple_output;
1651 }
1652 if (!is_tuple_output) {
1653 MS_EXCEPTION_IF_CHECK_FAIL(size == kSizeOne, "The output size should be one!");
1654 if (op->output_abs() != nullptr || op->output_value_simple_info() != nullptr) {
1655 // Set auto grad meta data for op output
1656 if (op_run_info->requires_grad) {
1657 op->outputs()[0]->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
1658 op->outputs()[0]->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
1659 }
1660 op_run_info->real_out = op->outputs()[0];
1661 return;
1662 }
1663 }
1664 std::vector<ValuePtr> output_values(size);
1665 for (size_t i = 0; i < size; ++i) {
1666 const auto &output_tensor = op->outputs()[i];
1667 MS_EXCEPTION_IF_NULL(output_tensor);
1668 // Set auto grad meta data for op outputs
1669 if (op_run_info->requires_grad) {
1670 output_tensor->set_auto_grad_meta_data(std::make_shared<AutoGradMetaData>());
1671 output_tensor->auto_grad_meta_data()->set_input_type(InputType::kOpOutput);
1672 }
1673 output_values[i] = output_tensor;
1674 }
1675 op_run_info->real_out = std::make_shared<ValueTuple>(output_values);
1676 }
1677
UpdateStubOutput(const FrontendOpRunInfoPtr & op_run_info,const AbstractBasePtr & abstract,const kernel::pyboost::OpPtr & op)1678 void PyBoost::UpdateStubOutput(const FrontendOpRunInfoPtr &op_run_info, const AbstractBasePtr &abstract,
1679 const kernel::pyboost::OpPtr &op) {
1680 MS_EXCEPTION_IF_NULL(op);
1681 if (op_run_info->stub_output == nullptr) {
1682 return;
1683 }
1684 if (MS_UNLIKELY(op->output_value_simple_info() != nullptr)) {
1685 op_run_info->stub_output->SetValueSimpleInfo(op->output_value_simple_info());
1686 } else {
1687 MS_EXCEPTION_IF_NULL(abstract);
1688 auto success = op_run_info->stub_output->SetAbstract(abstract);
1689 if (!success) {
1690 const auto &op_name = op_run_info->base_op_run_info.op_name;
1691 MS_EXCEPTION(TypeError) << "The predicted type and the inferred type do not match: the predicted type is "
1692 << PredictOutType(op_run_info) << ", the inferred type is " << abstract->BuildType()
1693 << ", and the name of the operator is [" << op_name
1694 << "]. Please modify or add the predicted type of the operator in predict_out_type_map.h.";
1695 }
1696 MS_LOG(DEBUG) << "Update StubNode abstract " << abstract->ToString();
1697 }
1698 op_run_info->stub_output->SetValue(op_run_info->real_out);
1699 }
1700
UpdateOpRunInfo(const kernel::pyboost::OpPtr & op,const FrontendOpRunInfoPtr & op_run_info)1701 void PyBoost::UpdateOpRunInfo(const kernel::pyboost::OpPtr &op, const FrontendOpRunInfoPtr &op_run_info) {
1702 MS_EXCEPTION_IF_NULL(op);
1703 MS_EXCEPTION_IF_NULL(op_run_info);
1704 // Create output value
1705 MakeOutputValue(op_run_info, op);
1706
1707 // Set output value to python
1708 UpdateStubOutput(op_run_info, op->output_abs(), op);
1709 }
1710
DataSyncForGraph(const kernel::pyboost::OpPtr & op,ValuePtrList && op_inputs)1711 void PyBoost::DataSyncForGraph(const kernel::pyboost::OpPtr &op, ValuePtrList &&op_inputs) {
1712 auto ms_context = MsContext::GetInstance();
1713 MS_EXCEPTION_IF_NULL(ms_context);
1714 if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
1715 !runtime::OpExecutor::GetInstance().async_for_graph()) {
1716 // If the execution mode in MsContext is Graph Mode, the tensor will be an input of a graph which executes in
1717 // Graph Mode; if the graph contains no CNode after optimization, the tensor needs to be synced to host.
1718 for (const auto &output : op->outputs()) {
1719 auto device_address = std::static_pointer_cast<device::DeviceAddress>(output->device_address());
1720 runtime::DeviceAddressUtils::CreateKernelTensor(device_address, output);
1721 output->data_sync(true);
1722 output->set_abstract(std::weak_ptr<abstract::AbstractBase>());
1723 }
1724 for (const auto &input : op_inputs) {
1725 if (input->isa<tensor::BaseTensor>()) {
1726 auto tensor = input->cast<tensor::BaseTensorPtr>();
1727 auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
1728 runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
1729 }
1730 UnsetValueAbstractCache(input);
1731 }
1732 }
1733 }
1734
ConvertPrimitive(const py::object & obj)1735 PrimitivePtr PyBoost::ConvertPrimitive(const py::object &obj) {
1736 const auto &adapter = obj.cast<PrimitivePyAdapterPtr>();
1737 MS_EXCEPTION_IF_NULL(adapter);
1738
1739 auto prim = adapter->attached_primitive();
1740 if (prim == nullptr) {
1741 #ifndef ENABLE_TEST
1742 return std::make_shared<Primitive>(adapter->name(), adapter->attrs());
1743 #else
1744 prim = std::make_shared<PrimitivePy>(obj);
1745 adapter->set_attached_primitive(prim);
1746 #endif
1747 }
1748 if (!prim->HasPyObj()) {
1749 MS_LOG(EXCEPTION) << "Pyobj is empty";
1750 }
1751 prim->EnableSharedMutex();
1752 return prim;
1753 }
1754
RunPyFunction(const PrimitivePtr & prim,const py::list & args)1755 py::object PyBoost::RunPyFunction(const PrimitivePtr &prim, const py::list &args) {
1756 py::tuple wrap_args(kIndex3);
1757 if (prim->isa<PrimitivePy>()) {
1758 auto prim_py = prim->cast<PrimitivePyPtr>();
1759 if (!prim_py->HasPyObj()) {
1760 MS_LOG(EXCEPTION) << "Prim has no python obj!";
1761 }
1762 wrap_args[kIndex0] = prim_py->GetPyObj();
1763 } else {
1764 wrap_args[kIndex0] = std::make_shared<PrimitivePyAdapter>(prim->name());
1765 }
1766 wrap_args[kIndex1] = prim->name();
1767 wrap_args[kIndex2] = args;
1768 const auto &pynative_executor = Common::GetPyNativeExecutor();
1769 return pynative_executor->RunOpStub(wrap_args);
1770 }
1771
SetAnyValueForAbstract(const kernel::pyboost::OpPtr & op)1772 void PyBoost::SetAnyValueForAbstract(const kernel::pyboost::OpPtr &op) {
1773 const auto &input_abs = op->input_abs();
1774 for (const auto &abs : input_abs) {
1775 Common::SetAbstractValueToAnyValue(abs);
1776 }
1777 Common::SetAbstractValueToAnyValue(op->output_abs());
1778 }
1779
DoGrad(const kernel::pyboost::OpPtr & op,const FrontendOpRunInfoPtr & op_run_info,ValuePtrList && op_inputs)1780 void PyBoost::DoGrad(const kernel::pyboost::OpPtr &op, const FrontendOpRunInfoPtr &op_run_info,
1781 ValuePtrList &&op_inputs) {
1782 static const std::string kDoGradName = "DoGrad";
1783 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeFrontendTask,
1784 kDoGradName, false);
1785 MS_EXCEPTION_IF_NULL(op);
1786 // Update op grad info
1787 op_run_info->op_grad_info->input_value = std::move(op_inputs);
1788 op_run_info->op_grad_info->out_value = op_run_info->real_out;
1789
1790 const auto &pynative_executor = Common::GetPyNativeExecutor();
1791 const auto &forward = pynative_executor->forward_executor();
1792 op_run_info->op_grad_info->output_size = op->outputs().size();
1793 if (op->output_value_simple_info() == nullptr) {
1794 if (op->input_abs().size() != op_run_info->input_size) {
1795 MS_LOG(EXCEPTION) << "Op " << op_run_info->base_op_run_info.op_name << " input size is "
1796 << op_run_info->input_size << " but got input abstract size " << op->input_abs().size();
1797 }
1798 SetAnyValueForAbstract(op);
1799 op_run_info->op_grad_info->input_abs = op->input_abs();
1800 op_run_info->base_op_run_info.abstract = op->output_abs();
1801 }
1802
1803 if (MS_LIKELY(!forward->grad()->top_cell()->is_bprop_need_get_forward_graph())) {
1804 // Check and set input auto grad meta info and InputType
1805 op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
1806 for (size_t index = 0; index < op_run_info->input_size; ++index) {
1807 // Replace input_value in place with a contiguous tensor.
1808 RefreshGradContiguousTensor(op_run_info, index);
1809 const ValuePtr &input_object = op_run_info->op_grad_info->input_value[index];
1810 DataConvert::MarkInputs(op_run_info, input_object, index, forward->grad()->top_cell());
1811 }
1812 }
1813 forward->ForwardOpGradImpl(op_run_info);
1814 }
1815
PlantTensorTupleToVector(const FrontendOpRunInfoPtr & op_run_info,const ValueSequencePtr & value_seq,size_t index,const TopCellInfoPtr & top_cell)1816 void DataConvert::PlantTensorTupleToVector(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &value_seq,
1817 size_t index, const TopCellInfoPtr &top_cell) {
1818 MS_EXCEPTION_IF_NULL(op_run_info);
1819 MS_EXCEPTION_IF_NULL(value_seq);
1820 if (op_run_info->requires_grad) {
1821 op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kOpOutput;
1822 }
1823 for (const auto &v : value_seq->value()) {
1824 if (!v->isa<tensor::BaseTensor>()) {
1825 MS_LOG(EXCEPTION) << "The input object is not a tensor!";
1826 }
1827 InputType input_type = InputType::kInput;
1828 auto tensor = v->cast<tensor::BaseTensorPtr>();
1829 MS_EXCEPTION_IF_NULL(tensor);
1830 if (tensor->is_parameter()) {
1831 input_type = InputType::kParameter;
1832 }
1833 if (op_run_info->requires_grad) {
1834 auto grad_type = Common::SetTensorGradInfo(tensor, top_cell);
1835 if (Common::IsParam(grad_type)) {
1836 op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kParameter;
1837 }
1838 }
1839 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(tensor);
1840 (void)op_run_info->base_op_run_info.input_types.emplace_back(input_type);
1841 }
1842
1843 if (!op_run_info->base_op_run_info.dyn_input_sizes.empty()) {
1844 int64_t elem_size = SizeToLong(value_seq->size());
1845 if (op_run_info->base_op_run_info.dyn_input_sizes.size() != op_run_info->input_size) {
1846 for (size_t i = op_run_info->base_op_run_info.dyn_input_sizes.size(); i < index; ++i) {
1847 (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(-1);
1848 }
1849 (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(elem_size);
1850 } else {
1851 op_run_info->base_op_run_info.dyn_input_sizes[index] = elem_size;
1852 }
1853 } else {
1854 for (size_t i = 0; i < index; ++i) {
1855 (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(-1);
1856 }
1857 (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(SizeToLong(value_seq->size()));
1858 }
1859 }
1860
ConvertValueDictToValueTuple(const ValuePtr & v)1861 ValuePtr DataConvert::ConvertValueDictToValueTuple(const ValuePtr &v) {
1862 MS_EXCEPTION_IF_NULL(v);
1863 const auto &dic_v = v->cast<ValueDictionaryPtr>();
1864 MS_EXCEPTION_IF_NULL(dic_v);
1865 std::vector<ValuePtr> v_list;
1866 (void)std::transform(dic_v->value().begin(), dic_v->value().end(), std::back_inserter(v_list),
1867 [](const std::pair<ValuePtr, ValuePtr> &elem) { return elem.second; });
1868 return std::make_shared<ValueTuple>(v_list);
1869 }
1870
ConvertMapTensor(const FrontendOpRunInfoPtr & op_run_info,const tensor::MapTensorPtr & map_tensor,const TopCellInfoPtr & top_cell,size_t index)1871 void DataConvert::ConvertMapTensor(const FrontendOpRunInfoPtr &op_run_info, const tensor::MapTensorPtr &map_tensor,
1872 const TopCellInfoPtr &top_cell, size_t index) {
1873 MS_EXCEPTION_IF_NULL(op_run_info);
1874 MS_EXCEPTION_IF_NULL(map_tensor);
1875 constexpr int input_num = 1;
1876 const auto input_names = op_run_info->op_grad_info->op_prim->GetAttr(kAttrInputNames);
1877 if (input_names == nullptr) {
1878 MS_LOG(DEBUG) << "input_names is nullptr";
1879 return;
1880 }
1881 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(map_tensor);
1882 const auto it = op_run_info->base_op_run_info.input_types.end();
1883 (void)op_run_info->base_op_run_info.input_types.insert(it, input_num, InputType::kParameter);
1884 if (op_run_info->requires_grad) {
1885 op_run_info->op_grad_info->input_value_grad_type[index] = Common::SetTensorGradInfo(map_tensor, top_cell);
1886 }
1887 }
1888
ConvertCSRTensorToTensorList(const FrontendOpRunInfoPtr & op_run_info,const tensor::CSRTensorPtr & csr_tensor,const TopCellInfoPtr & top_cell,size_t index)1889 void DataConvert::ConvertCSRTensorToTensorList(const FrontendOpRunInfoPtr &op_run_info,
1890 const tensor::CSRTensorPtr &csr_tensor, const TopCellInfoPtr &top_cell,
1891 size_t index) {
1892 MS_EXCEPTION_IF_NULL(op_run_info);
1893 MS_EXCEPTION_IF_NULL(csr_tensor);
1894 constexpr int input_num = 3;
1895 const auto input_names = op_run_info->op_grad_info->op_prim->GetAttr(kAttrInputNames);
1896 if (input_names == nullptr) {
1897 MS_LOG(DEBUG) << "input_names is nullptr";
1898 return;
1899 }
1900
1901 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(csr_tensor->GetIndptr());
1902 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(csr_tensor->GetIndices());
1903 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(csr_tensor->GetValues());
1904 const auto it = op_run_info->base_op_run_info.input_types.end();
1905 (void)op_run_info->base_op_run_info.input_types.insert(it, input_num, InputType::kInput);
1906 op_run_info->op_grad_info->op_prim->set_attr("is_csr", MakeValue(true));
1907 op_run_info->op_grad_info->op_prim->set_attr("dense_shape", MakeValue(csr_tensor->shape()));
1908 if (op_run_info->requires_grad) {
1909 op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kOpOutput;
1910 for (int i = 0; i < input_num; ++i) {
1911 auto iter = op_run_info->base_op_run_info.expanded_input_values.rbegin() + i;
1912 auto grad_type = Common::SetTensorGradInfo((*iter)->cast<tensor::BaseTensorPtr>(), top_cell);
1913 if (Common::IsParam(grad_type)) {
1914 op_run_info->op_grad_info->input_value_grad_type[index] = InputType::kParameter;
1915 }
1916 }
1917 }
1918 }
1919
ConvertValueTensorId(const ValuePtr & value,std::vector<std::string> * converted_tensor_id)1920 void DataConvert::ConvertValueTensorId(const ValuePtr &value, std::vector<std::string> *converted_tensor_id) {
1921 if (value->isa<tensor::BaseTensor>()) {
1922 (void)converted_tensor_id->emplace_back(value->cast<tensor::BaseTensorPtr>()->id());
1923 MS_LOG(DEBUG) << "Get top cell output tensor id " << converted_tensor_id->back();
1924 } else if (value->isa<ValueSequence>()) {
1925 const auto &seq = value->cast<ValueSequencePtr>();
1926 for (const auto &val : seq->value()) {
1927 ConvertValueTensorId(val, converted_tensor_id);
1928 }
1929 } else if (value->isa<ValueDictionary>()) {
1930 ConvertValueTensorId(ConvertValueDictToValueTuple(value), converted_tensor_id);
1931 }
1932 }
1933
ConvertTupleValueToTensor(const FrontendOpRunInfoPtr & op_run_info,const ValueSequencePtr & value_seq,size_t index,const TopCellInfoPtr & top_cell)1934 void DataConvert::ConvertTupleValueToTensor(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &value_seq,
1935 size_t index, const TopCellInfoPtr &top_cell) {
1936 MS_EXCEPTION_IF_NULL(op_run_info);
1937 MS_EXCEPTION_IF_NULL(value_seq);
1938
1939 const auto &tuple_inputs = value_seq->value();
1940 if (tuple_inputs.empty()) {
1941 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(value_seq);
1942 (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1943 return;
1944 }
1945 if (tuple_inputs[0]->isa<tensor::BaseTensor>()) {
1946 PlantTensorTupleToVector(op_run_info, value_seq, index, top_cell);
1947 } else {
1948 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(value_seq);
1949 (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1950 }
1951 }
1952
MarkInputs(const FrontendOpRunInfoPtr & op_run_info,const ValuePtr & v,size_t index,const TopCellInfoPtr & top_cell)1953 void DataConvert::MarkInputs(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v, size_t index,
1954 const TopCellInfoPtr &top_cell) {
1955 MS_EXCEPTION_IF_NULL(op_run_info);
1956 MS_EXCEPTION_IF_NULL(v);
1957 tensor::BaseTensorPtr tensor_ptr = nullptr;
1958 InputType input_type = InputType::kInput;
1959 if (v->isa<tensor::BaseTensor>()) {
1960 tensor_ptr = v->cast<tensor::BaseTensorPtr>();
1961 if (tensor_ptr->is_parameter()) {
1962 input_type = InputType::kParameter;
1963 }
1964 if (op_run_info->requires_grad) {
1965 op_run_info->op_grad_info->input_value_grad_type[index] = Common::SetTensorGradInfo(tensor_ptr, top_cell);
1966 }
1967 } else if (v->isa<BoolImm>() || v->isa<FloatImm>() || v->isa<Type>() || v->isa<StringImm>() || v->isa<None>()) {
1968 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(v);
1969 (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1970 return;
1971 } else if (v->isa<IntegerImm>()) {
1972 if (op_run_info->base_op_run_info.op_name == prim::kPrimCSRReduceSum->name()) {
1973 int64_t input = v->cast<Int64ImmPtr>()->value();
1974 op_run_info->op_grad_info->op_prim->set_attr("axis", MakeValue(input));
1975 return;
1976 }
1977 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(v);
1978 (void)op_run_info->base_op_run_info.input_types.emplace_back(InputType::kConstant);
1979 return;
1980 } else if (v->isa<ValueSequence>()) {
1981 ConvertTupleValueToTensor(op_run_info, v->cast<ValueSequencePtr>(), index, top_cell);
1982 return;
1983 } else if (v->isa<tensor::MapTensor>()) {
1984 ConvertMapTensor(op_run_info, v->cast<tensor::MapTensorPtr>(), top_cell, index);
1985 return;
1986 } else if (v->isa<tensor::CSRTensor>()) {
1987 ConvertCSRTensorToTensorList(op_run_info, v->cast<tensor::CSRTensorPtr>(), top_cell, index);
1988 return;
1989 } else if (v->isa<Monad>()) {
1990 return;
1991 } else if (v->isa<parse::InterpretedObject>()) {
1992 MS_EXCEPTION(TypeError) << "Unsupported input type: " << v->ToString();
1993 } else {
1994 MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
1995 }
1996 MS_EXCEPTION_IF_NULL(tensor_ptr);
1997 (void)op_run_info->base_op_run_info.expanded_input_values.emplace_back(tensor_ptr);
1998 (void)op_run_info->base_op_run_info.input_types.emplace_back(input_type);
1999 }
2000
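// For reduce ops whose axis input is an empty sequence or None, rewrite the axis input to the full axis
// tuple of the first input. Illustrative example (assumed): for a rank-3 input the rewritten axis
// becomes (0, 1, 2).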
ReplaceReduceAxis(const FrontendOpRunInfoPtr & op_run_info)2001 void ReplaceReduceAxis(const FrontendOpRunInfoPtr &op_run_info) {
2002 MS_EXCEPTION_IF_NULL(op_run_info);
2003 if (!common::AnfAlgo::IsReduceOp(op_run_info->base_op_run_info.op_name)) {
2004 return;
2005 }
2006 const auto &inputs = op_run_info->base_op_run_info.expanded_input_values;
2007 constexpr size_t kReduceOpInputNum = 2;
2008 if (inputs.size() < kReduceOpInputNum) {
2009 MS_LOG(EXCEPTION) << "Invalid input tensor size " << inputs.size() << " of Op "
2010 << op_run_info->base_op_run_info.op_name;
2011 }
2012
2013 MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info);
2014 const auto &op_prim = op_run_info->op_grad_info->op_prim;
2015 MS_EXCEPTION_IF_NULL(op_prim);
2016 if (op_prim->HasAttr(kAttrSkipMode) && GetValue<bool>(op_prim->GetAttr(kAttrSkipMode))) {
2017 return;
2018 }
2019
2020 // If the 2nd input is an empty sequence or None, reduce over all axes.
2021 bool reduce_all_axis = false;
2022 if (inputs[kIndex1]->isa<ValueSequence>()) {
2023 auto seq_size = inputs[1]->cast<ValueSequencePtr>()->size();
2024 reduce_all_axis = seq_size == 0;
2025 } else if (inputs[kIndex1]->isa<None>()) {
2026 reduce_all_axis = true;
2027 }
2028 if (reduce_all_axis) {
2029 auto size = inputs[0]->cast<tensor::BaseTensorPtr>()->shape().size();
2030 // For example, input 0 is Tensor(shape=[], value=1), the axis to reduce is 0.
2031 std::vector<ValuePtr> axis = {std::make_shared<Int64Imm>(0)};
2032 for (size_t i = 1; i < size; ++i) {
2033 axis.push_back(std::make_shared<Int64Imm>(static_cast<int64_t>(i)));
2034 }
2035 op_run_info->base_op_run_info.expanded_input_values[1] = std::make_shared<ValueTuple>(axis);
2036 }
2037 }
2038
GetInputTensor(const FrontendOpRunInfoPtr & op_run_info,const TopCellInfoPtr & top_cell)2039 void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const TopCellInfoPtr &top_cell) {
2040 MS_EXCEPTION_IF_NULL(op_run_info);
2041
2042 (void)op_run_info->base_op_run_info.expanded_input_values.reserve(op_run_info->input_size);
2043 (void)op_run_info->base_op_run_info.input_types.reserve(op_run_info->input_size);
2044 // Get input tensors.
2045 op_run_info->op_grad_info->op_prim->BeginRecordAddAttr();
2046 for (size_t index = 0; index < op_run_info->input_size; ++index) {
2047 const ValuePtr &input_object = op_run_info->op_grad_info->input_value[index];
2048 // convert const input to attr
2049 if (RunOpConvertConstInputToAttr(op_run_info, input_object, index)) {
2050 continue;
2051 }
2052 // Mark tensors: common tensor data: 0, weight param: 1, value node (float_, int_): 2
2053 MarkInputs(op_run_info, input_object, index, top_cell);
2054 // -1 indicates input_object is not a dynInput
2055 if (!op_run_info->base_op_run_info.dyn_input_sizes.empty() && !input_object->isa<ValueSequence>()) {
2056 (void)op_run_info->base_op_run_info.dyn_input_sizes.emplace_back(-1);
2057 }
2058 }
2059 op_run_info->op_grad_info->op_prim->EndRecordAddAttr();
2060 ReplaceReduceAxis(op_run_info);
2061 AddDynInputsSizesAttr(op_run_info);
2062 }
2063
2064 namespace {
2065 const mindspore::HashSet<std::string> kGradBlackList{kMakeTupleOpName, kMakeListOpName,
2066 kTupleGetItemOpName, kStopGradientOpName,
2067 kUpdateStateOpName, kNPUAllocFloatStatusOpName,
2068 kNPUGetFloatStatusOpName, kNPUClearFloatStatusOpName};
2069
2070 mindspore::HashMap<std::string, pipeline::ResourcePtr> jit_call_graph_compile_cache_;
2071
CreateMakeTupleNode(const KernelGraphPtr & tape,const ValueSequencePtr & tuple,const abstract::AbstractSequencePtr & abs_seq,const SpecialType & type)2072 AnfNodePtr CreateMakeTupleNode(const KernelGraphPtr &tape, const ValueSequencePtr &tuple,
2073 const abstract::AbstractSequencePtr &abs_seq, const SpecialType &type) {
2074 AnfNodePtrList args{NewValueNode(prim::kPrimMakeTuple)};
2075 for (size_t i = 0; i < tuple->size(); ++i) {
2076 AnfNodePtr special_like_value = AutoGrad::BuildSpecialNode(tape, tuple->value()[i], abs_seq->elements()[i], type);
2077 (void)args.emplace_back(special_like_value);
2078 }
2079 auto special_like_value = tape->FuncGraph::NewCNode(args);
2080 special_like_value->set_abstract(abs_seq);
2081 return special_like_value;
2082 }
2083
CreateMakeDictNode(const KernelGraphPtr & tape,const ValueDictionaryPtr & v_dict,const abstract::AbstractDictionaryPtr & abs_dict,const SpecialType & type)2084 AnfNodePtr CreateMakeDictNode(const KernelGraphPtr &tape, const ValueDictionaryPtr &v_dict,
2085 const abstract::AbstractDictionaryPtr &abs_dict, const SpecialType &type) {
2086 MS_EXCEPTION_IF_NULL(tape);
2087 MS_EXCEPTION_IF_NULL(v_dict);
2088 MS_EXCEPTION_IF_NULL(abs_dict);
2089 AnfNodePtrList key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2090 AnfNodePtrList value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2091 abstract::AbstractBasePtrList local_key_abs_inputs;
2092 abstract::AbstractBasePtrList local_value_abs_inputs;
2093 for (size_t i = 0; i < v_dict->size(); ++i) {
2094 (void)key_inputs.emplace_back(
2095 Common::CreateValueNodeByValue(v_dict->value()[i].first, abs_dict->elements()[i].first));
2096 (void)local_key_abs_inputs.emplace_back(abs_dict->elements()[i].first);
2097 AnfNodePtr special_like_value =
2098 AutoGrad::BuildSpecialNode(tape, v_dict->value()[i].second, abs_dict->elements()[i].second, type);
2099 (void)value_inputs.emplace_back(special_like_value);
2100 (void)local_value_abs_inputs.emplace_back(abs_dict->elements()[i].second);
2101 }
2102 auto local_key_node = tape->NewCNode(key_inputs);
2103 local_key_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_key_abs_inputs));
2104 auto local_value_node = tape->NewCNode(value_inputs);
2105 local_value_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_value_abs_inputs));
2106 auto dict_node = tape->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
2107 dict_node->set_abstract(abs_dict);
2108 return dict_node;
2109 }
2110
GetSparseTensorShapeNode(const ShapeVector & shape)2111 ValueNodePtr GetSparseTensorShapeNode(const ShapeVector &shape) {
2112 auto value_shape = NewValueNode(shape);
2113 std::vector<abstract::AbstractBasePtr> abstract_shape;
2114 (void)std::transform(
2115 shape.begin(), shape.end(), std::back_inserter(abstract_shape),
2116 [](auto shp) -> abstract::AbstractScalarPtr { return std::make_shared<abstract::AbstractScalar>(shp); });
2117 auto abs_shape = std::make_shared<abstract::AbstractTuple>(abstract_shape);
2118 value_shape->set_abstract(abs_shape);
2119 return value_shape;
2120 }
2121
WrapCOOTensor(const ValuePtr & coo_out,const ValuePtr & value)2122 ValuePtr WrapCOOTensor(const ValuePtr &coo_out, const ValuePtr &value) {
2123 MS_EXCEPTION_IF_NULL(coo_out);
2124 auto coo_tensor = coo_out->cast<tensor::COOTensorPtr>();
2125 MS_EXCEPTION_IF_NULL(coo_tensor);
2126 auto value_tensor = value->cast<tensor::TensorPtr>();
2127 MS_EXCEPTION_IF_NULL(value_tensor);
2128 auto indices_tensor = coo_tensor->GetIndices();
2129 auto shape_vector = coo_tensor->shape();
2130 return std::make_shared<tensor::COOTensor>(indices_tensor, value_tensor, shape_vector);
2131 }
2132
WrapCSRTensor(const ValuePtr & csr_out,const ValuePtr & value)2133 ValuePtr WrapCSRTensor(const ValuePtr &csr_out, const ValuePtr &value) {
2134 MS_EXCEPTION_IF_NULL(csr_out);
2135 auto csr_tensor = csr_out->cast<tensor::CSRTensorPtr>();
2136 MS_EXCEPTION_IF_NULL(csr_tensor);
2137 auto value_tensor = value->cast<tensor::TensorPtr>();
2138 MS_EXCEPTION_IF_NULL(value_tensor);
2139 auto indptr_tensor = csr_tensor->GetIndptr();
2140 auto indices_tensor = csr_tensor->GetIndices();
2141 auto shape_vector = csr_tensor->shape();
2142 return std::make_shared<tensor::CSRTensor>(indptr_tensor, indices_tensor, value_tensor, shape_vector);
2143 }
2144 } // namespace
2145
IsPrimNeedGrad(const PrimitivePtr & prim)2146 bool AutoGrad::IsPrimNeedGrad(const PrimitivePtr &prim) {
2147 MS_EXCEPTION_IF_NULL(prim);
2148 return kGradBlackList.find(prim->name()) == kGradBlackList.end();
2149 }
2150
NeedGrad(const std::vector<ValuePtr> & input_values)2151 bool AutoGrad::NeedGrad(const std::vector<ValuePtr> &input_values) {
2152 for (const ValuePtr &input_arg : input_values) {
2153 MS_EXCEPTION_IF_NULL(input_arg);
2154 if (input_arg->isa<tensor::BaseTensor>()) {
2155 tensor::BaseTensorPtr input_tensor = nullptr;
2156 input_tensor = input_arg->cast<tensor::BaseTensorPtr>();
2157 auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
2158 MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
2159 if (auto_grad_meta_data->input_type() == InputType::kParameter && Common::IsParamRequiresGrad(input_tensor)) {
2160 return true;
2161 }
2162 auto variable = auto_grad_meta_data->variable();
2163 if (variable != nullptr) {
2164 return true;
2165 }
2166 } else if (input_arg->isa<ValueSequence>()) {
2167 auto value_seq = input_arg->cast<ValueSequencePtr>()->value();
2168 if (NeedGrad(value_seq)) {
2169 return true;
2170 }
2171 } else if (input_arg->isa<tensor::COOTensor>() || input_arg->isa<tensor::CSRTensor>()) {
2172 return true;
2173 }
2174 MS_LOG(DEBUG) << "Get value " << input_arg->ToString();
2175 }
2176 return false;
2177 }
2178
IsZerosLikeNode(const AnfNodePtr & node)2179 bool AutoGrad::IsZerosLikeNode(const AnfNodePtr &node) {
2180 MS_EXCEPTION_IF_NULL(node);
2181 if (!node->isa<CNode>()) {
2182 return false;
2183 }
2184 auto cnode = node->cast<CNodePtr>();
2185 if (IsPrimitiveCNode(cnode, prim::kPrimZerosLike)) {
2186 return true;
2187 }
2188 if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
2189 return std::all_of(cnode->inputs().begin() + 1, cnode->inputs().end(),
2190 [](const auto &node) { return IsZerosLikeNode(node); });
2191 }
2192 if (IsPrimitiveCNode(cnode, prim::kPrimMakeDict)) {
2193 return IsZerosLikeNode(cnode->input(kIndex2));
2194 }
2195 return false;
2196 }
2197
GetFakeZeroTensor()2198 ValuePtr AutoGrad::GetFakeZeroTensor() {
2199 static ValuePtr fake_v = std::make_shared<tensor::Tensor>(0);
2200 return fake_v;
2201 }
2202
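// Build a zeros-like or ones-like gradient value for `value` when no explicit grad is given, recursing
// through sequences and wrapping CSR/COO tensors around the gradient of their values tensor; scalars
// and other types fall back to a 0-d fake tensor.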
BuildSpecialValueGrad(const ValuePtr & value,const tensor::BaseTensorPtr & grad,autograd::FuncBuilder * func_builder,const SpecialType & type)2203 ValuePtr AutoGrad::BuildSpecialValueGrad(const ValuePtr &value, const tensor::BaseTensorPtr &grad,
2204 autograd::FuncBuilder *func_builder, const SpecialType &type) {
2205 MS_EXCEPTION_IF_NULL(value);
2206 if (grad != nullptr) {
2207 return grad;
2208 }
2209 if (value->isa<tensor::BaseTensor>()) {
2210 return (type == SpecialType::kZerosLikeType ? func_builder->Zeros(value) : func_builder->Ones(value));
2211 }
2212 if (value->isa<ValueSequence>()) {
2213 ValuePtr zero_value = nullptr;
2214 auto v_seq = value->cast<ValueSequencePtr>();
2215 ValuePtrList v_list;
2216 for (const auto &item : v_seq->value()) {
2217 (void)v_list.emplace_back(BuildSpecialValueGrad(item, grad, func_builder, type));
2218 }
2219 return std::make_shared<ValueTuple>(v_list);
2220 }
2221 if (value->isa<Scalar>()) {
2222 auto fake_tensor = std::make_shared<tensor::Tensor>(0, value->type());
2223 return BuildSpecialValueGrad(fake_tensor, grad, func_builder, type);
2224 }
2225 if (value->isa<tensor::CSRTensor>()) {
2226 auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
2227 return WrapCSRTensor(csr_tensor, BuildSpecialValueGrad(csr_tensor->GetValues(), grad, func_builder, type));
2228 }
2229 if (value->isa<tensor::COOTensor>()) {
2230 auto coo_tensor = value->cast<tensor::COOTensorPtr>();
2231 return WrapCOOTensor(coo_tensor, BuildSpecialValueGrad(coo_tensor->GetValues(), grad, func_builder, type));
2232 }
2233 MS_LOG(INFO) << "For value " << value->ToString() << ", the type is neither tensor nor scalar";
2234 auto fake_tensor = std::make_shared<tensor::Tensor>(0, value->type());
2235 return BuildSpecialValueGrad(fake_tensor, grad, func_builder, type);
2236 }
2237
BuildSpecialNode(const KernelGraphPtr & tape,const ValuePtr & value,const abstract::AbstractBasePtr & abs,const SpecialType & type)2238 AnfNodePtr AutoGrad::BuildSpecialNode(const KernelGraphPtr &tape, const ValuePtr &value,
2239 const abstract::AbstractBasePtr &abs, const SpecialType &type) {
2240 MS_EXCEPTION_IF_NULL(value);
2241 if (value->isa<tensor::BaseTensor>()) {
2242 auto prim_node =
2243 (type == SpecialType::kZerosLikeType ? NewValueNode(std::make_shared<Primitive>(*prim::kPrimZerosLike))
2244 : NewValueNode(std::make_shared<Primitive>(*prim::kPrimOnesLike)));
2245 auto value_node = Common::CreateValueNodeByValue(value, abs);
2246 auto special_like_value = tape->FuncGraph::NewCNode({prim_node, value_node});
2247 special_like_value->set_abstract(value_node->abstract());
2248 return special_like_value;
2249 }
2250 if (value->isa<ValueSequence>()) {
2251 auto tuple = value->cast<ValueSequencePtr>();
2252 abstract::AbstractSequencePtr abs_seq;
2253 if (abs == nullptr) {
2254 abs_seq = Common::SetAbstractValueToAnyValue(value->ToAbstract())->cast<abstract::AbstractSequencePtr>();
2255 } else {
2256 abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2257 }
2258 return CreateMakeTupleNode(tape, tuple, abs_seq, type);
2259 }
2260 if (value->isa<Scalar>()) {
2261 auto fake_tensor = GetFakeZeroTensor();
2262 return BuildSpecialNode(tape, fake_tensor, nullptr, type);
2263 }
2264 if (value->isa<tensor::CSRTensor>()) {
2265 auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
2266 MS_EXCEPTION_IF_NULL(csr_tensor);
2267 auto data = csr_tensor->GetValues();
2268 return BuildSpecialNode(tape, data, nullptr, type);
2269 }
2270 if (value->isa<tensor::COOTensor>()) {
2271 auto coo_tensor = value->cast<tensor::COOTensorPtr>();
2272 MS_EXCEPTION_IF_NULL(coo_tensor);
2273 auto data = coo_tensor->GetValues();
2274 return BuildSpecialNode(tape, data, nullptr, type);
2275 }
2276 if (value->isa<ValueDictionary>()) {
2277 auto v_dict = value->cast<ValueDictionaryPtr>();
2278 abstract::AbstractDictionaryPtr abs_dict;
2279 if (abs == nullptr) {
2280 abs_dict = Common::SetAbstractValueToAnyValue(value->ToAbstract())->cast<abstract::AbstractDictionaryPtr>();
2281 } else {
2282 abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
2283 }
2284 return CreateMakeDictNode(tape, v_dict, abs_dict, type);
2285 }
2286 MS_LOG(INFO) << "For value " << value->ToString() << ", the type is neither tensor nor scalar";
2287 return BuildSpecialNode(tape, GetFakeZeroTensor(), nullptr, type);
2288 }
2289
BuildSparseTensorNode(const KernelGraphPtr & tape,const ValuePtr & sparse_value,const AnfNodePtr & dout_value_node)2290 AnfNodePtr AutoGrad::BuildSparseTensorNode(const KernelGraphPtr &tape, const ValuePtr &sparse_value,
2291 const AnfNodePtr &dout_value_node) {
2292 MS_EXCEPTION_IF_NULL(tape);
2293 MS_EXCEPTION_IF_NULL(sparse_value);
2294 if (sparse_value->isa<tensor::CSRTensor>()) {
2295 auto csr_tensor = sparse_value->cast<tensor::CSRTensorPtr>();
2296 MS_EXCEPTION_IF_NULL(csr_tensor);
2297 auto indptr_node = Common::CreateValueNodeByValue(csr_tensor->GetIndptr());
2298 auto indices_node = Common::CreateValueNodeByValue(csr_tensor->GetIndices());
2299 auto value_shape = GetSparseTensorShapeNode(csr_tensor->shape());
2300 auto special_like_csr_node = tape->FuncGraph::NewCNode(
2301 {NewValueNode(prim::kPrimMakeTuple), indptr_node, indices_node, dout_value_node, value_shape});
2302 special_like_csr_node->set_abstract(sparse_value->ToAbstract()->Broaden());
2303 return special_like_csr_node;
2304 }
2305 if (sparse_value->isa<tensor::COOTensor>()) {
2306 auto coo_tensor = sparse_value->cast<tensor::COOTensorPtr>();
2307 MS_EXCEPTION_IF_NULL(coo_tensor);
2308 auto indices_node = Common::CreateValueNodeByValue(coo_tensor->GetIndices());
2309 auto value_shape = GetSparseTensorShapeNode(coo_tensor->shape());
2310 auto special_like_coo_node =
2311 tape->FuncGraph::NewCNode({NewValueNode(prim::kPrimMakeTuple), indices_node, dout_value_node, value_shape});
2312 special_like_coo_node->set_abstract(sparse_value->ToAbstract()->Broaden());
2313 return special_like_coo_node;
2314 }
2315 MS_LOG(EXCEPTION) << "Get invalid sparse tensor";
2316 }
2317
SetGradMetaData(const ValuePtr & value,const VariablePtr & variable,const ParameterPtr & param)2318 void AutoGrad::SetGradMetaData(const ValuePtr &value, const VariablePtr &variable, const ParameterPtr ¶m) {
2319 if (value->isa<tensor::BaseTensor>()) {
2320 tensor::BaseTensorPtr tensor = nullptr;
2321 tensor = value->cast<tensor::BaseTensorPtr>();
2322 auto auto_grad_meta_data = tensor->auto_grad_meta_data();
2323 if (auto_grad_meta_data == nullptr) {
2324 MS_LOG(DEBUG) << "tensor has no auto_grad_meta_data";
2325 auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
2326 tensor->set_auto_grad_meta_data(auto_grad_meta_data);
2327 }
2328 auto_grad_meta_data->set_variable(variable);
2329 if (param != nullptr) {
2330 auto_grad_meta_data->set_parameter(param);
2331 auto_grad_meta_data->set_input_type(InputType::kParameter);
2332 }
2333 } else if (value->isa<ValueSequence>()) {
2334 auto value_sequence = value->cast<ValueSequencePtr>();
2335 for (const auto &val : value_sequence->value()) {
2336 SetGradMetaData(val, variable);
2337 }
2338 } else if (value->isa<ValueDictionary>()) {
2339 auto value_dict = value->cast<ValueDictionaryPtr>();
2340 for (const auto &val : value_dict->value()) {
2341 SetGradMetaData(val.second, variable);
2342 }
2343 }
2344 }
2345
SetGradInfoForInputs(const ValuePtr & value,const VariablePtr & variable,const ParameterPtr & param)2346 void AutoGrad::SetGradInfoForInputs(const ValuePtr &value, const VariablePtr &variable, const ParameterPtr ¶m) {
2347 if (value->isa<tensor::BaseTensor>()) {
2348 const auto &input_tensor = value->cast<tensor::BaseTensorPtr>();
2349 const auto &auto_grad_meta_data = input_tensor->auto_grad_meta_data();
2350 MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
2351 auto_grad_meta_data->set_variable(variable);
2352 auto_grad_meta_data->set_parameter(param);
2353 } else if (value->isa<tensor::COOTensor>()) {
2354 const auto &coo_tensor = value->cast<tensor::COOTensorPtr>();
2355 const auto &indices_tensor = coo_tensor->GetIndices();
2356 SetGradInfoForInputs(indices_tensor, variable, param);
2357 } else if (value->isa<tensor::CSRTensor>()) {
2358 const auto &csr_tensor = value->cast<tensor::CSRTensorPtr>();
2359 const auto &indices_tensor = csr_tensor->GetIndices();
2360 SetGradInfoForInputs(indices_tensor, variable, param);
2361 }
2362 }
2363
2364 // Create fake bprop
BuildFakeBpropCNode(const CNodePtr & cnode,std::vector<CNodePtr> * outputs)2365 void AutoGrad::BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
2366 auto prim = GetCNodePrimitive(cnode);
2367 if (prim == nullptr) {
2368 MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
2369 }
2370 size_t dout_index = cnode->size() - 1;
2371 const auto &dout = cnode->input(dout_index);
2372 const auto &dout_cnode = dout->cast<CNodePtr>();
2373 MS_EXCEPTION_IF_NULL(dout_cnode);
2374 // Size is same as op_arg size
2375 size_t input_size = cnode->size() - 2;
2376 for (size_t i = 1; i < input_size; ++i) {
2377 (void)outputs->emplace_back(dout_cnode);
2378 }
2379 }
2380
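// Build the callable that runs a bprop call graph: mark the graph with the jit / no-inline / pynative
// bprop flags, compile it through a pipeline Resource on first use (cached by cache_key unless the call
// is dynamic-shape and not a jit graph), and return a closure that lazily runs TaskEmitAction and
// ExecuteAction, then dispatches the compiled VmEvalFunc on the argument list.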
CreateGraphCallBack(const FuncGraphPtr & call_graph,const std::string & cache_key,const GraphCallCondition & graph_call_condition)2381 CallBackFn AutoGrad::CreateGraphCallBack(const FuncGraphPtr &call_graph, const std::string &cache_key,
2382 const GraphCallCondition &graph_call_condition) {
2383 // kFlagJitCallGraph is set to true to avoid compiling call_graph when compiling the main graph.
2384 call_graph->set_flag(kFlagJitCallGraph, true);
2385 // The call graph is not inlined into the grad top graph.
2386 call_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
2387 // Pynative bprop graph flag
2388 call_graph->set_flag(kFlagIsPynativeBpropGraph, true);
2389 // Running the graph by single op will use this kFlagPyNativeBpropGraphWithBpropCut flag.
2390 if (graph_call_condition.is_dynamic_shape_process_) {
2391 call_graph->set_flag(kFlagPyNativeBpropGraphWithBpropCut, false);
2392 if (!graph_call_condition.is_jit_graph_) {
2393 call_graph->set_flag(kFlagEnableRunGraphBySingleOp, true);
2394 }
2395 }
2396 pipeline::ResourcePtr resource;
2397 constexpr auto kNeedCompile = "NeedCompile";
2398 const auto it = jit_call_graph_compile_cache_.find(cache_key);
2399 bool need_compile = (it == jit_call_graph_compile_cache_.end());
2400 if (need_compile) {
2401 resource = std::make_shared<pipeline::Resource>();
2402 resource->set_func_graph(call_graph);
2403 if (graph_call_condition.is_func_grad_) {
2404 auto manager = resource->manager();
2405 manager->AddFuncGraph(call_graph, false);
2406 (void)opt::EnvironConversion(resource);
2407 if (graph_call_condition.jit_out_has_dict_) {
2408 MS_LOG(DEBUG) << "Jit out is dict, need convert make dict to pyexecute";
2409 (void)mindspore::opt::RewriterAfterOptA(resource->func_graph(), resource);
2410 }
2411 }
2412 if (graph_call_condition.is_jit_graph_ || !graph_call_condition.is_dynamic_shape_process_) {
2413 (void)jit_call_graph_compile_cache_.emplace(cache_key, resource);
2414 }
2415 resource->SetResult(kNeedCompile, true);
2416 } else {
2417 resource = it->second;
2418 // The resource func graph has not been compiled yet (run grad graph was not called), but the cache was hit.
2419 need_compile = resource->GetResult(kNeedCompile).cast<bool>();
2420 }
2421 MS_EXCEPTION_IF_NULL(resource);
2422 bool is_control_flow = graph_call_condition.is_control_flow_;
2423 auto fn = [resource, need_compile, is_control_flow, kNeedCompile](const VectorRef &arg_list) -> VectorRef {
2424 if (need_compile) {
2425 MS_LOG(DEBUG) << "Start emit action for graph " << resource->func_graph()->ToString();
2426 auto manager = resource->manager();
2427 manager->AddFuncGraph(resource->func_graph(), true);
2428 resource->SetBackendAsync([]() { return compile::CreateBackend(); });
2429 // kFlagJitCallGraph is set to false to compile sub graphs in control flow.
2430 if (is_control_flow) {
2431 for (const auto &g : manager->func_graphs()) {
2432 g->set_flag(kFlagJitCallGraph, false);
2433 }
2434 }
2435 (void)TaskEmitAction(resource);
2436 (void)ExecuteAction(resource);
2437 resource->SetResult(kNeedCompile, false);
2438 }
2439 MS_LOG(DEBUG) << "Start execute action for graph " << resource->func_graph()->ToString();
2440 compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>();
2441 return utils::cast<VectorRef>((*run)(arg_list));
2442 };
2443 return fn;
2444 }
2445
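// Create a "bprop_cut" PrimitivePy that carries the user hook of `prim`: copy the hook function,
// propagate cell_id as a cell hook, and tag custom-op / recompute bprops so the bprop graph can cut at
// this point and call back into Python.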
BuildBpropCutPrim(const PrimitivePtr & prim,bool is_need_recompute)2446 PrimitivePyPtr AutoGrad::BuildBpropCutPrim(const PrimitivePtr &prim, bool is_need_recompute) {
2447 MS_EXCEPTION_IF_NULL(prim);
2448 auto prim_py = prim->cast<PrimitivePyPtr>();
2449 MS_EXCEPTION_IF_NULL(prim_py);
2450 auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
2451 bprop_cut->CopyHookFunction(prim_py);
2452 prim_py->AddBpropCutPrim(bprop_cut);
2453 if (prim->HasAttr("cell_id")) {
2454 auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
2455 if (!cell_id.empty()) {
2456 (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
2457 (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
2458 }
2459 }
2460 // Only custom ops need to add this attr; hook functions do not.
2461 if (prim->HasAttr("custom_op_bprop")) {
2462 (void)bprop_cut->AddAttr("custom_op_bprop", MakeValue(true));
2463 }
2464 (void)bprop_cut->AddAttr("custom_op_name", MakeValue(prim->name()));
2465 if (is_need_recompute) {
2466 (void)bprop_cut->AddAttr("is_recompute", MakeValue(true));
2467 }
2468 return bprop_cut;
2469 }
2470
CheckRecomputeInputs(const GradParamPtr & grad_param)2471 void AutoGrad::CheckRecomputeInputs(const GradParamPtr &grad_param) {
2472 if (grad_param->op_grad_info->is_need_recompute) {
2473 for (const auto &input : grad_param->op_grad_info->input_value) {
2474 if (input->isa<ValueSequence>()) {
2475 const auto &seq = input->cast<ValueSequencePtr>();
2476 const auto val = seq->value();
2477 if (AutoGrad::NeedGrad(val)) {
2478 MS_LOG(EXCEPTION) << "For a recompute cell, calculating a tensor's gradient from a tuple is not supported yet. "
2479 "Please check the inputs of the construct function of the recompute cell, and do not put "
2480 "tensors that need grad into a tuple!";
2481 }
2482 }
2483 }
2484 }
2485 }
2486
ClearAutoGradStaticCache()2487 void AutoGrad::ClearAutoGradStaticCache() { jit_call_graph_compile_cache_.clear(); }
2488
IsRealOp(const AnfNodePtr & cnode)2489 bool GradCommon::IsRealOp(const AnfNodePtr &cnode) {
2490 MS_EXCEPTION_IF_NULL(cnode);
2491 const auto &prim = GetCNodePrimitive(cnode);
2492 if (prim == nullptr) {
2493 return false;
2494 }
2495 return kNotRealOP.find(prim->name()) == kNotRealOP.end();
2496 }
2497
SetForward(const AnfNodePtrList & node_list)2498 void GradCommon::SetForward(const AnfNodePtrList &node_list) {
2499 for (const auto &cn : node_list) {
2500 auto out = Common::CreatOutputTensorValueByAbstract(cn->abstract());
2501 const auto &c_node = cn->cast<CNodePtr>();
2502 MS_EXCEPTION_IF_NULL(c_node);
2503 c_node->set_forward(Common::CreateValueNodeByValue(out, cn->abstract()), "");
2504 }
2505 }
2506
GetUsedCNodeInBpropGraph(const CNodePtr & cnode,const mindspore::HashSet<size_t> & unused_inputs,AnfNodePtrList * node_list)2507 void GradCommon::GetUsedCNodeInBpropGraph(const CNodePtr &cnode, const mindspore::HashSet<size_t> &unused_inputs,
2508 AnfNodePtrList *node_list) {
2509 MS_EXCEPTION_IF_NULL(cnode);
2510 MS_EXCEPTION_IF_NULL(node_list);
2511 // Check input used in single op bprop graph. For example,
2512 // A = a * b;
2513 // B = A * c;
2514 // So, A can also be replaced by its output.
2515 size_t input_num = cnode->size() - 1;
2516 for (size_t i = 0; i < input_num; ++i) {
2517 if (unused_inputs.find(i) == unused_inputs.end() && cnode->input(i + 1)->isa<CNode>()) {
2518 // Input is used by the bprop graph, and it is a cnode that produces a real output.
2519 const auto &input_c = cnode->input(i + 1)->cast<CNodePtr>();
2520 MS_EXCEPTION_IF_NULL(input_c);
2521 if (IsPrimitive(input_c, prim::kPrimMakeTuple)) {
2522 size_t tuple_input_num = input_c->size() - 1;
2523 for (size_t j = 0; j < tuple_input_num; ++j) {
2524 if (auto f_node = common::AnfAlgo::VisitKernel(input_c, j).first; f_node->isa<CNode>() && IsRealOp(f_node)) {
2525 MS_LOG(DEBUG) << "Get used input node " << f_node->DebugString();
2526 (void)node_list->emplace_back(f_node);
2527 }
2528 }
2529 } else {
2530 if (auto f_node = common::AnfAlgo::VisitKernel(input_c, 0).first; f_node->isa<CNode>() && IsRealOp(f_node)) {
2531 MS_LOG(DEBUG) << "Get used input node " << f_node->DebugString();
2532 (void)node_list->emplace_back(f_node);
2533 }
2534 }
2535 }
2536 }
2537 // Check output used in single op bprop graph
2538 if (unused_inputs.find(cnode->size() - 1) == unused_inputs.end()) {
2539 MS_LOG(DEBUG) << "Get used output node " << cnode->DebugString();
2540 (void)node_list->emplace_back(cnode);
2541 }
2542 }
2543 } // namespace PyNativeAlgo
2544
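// Dispatch a frontend async task: when the executor requires synchronous execution (and async-for-graph
// is off), wait for all pending ops and run the task inline; otherwise push it onto the frontend
// pipeline stage. Illustrative usage (hypothetical task type, for illustration only):
//   DispatchOp(std::make_shared<FrontendTask>(run_func, op_run_info));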
DispatchOp(const std::shared_ptr<runtime::AsyncTask> & task)2545 void DispatchOp(const std::shared_ptr<runtime::AsyncTask> &task) {
2546 static bool need_sync = runtime::OpExecutor::NeedSync();
2547 if (need_sync && !runtime::OpExecutor::GetInstance().async_for_graph()) {
2548 MS_LOG(INFO) << "PyBoost sync run frontend task";
2549 runtime::OpExecutor::GetInstance().WaitAll();
2550 task->Run();
2551 } else {
2552 runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id());
2553 runtime::Pipeline::Get().frontend_stage()->Push(task);
2554 }
2555 }
2556 } // namespace pynative
2557 } // namespace mindspore
2558