1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <unordered_map>
18 #include <functional>
19 #include <map>
20 #include "runtime/graph_scheduler/control_node_parser.h"
21 #include "mindspore/core/ops/sparse_tensor_ops.h"
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "runtime/graph_scheduler/actor/actor_common.h"
25 #include "runtime/device/device_address_utils.h"
26 #include "include/common/utils/convert_utils.h"
27 #include "abstract/utils.h"
28 #include "utils/ms_context.h"
29 #include "ir/tensor.h"
30 #include "abstract/abstract_function.h"
31 #include "include/common/debug/anf_ir_dump.h"
32
33 namespace mindspore {
34 namespace runtime {
35 namespace {
// Named constant for the value 2; judging by the name it is a depth for debug-string dumps (its use sites are not in
// this part of the file -- confirm there).
constexpr int kDebugStrDepthTwo = 2;
37 // Check if node is a value node need to create a device tensor.
IsFrontValueNode(const KernelWithIndex & node_with_index)38 bool IsFrontValueNode(const KernelWithIndex &node_with_index) {
39 const auto &node = node_with_index.first;
40 MS_EXCEPTION_IF_NULL(node);
41 if (!node->isa<ValueNode>() || IsValueNode<FuncGraph>(node) || IsValueNode<Primitive>(node)) {
42 return false;
43 }
44
45 return true;
46 }
47
48 // Fetch real input node in maketuple.
FetchRealInputNode(const KernelWithIndex & node_with_index)49 KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) {
50 const auto &node = node_with_index.first;
51 MS_EXCEPTION_IF_NULL(node);
52 if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
53 return node_with_index;
54 }
55
56 const auto &abstract = node->abstract();
57 MS_EXCEPTION_IF_NULL(abstract);
58 size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
59 if (output_num <= node_with_index.second) {
60 MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid index:" << node_with_index.second
61 << "for tuple node:" << node->DebugString();
62 }
63
64 const auto &cnode = node->cast<CNodePtr>();
65 MS_EXCEPTION_IF_NULL(cnode);
66 const auto &inputs = cnode->inputs();
67 size_t real_index = node_with_index.second;
68 for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
69 MS_EXCEPTION_IF_NULL(inputs[i]);
70 const auto &sub_abstract = inputs[i]->abstract();
71 MS_EXCEPTION_IF_NULL(sub_abstract);
72 size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
73 // If it is not the output of node, need to subtract the number of inputs of it.
74 if (real_index >= tmp_index) {
75 real_index -= tmp_index;
76 continue;
77 }
78 return {inputs[i], real_index};
79 }
80 MS_LOG_WITH_NODE(EXCEPTION, node) << "Failed to get real output from node:" << node->DebugString()
81 << " index:" << node_with_index.second;
82 }
83
84 // Fetch all the output index in the sub-abstract of abstract.
FetchRealIndexByAbstract(const AbstractBasePtr & abstract,std::vector<size_t> * const indexes)85 std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std::vector<size_t> *const indexes) {
86 MS_EXCEPTION_IF_NULL(abstract);
87 MS_EXCEPTION_IF_NULL(indexes);
88 AbstractBasePtr dst_abstract = abstract;
89 size_t pre_abstract_num = 0;
90 std::set<size_t> output_indexs;
91 if (indexes->empty()) {
92 size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
93 for (size_t i = 0; i < output_num; ++i) {
94 (void)output_indexs.emplace(i);
95 }
96 return output_indexs;
97 }
98
99 size_t index = indexes->back();
100 indexes->pop_back();
101
102 // Fetch the dest abstract by index, and the abstracts num before the dest abstract.
103 if (abstract->isa<abstract::AbstractSequence>()) {
104 auto sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
105 MS_EXCEPTION_IF_NULL(sequence_abstract);
106 const auto &sub_abstracts = sequence_abstract->elements();
107 if (sub_abstracts.size() <= index) {
108 MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
109 }
110 for (size_t i = 0; i < index; ++i) {
111 pre_abstract_num += common::AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
112 }
113 dst_abstract = sub_abstracts[index];
114 } else {
115 if (index != 0) {
116 MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
117 }
118 }
119 MS_EXCEPTION_IF_NULL(dst_abstract);
120
121 // Fetch real output index.
122 auto tmp_indexs = FetchRealIndexByAbstract(dst_abstract, indexes);
123 for (auto tmp_index : tmp_indexs) {
124 (void)output_indexs.emplace(tmp_index + pre_abstract_num);
125 }
126 return output_indexs;
127 }
128
129 // Get all the real parameters corresponding to node.
FetchRealParameterByNode(const KernelWithIndex & node,std::set<KernelWithIndex> * const real_parameters,std::set<KernelWithIndex> * invalid_call_nodes,const mindspore::HashMap<AnfNodePtr,std::set<FuncGraphPtr>> & call_node_to_func_graphs)130 void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *const real_parameters,
131 std::set<KernelWithIndex> *invalid_call_nodes,
132 const mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr>> &call_node_to_func_graphs) {
133 MS_EXCEPTION_IF_NULL(node.first);
134 MS_EXCEPTION_IF_NULL(real_parameters);
135 MS_EXCEPTION_IF_NULL(invalid_call_nodes);
136 MS_LOG(DEBUG) << "Fetch real parameter by node:" << node.first->DebugString() << " index:" << node.second;
137 auto node_with_index = common::AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
138 MS_EXCEPTION_IF_NULL(node_with_index.first);
139 if (node_with_index.first->isa<ValueNode>() || node_with_index.first->isa<Parameter>()) {
140 // If node is a valuenode or parameter, the real parameter is itself.
141 MS_LOG(DEBUG) << "Add real parameter:" << node_with_index.first->DebugString()
142 << " index:" << node_with_index.second;
143 (void)real_parameters->emplace(node_with_index);
144 } else if (common::AnfAlgo::IsCallNode(node_with_index.first)) {
145 // If node is a call node, the real parameters are the outputs of funcgraph the node called.
146 if (invalid_call_nodes->find(node_with_index) != invalid_call_nodes->end()) {
147 return;
148 }
149 (void)invalid_call_nodes->emplace(node_with_index);
150 const auto &iter = call_node_to_func_graphs.find(node_with_index.first);
151 if (iter == call_node_to_func_graphs.end()) {
152 MS_LOG(DEBUG) << "Invalid call node:" << node_with_index.first->DebugString();
153 return;
154 }
155 const auto &func_graphs = iter->second;
156 for (const auto &func_graph : func_graphs) {
157 MS_EXCEPTION_IF_NULL(func_graph);
158 FetchRealParameterByNode({func_graph->output(), node_with_index.second}, real_parameters, invalid_call_nodes,
159 call_node_to_func_graphs);
160 }
161 } else if (common::AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) {
162 // If node is a maketuple node, the real parameters are its total inputs.
163 const auto &real_input = FetchRealInputNode(node_with_index);
164 MS_EXCEPTION_IF_NULL(real_input.first);
165 MS_LOG(DEBUG) << "Real input node:" << real_input.first->DebugString() << " index:" << real_input.second
166 << " for tuple node:" << node_with_index.first->DebugString() << " index:" << node_with_index.second;
167 FetchRealParameterByNode(real_input, real_parameters, invalid_call_nodes, call_node_to_func_graphs);
168 } else if (common::AnfAlgo::CheckPrimitiveType(node.first, prim::kPrimSwitch)) {
169 // If node is a switch node, the real parameters are its both true and false branches.
170 const auto cnode = node_with_index.first->cast<CNodePtr>();
171 MS_EXCEPTION_IF_NULL(cnode);
172 const auto inputs = cnode->inputs();
173 for (size_t i = kSwitchTrueBranchPos; i < inputs.size(); ++i) {
174 FetchRealParameterByNode({inputs[i], 0}, real_parameters, invalid_call_nodes, call_node_to_func_graphs);
175 }
176 } else if (common::AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitchLayer)) {
177 // If node is a switchlyaer node, the real parameters are its total branches.
178 const auto &switch_layer_cnode = node_with_index.first->cast<CNodePtr>();
179 MS_EXCEPTION_IF_NULL(switch_layer_cnode);
180 const auto &switch_layer_inputs = switch_layer_cnode->inputs();
181 if (switch_layer_inputs.size() != kSwitchLayerInputNum ||
182 (!common::AnfAlgo::CheckPrimitiveType(switch_layer_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple))) {
183 MS_LOG_WITH_NODE(EXCEPTION, switch_layer_cnode)
184 << "Invalid switch layer node:" << switch_layer_cnode->DebugString();
185 }
186 const auto &make_tuple_cnode = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>();
187 MS_EXCEPTION_IF_NULL(make_tuple_cnode);
188 const auto &make_tuple_inputs = make_tuple_cnode->inputs();
189 for (size_t i = kSwitchTrueBranchPos; i < make_tuple_inputs.size(); ++i) {
190 FetchRealParameterByNode({make_tuple_inputs[i], 0}, real_parameters, invalid_call_nodes,
191 call_node_to_func_graphs);
192 }
193 } else {
194 // If node is a kernel, the real parameter is itself.
195 MS_LOG(DEBUG) << "Add real parameter:" << node_with_index.first->DebugString()
196 << " index:" << node_with_index.second;
197 (void)real_parameters->emplace(node_with_index);
198 }
199 }
200
201 // Topologically sort all funcgraphs according to the function call relationship.
TopoSortForFuncGraph(const FuncGraphPtr & root,FuncGraphCallRelation * const edges)202 std::vector<FuncGraphPtr> TopoSortForFuncGraph(const FuncGraphPtr &root, FuncGraphCallRelation *const edges) {
203 MS_EXCEPTION_IF_NULL(root);
204 MS_EXCEPTION_IF_NULL(edges);
205 MS_EXCEPTION_IF_NULL(root->manager());
206 std::set<FuncGraphPtr> nodes;
207 (void)nodes.emplace(root);
208
209 FuncGraphSet subs = root->manager()->func_graphs();
210 for (auto sub : subs) {
211 if (sub != root) {
212 (void)nodes.emplace(sub);
213 }
214 }
215
216 std::queue<FuncGraphPtr> que;
217 for (const auto &node : nodes) {
218 if (edges->find(node) == edges->end()) {
219 que.push(node);
220 }
221 }
222
223 std::vector<FuncGraphPtr> result;
224 while (!que.empty()) {
225 const auto node = que.front();
226 que.pop();
227 (void)result.emplace_back(node);
228 for (auto iter = edges->begin(); iter != edges->end();) {
229 auto &sub_edges = iter->second;
230 for (auto sub_iter = sub_edges.begin(); sub_iter != sub_edges.end();) {
231 if (sub_iter->find(node) != sub_iter->end()) {
232 sub_iter = sub_edges.erase(sub_iter);
233 } else {
234 ++sub_iter;
235 }
236 }
237 if (sub_edges.empty()) {
238 que.push(iter->first);
239 iter = edges->erase(iter);
240 } else {
241 ++iter;
242 }
243 }
244 }
245
246 return result;
247 }
248
FetchTypeIdByNode(const AnfNodePtr & node,size_t index)249 TypeId FetchTypeIdByNode(const AnfNodePtr &node, size_t index) {
250 MS_EXCEPTION_IF_NULL(node);
251 TypeId type_id = kTypeUnknown;
252 if (node->isa<ValueNode>() && node->abstract() != nullptr) {
253 // For valuenode, fetch type from abstract.
254 const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
255 MS_EXCEPTION_IF_NULL(abs);
256 const auto &type = abs->BuildType();
257 MS_EXCEPTION_IF_NULL(type);
258 if (type->isa<TensorType>()) {
259 const auto &tensor_type = type->cast<TensorTypePtr>();
260 MS_EXCEPTION_IF_NULL(tensor_type);
261 const auto &element = tensor_type->element();
262 MS_EXCEPTION_IF_NULL(element);
263 type_id = element->type_id();
264 } else if (common::AnfAlgo::IsDynamicSequence(node)) {
265 const auto &sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
266 MS_EXCEPTION_IF_NULL(sequence_abs);
267 if (sequence_abs->dynamic_len_element_abs() == nullptr) {
268 type_id = type->type_id();
269 } else {
270 if (sequence_abs->dynamic_len_element_abs()->isa<abstract::AbstractTensor>()) {
271 const auto &tensor_abs = sequence_abs->dynamic_len_element_abs()->cast<abstract::AbstractTensorPtr>();
272 MS_EXCEPTION_IF_NULL(tensor_abs);
273 MS_EXCEPTION_IF_NULL(tensor_abs->element());
274 const auto &tensor_element_type = tensor_abs->element()->BuildType();
275 MS_EXCEPTION_IF_NULL(tensor_element_type);
276 return tensor_element_type->type_id();
277 }
278 const auto &element_type = sequence_abs->dynamic_len_element_abs()->BuildType();
279 MS_EXCEPTION_IF_NULL(element_type);
280 type_id = element_type->type_id();
281 }
282 } else {
283 type_id = type->type_id();
284 }
285 } else {
286 type_id = common::AnfAlgo::GetOutputInferDataType(node, index);
287 }
288 return type_id;
289 }
290
FetchOutputSizeByValue(const ValuePtr & value)291 size_t FetchOutputSizeByValue(const ValuePtr &value) {
292 MS_EXCEPTION_IF_NULL(value);
293 if (value->isa<Scalar>()) {
294 return GetTypeByte(value->type());
295 } else if (value->isa<tensor::Tensor>()) {
296 const auto &tensor = value->cast<tensor::TensorPtr>();
297 MS_EXCEPTION_IF_NULL(tensor);
298 return tensor->Size();
299 } else if (value->isa<ValueSequence>()) {
300 const auto &value_sequence = value->cast<ValueSequencePtr>();
301 MS_EXCEPTION_IF_NULL(value_sequence);
302 if (value_sequence->size() == 0) {
303 return 0;
304 }
305 size_t size = 0;
306 for (const auto &sub_value : value_sequence->value()) {
307 MS_EXCEPTION_IF_NULL(sub_value);
308 size += FetchOutputSizeByValue(sub_value);
309 }
310 return size;
311 } else {
312 MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
313 }
314 }
315
FetchOutputSizeByNode(const AnfNodePtr & node,size_t index,TypeId type_id)316 size_t FetchOutputSizeByNode(const AnfNodePtr &node, size_t index, TypeId type_id) {
317 MS_EXCEPTION_IF_NULL(node);
318 size_t size = GetTypeByte(TypeIdToType(type_id));
319 if (node->isa<ValueNode>() && node->abstract() != nullptr) {
320 const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
321 MS_EXCEPTION_IF_NULL(abs);
322 const auto &shape_ptr = abs->BuildShape();
323 MS_EXCEPTION_IF_NULL(shape_ptr);
324 if (shape_ptr->isa<abstract::Shape>()) {
325 const auto &shapes = shape_ptr->cast<abstract::ShapePtr>()->shape();
326 size = std::accumulate(shapes.begin(), shapes.end(), size, std::multiplies<int64_t>());
327 } else if (shape_ptr->isa<abstract::DynamicSequenceShape>()) {
328 const auto &value_node = node->cast<ValueNodePtr>();
329 MS_EXCEPTION_IF_NULL(value_node);
330 const auto &value = value_node->value();
331 MS_EXCEPTION_IF_NULL(value);
332 size = FetchOutputSizeByValue(value);
333 MS_LOG(INFO) << "Abstract;" << abs->ToString() << " for node:" << node->DebugString() << " index:" << index
334 << " shape:" << shape_ptr->ToString() << " size:" << size;
335 } else if (abs->isa<abstract::AbstractMonad>() || abs->isa<abstract::AbstractScalar>()) {
336 MS_LOG(DEBUG) << "For scalar, the output shape is 1.";
337 } else {
338 MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid abstract;" << abs->ToString() << " for node:" << node->DebugString()
339 << " index:" << index << " shape:" << shape_ptr->ToString();
340 }
341 } else {
342 size = AnfAlgo::GetOutputTensorMemSize(node, index);
343 }
344 return size;
345 }
346
347 // Create a device tensor for the front node.
348 // Get the output format and select kernel build info from the backend node corresponding to the front node to
349 // create the device address.
CreateDeviceTensorForValueNode(const KernelWithIndex & front_node_with_index,const AnfNodePtr & backend_node,const DeviceContext * device_context)350 void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node,
351 const DeviceContext *device_context) {
352 MS_EXCEPTION_IF_NULL(backend_node);
353 MS_EXCEPTION_IF_NULL(device_context);
354 const auto &front_node = front_node_with_index.first;
355 MS_EXCEPTION_IF_NULL(front_node);
356
357 const auto &node_value = front_node->cast<ValueNodePtr>()->value();
358 MS_EXCEPTION_IF_NULL(node_value);
359 if (node_value->isa<FuncGraph>() || node_value->isa<Primitive>() ||
360 (node_value->isa<ValueSequence>() && node_value->cast<ValueSequencePtr>()->size() == 0)) {
361 return;
362 }
363
364 size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
365 TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
366 if (output_type_id == kTypeUnknown) {
367 output_type_id = common::AnfAlgo::GetOutputInferDataType(backend_node, 0);
368 }
369 if (front_node->abstract() != nullptr && front_node->abstract()->isa<abstract::AbstractSequence>() &&
370 front_node->abstract()->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
371 tensor_size = FetchOutputSizeByNode(front_node, front_node_with_index.second, output_type_id);
372 }
373 CreateBuildInfoForFrontNode(front_node_with_index, backend_node);
374 device::DeviceAddressPtr address = nullptr;
375 if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
376 // If is_forward_output, get address from tensor
377 auto tensor = node_value->cast<TensorPtr>();
378 MS_EXCEPTION_IF_NULL(tensor);
379 address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
380 } else {
381 // Create device tensor.
382 std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
383
384 const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
385 {backend_node, 0}, nullptr, tensor_size, output_format, output_type_id, ShapeVector(),
386 device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
387 kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(backend_node));
388 address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
389 }
390 MS_EXCEPTION_IF_NULL(address);
391 MS_LOG(DEBUG) << "Create address for front node:" << front_node->DebugString()
392 << " backend node:" << backend_node->DebugString() << " index:" << front_node_with_index.second
393 << " addr:" << address << " size:" << tensor_size;
394 AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get());
395 UpdateRefCount(address.get(), true);
396 }
397
398 // Create a device tensor for front node.
399 // When the condition input of the switch and switchlayer or the output of a subgraph is a parameter or value node,
400 // there is no corresponding backend node for this parameter, so a device tensor needs to be created for it.
CreateDeviceTensorForFrontNode(const KernelWithIndex & front_node_with_index,const DeviceContext * device_context)401 void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index, const DeviceContext *device_context) {
402 MS_EXCEPTION_IF_NULL(device_context);
403 const auto &node = front_node_with_index.first;
404
405 MS_EXCEPTION_IF_NULL(node);
406 MS_LOG(DEBUG) << "Start create device tensor for front node:" << front_node_with_index.first->DebugString()
407 << " index:" << front_node_with_index.second;
408
409 // Create kernel info for front node.
410 if (node->kernel_info() == nullptr) {
411 auto kernel_info = std::make_shared<device::KernelInfo>();
412 MS_EXCEPTION_IF_NULL(kernel_info);
413 std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
414 MS_EXCEPTION_IF_NULL(builder);
415 kernel_info->set_select_kernel_build_info(builder->Build());
416 node->set_kernel_info(kernel_info);
417 }
418
419 // Set format.
420 const auto &kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
421 MS_EXCEPTION_IF_NULL(kernel_info);
422 const auto &builder = kernel_info->GetMutableSelectKernelBuildInfo();
423 MS_EXCEPTION_IF_NULL(builder);
424
425 if (node->isa<ValueNode>()) {
426 const auto &node_value = node->cast<ValueNodePtr>()->value();
427 MS_EXCEPTION_IF_NULL(node_value);
428 if (node_value->isa<ValueSequence>() && node_value->cast<ValueSequencePtr>()->size() == 0) {
429 return;
430 }
431 }
432
433 if (builder->GetAllOutputFormats().size() > front_node_with_index.second) {
434 builder->SetOutputFormat(kOpFormat_DEFAULT, front_node_with_index.second);
435 } else {
436 auto formats = builder->GetAllOutputFormats();
437 for (size_t i = 0; i <= front_node_with_index.second - builder->GetAllOutputFormats().size(); ++i) {
438 (void)formats.emplace_back(kOpFormat_DEFAULT);
439 }
440 builder->SetOutputsFormat(formats);
441 }
442
443 // Set type.
444 TypeId type_id = FetchTypeIdByNode(node, front_node_with_index.second);
445 if (builder->GetAllOutputDeviceTypes().size() > front_node_with_index.second) {
446 builder->SetOutputDeviceType(type_id, front_node_with_index.second);
447 } else {
448 auto types = builder->GetAllOutputDeviceTypes();
449 for (size_t i = 0; i <= front_node_with_index.second - builder->GetAllOutputDeviceTypes().size(); ++i) {
450 (void)types.emplace_back(type_id);
451 }
452 builder->SetOutputsDeviceType(types);
453 }
454
455 const auto &abstract = AnfAlgo::GetNodeAbstractByIndex(front_node_with_index.first, front_node_with_index.second);
456 bool is_map_parameter = abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>();
457 if (is_map_parameter) {
458 DeviceAddressUtils::CreateDeviceAddressByMapTensorNode(device_context, front_node_with_index.first,
459 front_node_with_index.second);
460 UpdateRefCount(AnfAlgo::GetMutableOutputAddr(front_node_with_index.first, front_node_with_index.second).get(),
461 true);
462 return;
463 }
464
465 // Fetch mem size by shape, the shape is first obtained from the abstract to deal with the scenario where
466 // the value node is a multi-level tuple.
467 size_t size = FetchOutputSizeByNode(node, front_node_with_index.second, type_id);
468 device::DeviceAddressPtr address = nullptr;
469 if (node->isa<ValueNode>()) {
470 const auto &node_value = node->cast<ValueNodePtr>()->value();
471 MS_EXCEPTION_IF_NULL(node_value);
472 if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
473 // If is_forward_output, get address from tensor
474 auto tensor = node_value->cast<TensorPtr>();
475 MS_EXCEPTION_IF_NULL(tensor);
476 address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
477 } else {
478 // Create device tensor.
479 const auto &sub_abstract = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), front_node_with_index.second);
480 MS_EXCEPTION_IF_NULL(sub_abstract);
481 const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
482 sub_abstract->BuildShape(), sub_abstract->BuildType(), sub_abstract->BuildValue(), nullptr, size,
483 kOpFormat_DEFAULT, type_id, ShapeVector(), device_context->device_context_key().device_name_,
484 device_context->device_context_key().device_id_);
485 kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(node));
486 address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
487 }
488 } else {
489 // Create device tensor.
490 const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
491 {node, front_node_with_index.second}, nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector(),
492 device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
493 kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(node));
494 address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
495 }
496 MS_EXCEPTION_IF_NULL(address);
497 MS_LOG(INFO) << "Create address for node that has no corresponding backend node:"
498 << common::AnfAlgo::GetNodeDebugString(node) << " addr:" << address << " size:" << size
499 << ", type id:" << type_id;
500 AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get());
501 UpdateRefCount(address.get(), true);
502 }
503
504 // Fetch all funcgraph by a seed graph, if a calls b, b calls c, and c calls a, return a set of a, b, c.
FetchAllExecutionFunction(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * const checked_funcgraphs,const std::unordered_map<FuncGraphPtr,std::set<FuncGraphPtr>> & call_relation)505 void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *const checked_funcgraphs,
506 const std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> &call_relation) {
507 MS_EXCEPTION_IF_NULL(func_graph);
508 MS_EXCEPTION_IF_NULL(checked_funcgraphs);
509 if (checked_funcgraphs->find(func_graph) != checked_funcgraphs->end()) {
510 return;
511 }
512 (void)checked_funcgraphs->emplace(func_graph);
513 auto iter = call_relation.find(func_graph);
514 if (iter == call_relation.end()) {
515 return;
516 }
517
518 for (const auto &called_func_graph : iter->second) {
519 MS_EXCEPTION_IF_NULL(called_func_graph);
520 FetchAllExecutionFunction(called_func_graph, checked_funcgraphs, call_relation);
521 }
522 }
523
IsValidMonadNode(const AnfNodePtr & node)524 bool IsValidMonadNode(const AnfNodePtr &node) {
525 MS_EXCEPTION_IF_NULL(node);
526 return node->isa<ValueNode>() || node->isa<Parameter>() || common::AnfAlgo::IsCallNode(node);
527 }
528
529 // Fetch all inputs of node.
FetchInputNodeByNode(const AnfNodePtr & node)530 std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
531 MS_EXCEPTION_IF_NULL(node);
532 if (HasAbstractMonad(node)) {
533 const auto &real_node_with_index = common::AnfAlgo::VisitKernelWithReturnType(node, 0);
534 const auto &real_node = real_node_with_index.first;
535 MS_EXCEPTION_IF_NULL(real_node);
536 if (IsValidMonadNode(real_node)) {
537 return {real_node_with_index};
538 }
539 MS_LOG_WITH_NODE(EXCEPTION, real_node) << "Invalid monad node:" << real_node->DebugString();
540 }
541
542 // The node is divided into the following types:
543 // 1. depend and load.
544 const auto &node_with_index =
545 common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
546 auto real_node = node_with_index.first;
547 size_t real_index = node_with_index.second;
548 MS_EXCEPTION_IF_NULL(real_node);
549 std::vector<KernelWithIndex> results;
550
551 // 2. Tuple node.
552 const PrimitiveSet expand_prims{prim::kPrimMakeTuple};
553 // The MakeTuple/MakeSparse node need expand and recurse.
554 if (IsOneOfPrimitiveCNode(real_node, expand_prims)) {
555 const auto &cnode = real_node->cast<CNodePtr>();
556 MS_EXCEPTION_IF_NULL(cnode);
557 const auto &inputs = cnode->inputs();
558 for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
559 const auto &sub_results = FetchInputNodeByNode(inputs[i]);
560 (void)results.insert(results.end(), sub_results.begin(), sub_results.end());
561 }
562 return results;
563 }
564
565 // 3. One output node.
566 const auto &abstract = real_node->abstract();
567 if (abstract == nullptr) {
568 MS_LOG(DEBUG) << "Empty abstract for node:" << real_node->DebugString();
569 (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(real_node, real_index));
570 return results;
571 }
572
573 // 4 Other.
574 if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
575 if (real_node->cast<CNodePtr>()->HasAttr(kAttrReplaceRealKernelInBackend) && real_node->abstract() != nullptr) {
576 size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(real_node->abstract());
577 MS_LOG(INFO) << "Fetch an tuple get item with repalce flag:" << real_node->DebugString()
578 << " output num:" << output_num;
579 for (size_t i = 0; i < output_num; ++i) {
580 (void)results.emplace_back(real_node, i);
581 }
582 return results;
583 }
584 std::vector<size_t> index_stack;
585 auto get_item_src_node = common::AnfAlgo::GetTupleIndexes(real_node, &index_stack);
586 MS_EXCEPTION_IF_NULL(get_item_src_node);
587 if (index_stack.empty()) {
588 const auto &sub_results = FetchInputNodeByNode(get_item_src_node);
589 (void)results.insert(results.end(), sub_results.begin(), sub_results.end());
590 return results;
591 }
592 auto get_item_src_abstract = get_item_src_node->abstract();
593 MS_EXCEPTION_IF_NULL(get_item_src_abstract);
594 auto indexes = FetchRealIndexByAbstract(get_item_src_abstract, &index_stack);
595 (void)std::transform(indexes.begin(), indexes.end(), std::back_inserter(results),
596 [&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
597 return results;
598 }
599
600 size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
601 for (size_t i = 0; i < output_num; ++i) {
602 (void)results.emplace_back(real_node, i);
603 }
604 return results;
605 }
606
607 // Add formal parameter and real parameter into realationship map.
AddFormalToRealParameter(const AnfNodePtr & formal_parameter,const AnfNodePtr & real_parameter,const CallNodeToFuncGraph & call_node_to_func_graphs,FormalToRealParameter * const formal_to_real_parameters)608 void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodePtr &real_parameter,
609 const CallNodeToFuncGraph &call_node_to_func_graphs,
610 FormalToRealParameter *const formal_to_real_parameters) {
611 MS_EXCEPTION_IF_NULL(formal_parameter);
612 MS_EXCEPTION_IF_NULL(real_parameter);
613 MS_EXCEPTION_IF_NULL(formal_to_real_parameters);
614 auto abstract = formal_parameter->abstract();
615 if (abstract == nullptr) {
616 MS_LOG_WITH_NODE(EXCEPTION, formal_parameter) << "Empty abstract for parameter:" << formal_parameter->DebugString();
617 }
618 size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
619
620 for (size_t i = 0; i < output_num; ++i) {
621 std::set<KernelWithIndex> real_parameters;
622 std::set<KernelWithIndex> invalid_call_nodes;
623 FetchRealParameterByNode({real_parameter, i}, &real_parameters, &invalid_call_nodes, call_node_to_func_graphs);
624 if (real_parameters.empty()) {
625 MS_LOG(DEBUG) << "Failed to find real parameter for formal parameter:" << real_parameter->DebugString();
626 continue;
627 }
628
629 for (const auto ¶meter : real_parameters) {
630 MS_EXCEPTION_IF_NULL(parameter.first);
631 MS_LOG(DEBUG) << "Add formal parameter:" << formal_parameter->DebugString() << " index:" << i
632 << " to real parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
633 }
634 (*formal_to_real_parameters)[{formal_parameter, i}].insert(real_parameters.begin(), real_parameters.end());
635 }
636 }
637
638 // Recursively traverse the input to confirm whether there is an input of recursive call.
IsFirstControlNode(const AnfNodePtr & node,std::set<AnfNodePtr> * checked_nodes,const std::set<AnfNodePtr> & unrecursion_call_nodes)639 bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
640 const std::set<AnfNodePtr> &unrecursion_call_nodes) {
641 MS_EXCEPTION_IF_NULL(node);
642 MS_EXCEPTION_IF_NULL(checked_nodes);
643 if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
644 return true;
645 }
646 (void)checked_nodes->emplace(node);
647
648 const auto &cnode = node->cast<CNodePtr>();
649 MS_EXCEPTION_IF_NULL(cnode);
650 const auto &inputs = cnode->inputs();
651 for (const auto &input : inputs) {
652 MS_EXCEPTION_IF_NULL(input);
653 if ((common::AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) ||
654 (!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) {
655 return false;
656 }
657 }
658 return true;
659 }
660
661 // Check if src_node depends on dst_node.
IsTopoDependNode(const AnfNodePtr & src_node,const AnfNodePtr & dst_node,std::set<AnfNodePtr> * checked_node)662 bool IsTopoDependNode(const AnfNodePtr &src_node, const AnfNodePtr &dst_node, std::set<AnfNodePtr> *checked_node) {
663 MS_EXCEPTION_IF_NULL(src_node);
664 MS_EXCEPTION_IF_NULL(dst_node);
665 MS_EXCEPTION_IF_NULL(checked_node);
666 if (src_node == dst_node) {
667 return true;
668 }
669 if (!src_node->isa<CNode>() || checked_node->find(src_node) != checked_node->end()) {
670 return false;
671 }
672
673 (void)checked_node->emplace(src_node);
674 const auto &cnode = src_node->cast<CNodePtr>();
675 MS_EXCEPTION_IF_NULL(cnode);
676 const auto &inputs = cnode->inputs();
677 for (const auto &input : inputs) {
678 MS_EXCEPTION_IF_NULL(input);
679 if (IsTopoDependNode(input, dst_node, checked_node)) {
680 return true;
681 }
682 }
683 return false;
684 }
685
IsValidBackendParameter(const AnfNodePtr & node)686 bool IsValidBackendParameter(const AnfNodePtr &node) {
687 if (node == nullptr) {
688 return false;
689 }
690 if (node->abstract() == nullptr) {
691 return true;
692 }
693 if (node->abstract()->isa<abstract::AbstractAny>()) {
694 return false;
695 }
696 const auto &shape = node->abstract()->BuildShape();
697 if (shape == nullptr || shape->IsDynamic()) {
698 return false;
699 }
700 return true;
701 }
702 } // namespace
CreateBuildInfoForFrontNode(const KernelWithIndex & front_node_with_index,const AnfNodePtr & backend_node)703 void CreateBuildInfoForFrontNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node) {
704 MS_EXCEPTION_IF_NULL(front_node_with_index.first);
705 MS_EXCEPTION_IF_NULL(backend_node);
706 const auto &front_node = front_node_with_index.first;
707 if (front_node->kernel_info() == nullptr) {
708 auto kernel_info = std::make_shared<device::KernelInfo>();
709 MS_EXCEPTION_IF_NULL(kernel_info);
710 front_node->set_kernel_info(kernel_info);
711 std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
712 MS_EXCEPTION_IF_NULL(builder);
713 kernel_info->set_select_kernel_build_info(builder->Build());
714 kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsKernelObjectType(
715 {kernel::KernelObjectType::TUPLE_UNFOLD});
716 }
717
718 // Set build info to front node.
719 auto backend_kernel_info = static_cast<device::KernelInfo *>(backend_node->kernel_info());
720 MS_EXCEPTION_IF_NULL(backend_kernel_info);
721 auto backend_build_info = backend_kernel_info->GetMutableSelectKernelBuildInfo();
722 MS_EXCEPTION_IF_NULL(backend_build_info);
723
724 auto front_kernel_info = static_cast<device::KernelInfo *>(front_node->kernel_info());
725 MS_EXCEPTION_IF_NULL(front_kernel_info);
726 auto front_build_info = front_kernel_info->GetMutableSelectKernelBuildInfo();
727 MS_EXCEPTION_IF_NULL(front_build_info);
728 // Set output format and device data type.
729 if (front_build_info->GetAllOutputFormats().size() > front_node_with_index.second) {
730 front_build_info->SetOutputFormat(backend_build_info->GetOutputFormat(0), front_node_with_index.second);
731 front_build_info->SetOutputDeviceType(backend_build_info->GetOutputDeviceType(0), front_node_with_index.second);
732 } else {
733 auto formats = front_build_info->GetAllOutputFormats();
734 auto types = front_build_info->GetAllOutputDeviceTypes();
735 for (size_t i = 0; i <= front_node_with_index.second - front_build_info->GetAllOutputFormats().size(); ++i) {
736 (void)formats.emplace_back(backend_build_info->GetOutputFormat(0));
737 (void)types.emplace_back(backend_build_info->GetOutputDeviceType(0));
738 }
739 front_build_info->SetOutputsFormat(formats);
740 front_build_info->SetOutputsDeviceType(types);
741 }
742 }
743
IsInvalidPartial(const AnfNodePtr & node)744 bool IsInvalidPartial(const AnfNodePtr &node) {
745 MS_EXCEPTION_IF_NULL(node);
746 if (!node->isa<CNode>()) {
747 return false;
748 }
749
750 const auto &cnode = node->cast<CNodePtr>();
751 MS_EXCEPTION_IF_NULL(cnode);
752 const auto &inputs = cnode->inputs();
753 if (inputs.size() <= kPartialFuncGraphPos) {
754 return false;
755 }
756
757 if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
758 return false;
759 }
760 if (IsDeadNode(inputs[kPartialFuncGraphPos])) {
761 return true;
762 }
763 return false;
764 }
765
// Resolve node_with_index to the node that the tuple indexes are taken from and to an index on that
// node. The index stack is seeded with the given index, filled by GetTupleIndexes and translated by
// FetchRealIndexByAbstract. An exception is raised when no index is found; when several are found,
// the first one is returned and the situation is logged.
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  std::vector<size_t> tuple_indexes;
  tuple_indexes.push_back(node_with_index.second);

  const auto &src_node = common::AnfAlgo::GetTupleIndexes(node_with_index.first, &tuple_indexes);
  MS_EXCEPTION_IF_NULL(src_node);
  const auto &src_abstract = src_node->abstract();
  MS_EXCEPTION_IF_NULL(src_abstract);

  auto real_indexes = FetchRealIndexByAbstract(src_abstract, &tuple_indexes);
  if (real_indexes.empty()) {
    MS_LOG_WITH_NODE(EXCEPTION, src_node) << "Failed to find index for node:" << src_node;
  }
  if (real_indexes.size() > 1) {
    MS_LOG(DEBUG) << "Output size:" << real_indexes.size() << " for node:" << src_node->DebugString()
                  << " more than 1";
  }
  return {src_node, *(real_indexes.begin())};
}
784
IsCsrNode(const AnfNodePtr & node)785 bool IsCsrNode(const AnfNodePtr &node) {
786 MS_EXCEPTION_IF_NULL(node);
787 if (!node->isa<CNode>()) {
788 return false;
789 }
790 return common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndptr) ||
791 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndices) ||
792 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetValues) ||
793 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetDenseShape);
794 }
795
IsCooNode(const AnfNodePtr & node)796 bool IsCooNode(const AnfNodePtr &node) {
797 MS_EXCEPTION_IF_NULL(node);
798 if (!node->isa<CNode>()) {
799 return false;
800 }
801 return common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCOOTensorGetIndices) ||
802 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCOOTensorGetValues) ||
803 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCOOTensorGetDenseShape);
804 }
805
GetFrontNodeByKernelGraph(const AnfNodePtr & backend_node,const KernelGraph * const graph)806 KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph) {
807 MS_EXCEPTION_IF_NULL(backend_node);
808 MS_EXCEPTION_IF_NULL(graph);
809 const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
810 if (front_node != nullptr) {
811 MS_LOG(DEBUG) << "Front node:" << front_node->DebugString() << " index:0"
812 << " for backend node:" << backend_node->DebugString();
813 return {front_node, 0};
814 }
815 const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
816 if (front_node_with_index.first != nullptr) {
817 MS_LOG(DEBUG) << "Internal front node:" << front_node_with_index.first->DebugString()
818 << " index:" << front_node_with_index.second << " for backend node:" << backend_node->DebugString();
819 return front_node_with_index;
820 }
821 const auto &front_tuple_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(backend_node);
822 if (front_tuple_node_with_index.first == nullptr) {
823 MS_LOG_WITH_NODE(EXCEPTION, backend_node)
824 << "Cannot find front node for backend node:" << backend_node->DebugString() << " in graph:" << graph->ToString();
825 }
826 MS_LOG(DEBUG) << "Tuple front node:" << front_tuple_node_with_index.first->DebugString()
827 << " index:" << front_tuple_node_with_index.second;
828 return front_tuple_node_with_index;
829 }
830
FetchInputNodeByCNode(const AnfNodePtr & node)831 std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
832 MS_EXCEPTION_IF_NULL(node);
833 MS_LOG(DEBUG) << "Fetch input node for:" << node->DebugString();
834 if (!node->isa<CNode>()) {
835 MS_LOG(DEBUG) << "Empty input node for:" << node->DebugString();
836 return {};
837 }
838
839 std::vector<KernelWithIndex> results;
840 // The first input of normal cnode is the primitive of node, and the real input starts from the second input,
841 // but in control flow, the call node has no primitive, and the 0th input is funcgraph or partial.
842 size_t input_start_pos = kCNodeInputStartPos;
843 if (common::AnfAlgo::IsCallNode(node)) {
844 input_start_pos = 0;
845 }
846 const auto &cnode = node->cast<CNodePtr>();
847 MS_EXCEPTION_IF_NULL(cnode);
848 const auto inputs = cnode->inputs();
849
850 // The first branch of the input of the switch node is the true branch, and the second is the false branch.
851 // But in switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the input
852 // of the switch node needs to exchange the positions of the two branches. So deal separately.
853 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
854 if (inputs.size() != kSwitchInputNum) {
855 MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid switch node:" << node->DebugString();
856 }
857 (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0));
858 (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0));
859 (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0));
860 return results;
861 }
862
863 for (size_t i = input_start_pos; i < inputs.size(); ++i) {
864 MS_EXCEPTION_IF_NULL(inputs[i]);
865 const auto &sub_results = FetchInputNodeByNode(inputs[i]);
866 (void)results.insert(results.end(), sub_results.begin(), sub_results.end());
867 }
868 return results;
869 }
870
IsPartialInput(const AnfNodePtr & node)871 bool IsPartialInput(const AnfNodePtr &node) {
872 MS_EXCEPTION_IF_NULL(node);
873 const auto &abstract = node->abstract();
874 if (abstract != nullptr) {
875 if (abstract->isa<abstract::AbstractFunction>()) {
876 return true;
877 }
878 return false;
879 }
880
881 if (!node->isa<CNode>()) {
882 return false;
883 }
884
885 // If the abstract is empty and the node is a cnode, check its true branch.
886 const auto &cnode = node->cast<CNodePtr>();
887 MS_EXCEPTION_IF_NULL(cnode);
888
889 const auto &inputs = cnode->inputs();
890 if (inputs.size() < kSwitchTrueBranchIndex + 1) {
891 MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid switch node:" << node->DebugString();
892 }
893 const auto &branch_node = inputs[kSwitchTrueBranchIndex];
894 MS_EXCEPTION_IF_NULL(branch_node);
895 const auto &branch_abstract = branch_node->abstract();
896 // If abstract is empty, the default is true.
897 if (branch_abstract == nullptr) {
898 MS_LOG(DEBUG) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
899 return true;
900 }
901
902 if (branch_abstract->isa<abstract::AbstractFunction>()) {
903 return true;
904 } else if (branch_abstract->isa<abstract::AbstractSequence>()) {
905 // In switch layer, the true branch input is a make tuple.
906 auto sequence_abstract = branch_abstract->cast<abstract::AbstractSequencePtr>();
907 MS_EXCEPTION_IF_NULL(sequence_abstract);
908 const auto &sub_abstracts = sequence_abstract->elements();
909 if (sub_abstracts.empty() || sub_abstracts[0] == nullptr) {
910 MS_LOG(DEBUG) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
911 return true;
912 }
913 if (sub_abstracts[0]->isa<abstract::AbstractFunction>()) {
914 return true;
915 }
916 }
917 return false;
918 }
919
920 // Fetch the depend nodes according to the monad node.
FetchRealDependNodeByAutoMonad(const AnfNodePtr & node,std::set<AnfNodePtr> * const depend_nodes)921 void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> *const depend_nodes) {
922 // Find the real input node, include the monad node and make tuple node.
923 const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
924 prim::kPrimMakeTuple};
925 const auto &node_with_index = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_types);
926 auto real_node = node_with_index.first;
927 MS_EXCEPTION_IF_NULL(real_node);
928 if (!real_node->isa<CNode>()) {
929 return;
930 }
931
932 const auto &real_cnode = real_node->cast<CNodePtr>();
933 MS_EXCEPTION_IF_NULL(real_cnode);
934 const auto &real_inputs = real_cnode->inputs();
935
936 // Make tuple node needs to be expanded.
937 if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
938 for (size_t i = 1; i < real_inputs.size(); ++i) {
939 MS_EXCEPTION_IF_NULL(real_inputs[i]);
940 FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
941 }
942 return;
943 }
944
945 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
946 prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
947 if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimDepend) ||
948 common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimLoad)) {
949 FetchRealDependNodeByAutoMonad(real_inputs[kDependAttachNodeIndex], depend_nodes);
950 // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
951 // node in this scene.
952 if (IsOneOfPrimitiveCNode(real_inputs[kRealInputIndexInDepend], recursion_prims)) {
953 FetchRealDependNodeByAutoMonad(real_inputs[kRealInputIndexInDepend], depend_nodes);
954 }
955 } else if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimUpdateState)) {
956 for (size_t i = kUpdateStateRealInput; i < real_inputs.size(); ++i) {
957 FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
958 }
959 } else {
960 MS_EXCEPTION_IF_NULL(depend_nodes);
961 (void)depend_nodes->emplace(real_node);
962 }
963 }
964
965 // Get all the depend nodes of node in side effect.
FetchAllMonadNodeByNode(const AnfNodePtr & node)966 std::vector<AnfNodePtr> FetchAllMonadNodeByNode(const AnfNodePtr &node) {
967 MS_EXCEPTION_IF_NULL(node);
968 if (!node->isa<CNode>()) {
969 return {};
970 }
971 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) ||
972 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) ||
973 common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad)) {
974 return {node};
975 }
976
977 std::vector<AnfNodePtr> results;
978 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
979 const auto &cnode = node->cast<CNodePtr>();
980 MS_EXCEPTION_IF_NULL(cnode);
981 for (auto &weak_input : cnode->weak_inputs()) {
982 auto input = weak_input.lock();
983 MS_EXCEPTION_IF_NULL(input);
984 const auto &result = FetchAllMonadNodeByNode(input);
985 (void)results.insert(results.end(), result.begin(), result.end());
986 }
987 }
988 return results;
989 }
990
Parse(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const FuncGraphPtr & root_graph,const FuncGraphToKernelGraphGroup & func_graph_to_kernel_graphs)991 void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
992 const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
993 const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
994 if (graphs.size() != device_contexts.size()) {
995 MS_LOG(EXCEPTION) << "Graph num is not equal to device context, graph:" << graphs.size()
996 << " device context num:" << device_contexts.size();
997 }
998
999 if (control_nodes.size() <= 1) {
1000 MS_LOG(DEBUG) << "Control node parser is not inited.";
1001 return;
1002 }
1003 MS_LOG(INFO) << "Control node parse start.";
1004
1005 // Fetch default device context.
1006 auto context_ptr = MsContext::GetInstance();
1007 MS_EXCEPTION_IF_NULL(context_ptr);
1008 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1009 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
1010 DeviceContext *default_context = nullptr;
1011 if (device_contexts.empty()) {
1012 default_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
1013 } else {
1014 default_context = device_contexts[0];
1015 }
1016 MS_EXCEPTION_IF_NULL(default_context);
1017
1018 KernelGraphToDeviceContext kernel_graph_to_device_contexts;
1019 for (size_t i = 0; i < graphs.size(); ++i) {
1020 kernel_graph_to_device_contexts[graphs[i]] = device_contexts[i];
1021 }
1022
1023 for (const auto &control_node : control_nodes) {
1024 MS_EXCEPTION_IF_NULL(control_node);
1025 MS_LOG(DEBUG) << "Print control node:" << control_node->DebugString();
1026 }
1027
1028 is_inited_ = true;
1029
1030 root_func_graph_ = root_graph;
1031
1032 root_graph_parameters_ = root_graph->parameters();
1033
1034 func_graph_to_kernel_graph_groups_ = func_graph_to_kernel_graphs;
1035 for (const auto &func_graph_to_kernel_graph_groups : func_graph_to_kernel_graph_groups_) {
1036 for (const auto &kernel_graph_group : func_graph_to_kernel_graph_groups.second) {
1037 for (const auto &kernel_graph : kernel_graph_group) {
1038 MS_EXCEPTION_IF_NULL(func_graph_to_kernel_graph_groups.first);
1039 MS_EXCEPTION_IF_NULL(kernel_graph);
1040 MS_LOG(DEBUG) << "Funcgraph to kernel graph, func:" << func_graph_to_kernel_graph_groups.first->ToString()
1041 << " kernel_graph:" << kernel_graph->ToString();
1042 }
1043 }
1044 }
1045
1046 CreateBranchIDForCallNode(control_nodes);
1047
1048 ParseFrontNodeToKernelGraph(graphs);
1049
1050 ParseCallNodeToFuncGraph(control_nodes);
1051
1052 ParseUnRecursionCallNode();
1053
1054 InsertDependForParallelCall(control_nodes);
1055
1056 ParseKernelGraphGroup(kernel_graph_to_device_contexts);
1057
1058 ParseNodeLevel(control_nodes);
1059
1060 ParseNeedStackControlNode(control_nodes);
1061
1062 ParseFormalToRealParameter(control_nodes);
1063
1064 ParseFrontToBackendParameter(graphs, device_contexts);
1065
1066 CreateDeviceTensorForRootGraphParameter(default_context);
1067
1068 ParseFrontToBackendKernel(graphs, device_contexts);
1069
1070 ParseDeviceContext(control_nodes, graphs, device_contexts, default_context, func_graph_to_kernel_graphs);
1071
1072 FetchFrontValueNode(control_nodes, default_context);
1073
1074 ParseControlNodeParameter(control_nodes);
1075
1076 ParseFirstControlNodeAndKernelGraphForFuncGraph(control_nodes);
1077
1078 ParseDynamicLenFormalParameter(control_nodes);
1079 MS_LOG(INFO) << "Control node parse end.";
1080 }
1081
1082 namespace {
GetArgumentIndexForDynamicLenParameter(const abstract::AbstractBasePtr & argument_abs,size_t argument_index,const abstract::AbstractBasePtr & parameter_abs,mindspore::HashMap<size_t,size_t> * indexes)1083 void GetArgumentIndexForDynamicLenParameter(const abstract::AbstractBasePtr &argument_abs, size_t argument_index,
1084 const abstract::AbstractBasePtr ¶meter_abs,
1085 mindspore::HashMap<size_t, size_t> *indexes) {
1086 if (argument_abs == nullptr || parameter_abs == nullptr) {
1087 return;
1088 }
1089 MS_EXCEPTION_IF_NULL(indexes);
1090 if ((!argument_abs->isa<abstract::AbstractSequence>()) || (!parameter_abs->isa<abstract::AbstractSequence>())) {
1091 return;
1092 }
1093 const auto &arg_seq_abs = argument_abs->cast<abstract::AbstractSequencePtr>();
1094 const auto ¶_seq_abs = parameter_abs->cast<abstract::AbstractSequencePtr>();
1095 MS_EXCEPTION_IF_NULL(arg_seq_abs);
1096 MS_EXCEPTION_IF_NULL(para_seq_abs);
1097 if (arg_seq_abs->dynamic_len() && para_seq_abs->dynamic_len()) {
1098 return;
1099 }
1100 if ((!arg_seq_abs->dynamic_len()) && para_seq_abs->dynamic_len()) {
1101 MS_LOG(DEBUG) << "Add argument index:" << argument_index << " size:" << arg_seq_abs->size();
1102 (*indexes)[argument_index] = arg_seq_abs->size();
1103 return;
1104 }
1105 if (arg_seq_abs->dynamic_len() || para_seq_abs->dynamic_len() || arg_seq_abs->size() != para_seq_abs->size()) {
1106 MS_LOG(EXCEPTION) << "Invalid dynamic len flag for argument abstract:" << arg_seq_abs->ToString()
1107 << " parameter abstract:" << para_seq_abs->ToString();
1108 }
1109 size_t start_index = argument_index;
1110 for (size_t i = 0; i < arg_seq_abs->size(); ++i) {
1111 GetArgumentIndexForDynamicLenParameter(arg_seq_abs->elements()[i], start_index, para_seq_abs->elements()[i],
1112 indexes);
1113 start_index += common::AnfAlgo::GetOutputNumByAbstract(arg_seq_abs->elements()[i]);
1114 }
1115 }
1116 } // namespace
1117
ParseDynamicLenFormalParameterByCallNode(const AnfNodePtr & node)1118 void ControlNodeParser::ParseDynamicLenFormalParameterByCallNode(const AnfNodePtr &node) {
1119 MS_EXCEPTION_IF_NULL(node);
1120 const auto &cnode = node->cast<CNodePtr>();
1121 MS_EXCEPTION_IF_NULL(cnode);
1122 const auto &func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
1123 if (func_graphs.empty()) {
1124 MS_LOG(EXCEPTION) << "Get func_graph from abstract failed.";
1125 }
1126 mindspore::HashMap<size_t, size_t> sequence_indexes;
1127 for (auto func_graph : func_graphs) {
1128 MS_EXCEPTION_IF_NULL(func_graph);
1129 // Check the consistency of return outputs and call outputs.
1130 MS_EXCEPTION_IF_NULL(func_graph->return_node());
1131 mindspore::HashMap<size_t, size_t> return_sequence_indexes;
1132 GetArgumentIndexForDynamicLenParameter(func_graph->return_node()->abstract(), 0, node->abstract(),
1133 &return_sequence_indexes);
1134 if (!return_sequence_indexes.empty()) {
1135 return_to_call_with_dynamic_sequence_index_[func_graph->return_node()][node] = return_sequence_indexes;
1136 }
1137 // Check the consistency of arguments and parameters.
1138 if (cnode->inputs().empty()) {
1139 MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid cnode:" << cnode->DebugString();
1140 }
1141 size_t args_num = cnode->size() - 1;
1142 size_t para_num = func_graph->parameters().size();
1143 MS_LOG(DEBUG) << "for call node:" << cnode->DebugString() << " arg size:" << args_num << " para size:" << para_num;
1144 if (args_num > para_num) {
1145 MS_LOG(EXCEPTION) << "Invalid args num:" << args_num << " for funcgraph:" << func_graph->ToString()
1146 << " parameters num:" << func_graph->parameters().size();
1147 }
1148 size_t start_index = 1;
1149 for (size_t i = 0; i < args_num; ++i) {
1150 MS_EXCEPTION_IF_NULL(cnode->input(i + 1));
1151 MS_EXCEPTION_IF_NULL((func_graph->parameters())[i + para_num - args_num]);
1152 MS_LOG(DEBUG) << "Check formal parameter:" << cnode->input(i + 1)->DebugString()
1153 << " real node:" << (func_graph->parameters())[i + para_num - args_num]->DebugString();
1154 GetArgumentIndexForDynamicLenParameter(cnode->input(i + 1)->abstract(), start_index,
1155 (func_graph->parameters())[i + para_num - args_num]->abstract(),
1156 &sequence_indexes);
1157 start_index += common::AnfAlgo::GetOutputNumByAbstract(cnode->input(i + 1)->abstract());
1158 }
1159 if (!sequence_indexes.empty()) {
1160 for (const auto &pair : sequence_indexes) {
1161 MS_LOG(DEBUG) << "Add dynamic len formal parameter for call node:" << node->DebugString()
1162 << " funcgraph:" << func_graph->ToString() << " argument index:" << pair.first
1163 << " size:" << pair.second;
1164 }
1165 control_node_to_funcgraph_with_dynamic_sequence_index_[node][func_graph.get()] = sequence_indexes;
1166 }
1167 }
1168 }
1169
ParseDynamicLenFormalParameterByPartial(const AnfNodePtr & node)1170 void ControlNodeParser::ParseDynamicLenFormalParameterByPartial(const AnfNodePtr &node) {
1171 MS_EXCEPTION_IF_NULL(node);
1172 const auto &cnode = node->cast<CNodePtr>();
1173 MS_EXCEPTION_IF_NULL(cnode);
1174 size_t input_num = cnode->size();
1175 if (input_num <= kPartialFuncGraphPos || cnode->input(kPartialFuncGraphPos) == nullptr ||
1176 (!cnode->input(kPartialFuncGraphPos)->isa<ValueNode>())) {
1177 MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid partial node:" << node->DebugString();
1178 }
1179 const auto &func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kPartialFuncGraphPos));
1180 if (func_graph == nullptr) {
1181 MS_LOG(DEBUG) << "Failed to get funcgraph in partial node:" << node->DebugString();
1182 return;
1183 }
1184 if (func_graph->parameters().size() < input_num - kPartialInputStartPos) {
1185 MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid args num:" << input_num - kPartialInputStartPos
1186 << " in partial node:" << cnode->DebugString()
1187 << " for fungraph:" << func_graph->ToString()
1188 << " parameter num:" << func_graph->parameters().size();
1189 }
1190 size_t start_index = 1;
1191 mindspore::HashMap<size_t, size_t> sequence_indexes;
1192 for (size_t i = kPartialInputStartPos; i < input_num; ++i) {
1193 MS_EXCEPTION_IF_NULL(cnode->input(i));
1194 MS_EXCEPTION_IF_NULL(func_graph->parameters()[i - kPartialInputStartPos]);
1195 GetArgumentIndexForDynamicLenParameter(cnode->input(i)->abstract(), start_index,
1196 func_graph->parameters()[i - kPartialInputStartPos]->abstract(),
1197 &sequence_indexes);
1198 start_index += common::AnfAlgo::GetOutputNumByAbstract(cnode->input(i)->abstract());
1199 }
1200 if (!sequence_indexes.empty()) {
1201 mindspore::HashMap<size_t, size_t> new_sequence_indexes;
1202 for (const auto &index_pair : sequence_indexes) {
1203 new_sequence_indexes[index_pair.first + 1] = index_pair.second;
1204 }
1205 control_node_to_funcgraph_with_dynamic_sequence_index_[node][func_graph.get()] = new_sequence_indexes;
1206 }
1207 }
1208
ParseDynamicLenFormalParameter(const std::vector<AnfNodePtr> & control_nodes)1209 void ControlNodeParser::ParseDynamicLenFormalParameter(const std::vector<AnfNodePtr> &control_nodes) {
1210 for (const auto &node : control_nodes) {
1211 MS_EXCEPTION_IF_NULL(node);
1212 if (common::AnfAlgo::IsCallNode(node)) {
1213 ParseDynamicLenFormalParameterByCallNode(node);
1214 } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
1215 ParseDynamicLenFormalParameterByPartial(node);
1216 }
1217 }
1218 for (const auto &node_to_func_with_index : control_node_to_funcgraph_with_dynamic_sequence_index_) {
1219 const auto &node = node_to_func_with_index.first;
1220 MS_EXCEPTION_IF_NULL(node);
1221 for (const auto &func_with_index : node_to_func_with_index.second) {
1222 const auto &func_graph = func_with_index.first;
1223 MS_EXCEPTION_IF_NULL(func_graph);
1224 for (const auto &indexes : func_with_index.second) {
1225 MS_LOG(DEBUG) << "Node:" << node->DebugString() << " func_graph:" << func_graph->ToString()
1226 << " start index:" << indexes.first << " size:" << indexes.second;
1227 }
1228 }
1229 }
1230 for (const auto &node_to_call_with_index : return_to_call_with_dynamic_sequence_index_) {
1231 const auto &node = node_to_call_with_index.first;
1232 MS_EXCEPTION_IF_NULL(node);
1233 for (const auto &call_with_index : node_to_call_with_index.second) {
1234 const auto &call = call_with_index.first;
1235 MS_EXCEPTION_IF_NULL(call);
1236 for (const auto &indexes : call_with_index.second) {
1237 MS_LOG(DEBUG) << "Node:" << node->DebugString() << " call node:" << call->DebugString()
1238 << " start index:" << indexes.first << " size:" << indexes.second;
1239 }
1240 }
1241 }
1242 }
1243
1244 // Fetch all the funcgraph recursively that the call node will call.
FetchAllCalledFuncGraph(const AnfNodePtr & call_node,std::set<FuncGraphPtr> * called_graphs,const CallNodeToFuncGraph & call_node_to_func_graphs,const FuncGraphToCallNode & func_graph_to_call_nodes)1245 void FetchAllCalledFuncGraph(const AnfNodePtr &call_node, std::set<FuncGraphPtr> *called_graphs,
1246 const CallNodeToFuncGraph &call_node_to_func_graphs,
1247 const FuncGraphToCallNode &func_graph_to_call_nodes) {
1248 MS_EXCEPTION_IF_NULL(call_node);
1249 MS_EXCEPTION_IF_NULL(called_graphs);
1250 const auto &call_iter = call_node_to_func_graphs.find(call_node);
1251 if (call_iter == call_node_to_func_graphs.end()) {
1252 return;
1253 }
1254 for (const auto &func_graph : call_iter->second) {
1255 MS_EXCEPTION_IF_NULL(func_graph);
1256 if (called_graphs->find(func_graph) != called_graphs->end()) {
1257 continue;
1258 }
1259 (void)called_graphs->emplace(func_graph);
1260 const auto &graph_iter = func_graph_to_call_nodes.find(func_graph);
1261 if (graph_iter == func_graph_to_call_nodes.end()) {
1262 continue;
1263 }
1264
1265 // Fetch the funcgraph recursively.
1266 for (const auto &node : graph_iter->second) {
1267 FetchAllCalledFuncGraph(node, called_graphs, call_node_to_func_graphs, func_graph_to_call_nodes);
1268 }
1269 }
1270 }
1271
CreateTensorForValue(const ValuePtr & value)1272 tensor::TensorPtr ControlNodeParser::CreateTensorForValue(const ValuePtr &value) {
1273 MS_EXCEPTION_IF_NULL(value);
1274 tensor::TensorPtr tensor = nullptr;
1275 if (value->isa<Monad>()) {
1276 tensor = std::make_shared<tensor::Tensor>(int8_t('U'), TypeIdToType(kNumberTypeInt8));
1277 } else if (value->isa<Scalar>()) {
1278 const auto scalar_value = value->cast<ScalarPtr>();
1279 MS_EXCEPTION_IF_NULL(scalar_value);
1280 tensor = ScalarToTensor(scalar_value);
1281 } else {
1282 MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
1283 }
1284 control_node_tensors_.emplace_back(tensor);
1285 return tensor;
1286 }
1287
IsParallelCallRecursionGraph(const AnfNodePtr & call_node1,const AnfNodePtr & call_node2,const FuncGraphToCallNode & func_graph_to_call_nodes)1288 bool ControlNodeParser::IsParallelCallRecursionGraph(const AnfNodePtr &call_node1, const AnfNodePtr &call_node2,
1289 const FuncGraphToCallNode &func_graph_to_call_nodes) {
1290 // Fetch all funcgraphs the two call nodes will call both.
1291 std::set<FuncGraphPtr> called_graphs_1;
1292 FetchAllCalledFuncGraph(call_node1, &called_graphs_1, call_node_to_func_graphs_, func_graph_to_call_nodes);
1293 std::set<FuncGraphPtr> called_graphs_2;
1294 FetchAllCalledFuncGraph(call_node2, &called_graphs_2, call_node_to_func_graphs_, func_graph_to_call_nodes);
1295 std::vector<FuncGraphPtr> common_called_graphs;
1296 (void)std::set_intersection(called_graphs_1.begin(), called_graphs_1.end(), called_graphs_2.begin(),
1297 called_graphs_2.end(), std::back_inserter(common_called_graphs));
1298
1299 // Check for recursive calls in funcgraph.
1300 for (const auto &func_graph : common_called_graphs) {
1301 MS_EXCEPTION_IF_NULL(func_graph);
1302 const auto &iter = func_graph_to_call_nodes.find(func_graph);
1303 if (iter == func_graph_to_call_nodes.end()) {
1304 continue;
1305 }
1306 for (const auto &call_node : iter->second) {
1307 MS_EXCEPTION_IF_NULL(call_node);
1308 if (IsRecursionCallNode(call_node)) {
1309 MS_LOG(INFO) << "Call node:" << call_node1->DebugString() << " and:" << call_node2->DebugString()
1310 << " would call the same recursion in graph:" << func_graph
1311 << " which has a recursion call:" << call_node->DebugString();
1312 return true;
1313 }
1314 }
1315 }
1316 return false;
1317 }
1318
InsertDependForParallelCall(const std::vector<AnfNodePtr> & control_nodes)1319 void ControlNodeParser::InsertDependForParallelCall(const std::vector<AnfNodePtr> &control_nodes) {
1320 MS_LOG(INFO) << "InsertDependForParallelCall start";
1321 std::vector<AnfNodePtr> call_nodes;
1322 for (const auto &control_node : control_nodes) {
1323 MS_EXCEPTION_IF_NULL(control_node);
1324 if (!common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
1325 if (common::AnfAlgo::IsCallNode(control_node)) {
1326 // Fetch all the call nodes in the same graph.
1327 (void)call_nodes.emplace_back(control_node);
1328 }
1329 continue;
1330 }
1331
1332 // Check whether there is a topology relationship between call nodes.
1333 for (size_t i = 0; i < call_nodes.size(); ++i) {
1334 for (size_t j = 0; j < i; ++j) {
1335 std::set<AnfNodePtr> checked_nodes;
1336 if ((!IsParallelCallRecursionGraph(call_nodes[i], call_nodes[j], func_graph_to_call_nodes_)) ||
1337 IsTopoDependNode(call_nodes[i], call_nodes[j], &checked_nodes)) {
1338 continue;
1339 }
1340 // If there is no topological relationship between call nodes, and the same recursive graph will be called
1341 // at the same time, then a depend node needs to be inserted between call nodes.
1342 auto func_graph = call_nodes[i]->func_graph();
1343 MS_EXCEPTION_IF_NULL(func_graph);
1344 auto cnode = call_nodes[i]->cast<CNodePtr>();
1345 MS_EXCEPTION_IF_NULL(cnode);
1346 const auto &inputs = cnode->inputs();
1347 MS_EXCEPTION_IF_NULL(inputs[0]);
1348
1349 // Create a depend node.
1350 std::vector<AnfNodePtr> depend_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
1351 cnode->input(0), call_nodes[j]};
1352 auto new_depend = func_graph->NewCNode(depend_inputs);
1353 MS_EXCEPTION_IF_NULL(new_depend);
1354 new_depend->set_abstract(cnode->input(0)->abstract());
1355
1356 // Set depend node to call input.
1357 std::vector<AnfNodePtr> new_call_inputs{new_depend};
1358 for (size_t k = 1; k < inputs.size(); ++k) {
1359 (void)new_call_inputs.emplace_back(inputs[k]);
1360 }
1361 cnode->set_inputs(new_call_inputs);
1362 MS_LOG(INFO) << "Add depend node:" << new_depend->DebugString()
1363 << " for call node:" << call_nodes[i]->DebugString() << " and:" << call_nodes[j]->DebugString();
1364 }
1365 }
1366 call_nodes.clear();
1367 }
1368 MS_LOG(INFO) << "InsertDependForParallelCall end";
1369 }
1370
IsControlFlowDataArrow(const KernelGraphPtr & graph,const AnfNodePtr & backend_node)1371 bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) {
1372 MS_EXCEPTION_IF_NULL(graph);
1373 // Has no control flow node.
1374 if (!IsInited()) {
1375 return false;
1376 }
1377
1378 MS_EXCEPTION_IF_NULL(backend_node);
1379 if (!backend_node->isa<Parameter>()) {
1380 return false;
1381 }
1382 auto parameter_node = backend_node->cast<ParameterPtr>();
1383 MS_EXCEPTION_IF_NULL(parameter_node);
1384
1385 // Parameter input should be linked to its entrance actor.
1386 auto front_node = graph->GetFrontAnfByBackendAnf(backend_node);
1387 auto internal_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
1388 front_node = (front_node != nullptr ? front_node : internal_node_with_index.first);
1389 if (front_node == nullptr) {
1390 auto front_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(backend_node);
1391 front_node = front_node_with_index.first;
1392 }
1393 MS_EXCEPTION_IF_NULL(front_node);
1394 const auto &real_front_node = common::AnfAlgo::VisitKernelWithReturnType(front_node, 0).first;
1395 if (real_front_node != nullptr && real_front_node->isa<ValueNode>() && (!HasAbstractMonad(real_front_node))) {
1396 // If the real front node is a value node, we have two situations:
1397 // 1. if the value in value node is a tensor, it should be set into device tensor store by graph scheduler;
1398 // 2. if the value is a monad state, it should be converted to control arrow, which should link by control
1399 // node scheduler.
1400 MS_LOG(DEBUG) << "Front node:" << real_front_node->DebugString()
1401 << " of backend node:" << backend_node->DebugString() << " is a valuenode.";
1402 return false;
1403 }
1404
1405 // If parameter is a weight node in root funcgraph, it should be set to kernel actor directly.
1406 if (IsRootGraphPersistentDeviceTensor(front_node)) {
1407 MS_LOG(DEBUG) << "backend node:" << backend_node->DebugString()
1408 << " front node:" << (front_node == nullptr ? "null" : front_node->DebugString());
1409 return false;
1410 }
1411
1412 // If the input front node and graph not in same graph group, the input arrow should be link to the exit actor
1413 // of the graph.
1414 if (!IsSameKernelGraphGroup(front_node, graph)) {
1415 return true;
1416 }
1417
1418 // If the graph has a call input, all of its inputs in the graph should be linked to its stack actor.
1419 if (IsCallInputKernelGraph(graph.get())) {
1420 // If the input come from a kernel graph belong the same group, it should be linked by internal parameter.
1421 if (front_node != nullptr && (IsSameKernelGraphGroup(front_node, graph) || front_node->isa<ValueNode>())) {
1422 return false;
1423 }
1424 return true;
1425 }
1426
1427 return (front_node != nullptr && front_node->isa<Parameter>());
1428 }
1429
IsRootGraphPersistentDeviceTensor(const AnfNodePtr & node)1430 bool ControlNodeParser::IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node) {
1431 MS_EXCEPTION_IF_NULL(node);
1432 if (!IsPersistentDeviceTensor(node)) {
1433 return false;
1434 }
1435
1436 // No control flow.
1437 if (!is_inited_) {
1438 return true;
1439 }
1440
1441 // Maybe the load node, need fetch the real parameter node.
1442 auto real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
1443 MS_EXCEPTION_IF_NULL(real_node);
1444 return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), real_node) != root_graph_parameters_.end();
1445 }
1446
IsNeedStackControlNode(const AnfNodePtr & node)1447 bool ControlNodeParser::IsNeedStackControlNode(const AnfNodePtr &node) {
1448 MS_EXCEPTION_IF_NULL(node);
1449 if (!(node->isa<CNode>())) {
1450 return false;
1451 }
1452
1453 return need_stack_control_nodes_.find(node) != need_stack_control_nodes_.end();
1454 }
1455
IsRecursionCallNode(const AnfNodePtr & node)1456 bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) {
1457 MS_EXCEPTION_IF_NULL(node);
1458 if (!common::AnfAlgo::IsCallNode(node)) {
1459 return false;
1460 }
1461 return unrecursion_call_nodes_.find(node) == unrecursion_call_nodes_.end();
1462 }
1463
IsRecursionKernelGraph(const KernelGraphPtr & graph)1464 bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) {
1465 MS_EXCEPTION_IF_NULL(graph);
1466 auto group_info_iter = kernel_graphs_to_group_info_.find(graph);
1467 if (group_info_iter == kernel_graphs_to_group_info_.end()) {
1468 MS_LOG(EXCEPTION) << "Invalid kernel graph:" << graph->ToString();
1469 }
1470 MS_EXCEPTION_IF_NULL(group_info_iter->second);
1471 if (!group_info_iter->second->need_stack_) {
1472 return false;
1473 }
1474 for (const auto &front_input_node : group_info_iter->second->front_input_nodes_) {
1475 const auto &node = front_input_node.first.first;
1476 MS_EXCEPTION_IF_NULL(node);
1477 if (IsRecursionCallNode(node)) {
1478 return true;
1479 }
1480 }
1481 return false;
1482 }
1483
IsSameKernelGraphGroup(const AnfNodePtr & node,const KernelGraphPtr & graph)1484 bool ControlNodeParser::IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph) {
1485 MS_EXCEPTION_IF_NULL(node);
1486 MS_EXCEPTION_IF_NULL(graph);
1487 if (!node->isa<CNode>()) {
1488 MS_LOG(DEBUG) << "Not a cnode:" << node->DebugString();
1489 return false;
1490 }
1491
1492 const auto node_graph = FetchKernelGraphByFrontNode(node);
1493 if (node_graph == nullptr) {
1494 MS_LOG(DEBUG) << "Fail to get kernel graph for cnode:" << node->DebugString();
1495 return false;
1496 }
1497 MS_LOG(DEBUG) << "Get kernel graph:" << node_graph->ToString() << " for cnode:" << node->DebugString()
1498 << " compare to graph:" << graph->ToString();
1499 const auto iter1 = kernel_graphs_to_group_info_.find(node_graph);
1500 const auto iter2 = kernel_graphs_to_group_info_.find(graph);
1501
1502 return iter1 != kernel_graphs_to_group_info_.end() && iter2 != kernel_graphs_to_group_info_.end() &&
1503 iter1->second == iter2->second;
1504 }
1505
ParseDeviceContext(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & kernel_graphs,const std::vector<DeviceContext * > & device_contexts,DeviceContext * default_context,const FuncGraphToKernelGraphGroup & func_graph_to_kernel_graphs)1506 void ControlNodeParser::ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
1507 const std::vector<KernelGraphPtr> &kernel_graphs,
1508 const std::vector<DeviceContext *> &device_contexts,
1509 DeviceContext *default_context,
1510 const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
1511 MS_EXCEPTION_IF_NULL(default_context);
1512 ParseDeviceContextForFuncGraph(kernel_graphs, device_contexts, default_context, func_graph_to_kernel_graphs);
1513 ParseDeviceContextForReturnNode(default_context);
1514 ParseDeviceContextForCallNode(control_nodes);
1515 ParseDeviceContextForPartialNode(control_nodes);
1516 }
1517
ParseDeviceContextForFuncGraph(const std::vector<KernelGraphPtr> & kernel_graphs,const std::vector<DeviceContext * > & device_contexts,DeviceContext * default_context,const FuncGraphToKernelGraphGroup & func_graph_to_kernel_graphs)1518 void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<KernelGraphPtr> &kernel_graphs,
1519 const std::vector<DeviceContext *> &device_contexts,
1520 DeviceContext *default_context,
1521 const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
1522 MS_EXCEPTION_IF_NULL(default_context);
1523 if (device_contexts.size() != kernel_graphs.size()) {
1524 MS_LOG(EXCEPTION) << "Invalid device context size:" << device_contexts.size()
1525 << " graph size:" << kernel_graphs.size();
1526 }
1527 mindspore::HashMap<KernelGraphPtr, DeviceContext *> kernel_graph_to_device_context;
1528 for (size_t i = 0; i < kernel_graphs.size(); ++i) {
1529 kernel_graph_to_device_context[kernel_graphs[i]] = device_contexts[i];
1530 }
1531
1532 // Collect the device context type of the parameter in the kernel graph as the type of the real parameters.
1533 for (const auto &func_graph_to_kernel_graph : func_graph_to_kernel_graphs) {
1534 const auto &func_graph = func_graph_to_kernel_graph.first;
1535 MS_EXCEPTION_IF_NULL(func_graph);
1536 std::vector<KernelWithIndex> front_parameters;
1537 for (const auto ¶meter : func_graph->parameters()) {
1538 const auto &abstract = parameter->abstract();
1539 MS_EXCEPTION_IF_NULL(abstract);
1540 for (size_t i = 0; i < common::AnfAlgo::GetOutputNumByAbstract(abstract); ++i) {
1541 (void)front_parameters.emplace_back(parameter, i);
1542 }
1543 }
1544 std::vector<const DeviceContext *> parameter_device_contexts(front_parameters.size(), default_context);
1545 std::map<KernelWithIndex, DeviceContext *> front_parameter_to_device_context;
1546
1547 for (const auto &kernel_graph_group : func_graph_to_kernel_graph.second) {
1548 for (const auto &kernel_graph : kernel_graph_group) {
1549 MS_EXCEPTION_IF_NULL(kernel_graph);
1550 const auto &backend_parameters = kernel_graph->parameters();
1551
1552 for (const auto &backend_parameter : backend_parameters) {
1553 auto front_parameter = KernelWithIndex(kernel_graph->GetFrontAnfByBackendAnf(backend_parameter), 0);
1554 if (front_parameter.first == nullptr) {
1555 front_parameter = kernel_graph->GetElementInTupleBackendFrontIndexMap(backend_parameter);
1556 }
1557 if (front_parameter.first != nullptr && front_parameter.first->isa<Parameter>()) {
1558 front_parameter_to_device_context[front_parameter] = kernel_graph_to_device_context[kernel_graph];
1559 }
1560 }
1561 }
1562 }
1563
1564 for (size_t i = 0; i < front_parameters.size(); ++i) {
1565 const auto &front_parameter = front_parameters[i];
1566 const auto &iter = front_parameter_to_device_context.find(front_parameter);
1567 if (iter != front_parameter_to_device_context.end()) {
1568 parameter_device_contexts[i] = iter->second;
1569 }
1570 }
1571 func_graph_to_device_contexts_[func_graph] = parameter_device_contexts;
1572 }
1573
1574 // If there is no kernel in funcgraph, the parameter uses the default device context type.
1575 MS_EXCEPTION_IF_NULL(root_func_graph_);
1576 MS_EXCEPTION_IF_NULL(root_func_graph_->manager());
1577 FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
1578 for (auto sub_graph : sub_graphs) {
1579 MS_EXCEPTION_IF_NULL(sub_graph);
1580 if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
1581 size_t output_num = 0;
1582 for (const auto ¶meter : sub_graph->parameters()) {
1583 const auto &abstract = parameter->abstract();
1584 MS_EXCEPTION_IF_NULL(abstract);
1585 output_num += common::AnfAlgo::GetOutputNumByAbstract(abstract);
1586 }
1587 func_graph_to_device_contexts_[sub_graph] = std::vector<const DeviceContext *>(output_num, default_context);
1588 }
1589 }
1590 }
1591
ParseDeviceContextForPartialNode(const std::vector<AnfNodePtr> & control_nodes)1592 void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNodePtr> &control_nodes) {
1593 for (const auto &control_node : control_nodes) {
1594 if (!common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
1595 continue;
1596 }
1597
1598 MS_EXCEPTION_IF_NULL(control_node);
1599 const auto &cnode = control_node->cast<CNodePtr>();
1600 MS_EXCEPTION_IF_NULL(cnode);
1601 const auto &inputs = cnode->inputs();
1602 if (inputs.size() <= kPartialFuncGraphPos) {
1603 MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid input size for partial node:" << cnode->DebugString();
1604 }
1605 auto &func_node = inputs[kPartialFuncGraphPos];
1606 // Ignore if the node is 'Partial(DeadNode,)'.
1607 if (IsDeadNode(func_node)) {
1608 MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
1609 continue;
1610 }
1611 // Fetch the funcgraph in partial node.
1612 const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
1613 if (func_graph == nullptr) {
1614 MS_LOG_WITH_NODE(EXCEPTION, func_node)
1615 << "Invalid funcgraph node:" << func_node->DebugString() << " for partial node:" << cnode->DebugString();
1616 }
1617
1618 // Fetch the device contexts for the formal parameters in the funcgraph of partial node.
1619 auto iter = func_graph_to_device_contexts_.find(func_graph);
1620 if (iter == func_graph_to_device_contexts_.end()) {
1621 MS_LOG(EXCEPTION) << "Failed to get device contexts for funcgraph:" << func_graph->ToString();
1622 }
1623
1624 size_t input_num = 0;
1625 for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
1626 MS_EXCEPTION_IF_NULL(inputs[i]);
1627 const auto &abstract = inputs[i]->abstract();
1628 MS_EXCEPTION_IF_NULL(abstract);
1629 input_num += common::AnfAlgo::GetOutputNumByAbstract(abstract);
1630 }
1631 if (input_num > iter->second.size()) {
1632 MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid input num:" << input_num
1633 << " for funcgraph:" << func_graph->ToString()
1634 << " device context size:" << iter->second.size()
1635 << " for partial node:" << cnode->DebugString();
1636 }
1637
1638 // Get the device contexts for the real parameters.
1639 std::vector<const DeviceContext *> device_contexts;
1640 // In partial node, the first input is always a partial, maybe a funcgraph or a partial node, so we need
1641 // to insert an empty device context for it.
1642 (void)device_contexts.emplace_back(nullptr);
1643 for (size_t i = 0; i < input_num; ++i) {
1644 MS_EXCEPTION_IF_NULL(iter->second[i]);
1645 (void)device_contexts.emplace_back(iter->second[i]);
1646 }
1647 control_node_to_device_contexts_[control_node] = device_contexts;
1648 }
1649 }
1650
// Append to arg_context the device contexts of the real arguments of the call node cnode.
// The arg_num arguments of the call are matched with the last arg_num parameters of func_graph; each argument
// contributes one entry per flattened output, and every entry is the device context of the matched parameter.
// parameter_contexts must hold exactly one context per parameter of func_graph.
// Throws if a pointer argument is null, if the call has more arguments than the funcgraph has parameters, or if
// parameter_contexts does not match the parameter number.
void CollectDeviceContextByDynamicLen(const CNodePtr &cnode, const FuncGraphPtr &func_graph,
                                      const std::vector<const DeviceContext *> &parameter_contexts,
                                      std::vector<const DeviceContext *> *arg_context) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(arg_context);
  size_t para_num = func_graph->parameters().size();
  size_t arg_num = cnode->size() - 1;
  if (arg_num > para_num) {
    MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid arg size:" << arg_num << " parameter size:" << para_num
                                       << "for call node:" << cnode->DebugString()
                                       << " funcgraph:" << func_graph->ToString();
  }
  if (para_num != parameter_contexts.size()) {
    MS_LOG(EXCEPTION) << "Invalid parameter context size:" << parameter_contexts.size()
                      << " parameter size:" << para_num;
  }

  // Parameter index of the first argument: the arguments are aligned with the tail of the parameter list.
  const size_t first_para_index = para_num - arg_num;
  for (size_t arg_index = 0; arg_index < arg_num; ++arg_index) {
    // Input 0 of the call node is the called function, so argument k is input k + 1.
    const auto &arg_node = cnode->input(arg_index + 1);
    MS_EXCEPTION_IF_NULL(arg_node);
    const size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(arg_node->abstract());
    // Use the context of the parameter this argument is bound to, not the context of the first parameter.
    const DeviceContext *para_context = parameter_contexts[first_para_index + arg_index];
    for (size_t j = 0; j < output_num; ++j) {
      (void)arg_context->emplace_back(para_context);
    }
  }
}
1675
ParseDeviceContextForCallNode(const std::vector<AnfNodePtr> & control_nodes)1676 void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodePtr> &control_nodes) {
1677 for (const auto &control_node : control_nodes) {
1678 MS_EXCEPTION_IF_NULL(control_node);
1679 if (!common::AnfAlgo::IsCallNode(control_node)) {
1680 continue;
1681 }
1682
1683 // Fetch the device contexts of the funcgraph the node called.
1684 const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
1685 if (func_graphs.empty()) {
1686 MS_LOG_WITH_NODE(EXCEPTION, control_node)
1687 << "Failed to get funcgraph by call node:" << control_node->DebugString();
1688 }
1689 const auto &func_graph = *(func_graphs.begin());
1690 MS_EXCEPTION_IF_NULL(func_graph);
1691 auto iter = func_graph_to_device_contexts_.find(func_graph);
1692 if (iter == func_graph_to_device_contexts_.end()) {
1693 MS_LOG(EXCEPTION) << "Failed to get device contexts for funcgraph:" << func_graph->ToString();
1694 }
1695
1696 std::vector<const DeviceContext *> device_contexts;
1697 // In call node, the first input is always a partial, maybe a funcgraph or a partial node, so we need
1698 // to insert an empty device context for it.
1699 (void)device_contexts.emplace_back(nullptr);
1700 const auto &cnode = control_node->cast<CNodePtr>();
1701 MS_EXCEPTION_IF_NULL(cnode);
1702 const auto &inputs = cnode->inputs();
1703 size_t call_input_num = 0;
1704 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1705 MS_EXCEPTION_IF_NULL(inputs[i]);
1706 const auto &abstract = inputs[i]->abstract();
1707 MS_EXCEPTION_IF_NULL(abstract);
1708 call_input_num += common::AnfAlgo::GetOutputNumByAbstract(abstract);
1709 }
1710
1711 if (call_input_num > iter->second.size()) {
1712 MS_LOG(INFO) << "Call input size:" << call_input_num << " context size:" << iter->second.size() << "for funcgraph"
1713 << func_graph->ToString() << " for call node:" << cnode->DebugString();
1714 CollectDeviceContextByDynamicLen(cnode, func_graph, iter->second, &device_contexts);
1715 control_node_to_device_contexts_[control_node] = device_contexts;
1716 continue;
1717 }
1718
1719 // Fetch the device contexts for the real parameters on the call node.
1720 for (size_t i = iter->second.size() - call_input_num; i < iter->second.size(); ++i) {
1721 MS_EXCEPTION_IF_NULL(iter->second[i]);
1722 (void)device_contexts.emplace_back(iter->second[i]);
1723 }
1724 control_node_to_device_contexts_[control_node] = device_contexts;
1725 }
1726 }
1727
FetchDeviceContextByNode(const std::vector<KernelWithIndex> & output_nodes,std::vector<const DeviceContext * > * return_device_contexts,const FuncGraphPtr & func_graph,const DeviceContext * default_context)1728 void ControlNodeParser::FetchDeviceContextByNode(const std::vector<KernelWithIndex> &output_nodes,
1729 std::vector<const DeviceContext *> *return_device_contexts,
1730 const FuncGraphPtr &func_graph, const DeviceContext *default_context) {
1731 MS_EXCEPTION_IF_NULL(return_device_contexts);
1732 for (const auto &output_node : output_nodes) {
1733 MS_EXCEPTION_IF_NULL(output_node.first);
1734 if (output_node.first->isa<Parameter>()) {
1735 // If the output is parameter, get the device context type from the formal parameter.
1736 const auto &iter = find(func_graph->parameters().begin(), func_graph->parameters().end(), output_node.first);
1737 if (iter == func_graph->parameters().end()) {
1738 MS_LOG_WITH_NODE(EXCEPTION, output_node.first)
1739 << "Invalid parameter:" << output_node.first->DebugString() << " for func_graph:" << func_graph->ToString();
1740 }
1741 const auto &func_graph_iter = func_graph_to_device_contexts_.find(func_graph);
1742 if (func_graph_iter == func_graph_to_device_contexts_.end()) {
1743 MS_LOG(EXCEPTION) << "Cannot find device context for funcgraph:" << func_graph->ToString();
1744 }
1745 size_t index = LongToSize(iter - func_graph->parameters().begin());
1746 MS_EXCEPTION_IF_NULL(func_graph_iter->second[index]);
1747 (void)return_device_contexts->emplace_back(func_graph_iter->second[index]);
1748 } else if (output_node.first->isa<ValueNode>()) {
1749 // If the output is parameter, used the default context type.
1750 (void)return_device_contexts->emplace_back(default_context);
1751 } else if (common::AnfAlgo::IsCallNode(output_node.first)) {
1752 // If the output is call node, get the device context type by the output of funcgraph.
1753 const auto &func_graphs = call_node_to_func_graphs_[output_node.first];
1754 std::vector<const DeviceContext *> call_device_contexts;
1755 for (const auto &graph : func_graphs) {
1756 MS_EXCEPTION_IF_NULL(graph);
1757 const auto &node = graph->return_node();
1758 MS_EXCEPTION_IF_NULL(node);
1759 const auto &iter = control_node_to_device_contexts_.find(node);
1760 if (iter != control_node_to_device_contexts_.end()) {
1761 call_device_contexts = iter->second;
1762 break;
1763 }
1764 }
1765 // Since funcgraph has been topo-sorted according to the calling relationship, when there is a call node in
1766 // the output, the output type of the funcgraph called by it should have been determined, if not, an exception
1767 // will be thrown.
1768 if (call_device_contexts.empty() || call_device_contexts.size() <= output_node.second) {
1769 MS_LOG(DEBUG) << "Cannot find device context for call node:" << output_node.first->DebugString()
1770 << " device contexts size:" << call_device_contexts.size() << " index:" << output_node.second;
1771 (void)return_device_contexts->emplace_back(default_context);
1772 } else {
1773 MS_EXCEPTION_IF_NULL(call_device_contexts[output_node.second]);
1774 (void)return_device_contexts->emplace_back(call_device_contexts[output_node.second]);
1775 }
1776 } else if (common::AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimPartial) ||
1777 common::AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimSwitch)) {
1778 (void)return_device_contexts->emplace_back(default_context);
1779 } else if (output_node.first->isa<CNode>()) {
1780 // If the output is a cnode, get the device context type by the kernel.
1781 const auto &iter = front_to_backend_kernels_.find(output_node);
1782 if (iter == front_to_backend_kernels_.end()) {
1783 MS_LOG(DEBUG) << "Cannot find backend kernel for cnode:" << output_node.first->DebugString();
1784 (void)return_device_contexts->emplace_back(default_context);
1785 continue;
1786 }
1787 MS_EXCEPTION_IF_NULL(iter->second.second);
1788 (void)return_device_contexts->emplace_back(iter->second.second);
1789 } else {
1790 MS_LOG_WITH_NODE(EXCEPTION, output_node.first) << "Invalid node for return:" << output_node.first->DebugString();
1791 }
1792 }
1793 }
1794
ParseDeviceContextForReturnNode(const DeviceContext * default_context)1795 void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *default_context) {
1796 MS_EXCEPTION_IF_NULL(default_context);
1797 // Collect the call realationship between funcgraphs.
1798 FuncGraphCallRelation func_graph_call_relation;
1799 for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
1800 const auto &call_node = call_node_to_func_graphs.first;
1801 MS_EXCEPTION_IF_NULL(call_node);
1802 const auto &func_graph = call_node->func_graph();
1803 MS_EXCEPTION_IF_NULL(func_graph);
1804 (void)func_graph_call_relation[func_graph].emplace_back(call_node_to_func_graphs.second);
1805 }
1806
1807 // Topologically sort all funcgraphs according to the function call relationship.
1808 const auto &topo_sort_func_graphs = TopoSortForFuncGraph(root_func_graph_, &func_graph_call_relation);
1809
1810 // Deduces the device context type of funcgraph outputs according to the topological order.
1811 for (const auto &func_graph : topo_sort_func_graphs) {
1812 MS_EXCEPTION_IF_NULL(func_graph);
1813 const auto &return_node = func_graph->return_node();
1814 MS_EXCEPTION_IF_NULL(return_node);
1815 const auto &cnode = return_node->cast<CNodePtr>();
1816 MS_EXCEPTION_IF_NULL(cnode);
1817 const auto &inputs = cnode->inputs();
1818 if (inputs.size() <= kReturnInputPos) {
1819 MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid return node:" << cnode->DebugString();
1820 }
1821 const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]);
1822 std::vector<const DeviceContext *> return_device_contexts;
1823
1824 FetchDeviceContextByNode(output_nodes, &return_device_contexts, func_graph, default_context);
1825 control_node_to_device_contexts_[return_node] = return_device_contexts;
1826 }
1827 }
1828
ParseFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> & graphs)1829 void ControlNodeParser::ParseFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) {
1830 for (const auto &graph : graphs) {
1831 MS_EXCEPTION_IF_NULL(graph);
1832 if (graph->execution_order().empty()) {
1833 continue;
1834 }
1835 const auto &front_to_backend_nodes = graph->front_backend_anf_map();
1836 for (const auto &front_to_backend_node : front_to_backend_nodes) {
1837 MS_LOG(DEBUG) << "Add front node:" << front_to_backend_node.first->DebugString()
1838 << " for kernel graph:" << graph->ToString();
1839 front_node_to_kernel_graph_[front_to_backend_node.first] = graph;
1840 }
1841 }
1842 }
1843
FetchBranchIDByCallNode(const AnfNodePtr & call_node)1844 int ControlNodeParser::FetchBranchIDByCallNode(const AnfNodePtr &call_node) {
1845 MS_EXCEPTION_IF_NULL(call_node);
1846
1847 if (call_node_to_branch_id_.find(call_node) == call_node_to_branch_id_.end()) {
1848 MS_LOG_WITH_NODE(EXCEPTION, call_node) << "Invalid branch id for call_node:" << call_node->DebugString();
1849 }
1850 return call_node_to_branch_id_[call_node];
1851 }
1852
FetchKernelGraphByFrontNode(const AnfNodePtr & kernel)1853 KernelGraphPtr ControlNodeParser::FetchKernelGraphByFrontNode(const AnfNodePtr &kernel) {
1854 const auto &iter = front_node_to_kernel_graph_.find(kernel);
1855 if (iter == front_node_to_kernel_graph_.end()) {
1856 return nullptr;
1857 }
1858 return iter->second;
1859 }
1860
IsCallInputKernelGraph(KernelGraph * const graph)1861 bool ControlNodeParser::IsCallInputKernelGraph(KernelGraph *const graph) {
1862 if (call_input_kernel_graphs_.find(graph) == call_input_kernel_graphs_.end()) {
1863 return false;
1864 }
1865 return true;
1866 }
1867
IsCallInputKernelGraphGroup(const std::string & group_name)1868 bool ControlNodeParser::IsCallInputKernelGraphGroup(const std::string &group_name) {
1869 for (const auto &graph_group : kernel_graph_group_infos_) {
1870 MS_EXCEPTION_IF_NULL(graph_group);
1871 if (group_name.find(graph_group->group_name_) != std ::string::npos) {
1872 return graph_group->need_stack_;
1873 }
1874 }
1875 MS_LOG(EXCEPTION) << "Invalid kernel graph group name:" << group_name;
1876 }
1877
// Return the backend kernel with index recorded for the front node with index, or an empty pair if none exists.
KernelWithIndex ControlNodeParser::FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index) {
  const auto &iter = front_to_backend_kernels_.find(node_with_index);
  if (iter == front_to_backend_kernels_.end()) {
    return {};
  }
  return iter->second.first;
}
1885
FetchFuncGraphByKernelGraph(const KernelGraph * const graph)1886 FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *const graph) {
1887 for (const auto &func_graph_to_kernel_graphs : func_graph_to_kernel_graph_groups_) {
1888 const auto &kernel_graph_groups = func_graph_to_kernel_graphs.second;
1889 if (std::any_of(kernel_graph_groups.begin(), kernel_graph_groups.end(), [graph](const auto &kernel_graph_group) {
1890 return std::any_of(kernel_graph_group.begin(), kernel_graph_group.end(),
1891 [graph](const auto &kernel_graph) { return kernel_graph.get() == graph; });
1892 })) {
1893 return func_graph_to_kernel_graphs.first;
1894 }
1895 }
1896 return nullptr;
1897 }
1898
FetchBackendParameterWithContextByFrontParameter(const KernelWithIndex & front_parameter_with_index)1899 NodeWithIndexToContext ControlNodeParser::FetchBackendParameterWithContextByFrontParameter(
1900 const KernelWithIndex &front_parameter_with_index) {
1901 MS_EXCEPTION_IF_NULL(front_parameter_with_index.first);
1902 const auto &iter = front_to_backend_parameters_.find(front_parameter_with_index);
1903 if (iter == front_to_backend_parameters_.end()) {
1904 return {};
1905 }
1906
1907 for (const auto &node_with_index_to_context : iter->second) {
1908 const auto &node = node_with_index_to_context.first.first;
1909 MS_EXCEPTION_IF_NULL(node);
1910 const auto &abstract =
1911 AnfAlgo::GetNodeAbstractByIndex(front_parameter_with_index.first, front_parameter_with_index.second);
1912 bool is_map_parameter = abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>();
1913 if (AnfAlgo::GetOutputTensorMemSize(node, node_with_index_to_context.first.second) != 0 || is_map_parameter) {
1914 return node_with_index_to_context;
1915 }
1916 MS_LOG(DEBUG) << "Backend node:" << node->DebugString()
1917 << " for front node:" << front_parameter_with_index.first->DebugString()
1918 << " index:" << front_parameter_with_index.second << " output size is 0.";
1919 }
1920 return {};
1921 }
1922
CreateDeviceTensors(const std::vector<AnfNodePtr> & control_nodes,const DeviceContext * const default_context)1923 void ControlNodeParser::CreateDeviceTensors(const std::vector<AnfNodePtr> &control_nodes,
1924 const DeviceContext *const default_context) {
1925 for (const auto &control_node : control_nodes) {
1926 MS_EXCEPTION_IF_NULL(control_node);
1927 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
1928 common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
1929 auto input_with_indexs = FetchInputNodeByCNode(control_node);
1930 for (size_t i = 0; i < input_with_indexs.size(); ++i) {
1931 MS_EXCEPTION_IF_NULL(input_with_indexs[i].first);
1932 if (IsFrontValueNode(input_with_indexs[i])) {
1933 CreateDeviceTensorForFrontNode(input_with_indexs[i], default_context);
1934 (void)front_value_nodes_.emplace(input_with_indexs[i], default_context);
1935 }
1936 }
1937 continue;
1938 }
1939
1940 if ((!common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) &&
1941 (!common::AnfAlgo::IsCallNode(control_node))) {
1942 continue;
1943 }
1944
1945 auto input_with_indexs = FetchInputNodeByCNode(control_node);
1946 auto iter = control_node_to_device_contexts_.find(control_node);
1947 if (iter == control_node_to_device_contexts_.end() || iter->second.size() < input_with_indexs.size()) {
1948 MS_LOG_WITH_NODE(EXCEPTION, control_node)
1949 << "Invalid device context for control node:" << control_node->DebugString()
1950 << " need:" << input_with_indexs.size() << " current:"
1951 << (iter == control_node_to_device_contexts_.end() ? "null" : std::to_string(iter->second.size()));
1952 }
1953 for (size_t i = 0; i < input_with_indexs.size(); ++i) {
1954 const auto &input_with_index = input_with_indexs[i];
1955 if (IsFrontValueNode(input_with_index) &&
1956 front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
1957 MS_EXCEPTION_IF_NULL(input_with_index.first);
1958 MS_LOG(DEBUG) << "Create device tensor for value node:" << input_with_index.first->DebugString()
1959 << " index:" << i << " in control node:" << control_node->DebugString();
1960 const auto &node_with_index_with_context = FetchBackendParameterWithContextByFrontParameter(input_with_index);
1961 const auto &backend_node = node_with_index_with_context.first.first;
1962 if (IsValidBackendParameter(backend_node)) {
1963 CreateDeviceTensorForValueNode(input_with_index, backend_node, node_with_index_with_context.second);
1964 (void)front_value_nodes_.emplace(input_with_index, node_with_index_with_context.second);
1965 } else {
1966 CreateDeviceTensorForFrontNode(input_with_index, default_context);
1967 (void)front_value_nodes_.emplace(input_with_index, default_context);
1968 }
1969 }
1970 }
1971 }
1972 }
1973
FetchFrontValueNode(const std::vector<AnfNodePtr> & control_nodes,const DeviceContext * const default_context)1974 void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
1975 const DeviceContext *const default_context) {
1976 MS_EXCEPTION_IF_NULL(default_context);
1977
1978 for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
1979 for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
1980 if (!IsFrontValueNode(real_parameter_with_index)) {
1981 continue;
1982 }
1983
1984 const auto &node_with_index_to_context =
1985 FetchBackendParameterWithContextByFrontParameter(real_parameter_with_index);
1986 const auto &backend_node = node_with_index_to_context.first.first;
1987 if (IsValidBackendParameter(backend_node)) {
1988 (void)front_value_nodes_.emplace(real_parameter_with_index, node_with_index_to_context.second);
1989 CreateDeviceTensorForValueNode(real_parameter_with_index, backend_node, node_with_index_to_context.second);
1990 } else {
1991 (void)front_value_nodes_.emplace(real_parameter_with_index, default_context);
1992 CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context);
1993 }
1994 }
1995 }
1996
1997 // Create device tensors for those value nodes which direct return by a return node.
1998 CreateDeviceTensors(control_nodes, default_context);
1999 for (const auto &front_node : front_value_nodes_) {
2000 MS_EXCEPTION_IF_NULL(front_node.first.first);
2001 MS_LOG(DEBUG) << "Print front value node:" << front_node.first.first->DebugString()
2002 << " addr:" << front_node.first.first << " index:" << front_node.first.second;
2003 }
2004 }
2005
ParseFormalToRealParameter(const std::vector<AnfNodePtr> & control_nodes)2006 void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes) {
2007 FormalToRealParameter formal_to_real_parameters;
2008
2009 // The actual parameters of the function are divided into two parts:
2010 // 1. Input of partial node.
2011 // 2. Input of call node.
2012 for (const auto &node : control_nodes) {
2013 MS_EXCEPTION_IF_NULL(node);
2014 if (common::AnfAlgo::IsCallNode(node)) {
2015 const auto &cnode = node->cast<CNodePtr>();
2016 MS_EXCEPTION_IF_NULL(cnode);
2017 const auto &inputs = cnode->inputs();
2018 const auto &func_graphs = FetchFuncGraphbyCallNode(node);
2019 for (const auto &func_graph : func_graphs) {
2020 MS_EXCEPTION_IF_NULL(func_graph);
2021 const auto ¶meters = func_graph->parameters();
2022 for (int i = SizeToInt(inputs.size()) - 1, j = SizeToInt(parameters.size()) - 1; i >= 1 && j >= 0; --i, --j) {
2023 MS_EXCEPTION_IF_NULL(inputs[IntToSize(i)]);
2024 MS_EXCEPTION_IF_NULL(parameters[IntToSize(j)]);
2025 AddFormalToRealParameter(parameters[IntToSize(j)], inputs[IntToSize(i)], call_node_to_func_graphs_,
2026 &formal_to_real_parameters);
2027 }
2028 }
2029 } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
2030 const auto &cnode = node->cast<CNodePtr>();
2031 MS_EXCEPTION_IF_NULL(cnode);
2032 const auto &inputs = cnode->inputs();
2033 if (inputs.size() <= kPartialFuncGraphPos) {
2034 MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid input size for partial node:" << node->DebugString();
2035 }
2036 auto &func_node = inputs[kPartialFuncGraphPos];
2037 MS_EXCEPTION_IF_NULL(func_node);
2038 // Ignore if the node is 'Partial(DeadNode,)'.
2039 if (IsDeadNode(func_node)) {
2040 MS_LOG(DEBUG) << "Ignore partial dead node:" << node->DebugString();
2041 continue;
2042 }
2043 const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
2044 if (func_graph == nullptr) {
2045 MS_LOG_WITH_NODE(EXCEPTION, node)
2046 << "Invalid funcgraph node:" << func_node->DebugString() << " for partial node:" << node->DebugString();
2047 }
2048 const auto ¶meters = func_graph->parameters();
2049 if (inputs.size() - kPartialInputStartPos > parameters.size()) {
2050 MS_LOG(EXCEPTION) << "Invalid partial input size:" << inputs.size()
2051 << " formal parameter size:" << parameters.size();
2052 }
2053 for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
2054 MS_EXCEPTION_IF_NULL(inputs[i]);
2055 MS_EXCEPTION_IF_NULL(parameters[i - kPartialInputStartPos]);
2056 AddFormalToRealParameter(parameters[i - kPartialInputStartPos], inputs[i], call_node_to_func_graphs_,
2057 &formal_to_real_parameters);
2058 }
2059 }
2060 }
2061
2062 // When the real parameter is also a parameter, the corresponding actual parameter needs to be obtained recursively.
2063 for (const auto &formal_to_real_parameter : formal_to_real_parameters) {
2064 const auto &formal_parameter = formal_to_real_parameter.first;
2065 const auto &real_parameters = formal_to_real_parameter.second;
2066 std::set<KernelWithIndex> total_real_parameters = real_parameters;
2067 for (const auto &real_parameter : real_parameters) {
2068 MS_EXCEPTION_IF_NULL(real_parameter.first);
2069 if (real_parameter.first->isa<Parameter>()) {
2070 std::set<KernelWithIndex> invalid_real_parameter{formal_parameter};
2071 ParseAllRealParameterByFormalParameter(real_parameter, formal_to_real_parameters, &total_real_parameters,
2072 &invalid_real_parameter);
2073 (void)real_to_formal_parameters_[real_parameter].emplace(formal_parameter);
2074 } else {
2075 (void)total_real_parameters.emplace(real_parameter);
2076 }
2077 }
2078 std::swap(formal_to_real_parameters_[formal_parameter], total_real_parameters);
2079 }
2080
2081 for (const auto &formal_to_real : formal_to_real_parameters_) {
2082 for (const auto &real_parameter : formal_to_real.second) {
2083 MS_EXCEPTION_IF_NULL(formal_to_real.first.first);
2084 MS_EXCEPTION_IF_NULL(real_parameter.first);
2085 MS_LOG(DEBUG) << "Print formal to real node, formal:" << formal_to_real.first.first->DebugString()
2086 << " real:" << real_parameter.first->DebugString() << " index:" << real_parameter.second;
2087 }
2088 }
2089 }
2090
ParseAllRealParameterByFormalParameter(const KernelWithIndex & formal_parameter,const FormalToRealParameter & formal_to_real_parameters,std::set<KernelWithIndex> * const total_real_parameters,std::set<KernelWithIndex> * invalid_real_parameter)2091 void ControlNodeParser::ParseAllRealParameterByFormalParameter(const KernelWithIndex &formal_parameter,
2092 const FormalToRealParameter &formal_to_real_parameters,
2093 std::set<KernelWithIndex> *const total_real_parameters,
2094 std::set<KernelWithIndex> *invalid_real_parameter) {
2095 MS_EXCEPTION_IF_NULL(formal_parameter.first);
2096 MS_EXCEPTION_IF_NULL(total_real_parameters);
2097 MS_EXCEPTION_IF_NULL(invalid_real_parameter);
2098 if (invalid_real_parameter->find(formal_parameter) != invalid_real_parameter->end()) {
2099 return;
2100 }
2101 (void)invalid_real_parameter->emplace(formal_parameter);
2102
2103 // Get all the actual parameters corresponding to parameter recursively.
2104 const auto &dst_iter = formal_to_real_parameters_.find(formal_parameter);
2105 if (dst_iter != formal_to_real_parameters_.end()) {
2106 total_real_parameters->insert(dst_iter->second.begin(), dst_iter->second.end());
2107 return;
2108 }
2109 const auto &src_iter = formal_to_real_parameters.find(formal_parameter);
2110 if (src_iter == formal_to_real_parameters.end()) {
2111 const auto &func_graph = formal_parameter.first->func_graph();
2112 MS_EXCEPTION_IF_NULL(func_graph);
2113 if (func_graph == root_func_graph_) {
2114 return;
2115 }
2116 MS_LOG(DEBUG) << "Invalid formal parameter:" << formal_parameter.first->DebugString()
2117 << ", maybe there is no call node for funcgraph:"
2118 << (formal_parameter.first->func_graph() == nullptr
2119 ? "null"
2120 : formal_parameter.first->func_graph()->ToString());
2121 return;
2122 }
2123 const auto &real_parameters = src_iter->second;
2124 for (const auto &real_parameter : real_parameters) {
2125 MS_EXCEPTION_IF_NULL(real_parameter.first);
2126 (void)total_real_parameters->emplace(real_parameter);
2127 if (real_parameter.first->isa<Parameter>()) {
2128 ParseAllRealParameterByFormalParameter(real_parameter, formal_to_real_parameters, total_real_parameters,
2129 invalid_real_parameter);
2130 }
2131 }
2132 }
2133
ParseControlNodeParameter(const std::vector<AnfNodePtr> & control_nodes)2134 void ControlNodeParser::ParseControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes) {
2135 for (const auto &control_node : control_nodes) {
2136 MS_EXCEPTION_IF_NULL(control_node);
2137 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
2138 break;
2139 }
2140
2141 const auto &inputs = FetchInputNodeByCNode(control_node);
2142 for (size_t i = 0; i < inputs.size(); ++i) {
2143 MS_EXCEPTION_IF_NULL(inputs[i].first);
2144 MS_LOG(DEBUG) << "Control node:" << control_node->DebugString()
2145 << " input node:" << inputs[i].first->DebugString() << " index:" << inputs[i].second;
2146 if (inputs[i].first->isa<Parameter>()) {
2147 MS_LOG(DEBUG) << "Control node:" << control_node->DebugString()
2148 << " input parameter:" << inputs[i].first->DebugString() << " index:" << inputs[i].second;
2149 (void)control_node_parameters_.emplace_back(inputs[i]);
2150 // Set Dynamic shape flag for parameter.
2151 const auto ¶meter = inputs[i].first->cast<ParameterPtr>();
2152 MS_EXCEPTION_IF_NULL(parameter);
2153 const auto &base_shape = parameter->Shape();
2154 if (base_shape == nullptr) {
2155 continue;
2156 }
2157 if ((base_shape->isa<abstract::Shape>() && base_shape->IsDynamic()) ||
2158 base_shape->isa<abstract::DynamicSequenceShape>()) {
2159 MS_LOG(INFO) << "Set dynamic shape flag to parameter:" << parameter->DebugString();
2160 parameter->set_has_dynamic_shape(true);
2161 }
2162 }
2163 }
2164 }
2165 }
2166
CreateBranchIDForCallNode(const std::vector<AnfNodePtr> & control_nodes)2167 void ControlNodeParser::CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes) {
2168 int branch_id = kMainBranchID;
2169
2170 for (const auto &control_node : control_nodes) {
2171 // Root funcgraph does not need to create a gather actor.
2172 if (common::AnfAlgo::IsCallNode(control_node)) {
2173 call_node_to_branch_id_[control_node] = ++branch_id;
2174 MS_LOG(DEBUG) << "control node:" << control_node->DebugString()
2175 << " branch id:" << call_node_to_branch_id_[control_node];
2176 }
2177 }
2178 }
2179
ParseFrontToBackendParameter(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)2180 void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
2181 const std::vector<DeviceContext *> &device_contexts) {
2182 if (graphs.size() != device_contexts.size()) {
2183 MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
2184 }
2185
2186 // Fetch the mapping relationship between front parameters and backend parameters in the kernel graphs.
2187 for (size_t i = 0; i < graphs.size(); ++i) {
2188 const auto &graph = graphs[i];
2189 auto device_context = device_contexts[i];
2190 MS_EXCEPTION_IF_NULL(graph);
2191 MS_EXCEPTION_IF_NULL(device_context);
2192 for (const auto ¶meter : graph->input_nodes()) {
2193 MS_EXCEPTION_IF_NULL(parameter);
2194 const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter);
2195 const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
2196 const auto &front_tuple_parameter_with_index = graph->GetElementInTupleBackendFrontIndexMap(parameter);
2197 if (front_node == nullptr && front_node_with_index.first == nullptr &&
2198 front_tuple_parameter_with_index.first == nullptr) {
2199 MS_LOG_WITH_NODE(EXCEPTION, parameter)
2200 << "Invalid backend parameter:" << parameter->DebugString() << " for kernel graph:" << graph->ToString();
2201 }
2202
2203 if (front_node_with_index.first != nullptr) {
2204 std::set<KernelWithIndex> real_parameters;
2205 std::set<KernelWithIndex> invalid_call_nodes;
2206 FetchRealParameterByNode(front_node_with_index, &real_parameters, &invalid_call_nodes,
2207 call_node_to_func_graphs_);
2208 for (const auto &real_parameter : real_parameters) {
2209 MS_EXCEPTION_IF_NULL(real_parameter.first);
2210 if (real_parameter.first->isa<Parameter>() || real_parameter.first->isa<ValueNode>()) {
2211 (void)front_to_backend_parameters_[real_parameter].emplace(KernelWithIndex(parameter, 0), device_context);
2212 MS_LOG(DEBUG) << "Add front node:" << real_parameter.first->DebugString()
2213 << " index:" << real_parameter.second
2214 << " for backend parameter:" << parameter->DebugString();
2215 }
2216 }
2217 } else if (front_tuple_parameter_with_index.first != nullptr) {
2218 (void)front_to_backend_parameters_[front_tuple_parameter_with_index].emplace(KernelWithIndex(parameter, 0),
2219 device_context);
2220 } else {
2221 (void)front_to_backend_parameters_[{front_node, 0}].emplace(KernelWithIndex(parameter, 0), device_context);
2222 }
2223 }
2224 }
2225
2226 // Get the corresponding backend node for the real parameter according to the relationship between real
2227 // parameter and formal parameter.
2228 for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
2229 const auto &front_parameter = front_to_backend_parameters.first;
2230 const auto &backend_parameters = front_to_backend_parameters.second;
2231 const auto &iter = formal_to_real_parameters_.find(front_parameter);
2232 if (iter != formal_to_real_parameters_.end()) {
2233 for (const auto &real_parameter_with_index : iter->second) {
2234 const auto &real_parameter = real_parameter_with_index.first;
2235 MS_EXCEPTION_IF_NULL(real_parameter);
2236 if (real_parameter->isa<Parameter>()) {
2237 front_to_backend_parameters_[real_parameter_with_index].insert(backend_parameters.begin(),
2238 backend_parameters.end());
2239 }
2240 }
2241 }
2242 }
2243 for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
2244 for (const auto &backend_parameter : front_to_backend_parameters.second) {
2245 MS_EXCEPTION_IF_NULL(front_to_backend_parameters.first.first);
2246 MS_EXCEPTION_IF_NULL(backend_parameter.first.first);
2247 MS_LOG(DEBUG) << "Print front to backend parameter, front:"
2248 << front_to_backend_parameters.first.first->DebugString()
2249 << " index:" << front_to_backend_parameters.first.second
2250 << " backend:" << backend_parameter.first.first->DebugString()
2251 << " index:" << backend_parameter.first.second << " node addr:" << backend_parameter.first.first;
2252 }
2253 }
2254 }
2255
ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> & control_nodes)2256 void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
2257 for (const auto &control_node : control_nodes) {
2258 MS_EXCEPTION_IF_NULL(control_node);
2259 if (!common::AnfAlgo::IsCallNode(control_node)) {
2260 continue;
2261 }
2262
2263 const auto &belong_func_graph = control_node->func_graph();
2264 MS_EXCEPTION_IF_NULL(belong_func_graph);
2265 (void)func_graph_to_call_nodes_[belong_func_graph].emplace(control_node);
2266
2267 const auto &cnode = control_node->cast<CNodePtr>();
2268 MS_EXCEPTION_IF_NULL(cnode);
2269 const auto &func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
2270 if (func_graphs.empty()) {
2271 MS_LOG(EXCEPTION) << "Get func graphs from abstract failed.";
2272 }
2273 for (auto func_graph : func_graphs) {
2274 (void)call_node_to_func_graphs_[control_node].emplace(func_graph);
2275 }
2276 }
2277 }
2278
FetchFuncGraphbyCallNode(const AnfNodePtr & control_node)2279 const std::set<FuncGraphPtr> &ControlNodeParser::FetchFuncGraphbyCallNode(const AnfNodePtr &control_node) {
2280 MS_EXCEPTION_IF_NULL(control_node);
2281 const auto &iter = call_node_to_func_graphs_.find(control_node);
2282 if (iter == call_node_to_func_graphs_.end()) {
2283 MS_LOG_WITH_NODE(EXCEPTION, control_node) << "Invalid call node:" << control_node->DebugString();
2284 }
2285 return iter->second;
2286 }
2287
ParseFrontToBackendKernel(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)2288 void ControlNodeParser::ParseFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
2289 const std::vector<DeviceContext *> &device_contexts) {
2290 for (size_t i = 0; i < graphs.size(); ++i) {
2291 const auto &graph = graphs[i];
2292 const auto &device_context = device_contexts[i];
2293 MS_EXCEPTION_IF_NULL(graph);
2294 auto execution_order = graph->execution_order();
2295 for (auto &kernel : execution_order) {
2296 auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
2297 if (front_node != nullptr) {
2298 for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) {
2299 front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context};
2300 MS_LOG(DEBUG) << "Add front to backend kernel, front:" << common::AnfAlgo::GetNodeDebugString(front_node)
2301 << "index:" << j << " addr:" << front_node
2302 << " second:" << common::AnfAlgo::GetNodeDebugString(kernel) << "index:" << j
2303 << " addr:" << kernel;
2304 }
2305 }
2306 }
2307
2308 for (const auto &output_pair : graph->front_node_to_graph_output_map()) {
2309 MS_EXCEPTION_IF_NULL(output_pair.second.first);
2310 if (output_pair.second.first->isa<CNode>()) {
2311 front_to_backend_kernels_[output_pair.first] = {output_pair.second, device_context};
2312 }
2313 }
2314 }
2315 for (const auto &front_to_backend_kernels : front_to_backend_kernels_) {
2316 MS_EXCEPTION_IF_NULL(front_to_backend_kernels.first.first);
2317 MS_EXCEPTION_IF_NULL(front_to_backend_kernels.second.first.first);
2318 MS_LOG(DEBUG) << "Print front to backend kernel, front node:" << front_to_backend_kernels.first.first->DebugString()
2319 << " front index:" << front_to_backend_kernels.first.second
2320 << " backend node:" << front_to_backend_kernels.second.first.first->DebugString()
2321 << " backend index:" << front_to_backend_kernels.second.first.second;
2322 }
2323 }
2324
ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> & control_nodes)2325 void ControlNodeParser::ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
2326 for (const auto &control_node : control_nodes) {
2327 MS_EXCEPTION_IF_NULL(control_node);
2328 const auto &func_graph = control_node->func_graph();
2329 MS_EXCEPTION_IF_NULL(func_graph);
2330 // In the funcgraph with recursive call node, the call node is marked as level1, and the entrance actor is
2331 // notified to send data after the call node execute ends. At this time, it is necessary to ensure that the
2332 // data of all actors in the graph has been processed, so all control nodes of level0 need link control arrow
2333 // to entrance actor.
2334 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
2335 auto iter = node_to_level_.find(control_node);
2336 if (iter != node_to_level_.end() && iter->second == 0 && (!IsPartialInput(control_node))) {
2337 (void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
2338 }
2339 }
2340
2341 std::set<AnfNodePtr> checked_nodes;
2342 if (((common::AnfAlgo::IsCallNode(control_node) &&
2343 unrecursion_call_nodes_.find(control_node) == unrecursion_call_nodes_.end()) ||
2344 common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) &&
2345 IsFirstControlNode(control_node, &checked_nodes, unrecursion_call_nodes_)) {
2346 (void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
2347 MS_LOG(DEBUG) << "Add first control node:" << control_node->DebugString()
2348 << " for funcgraph:" << func_graph->ToString();
2349 if (!common::AnfAlgo::IsCallNode(control_node)) {
2350 continue;
2351 }
2352
2353 // If there is a recursive call node in the funcgraph, the kernel graph of the topo sort before the call node
2354 // needs to be executed before the call recursion, that is, the kernel graph whose level is less than the call
2355 // node needs to link a control arrow to the corresponding entry actor.
2356 // Fetch the level of control node.
2357 const auto &level_iter = node_to_level_.find(control_node);
2358 if (level_iter == node_to_level_.end()) {
2359 MS_LOG(DEBUG) << "Failed to get level for call node:" << control_node->DebugString();
2360 continue;
2361 }
2362
2363 // Fetch all of the kernel graph group info whose level less than the control node.
2364 const auto &graph_group_iter = func_graph_to_kernel_graph_groups_.find(func_graph);
2365 if (graph_group_iter == func_graph_to_kernel_graph_groups_.end()) {
2366 continue;
2367 }
2368 for (const auto &kernel_graphs : graph_group_iter->second) {
2369 // Fetch one graph from the group.
2370 KernelGraphPtr dst_graph = nullptr;
2371 for (const auto &graph : kernel_graphs) {
2372 MS_EXCEPTION_IF_NULL(graph);
2373 if (graph->execution_order().empty()) {
2374 continue;
2375 }
2376 dst_graph = graph;
2377 break;
2378 }
2379 if (dst_graph == nullptr) {
2380 continue;
2381 }
2382
2383 // Fetch the group info.
2384 const auto &group_info_iter = kernel_graphs_to_group_info_.find(dst_graph);
2385 if (group_info_iter == kernel_graphs_to_group_info_.end()) {
2386 MS_LOG(EXCEPTION) << "Failed to get group info for kernel_graph:" << dst_graph->ToString();
2387 }
2388 MS_EXCEPTION_IF_NULL(group_info_iter->second);
2389 if (group_info_iter->second->level_ < level_iter->second) {
2390 MS_LOG(DEBUG) << "Kernel graph group;" << group_info_iter->second->group_name_
2391 << " need link control to entrance of funcgraph:" << func_graph->ToString();
2392 (void)func_graph_to_first_kernel_graphs_[func_graph].emplace(group_info_iter->second);
2393 }
2394 }
2395 }
2396 }
2397 }
2398
ParseUnRecursionCallNode()2399 void ControlNodeParser::ParseUnRecursionCallNode() {
2400 std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> func_graph_call_relation;
2401 // Collect the call relationship between funcgraphs.
2402 for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
2403 const auto &call_node = call_node_to_func_graphs.first;
2404 MS_EXCEPTION_IF_NULL(call_node);
2405 const auto &func_graph = call_node->func_graph();
2406 MS_EXCEPTION_IF_NULL(func_graph);
2407 func_graph_call_relation[func_graph].insert(call_node_to_func_graphs.second.begin(),
2408 call_node_to_func_graphs.second.end());
2409 }
2410
2411 for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
2412 const auto &call_node = call_node_to_func_graphs.first;
2413 MS_EXCEPTION_IF_NULL(call_node);
2414 const auto &dest_func_graph = call_node->func_graph();
2415 MS_EXCEPTION_IF_NULL(dest_func_graph);
2416 std::set<FuncGraphPtr> exexution_func_graphs;
2417 for (const auto &func_graph : call_node_to_func_graphs.second) {
2418 FetchAllExecutionFunction(func_graph, &exexution_func_graphs, func_graph_call_relation);
2419 }
2420 if (exexution_func_graphs.find(dest_func_graph) == exexution_func_graphs.end()) {
2421 (void)unrecursion_call_nodes_.emplace(call_node);
2422 MS_LOG(DEBUG) << "Add unrecursion call control node:" << call_node->DebugString();
2423 }
2424 }
2425 }
2426
IsCallNodeNeedStack(const AnfNodePtr & node)2427 bool ControlNodeParser::IsCallNodeNeedStack(const AnfNodePtr &node) {
2428 MS_EXCEPTION_IF_NULL(node);
2429 const auto &cnode = node->cast<CNodePtr>();
2430 MS_EXCEPTION_IF_NULL(cnode);
2431 const auto &inputs = cnode->inputs();
2432 std::set<AnfNodePtr> depend_nodes;
2433
2434 // Fetch all the side effect inputs of call node.
2435 for (const auto &input : inputs) {
2436 MS_EXCEPTION_IF_NULL(input);
2437 std::vector<AnfNodePtr> monad_nodes = FetchAllMonadNodeByNode(input);
2438 for (const auto &monad_node : monad_nodes) {
2439 FetchRealDependNodeByAutoMonad(monad_node, &depend_nodes);
2440 }
2441 }
2442
2443 // Fetch all the data inputs of call node.
2444 auto input_with_indexs = FetchInputNodeByCNode(node);
2445 (void)std::for_each(
2446 input_with_indexs.begin(), input_with_indexs.end(),
2447 [&depend_nodes](const auto &input_with_index) { (void)depend_nodes.emplace(input_with_index.first); });
2448
2449 // Check if the call node need a stack.
2450 for (const auto &depend_node : depend_nodes) {
2451 MS_EXCEPTION_IF_NULL(depend_node);
2452 // If the call node has call or recursion graph input, a stack created for the call node is required.
2453 if (!common::AnfAlgo::IsCallNode(depend_node)) {
2454 if (!depend_node->isa<CNode>()) {
2455 continue;
2456 }
2457 const auto &graph = FetchKernelGraphByFrontNode(depend_node);
2458 if (graph == nullptr || (!IsRecursionKernelGraph(graph))) {
2459 continue;
2460 }
2461 }
2462 return true;
2463 }
2464 return false;
2465 }
2466
ParseNeedStackControlNode(const std::vector<AnfNodePtr> & control_nodes)2467 void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes) {
2468 for (const auto &control_node : control_nodes) {
2469 MS_EXCEPTION_IF_NULL(control_node);
2470 if (common::AnfAlgo::IsCallNode(control_node) && IsCallNodeNeedStack(control_node)) {
2471 (void)need_stack_control_nodes_.emplace(control_node);
2472 MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
2473 }
2474 }
2475
2476 for (const auto &control_node : control_nodes) {
2477 MS_EXCEPTION_IF_NULL(control_node);
2478 if (IsInvalidPartial(control_node)) {
2479 continue;
2480 }
2481
2482 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
2483 auto input_with_indexs = FetchInputNodeByCNode(control_node);
2484 size_t call_input_num = 0;
2485 for (auto input_with_index : input_with_indexs) {
2486 if (common::AnfAlgo::IsCallNode(input_with_index.first)) {
2487 ++call_input_num;
2488 }
2489 }
2490
2491 const auto &cnode = control_node->cast<CNodePtr>();
2492 MS_EXCEPTION_IF_NULL(cnode);
2493 const auto &inputs = cnode->inputs();
2494 if (inputs.size() <= kReturnInputPos) {
2495 MS_LOG_WITH_NODE(EXCEPTION, control_node) << "Invalid return node:" << control_node->DebugString();
2496 }
2497
2498 if ((!IsInputInSameLevel(control_node)) ||
2499 (call_input_num != 0 && (common::AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend)))) {
2500 (void)need_stack_control_nodes_.emplace(control_node);
2501 MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
2502 }
2503 } else if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) ||
2504 common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
2505 common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
2506 if (!IsInputInSameLevel(control_node)) {
2507 (void)need_stack_control_nodes_.emplace(control_node);
2508 MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
2509 }
2510 }
2511 }
2512 }
2513
CollectEffectiveInputByGraph(const KernelGraphPtr & graph,const DeviceContext * const device_context,KernelGraphGroupInfo * const kernel_graph_group_info)2514 void CollectEffectiveInputByGraph(const KernelGraphPtr &graph, const DeviceContext *const device_context,
2515 KernelGraphGroupInfo *const kernel_graph_group_info) {
2516 MS_EXCEPTION_IF_NULL(graph);
2517 MS_EXCEPTION_IF_NULL(device_context);
2518 MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2519
2520 const auto &outputs = kernel_graph_group_info->front_output_nodes_;
2521 const auto &monad_outputs = kernel_graph_group_info->monad_outputs_;
2522 const auto &real_parameters = graph->input_nodes();
2523 for (const auto ¶meter : real_parameters) {
2524 MS_EXCEPTION_IF_NULL(parameter);
2525 auto front_node_with_index = GetFrontNodeByKernelGraph(parameter, graph.get());
2526 MS_EXCEPTION_IF_NULL(front_node_with_index.first);
2527 // If input come from the output of kernel graph belong the same group, it should not be collected in
2528 // the group inputs.
2529 if (HasAbstractMonad(front_node_with_index.first) || HasAbstractMonad(parameter) ||
2530 outputs.find(front_node_with_index) != outputs.end() || front_node_with_index.first->isa<ValueNode>()) {
2531 // The monad input is used to link the control arrow of the graph. If it comes from other graphs in the same
2532 // group, it is not used as the monad input of the group.
2533 if ((HasAbstractMonad(front_node_with_index.first) || HasAbstractMonad(parameter)) &&
2534 monad_outputs.find(front_node_with_index) == monad_outputs.end()) {
2535 (void)kernel_graph_group_info->monad_inputs_.emplace(front_node_with_index.first);
2536 MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
2537 << " add front monad input node:" << front_node_with_index.first->DebugString();
2538 }
2539 continue;
2540 }
2541 if (common::AnfAlgo::IsCallNode(front_node_with_index.first)) {
2542 kernel_graph_group_info->need_stack_ = true;
2543 }
2544 MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
2545 << " add front input node:" << front_node_with_index.first->DebugString()
2546 << " index:" << front_node_with_index.second << " backend node:" << parameter->DebugString()
2547 << " index:0";
2548 kernel_graph_group_info->front_input_nodes_[front_node_with_index] = device_context;
2549 }
2550 }
2551
CollectEffectiveOutputByGraph(const KernelGraphPtr & graph,DeviceContext * const device_context,FrontToBackendKernelWithContext * const outputs,std::set<KernelWithIndex> * monad_outputs)2552 void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *const device_context,
2553 FrontToBackendKernelWithContext *const outputs,
2554 std::set<KernelWithIndex> *monad_outputs) {
2555 MS_EXCEPTION_IF_NULL(graph);
2556 MS_EXCEPTION_IF_NULL(device_context);
2557 MS_EXCEPTION_IF_NULL(outputs);
2558 MS_EXCEPTION_IF_NULL(monad_outputs);
2559
2560 for (const auto &front_to_backend : graph->front_node_to_graph_output_map()) {
2561 MS_EXCEPTION_IF_NULL(front_to_backend.first.first);
2562 MS_EXCEPTION_IF_NULL(front_to_backend.second.first);
2563 if (HasAbstractMonad(front_to_backend.second.first) || HasAbstractMonad(front_to_backend.first.first) ||
2564 front_to_backend.second.first->isa<Parameter>() ||
2565 common::AnfAlgo::CheckPrimitiveType(front_to_backend.first.first, prim::kPrimPartial) ||
2566 front_to_backend.first.first->isa<ValueNode>()) {
2567 if (HasAbstractMonad(front_to_backend.first.first) || HasAbstractMonad(front_to_backend.second.first)) {
2568 MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString() << " add monad output node:"
2569 << (front_to_backend.first.first != nullptr ? front_to_backend.first.first->DebugString()
2570 : "null")
2571 << " index:" << front_to_backend.first.second;
2572 (void)monad_outputs->emplace(front_to_backend.first);
2573 }
2574 continue;
2575 }
2576
2577 // Skip the function input.
2578 const auto &abstract = front_to_backend.first.first->abstract();
2579 MS_EXCEPTION_IF_NULL(abstract);
2580 const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, front_to_backend.first.second);
2581 MS_EXCEPTION_IF_NULL(real_abstract);
2582 if (real_abstract->isa<abstract::AbstractFunction>()) {
2583 continue;
2584 }
2585
2586 MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
2587 << " add front output node:" << front_to_backend.first.first->DebugString()
2588 << " index:" << front_to_backend.first.second
2589 << " backend node:" << front_to_backend.second.first->DebugString()
2590 << " full name:" << front_to_backend.second.first->fullname_with_scope()
2591 << " index:" << front_to_backend.second.second;
2592 (*outputs)[front_to_backend.first] = {front_to_backend.second, device_context};
2593 }
2594 }
2595
ParseKernelGraphGroup(const KernelGraphToDeviceContext & kernel_graph_to_device_contexts)2596 void ControlNodeParser::ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts) {
2597 for (const auto &func_graph_to_kernel_graph_groups : func_graph_to_kernel_graph_groups_) {
2598 for (const auto &kernel_graph_group : func_graph_to_kernel_graph_groups.second) {
2599 if (kernel_graph_group.empty()) {
2600 continue;
2601 }
2602
2603 KernelGraphGroupInfoPtr kernel_graph_group_info = std::make_shared<KernelGraphGroupInfo>();
2604 MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2605 for (const auto &kernel_graph : kernel_graph_group) {
2606 MS_EXCEPTION_IF_NULL(kernel_graph);
2607 if (kernel_graph->execution_order().empty()) {
2608 continue;
2609 }
2610 auto iter = kernel_graph_to_device_contexts.find(kernel_graph);
2611 if (iter == kernel_graph_to_device_contexts.end()) {
2612 MS_LOG(EXCEPTION) << "Failed to find device context for kernel graph:" << kernel_graph->ToString();
2613 }
2614 // Collect kernel graphs in group.
2615 (void)kernel_graph_group_info->graphs_.emplace(kernel_graph);
2616
2617 // Collect inputs in group.
2618 CollectEffectiveInputByGraph(kernel_graph, iter->second, kernel_graph_group_info.get());
2619
2620 // Collect outputs in group.
2621 CollectEffectiveOutputByGraph(kernel_graph, iter->second, &(kernel_graph_group_info->front_output_nodes_),
2622 &(kernel_graph_group_info->monad_outputs_));
2623
2624 kernel_graphs_to_group_info_[kernel_graph] = kernel_graph_group_info;
2625 }
2626 kernel_graph_group_info->group_name_ = "kernel_graph";
2627 for (const auto &graph : kernel_graph_group_info->graphs_) {
2628 if (kernel_graph_group_info->need_stack_) {
2629 MS_LOG(DEBUG) << "Add call input kernel graph:" << graph->ToString();
2630 (void)call_input_kernel_graphs_.emplace(graph.get());
2631 }
2632 kernel_graph_group_info->group_name_ += ("_" + std::to_string(graph->graph_id()));
2633 }
2634 MS_LOG(DEBUG) << "Add kernel graph info for group:" << kernel_graph_group_info->group_name_;
2635 (void)kernel_graph_group_infos_.emplace(kernel_graph_group_info);
2636 }
2637 }
2638 }
2639
ParseControlNodeLevel(const AnfNodePtr & node,std::set<AnfNodePtr> * checked_nodes)2640 size_t ControlNodeParser::ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes) {
2641 MS_EXCEPTION_IF_NULL(node);
2642 MS_EXCEPTION_IF_NULL(checked_nodes);
2643 if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
2644 return 0;
2645 }
2646 (void)checked_nodes->emplace(node);
2647
2648 auto iter = node_to_level_.find(node);
2649 if (iter != node_to_level_.end()) {
2650 return iter->second;
2651 }
2652
2653 size_t level = 0;
2654 const auto &kernel_graph = FetchKernelGraphByFrontNode(node);
2655 if (kernel_graph == nullptr) {
2656 // If the kernel graph is not found, it means that the input does not come from the kernel graph, then
2657 // just continue to traverse the input.
2658 const auto &cnode = node->cast<CNodePtr>();
2659 MS_EXCEPTION_IF_NULL(cnode);
2660 const auto &inputs = cnode->inputs();
2661 for (const auto &input : inputs) {
2662 size_t tmp_level = ParseControlNodeLevel(input, checked_nodes);
2663 level = (tmp_level > level ? tmp_level : level);
2664 }
2665 return level;
2666 }
2667
2668 // If the input comes from the kernel graph, you need to check all the graph's input, not just the node's input.
2669 auto group_info_iter = kernel_graphs_to_group_info_.find(kernel_graph);
2670 if (group_info_iter == kernel_graphs_to_group_info_.end()) {
2671 MS_LOG(EXCEPTION) << "Failed to get kernel graph group info for graph:" << kernel_graph->ToString();
2672 }
2673 MS_EXCEPTION_IF_NULL(group_info_iter->second);
2674 const auto &inputs = group_info_iter->second->front_input_nodes_;
2675 for (const auto &input : inputs) {
2676 const auto &input_node = input.first.first;
2677 size_t tmp_level = ParseControlNodeLevel(input_node, checked_nodes);
2678 level = (tmp_level > level ? tmp_level : level);
2679 }
2680 return level;
2681 }
2682
2683 namespace {
GetRealOutputNode(const KernelWithIndex & front_pair,const KernelWithIndex & backend_pair)2684 AnfNodePtr GetRealOutputNode(const KernelWithIndex &front_pair, const KernelWithIndex &backend_pair) {
2685 if (front_pair.first == nullptr || backend_pair.first == nullptr) {
2686 return nullptr;
2687 }
2688 if (common::AnfAlgo::CheckPrimitiveType(backend_pair.first, prim::kPrimLoad) &&
2689 common::AnfAlgo::CheckPrimitiveType(front_pair.first, prim::kPrimLoad)) {
2690 const auto &backend_cnode = backend_pair.first->cast<CNodePtr>();
2691 const auto &front_cnode = front_pair.first->cast<CNodePtr>();
2692 MS_EXCEPTION_IF_NULL(backend_cnode);
2693 MS_EXCEPTION_IF_NULL(front_cnode);
2694 if (backend_cnode->inputs().size() > 1 && backend_cnode->input(1) != nullptr &&
2695 backend_cnode->input(1)->isa<CNode>() && front_cnode->inputs().size() > 1 && front_cnode->input(1) != nullptr &&
2696 front_cnode->input(1)->isa<CNode>()) {
2697 return front_cnode->input(1);
2698 }
2699 }
2700 return nullptr;
2701 }
2702 } // namespace
2703
ParseNodeLevel(const std::vector<AnfNodePtr> & control_nodes)2704 void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
2705 size_t level = 0;
2706 // 1. Parse levels of control nodes.
2707 for (const auto &control_node : control_nodes) {
2708 MS_EXCEPTION_IF_NULL(control_node);
2709 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
2710 node_to_level_[control_node] = level;
2711 MS_LOG(DEBUG) << "Add level:" << level << " for node:" << control_node->DebugString();
2712 level = 0;
2713 const auto &func_graph = control_node->func_graph();
2714 MS_EXCEPTION_IF_NULL(func_graph);
2715 const auto ¶meters = func_graph->parameters();
2716 for (const auto ¶meter : parameters) {
2717 MS_EXCEPTION_IF_NULL(parameter);
2718 MS_LOG(DEBUG) << "Add level:" << level << " for node:" << parameter->DebugString();
2719 node_to_level_[parameter] = level;
2720 }
2721 continue;
2722 } else if (IsRecursionCallNode(control_node)) {
2723 ++level;
2724 MS_LOG(DEBUG) << "Add level:" << level << " for node:" << control_node->DebugString();
2725 node_to_level_[control_node] = level;
2726 } else {
2727 std::set<AnfNodePtr> checked_nodes;
2728 node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes);
2729 MS_LOG(DEBUG) << "Add level:" << node_to_level_[control_node] << " for node:" << control_node->DebugString();
2730 }
2731 }
2732
2733 // 2. Parse the levels of kernel graph outputs.
2734 for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
2735 MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2736 level = 0;
2737 for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
2738 const auto &input_node = front_input_node.first.first;
2739 auto iter = node_to_level_.find(input_node);
2740 if (iter != node_to_level_.end() && level < iter->second) {
2741 level = iter->second;
2742 }
2743 }
2744 for (const auto &front_output_node : kernel_graph_group_info->front_output_nodes_) {
2745 MS_EXCEPTION_IF_NULL(front_output_node.second.first.first);
2746 if (front_output_node.second.first.first->isa<Parameter>()) {
2747 continue;
2748 }
2749 const auto &output_node = front_output_node.first.first;
2750 MS_EXCEPTION_IF_NULL(output_node);
2751 MS_LOG(DEBUG) << "Add level:" << level << " for node:" << output_node->DebugString();
2752 node_to_level_[output_node] = level;
2753 const auto &real_output_node = GetRealOutputNode(front_output_node.first, front_output_node.second.first);
2754 if (real_output_node != nullptr && node_to_level_.find(real_output_node) == node_to_level_.end()) {
2755 node_to_level_[real_output_node] = level;
2756 }
2757 }
2758 }
2759
2760 // Parse the levels of kernel graph groups.
2761 for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
2762 MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2763 size_t max_level = 0;
2764 for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
2765 const auto &input_node = front_input_node.first.first;
2766 MS_EXCEPTION_IF_NULL(input_node);
2767 auto iter = node_to_level_.find(input_node);
2768 if (iter == node_to_level_.end()) {
2769 MS_LOG_WITH_NODE(EXCEPTION, input_node) << "Failed to get input node:" << input_node->DebugString()
2770 << " for kernel graph:" << kernel_graph_group_info->group_name_;
2771 }
2772 max_level = (max_level > iter->second ? max_level : iter->second);
2773 }
2774 if (max_level > 0) {
2775 kernel_graph_group_info->need_stack_ = true;
2776 kernel_graph_group_info->level_ = max_level;
2777 for (const auto &kernel_graph : kernel_graph_group_info->graphs_) {
2778 (void)call_input_kernel_graphs_.emplace(kernel_graph.get());
2779 }
2780 }
2781 MS_LOG(DEBUG) << "Kernel graph group:" << kernel_graph_group_info->group_name_
2782 << " need stack:" << kernel_graph_group_info->need_stack_
2783 << " level:" << kernel_graph_group_info->level_;
2784 }
2785 }
2786
IsInputInSameLevel(const AnfNodePtr & node)2787 bool ControlNodeParser::IsInputInSameLevel(const AnfNodePtr &node) {
2788 MS_EXCEPTION_IF_NULL(node);
2789 if (!node->isa<CNode>()) {
2790 return true;
2791 }
2792
2793 auto input_with_indexes = FetchInputNodeByCNode(node);
2794 size_t level = SIZE_MAX;
2795 for (const auto &input_with_index : input_with_indexes) {
2796 auto input_node = input_with_index.first;
2797 MS_EXCEPTION_IF_NULL(input_node);
2798 if (input_node->isa<ValueNode>()) {
2799 continue;
2800 }
2801 auto iter = node_to_level_.find(input_node);
2802 if (iter == node_to_level_.end()) {
2803 MS_LOG_WITH_NODE(EXCEPTION, node) << "Failed to find input:" << input_node->DebugString()
2804 << " for node:" << node->DebugString() << " in graph output map.";
2805 }
2806 if (level == SIZE_MAX) {
2807 level = iter->second;
2808 continue;
2809 }
2810 if (level != iter->second) {
2811 return false;
2812 }
2813 }
2814 return true;
2815 }
2816
CreateDeviceTensorForRootGraphParameter(DeviceContext * const default_context)2817 void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context) {
2818 MS_EXCEPTION_IF_NULL(default_context);
2819 for (const auto ¶meter : root_graph_parameters_) {
2820 MS_EXCEPTION_IF_NULL(parameter);
2821 const auto &abstract = parameter->abstract();
2822 MS_EXCEPTION_IF_NULL(abstract);
2823 size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
2824 for (size_t i = 0; i < output_num; ++i) {
2825 KernelWithIndex parameter_with_index(parameter, i);
2826 if (front_to_backend_parameters_.find(parameter_with_index) == front_to_backend_parameters_.end()) {
2827 MS_LOG(DEBUG) << "Create device tensor for root graph parameter:" << parameter->DebugString();
2828 CreateDeviceTensorForFrontNode(parameter_with_index, default_context);
2829 (void)front_to_backend_parameters_[parameter_with_index].emplace(parameter_with_index, default_context);
2830 }
2831 }
2832 }
2833 }
2834
FetchGroupNameByKernelGraph(const KernelGraphPtr & graph)2835 std::string ControlNodeParser::FetchGroupNameByKernelGraph(const KernelGraphPtr &graph) {
2836 MS_EXCEPTION_IF_NULL(graph);
2837 auto group_info_iter = kernel_graphs_to_group_info_.find(graph);
2838 if (group_info_iter == kernel_graphs_to_group_info_.end()) {
2839 MS_LOG(EXCEPTION) << "Failed to get kernel graph group info for graph:" << graph->ToString();
2840 }
2841 MS_EXCEPTION_IF_NULL(group_info_iter->second);
2842 return group_info_iter->second->group_name_;
2843 }
2844
// Look up the backend output that corresponds to a front node within the group of the given kernel graph.
// Returns {nullptr, 0} (after logging a warning) when the graph has no group info, and {nullptr, 0} when no front
// output matches. An exact key match in front_output_nodes_ returns the stored backend node as is; otherwise the
// front outputs are scanned for one that resolves to the requested node through VisitKernelWithReturnType, and the
// backend node of that entry is returned after being resolved the same way.
KernelWithIndex ControlNodeParser::FetchBackendOutputByKernelGraph(const KernelGraphPtr &graph,
                                                                   const KernelWithIndex &front_node_with_index) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto info_iter = kernel_graphs_to_group_info_.find(graph);
  if (info_iter == kernel_graphs_to_group_info_.end()) {
    MS_LOG(WARNING) << "Failed to get kernel graph group info for graph:" << graph->ToString();
    return {nullptr, 0};
  }
  MS_EXCEPTION_IF_NULL(info_iter->second);
  const auto &front_outputs = info_iter->second->front_output_nodes_;

  // Fast path: the front node is itself a registered front output.
  const auto exact_iter = front_outputs.find(front_node_with_index);
  if (exact_iter != front_outputs.end()) {
    return exact_iter->second.first;
  }

  // Slow path: take the first front output whose resolved node equals the requested one.
  for (const auto &front_output : front_outputs) {
    const auto &front_key = front_output.first;
    if (front_node_with_index == common::AnfAlgo::VisitKernelWithReturnType(front_key.first, front_key.second)) {
      const auto &backend_output = front_output.second.first;
      return common::AnfAlgo::VisitKernelWithReturnType(backend_output.first, backend_output.second);
    }
  }
  return {nullptr, 0};
}
2869
PrintParseInfo()2870 void ControlNodeParser::PrintParseInfo() {
2871 for (const auto &group : kernel_graph_group_infos_) {
2872 MS_EXCEPTION_IF_NULL(group);
2873 for (const auto &input_pair : group->front_input_nodes_) {
2874 if (input_pair.first.first != nullptr) {
2875 MS_LOG(WARNING) << "Kernel graph group:" << group->group_name_
2876 << " input node:" << input_pair.first.first->fullname_with_scope()
2877 << " debug string:" << input_pair.first.first->DebugString(kDebugStrDepthTwo)
2878 << " index:" << input_pair.first.second;
2879 }
2880 }
2881 for (const auto &output_pair : group->front_output_nodes_) {
2882 if (output_pair.first.first != nullptr && output_pair.second.first.first != nullptr) {
2883 MS_LOG(WARNING) << "Kernel graph group:" << group->group_name_
2884 << " output node:" << output_pair.first.first->fullname_with_scope()
2885 << " debug string:" << output_pair.first.first->DebugString(kDebugStrDepthTwo)
2886 << " index:" << output_pair.first.second
2887 << " backend node:" << output_pair.second.first.first->fullname_with_scope()
2888 << " debug string:" << output_pair.second.first.first->DebugString(kDebugStrDepthTwo)
2889 << " index:" << output_pair.second.first.second;
2890 }
2891 }
2892 }
2893 for (const auto &f_to_b : front_to_backend_kernels_) {
2894 if (f_to_b.first.first != nullptr && f_to_b.second.first.first != nullptr) {
2895 MS_LOG(WARNING) << "Front to backend map front node:" << f_to_b.first.first->fullname_with_scope()
2896 << " debug string:" << f_to_b.first.first->DebugString(kDebugStrDepthTwo)
2897 << " index:" << f_to_b.first.second
2898 << " backend node:" << f_to_b.second.first.first->fullname_with_scope()
2899 << " debug string:" << f_to_b.second.first.first->DebugString(kDebugStrDepthTwo)
2900 << " index:" << f_to_b.second.first.second;
2901 }
2902 }
2903 for (const auto &pair : front_node_to_kernel_graph_) {
2904 if (pair.first != nullptr && pair.second == nullptr) {
2905 MS_LOG(WARNING) << "Front node:" << pair.first->fullname_with_scope()
2906 << " debug string:" << pair.first->DebugString(kDebugStrDepthTwo)
2907 << " to kernel graph:" << pair.second->ToString();
2908 }
2909 }
2910 }
2911 } // namespace runtime
2912 } // namespace mindspore
2913