• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <unordered_map>
18 #include <functional>
19 #include <map>
20 #include "runtime/graph_scheduler/control_node_parser.h"
21 #include "mindspore/core/ops/sparse_tensor_ops.h"
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "runtime/graph_scheduler/actor/actor_common.h"
25 #include "runtime/device/device_address_utils.h"
26 #include "include/common/utils/convert_utils.h"
27 #include "abstract/utils.h"
28 #include "utils/ms_context.h"
29 #include "ir/tensor.h"
30 #include "abstract/abstract_function.h"
31 #include "include/common/debug/anf_ir_dump.h"
32 
33 namespace mindspore {
34 namespace runtime {
35 namespace {
// Named constant for a debug-string depth of two.
// NOTE(review): presumably passed as the depth argument of a DebugString() call elsewhere in the file; confirm at the
// use site.
constexpr int kDebugStrDepthTwo = 2;
37 // Check if node is a value node need to create a device tensor.
IsFrontValueNode(const KernelWithIndex & node_with_index)38 bool IsFrontValueNode(const KernelWithIndex &node_with_index) {
39   const auto &node = node_with_index.first;
40   MS_EXCEPTION_IF_NULL(node);
41   if (!node->isa<ValueNode>() || IsValueNode<FuncGraph>(node) || IsValueNode<Primitive>(node)) {
42     return false;
43   }
44 
45   return true;
46 }
47 
48 // Fetch real input node in maketuple.
FetchRealInputNode(const KernelWithIndex & node_with_index)49 KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) {
50   const auto &node = node_with_index.first;
51   MS_EXCEPTION_IF_NULL(node);
52   if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
53     return node_with_index;
54   }
55 
56   const auto &abstract = node->abstract();
57   MS_EXCEPTION_IF_NULL(abstract);
58   size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
59   if (output_num <= node_with_index.second) {
60     MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid index:" << node_with_index.second
61                                       << "for tuple node:" << node->DebugString();
62   }
63 
64   const auto &cnode = node->cast<CNodePtr>();
65   MS_EXCEPTION_IF_NULL(cnode);
66   const auto &inputs = cnode->inputs();
67   size_t real_index = node_with_index.second;
68   for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
69     MS_EXCEPTION_IF_NULL(inputs[i]);
70     const auto &sub_abstract = inputs[i]->abstract();
71     MS_EXCEPTION_IF_NULL(sub_abstract);
72     size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
73     // If it is not the output of node, need to subtract the number of inputs of it.
74     if (real_index >= tmp_index) {
75       real_index -= tmp_index;
76       continue;
77     }
78     return {inputs[i], real_index};
79   }
80   MS_LOG_WITH_NODE(EXCEPTION, node) << "Failed to get real output from node:" << node->DebugString()
81                                     << " index:" << node_with_index.second;
82 }
83 
84 // Fetch all the output index in the sub-abstract of abstract.
FetchRealIndexByAbstract(const AbstractBasePtr & abstract,std::vector<size_t> * const indexes)85 std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std::vector<size_t> *const indexes) {
86   MS_EXCEPTION_IF_NULL(abstract);
87   MS_EXCEPTION_IF_NULL(indexes);
88   AbstractBasePtr dst_abstract = abstract;
89   size_t pre_abstract_num = 0;
90   std::set<size_t> output_indexs;
91   if (indexes->empty()) {
92     size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
93     for (size_t i = 0; i < output_num; ++i) {
94       (void)output_indexs.emplace(i);
95     }
96     return output_indexs;
97   }
98 
99   size_t index = indexes->back();
100   indexes->pop_back();
101 
102   // Fetch the dest abstract by index, and the abstracts num before the dest abstract.
103   if (abstract->isa<abstract::AbstractSequence>()) {
104     auto sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
105     MS_EXCEPTION_IF_NULL(sequence_abstract);
106     const auto &sub_abstracts = sequence_abstract->elements();
107     if (sub_abstracts.size() <= index) {
108       MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
109     }
110     for (size_t i = 0; i < index; ++i) {
111       pre_abstract_num += common::AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]);
112     }
113     dst_abstract = sub_abstracts[index];
114   } else {
115     if (index != 0) {
116       MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
117     }
118   }
119   MS_EXCEPTION_IF_NULL(dst_abstract);
120 
121   // Fetch real output index.
122   auto tmp_indexs = FetchRealIndexByAbstract(dst_abstract, indexes);
123   for (auto tmp_index : tmp_indexs) {
124     (void)output_indexs.emplace(tmp_index + pre_abstract_num);
125   }
126   return output_indexs;
127 }
128 
129 // Get all the real parameters corresponding to node.
FetchRealParameterByNode(const KernelWithIndex & node,std::set<KernelWithIndex> * const real_parameters,std::set<KernelWithIndex> * invalid_call_nodes,const mindspore::HashMap<AnfNodePtr,std::set<FuncGraphPtr>> & call_node_to_func_graphs)130 void FetchRealParameterByNode(const KernelWithIndex &node, std::set<KernelWithIndex> *const real_parameters,
131                               std::set<KernelWithIndex> *invalid_call_nodes,
132                               const mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr>> &call_node_to_func_graphs) {
133   MS_EXCEPTION_IF_NULL(node.first);
134   MS_EXCEPTION_IF_NULL(real_parameters);
135   MS_EXCEPTION_IF_NULL(invalid_call_nodes);
136   MS_LOG(DEBUG) << "Fetch real parameter by node:" << node.first->DebugString() << " index:" << node.second;
137   auto node_with_index = common::AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
138   MS_EXCEPTION_IF_NULL(node_with_index.first);
139   if (node_with_index.first->isa<ValueNode>() || node_with_index.first->isa<Parameter>()) {
140     // If node is a valuenode or parameter, the real parameter is itself.
141     MS_LOG(DEBUG) << "Add real parameter:" << node_with_index.first->DebugString()
142                   << " index:" << node_with_index.second;
143     (void)real_parameters->emplace(node_with_index);
144   } else if (common::AnfAlgo::IsCallNode(node_with_index.first)) {
145     // If node is a call node, the real parameters are the outputs of funcgraph the node called.
146     if (invalid_call_nodes->find(node_with_index) != invalid_call_nodes->end()) {
147       return;
148     }
149     (void)invalid_call_nodes->emplace(node_with_index);
150     const auto &iter = call_node_to_func_graphs.find(node_with_index.first);
151     if (iter == call_node_to_func_graphs.end()) {
152       MS_LOG(DEBUG) << "Invalid call node:" << node_with_index.first->DebugString();
153       return;
154     }
155     const auto &func_graphs = iter->second;
156     for (const auto &func_graph : func_graphs) {
157       MS_EXCEPTION_IF_NULL(func_graph);
158       FetchRealParameterByNode({func_graph->output(), node_with_index.second}, real_parameters, invalid_call_nodes,
159                                call_node_to_func_graphs);
160     }
161   } else if (common::AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) {
162     // If node is a maketuple node, the real parameters are its total inputs.
163     const auto &real_input = FetchRealInputNode(node_with_index);
164     MS_EXCEPTION_IF_NULL(real_input.first);
165     MS_LOG(DEBUG) << "Real input node:" << real_input.first->DebugString() << " index:" << real_input.second
166                   << " for tuple node:" << node_with_index.first->DebugString() << " index:" << node_with_index.second;
167     FetchRealParameterByNode(real_input, real_parameters, invalid_call_nodes, call_node_to_func_graphs);
168   } else if (common::AnfAlgo::CheckPrimitiveType(node.first, prim::kPrimSwitch)) {
169     // If node is a switch node, the real parameters are its both true and false branches.
170     const auto cnode = node_with_index.first->cast<CNodePtr>();
171     MS_EXCEPTION_IF_NULL(cnode);
172     const auto inputs = cnode->inputs();
173     for (size_t i = kSwitchTrueBranchPos; i < inputs.size(); ++i) {
174       FetchRealParameterByNode({inputs[i], 0}, real_parameters, invalid_call_nodes, call_node_to_func_graphs);
175     }
176   } else if (common::AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitchLayer)) {
177     // If node is a switchlyaer node, the real parameters are its total branches.
178     const auto &switch_layer_cnode = node_with_index.first->cast<CNodePtr>();
179     MS_EXCEPTION_IF_NULL(switch_layer_cnode);
180     const auto &switch_layer_inputs = switch_layer_cnode->inputs();
181     if (switch_layer_inputs.size() != kSwitchLayerInputNum ||
182         (!common::AnfAlgo::CheckPrimitiveType(switch_layer_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple))) {
183       MS_LOG_WITH_NODE(EXCEPTION, switch_layer_cnode)
184         << "Invalid switch layer node:" << switch_layer_cnode->DebugString();
185     }
186     const auto &make_tuple_cnode = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>();
187     MS_EXCEPTION_IF_NULL(make_tuple_cnode);
188     const auto &make_tuple_inputs = make_tuple_cnode->inputs();
189     for (size_t i = kSwitchTrueBranchPos; i < make_tuple_inputs.size(); ++i) {
190       FetchRealParameterByNode({make_tuple_inputs[i], 0}, real_parameters, invalid_call_nodes,
191                                call_node_to_func_graphs);
192     }
193   } else {
194     // If node is a kernel, the real parameter is itself.
195     MS_LOG(DEBUG) << "Add real parameter:" << node_with_index.first->DebugString()
196                   << " index:" << node_with_index.second;
197     (void)real_parameters->emplace(node_with_index);
198   }
199 }
200 
201 // Topologically sort all funcgraphs according to the function call relationship.
TopoSortForFuncGraph(const FuncGraphPtr & root,FuncGraphCallRelation * const edges)202 std::vector<FuncGraphPtr> TopoSortForFuncGraph(const FuncGraphPtr &root, FuncGraphCallRelation *const edges) {
203   MS_EXCEPTION_IF_NULL(root);
204   MS_EXCEPTION_IF_NULL(edges);
205   MS_EXCEPTION_IF_NULL(root->manager());
206   std::set<FuncGraphPtr> nodes;
207   (void)nodes.emplace(root);
208 
209   FuncGraphSet subs = root->manager()->func_graphs();
210   for (auto sub : subs) {
211     if (sub != root) {
212       (void)nodes.emplace(sub);
213     }
214   }
215 
216   std::queue<FuncGraphPtr> que;
217   for (const auto &node : nodes) {
218     if (edges->find(node) == edges->end()) {
219       que.push(node);
220     }
221   }
222 
223   std::vector<FuncGraphPtr> result;
224   while (!que.empty()) {
225     const auto node = que.front();
226     que.pop();
227     (void)result.emplace_back(node);
228     for (auto iter = edges->begin(); iter != edges->end();) {
229       auto &sub_edges = iter->second;
230       for (auto sub_iter = sub_edges.begin(); sub_iter != sub_edges.end();) {
231         if (sub_iter->find(node) != sub_iter->end()) {
232           sub_iter = sub_edges.erase(sub_iter);
233         } else {
234           ++sub_iter;
235         }
236       }
237       if (sub_edges.empty()) {
238         que.push(iter->first);
239         iter = edges->erase(iter);
240       } else {
241         ++iter;
242       }
243     }
244   }
245 
246   return result;
247 }
248 
FetchTypeIdByNode(const AnfNodePtr & node,size_t index)249 TypeId FetchTypeIdByNode(const AnfNodePtr &node, size_t index) {
250   MS_EXCEPTION_IF_NULL(node);
251   TypeId type_id = kTypeUnknown;
252   if (node->isa<ValueNode>() && node->abstract() != nullptr) {
253     // For valuenode, fetch type from abstract.
254     const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
255     MS_EXCEPTION_IF_NULL(abs);
256     const auto &type = abs->BuildType();
257     MS_EXCEPTION_IF_NULL(type);
258     if (type->isa<TensorType>()) {
259       const auto &tensor_type = type->cast<TensorTypePtr>();
260       MS_EXCEPTION_IF_NULL(tensor_type);
261       const auto &element = tensor_type->element();
262       MS_EXCEPTION_IF_NULL(element);
263       type_id = element->type_id();
264     } else if (common::AnfAlgo::IsDynamicSequence(node)) {
265       const auto &sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
266       MS_EXCEPTION_IF_NULL(sequence_abs);
267       if (sequence_abs->dynamic_len_element_abs() == nullptr) {
268         type_id = type->type_id();
269       } else {
270         if (sequence_abs->dynamic_len_element_abs()->isa<abstract::AbstractTensor>()) {
271           const auto &tensor_abs = sequence_abs->dynamic_len_element_abs()->cast<abstract::AbstractTensorPtr>();
272           MS_EXCEPTION_IF_NULL(tensor_abs);
273           MS_EXCEPTION_IF_NULL(tensor_abs->element());
274           const auto &tensor_element_type = tensor_abs->element()->BuildType();
275           MS_EXCEPTION_IF_NULL(tensor_element_type);
276           return tensor_element_type->type_id();
277         }
278         const auto &element_type = sequence_abs->dynamic_len_element_abs()->BuildType();
279         MS_EXCEPTION_IF_NULL(element_type);
280         type_id = element_type->type_id();
281       }
282     } else {
283       type_id = type->type_id();
284     }
285   } else {
286     type_id = common::AnfAlgo::GetOutputInferDataType(node, index);
287   }
288   return type_id;
289 }
290 
FetchOutputSizeByValue(const ValuePtr & value)291 size_t FetchOutputSizeByValue(const ValuePtr &value) {
292   MS_EXCEPTION_IF_NULL(value);
293   if (value->isa<Scalar>()) {
294     return GetTypeByte(value->type());
295   } else if (value->isa<tensor::Tensor>()) {
296     const auto &tensor = value->cast<tensor::TensorPtr>();
297     MS_EXCEPTION_IF_NULL(tensor);
298     return tensor->Size();
299   } else if (value->isa<ValueSequence>()) {
300     const auto &value_sequence = value->cast<ValueSequencePtr>();
301     MS_EXCEPTION_IF_NULL(value_sequence);
302     if (value_sequence->size() == 0) {
303       return 0;
304     }
305     size_t size = 0;
306     for (const auto &sub_value : value_sequence->value()) {
307       MS_EXCEPTION_IF_NULL(sub_value);
308       size += FetchOutputSizeByValue(sub_value);
309     }
310     return size;
311   } else {
312     MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
313   }
314 }
315 
FetchOutputSizeByNode(const AnfNodePtr & node,size_t index,TypeId type_id)316 size_t FetchOutputSizeByNode(const AnfNodePtr &node, size_t index, TypeId type_id) {
317   MS_EXCEPTION_IF_NULL(node);
318   size_t size = GetTypeByte(TypeIdToType(type_id));
319   if (node->isa<ValueNode>() && node->abstract() != nullptr) {
320     const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
321     MS_EXCEPTION_IF_NULL(abs);
322     const auto &shape_ptr = abs->BuildShape();
323     MS_EXCEPTION_IF_NULL(shape_ptr);
324     if (shape_ptr->isa<abstract::Shape>()) {
325       const auto &shapes = shape_ptr->cast<abstract::ShapePtr>()->shape();
326       size = std::accumulate(shapes.begin(), shapes.end(), size, std::multiplies<int64_t>());
327     } else if (shape_ptr->isa<abstract::DynamicSequenceShape>()) {
328       const auto &value_node = node->cast<ValueNodePtr>();
329       MS_EXCEPTION_IF_NULL(value_node);
330       const auto &value = value_node->value();
331       MS_EXCEPTION_IF_NULL(value);
332       size = FetchOutputSizeByValue(value);
333       MS_LOG(INFO) << "Abstract;" << abs->ToString() << " for node:" << node->DebugString() << " index:" << index
334                    << " shape:" << shape_ptr->ToString() << " size:" << size;
335     } else if (abs->isa<abstract::AbstractMonad>() || abs->isa<abstract::AbstractScalar>()) {
336       MS_LOG(DEBUG) << "For scalar, the output shape is 1.";
337     } else {
338       MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid abstract;" << abs->ToString() << " for node:" << node->DebugString()
339                                         << " index:" << index << " shape:" << shape_ptr->ToString();
340     }
341   } else {
342     size = AnfAlgo::GetOutputTensorMemSize(node, index);
343   }
344   return size;
345 }
346 
347 // Create a device tensor for the front node.
348 // Get the output format and select kernel build info from the backend node corresponding to the front node to
349 // create the device address.
CreateDeviceTensorForValueNode(const KernelWithIndex & front_node_with_index,const AnfNodePtr & backend_node,const DeviceContext * device_context)350 void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node,
351                                     const DeviceContext *device_context) {
352   MS_EXCEPTION_IF_NULL(backend_node);
353   MS_EXCEPTION_IF_NULL(device_context);
354   const auto &front_node = front_node_with_index.first;
355   MS_EXCEPTION_IF_NULL(front_node);
356 
357   const auto &node_value = front_node->cast<ValueNodePtr>()->value();
358   MS_EXCEPTION_IF_NULL(node_value);
359   if (node_value->isa<FuncGraph>() || node_value->isa<Primitive>() ||
360       (node_value->isa<ValueSequence>() && node_value->cast<ValueSequencePtr>()->size() == 0)) {
361     return;
362   }
363 
364   size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
365   TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
366   if (output_type_id == kTypeUnknown) {
367     output_type_id = common::AnfAlgo::GetOutputInferDataType(backend_node, 0);
368   }
369   if (front_node->abstract() != nullptr && front_node->abstract()->isa<abstract::AbstractSequence>() &&
370       front_node->abstract()->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
371     tensor_size = FetchOutputSizeByNode(front_node, front_node_with_index.second, output_type_id);
372   }
373   CreateBuildInfoForFrontNode(front_node_with_index, backend_node);
374   device::DeviceAddressPtr address = nullptr;
375   if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
376     // If is_forward_output, get address from tensor
377     auto tensor = node_value->cast<TensorPtr>();
378     MS_EXCEPTION_IF_NULL(tensor);
379     address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
380   } else {
381     // Create device tensor.
382     std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
383 
384     const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
385       {backend_node, 0}, nullptr, tensor_size, output_format, output_type_id, ShapeVector(),
386       device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
387     kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(backend_node));
388     address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
389   }
390   MS_EXCEPTION_IF_NULL(address);
391   MS_LOG(DEBUG) << "Create address for front node:" << front_node->DebugString()
392                 << " backend node:" << backend_node->DebugString() << " index:" << front_node_with_index.second
393                 << " addr:" << address << " size:" << tensor_size;
394   AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get());
395   UpdateRefCount(address.get(), true);
396 }
397 
398 // Create a device tensor for front node.
399 // When the condition input of the switch and switchlayer or the output of a subgraph is a parameter or value node,
400 // there is no corresponding backend node for this parameter, so a device tensor needs to be created for it.
CreateDeviceTensorForFrontNode(const KernelWithIndex & front_node_with_index,const DeviceContext * device_context)401 void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index, const DeviceContext *device_context) {
402   MS_EXCEPTION_IF_NULL(device_context);
403   const auto &node = front_node_with_index.first;
404 
405   MS_EXCEPTION_IF_NULL(node);
406   MS_LOG(DEBUG) << "Start create device tensor for front node:" << front_node_with_index.first->DebugString()
407                 << " index:" << front_node_with_index.second;
408 
409   // Create kernel info for front node.
410   if (node->kernel_info() == nullptr) {
411     auto kernel_info = std::make_shared<device::KernelInfo>();
412     MS_EXCEPTION_IF_NULL(kernel_info);
413     std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
414     MS_EXCEPTION_IF_NULL(builder);
415     kernel_info->set_select_kernel_build_info(builder->Build());
416     node->set_kernel_info(kernel_info);
417   }
418 
419   // Set format.
420   const auto &kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
421   MS_EXCEPTION_IF_NULL(kernel_info);
422   const auto &builder = kernel_info->GetMutableSelectKernelBuildInfo();
423   MS_EXCEPTION_IF_NULL(builder);
424 
425   if (node->isa<ValueNode>()) {
426     const auto &node_value = node->cast<ValueNodePtr>()->value();
427     MS_EXCEPTION_IF_NULL(node_value);
428     if (node_value->isa<ValueSequence>() && node_value->cast<ValueSequencePtr>()->size() == 0) {
429       return;
430     }
431   }
432 
433   if (builder->GetAllOutputFormats().size() > front_node_with_index.second) {
434     builder->SetOutputFormat(kOpFormat_DEFAULT, front_node_with_index.second);
435   } else {
436     auto formats = builder->GetAllOutputFormats();
437     for (size_t i = 0; i <= front_node_with_index.second - builder->GetAllOutputFormats().size(); ++i) {
438       (void)formats.emplace_back(kOpFormat_DEFAULT);
439     }
440     builder->SetOutputsFormat(formats);
441   }
442 
443   // Set type.
444   TypeId type_id = FetchTypeIdByNode(node, front_node_with_index.second);
445   if (builder->GetAllOutputDeviceTypes().size() > front_node_with_index.second) {
446     builder->SetOutputDeviceType(type_id, front_node_with_index.second);
447   } else {
448     auto types = builder->GetAllOutputDeviceTypes();
449     for (size_t i = 0; i <= front_node_with_index.second - builder->GetAllOutputDeviceTypes().size(); ++i) {
450       (void)types.emplace_back(type_id);
451     }
452     builder->SetOutputsDeviceType(types);
453   }
454 
455   const auto &abstract = AnfAlgo::GetNodeAbstractByIndex(front_node_with_index.first, front_node_with_index.second);
456   bool is_map_parameter = abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>();
457   if (is_map_parameter) {
458     DeviceAddressUtils::CreateDeviceAddressByMapTensorNode(device_context, front_node_with_index.first,
459                                                            front_node_with_index.second);
460     UpdateRefCount(AnfAlgo::GetMutableOutputAddr(front_node_with_index.first, front_node_with_index.second).get(),
461                    true);
462     return;
463   }
464 
465   // Fetch mem size by shape, the shape is first obtained from the abstract to deal with the scenario where
466   // the value node is a multi-level tuple.
467   size_t size = FetchOutputSizeByNode(node, front_node_with_index.second, type_id);
468   device::DeviceAddressPtr address = nullptr;
469   if (node->isa<ValueNode>()) {
470     const auto &node_value = node->cast<ValueNodePtr>()->value();
471     MS_EXCEPTION_IF_NULL(node_value);
472     if (node_value->isa<tensor::Tensor>() && node_value->cast<TensorPtr>()->is_forward_output()) {
473       // If is_forward_output, get address from tensor
474       auto tensor = node_value->cast<TensorPtr>();
475       MS_EXCEPTION_IF_NULL(tensor);
476       address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
477     } else {
478       // Create device tensor.
479       const auto &sub_abstract = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), front_node_with_index.second);
480       MS_EXCEPTION_IF_NULL(sub_abstract);
481       const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
482         sub_abstract->BuildShape(), sub_abstract->BuildType(), sub_abstract->BuildValue(), nullptr, size,
483         kOpFormat_DEFAULT, type_id, ShapeVector(), device_context->device_context_key().device_name_,
484         device_context->device_context_key().device_id_);
485       kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(node));
486       address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
487     }
488   } else {
489     // Create device tensor.
490     const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
491       {node, front_node_with_index.second}, nullptr, size, kOpFormat_DEFAULT, type_id, ShapeVector(),
492       device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
493     kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(node));
494     address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
495   }
496   MS_EXCEPTION_IF_NULL(address);
497   MS_LOG(INFO) << "Create address for node that has no corresponding backend node:"
498                << common::AnfAlgo::GetNodeDebugString(node) << " addr:" << address << " size:" << size
499                << ", type id:" << type_id;
500   AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get());
501   UpdateRefCount(address.get(), true);
502 }
503 
504 // Fetch all funcgraph by a seed graph, if a calls b, b calls c, and c calls a, return a set of a, b, c.
FetchAllExecutionFunction(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * const checked_funcgraphs,const std::unordered_map<FuncGraphPtr,std::set<FuncGraphPtr>> & call_relation)505 void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *const checked_funcgraphs,
506                                const std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> &call_relation) {
507   MS_EXCEPTION_IF_NULL(func_graph);
508   MS_EXCEPTION_IF_NULL(checked_funcgraphs);
509   if (checked_funcgraphs->find(func_graph) != checked_funcgraphs->end()) {
510     return;
511   }
512   (void)checked_funcgraphs->emplace(func_graph);
513   auto iter = call_relation.find(func_graph);
514   if (iter == call_relation.end()) {
515     return;
516   }
517 
518   for (const auto &called_func_graph : iter->second) {
519     MS_EXCEPTION_IF_NULL(called_func_graph);
520     FetchAllExecutionFunction(called_func_graph, checked_funcgraphs, call_relation);
521   }
522 }
523 
IsValidMonadNode(const AnfNodePtr & node)524 bool IsValidMonadNode(const AnfNodePtr &node) {
525   MS_EXCEPTION_IF_NULL(node);
526   return node->isa<ValueNode>() || node->isa<Parameter>() || common::AnfAlgo::IsCallNode(node);
527 }
528 
529 // Fetch all inputs of node.
FetchInputNodeByNode(const AnfNodePtr & node)530 std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
531   MS_EXCEPTION_IF_NULL(node);
532   if (HasAbstractMonad(node)) {
533     const auto &real_node_with_index = common::AnfAlgo::VisitKernelWithReturnType(node, 0);
534     const auto &real_node = real_node_with_index.first;
535     MS_EXCEPTION_IF_NULL(real_node);
536     if (IsValidMonadNode(real_node)) {
537       return {real_node_with_index};
538     }
539     MS_LOG_WITH_NODE(EXCEPTION, real_node) << "Invalid monad node:" << real_node->DebugString();
540   }
541 
542   // The node is divided into the following types:
543   // 1. depend and load.
544   const auto &node_with_index =
545     common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
546   auto real_node = node_with_index.first;
547   size_t real_index = node_with_index.second;
548   MS_EXCEPTION_IF_NULL(real_node);
549   std::vector<KernelWithIndex> results;
550 
551   // 2. Tuple node.
552   const PrimitiveSet expand_prims{prim::kPrimMakeTuple};
553   // The MakeTuple/MakeSparse node need expand and recurse.
554   if (IsOneOfPrimitiveCNode(real_node, expand_prims)) {
555     const auto &cnode = real_node->cast<CNodePtr>();
556     MS_EXCEPTION_IF_NULL(cnode);
557     const auto &inputs = cnode->inputs();
558     for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
559       const auto &sub_results = FetchInputNodeByNode(inputs[i]);
560       (void)results.insert(results.end(), sub_results.begin(), sub_results.end());
561     }
562     return results;
563   }
564 
565   // 3. One output node.
566   const auto &abstract = real_node->abstract();
567   if (abstract == nullptr) {
568     MS_LOG(DEBUG) << "Empty abstract for node:" << real_node->DebugString();
569     (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(real_node, real_index));
570     return results;
571   }
572 
573   // 4 Other.
574   if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
575     if (real_node->cast<CNodePtr>()->HasAttr(kAttrReplaceRealKernelInBackend) && real_node->abstract() != nullptr) {
576       size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(real_node->abstract());
577       MS_LOG(INFO) << "Fetch an tuple get item with repalce flag:" << real_node->DebugString()
578                    << " output num:" << output_num;
579       for (size_t i = 0; i < output_num; ++i) {
580         (void)results.emplace_back(real_node, i);
581       }
582       return results;
583     }
584     std::vector<size_t> index_stack;
585     auto get_item_src_node = common::AnfAlgo::GetTupleIndexes(real_node, &index_stack);
586     MS_EXCEPTION_IF_NULL(get_item_src_node);
587     if (index_stack.empty()) {
588       const auto &sub_results = FetchInputNodeByNode(get_item_src_node);
589       (void)results.insert(results.end(), sub_results.begin(), sub_results.end());
590       return results;
591     }
592     auto get_item_src_abstract = get_item_src_node->abstract();
593     MS_EXCEPTION_IF_NULL(get_item_src_abstract);
594     auto indexes = FetchRealIndexByAbstract(get_item_src_abstract, &index_stack);
595     (void)std::transform(indexes.begin(), indexes.end(), std::back_inserter(results),
596                          [&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
597     return results;
598   }
599 
600   size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
601   for (size_t i = 0; i < output_num; ++i) {
602     (void)results.emplace_back(real_node, i);
603   }
604   return results;
605 }
606 
607 // Add formal parameter and real parameter into realationship map.
AddFormalToRealParameter(const AnfNodePtr & formal_parameter,const AnfNodePtr & real_parameter,const CallNodeToFuncGraph & call_node_to_func_graphs,FormalToRealParameter * const formal_to_real_parameters)608 void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodePtr &real_parameter,
609                               const CallNodeToFuncGraph &call_node_to_func_graphs,
610                               FormalToRealParameter *const formal_to_real_parameters) {
611   MS_EXCEPTION_IF_NULL(formal_parameter);
612   MS_EXCEPTION_IF_NULL(real_parameter);
613   MS_EXCEPTION_IF_NULL(formal_to_real_parameters);
614   auto abstract = formal_parameter->abstract();
615   if (abstract == nullptr) {
616     MS_LOG_WITH_NODE(EXCEPTION, formal_parameter) << "Empty abstract for parameter:" << formal_parameter->DebugString();
617   }
618   size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
619 
620   for (size_t i = 0; i < output_num; ++i) {
621     std::set<KernelWithIndex> real_parameters;
622     std::set<KernelWithIndex> invalid_call_nodes;
623     FetchRealParameterByNode({real_parameter, i}, &real_parameters, &invalid_call_nodes, call_node_to_func_graphs);
624     if (real_parameters.empty()) {
625       MS_LOG(DEBUG) << "Failed to find real parameter for formal parameter:" << real_parameter->DebugString();
626       continue;
627     }
628 
629     for (const auto &parameter : real_parameters) {
630       MS_EXCEPTION_IF_NULL(parameter.first);
631       MS_LOG(DEBUG) << "Add formal parameter:" << formal_parameter->DebugString() << " index:" << i
632                     << " to real parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
633     }
634     (*formal_to_real_parameters)[{formal_parameter, i}].insert(real_parameters.begin(), real_parameters.end());
635   }
636 }
637 
638 // Recursively traverse the input to confirm whether there is an input of recursive call.
IsFirstControlNode(const AnfNodePtr & node,std::set<AnfNodePtr> * checked_nodes,const std::set<AnfNodePtr> & unrecursion_call_nodes)639 bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes,
640                         const std::set<AnfNodePtr> &unrecursion_call_nodes) {
641   MS_EXCEPTION_IF_NULL(node);
642   MS_EXCEPTION_IF_NULL(checked_nodes);
643   if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
644     return true;
645   }
646   (void)checked_nodes->emplace(node);
647 
648   const auto &cnode = node->cast<CNodePtr>();
649   MS_EXCEPTION_IF_NULL(cnode);
650   const auto &inputs = cnode->inputs();
651   for (const auto &input : inputs) {
652     MS_EXCEPTION_IF_NULL(input);
653     if ((common::AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) ||
654         (!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) {
655       return false;
656     }
657   }
658   return true;
659 }
660 
661 // Check if src_node depends on dst_node.
IsTopoDependNode(const AnfNodePtr & src_node,const AnfNodePtr & dst_node,std::set<AnfNodePtr> * checked_node)662 bool IsTopoDependNode(const AnfNodePtr &src_node, const AnfNodePtr &dst_node, std::set<AnfNodePtr> *checked_node) {
663   MS_EXCEPTION_IF_NULL(src_node);
664   MS_EXCEPTION_IF_NULL(dst_node);
665   MS_EXCEPTION_IF_NULL(checked_node);
666   if (src_node == dst_node) {
667     return true;
668   }
669   if (!src_node->isa<CNode>() || checked_node->find(src_node) != checked_node->end()) {
670     return false;
671   }
672 
673   (void)checked_node->emplace(src_node);
674   const auto &cnode = src_node->cast<CNodePtr>();
675   MS_EXCEPTION_IF_NULL(cnode);
676   const auto &inputs = cnode->inputs();
677   for (const auto &input : inputs) {
678     MS_EXCEPTION_IF_NULL(input);
679     if (IsTopoDependNode(input, dst_node, checked_node)) {
680       return true;
681     }
682   }
683   return false;
684 }
685 
IsValidBackendParameter(const AnfNodePtr & node)686 bool IsValidBackendParameter(const AnfNodePtr &node) {
687   if (node == nullptr) {
688     return false;
689   }
690   if (node->abstract() == nullptr) {
691     return true;
692   }
693   if (node->abstract()->isa<abstract::AbstractAny>()) {
694     return false;
695   }
696   const auto &shape = node->abstract()->BuildShape();
697   if (shape == nullptr || shape->IsDynamic()) {
698     return false;
699   }
700   return true;
701 }
702 }  // namespace
CreateBuildInfoForFrontNode(const KernelWithIndex & front_node_with_index,const AnfNodePtr & backend_node)703 void CreateBuildInfoForFrontNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node) {
704   MS_EXCEPTION_IF_NULL(front_node_with_index.first);
705   MS_EXCEPTION_IF_NULL(backend_node);
706   const auto &front_node = front_node_with_index.first;
707   if (front_node->kernel_info() == nullptr) {
708     auto kernel_info = std::make_shared<device::KernelInfo>();
709     MS_EXCEPTION_IF_NULL(kernel_info);
710     front_node->set_kernel_info(kernel_info);
711     std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
712     MS_EXCEPTION_IF_NULL(builder);
713     kernel_info->set_select_kernel_build_info(builder->Build());
714     kernel_info->GetMutableSelectKernelBuildInfo()->SetOutputsKernelObjectType(
715       {kernel::KernelObjectType::TUPLE_UNFOLD});
716   }
717 
718   // Set build info to front node.
719   auto backend_kernel_info = static_cast<device::KernelInfo *>(backend_node->kernel_info());
720   MS_EXCEPTION_IF_NULL(backend_kernel_info);
721   auto backend_build_info = backend_kernel_info->GetMutableSelectKernelBuildInfo();
722   MS_EXCEPTION_IF_NULL(backend_build_info);
723 
724   auto front_kernel_info = static_cast<device::KernelInfo *>(front_node->kernel_info());
725   MS_EXCEPTION_IF_NULL(front_kernel_info);
726   auto front_build_info = front_kernel_info->GetMutableSelectKernelBuildInfo();
727   MS_EXCEPTION_IF_NULL(front_build_info);
728   // Set output format and device data type.
729   if (front_build_info->GetAllOutputFormats().size() > front_node_with_index.second) {
730     front_build_info->SetOutputFormat(backend_build_info->GetOutputFormat(0), front_node_with_index.second);
731     front_build_info->SetOutputDeviceType(backend_build_info->GetOutputDeviceType(0), front_node_with_index.second);
732   } else {
733     auto formats = front_build_info->GetAllOutputFormats();
734     auto types = front_build_info->GetAllOutputDeviceTypes();
735     for (size_t i = 0; i <= front_node_with_index.second - front_build_info->GetAllOutputFormats().size(); ++i) {
736       (void)formats.emplace_back(backend_build_info->GetOutputFormat(0));
737       (void)types.emplace_back(backend_build_info->GetOutputDeviceType(0));
738     }
739     front_build_info->SetOutputsFormat(formats);
740     front_build_info->SetOutputsDeviceType(types);
741   }
742 }
743 
IsInvalidPartial(const AnfNodePtr & node)744 bool IsInvalidPartial(const AnfNodePtr &node) {
745   MS_EXCEPTION_IF_NULL(node);
746   if (!node->isa<CNode>()) {
747     return false;
748   }
749 
750   const auto &cnode = node->cast<CNodePtr>();
751   MS_EXCEPTION_IF_NULL(cnode);
752   const auto &inputs = cnode->inputs();
753   if (inputs.size() <= kPartialFuncGraphPos) {
754     return false;
755   }
756 
757   if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
758     return false;
759   }
760   if (IsDeadNode(inputs[kPartialFuncGraphPos])) {
761     return true;
762   }
763   return false;
764 }
765 
// Resolve output `node_with_index.second` of `node_with_index.first` through its get-item indexing
// (as collected by common::AnfAlgo::GetTupleIndexes) to the node that really produces it.
// If the resolved position expands to several real outputs, only the first one is returned.
// Throws if no real output index can be found.
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
  MS_EXCEPTION_IF_NULL(node_with_index.first);
  std::vector<size_t> index_stack{node_with_index.second};

  const auto &get_item_src_node = common::AnfAlgo::GetTupleIndexes(node_with_index.first, &index_stack);
  MS_EXCEPTION_IF_NULL(get_item_src_node);
  const auto &get_item_src_abstract = get_item_src_node->abstract();
  MS_EXCEPTION_IF_NULL(get_item_src_abstract);
  auto indexes = FetchRealIndexByAbstract(get_item_src_abstract, &index_stack);
  if (indexes.empty()) {
    // Print the node description (as all other messages in this file do), not the smart pointer.
    MS_LOG_WITH_NODE(EXCEPTION, get_item_src_node)
      << "Failed to find index for node:" << get_item_src_node->DebugString();
  }
  if (indexes.size() > 1) {
    MS_LOG(DEBUG) << "Output size:" << indexes.size() << " for node:" << get_item_src_node->DebugString()
                  << " more than 1";
  }
  return {get_item_src_node, *(indexes.begin())};
}
784 
IsCsrNode(const AnfNodePtr & node)785 bool IsCsrNode(const AnfNodePtr &node) {
786   MS_EXCEPTION_IF_NULL(node);
787   if (!node->isa<CNode>()) {
788     return false;
789   }
790   return common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndptr) ||
791          common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetIndices) ||
792          common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetValues) ||
793          common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCSRTensorGetDenseShape);
794 }
795 
IsCooNode(const AnfNodePtr & node)796 bool IsCooNode(const AnfNodePtr &node) {
797   MS_EXCEPTION_IF_NULL(node);
798   if (!node->isa<CNode>()) {
799     return false;
800   }
801   return common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCOOTensorGetIndices) ||
802          common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCOOTensorGetValues) ||
803          common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCOOTensorGetDenseShape);
804 }
805 
GetFrontNodeByKernelGraph(const AnfNodePtr & backend_node,const KernelGraph * const graph)806 KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph) {
807   MS_EXCEPTION_IF_NULL(backend_node);
808   MS_EXCEPTION_IF_NULL(graph);
809   const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
810   if (front_node != nullptr) {
811     MS_LOG(DEBUG) << "Front node:" << front_node->DebugString() << " index:0"
812                   << " for backend node:" << backend_node->DebugString();
813     return {front_node, 0};
814   }
815   const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
816   if (front_node_with_index.first != nullptr) {
817     MS_LOG(DEBUG) << "Internal front node:" << front_node_with_index.first->DebugString()
818                   << " index:" << front_node_with_index.second << " for backend node:" << backend_node->DebugString();
819     return front_node_with_index;
820   }
821   const auto &front_tuple_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(backend_node);
822   if (front_tuple_node_with_index.first == nullptr) {
823     MS_LOG_WITH_NODE(EXCEPTION, backend_node)
824       << "Cannot find front node for backend node:" << backend_node->DebugString() << " in graph:" << graph->ToString();
825   }
826   MS_LOG(DEBUG) << "Tuple front node:" << front_tuple_node_with_index.first->DebugString()
827                 << " index:" << front_tuple_node_with_index.second;
828   return front_tuple_node_with_index;
829 }
830 
FetchInputNodeByCNode(const AnfNodePtr & node)831 std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
832   MS_EXCEPTION_IF_NULL(node);
833   MS_LOG(DEBUG) << "Fetch input node for:" << node->DebugString();
834   if (!node->isa<CNode>()) {
835     MS_LOG(DEBUG) << "Empty input node for:" << node->DebugString();
836     return {};
837   }
838 
839   std::vector<KernelWithIndex> results;
840   // The first input of normal cnode is the primitive of node, and the real input starts from the second input,
841   // but in control flow, the call node has no primitive, and the 0th input is funcgraph or partial.
842   size_t input_start_pos = kCNodeInputStartPos;
843   if (common::AnfAlgo::IsCallNode(node)) {
844     input_start_pos = 0;
845   }
846   const auto &cnode = node->cast<CNodePtr>();
847   MS_EXCEPTION_IF_NULL(cnode);
848   const auto inputs = cnode->inputs();
849 
850   // The first branch of the input of the switch node is the true branch, and the second is the false branch.
851   // But in switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the input
852   // of the switch node needs to exchange the positions of the two branches. So deal separately.
853   if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
854     if (inputs.size() != kSwitchInputNum) {
855       MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid switch node:" << node->DebugString();
856     }
857     (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0));
858     (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0));
859     (void)results.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0));
860     return results;
861   }
862 
863   for (size_t i = input_start_pos; i < inputs.size(); ++i) {
864     MS_EXCEPTION_IF_NULL(inputs[i]);
865     const auto &sub_results = FetchInputNodeByNode(inputs[i]);
866     (void)results.insert(results.end(), sub_results.begin(), sub_results.end());
867   }
868   return results;
869 }
870 
IsPartialInput(const AnfNodePtr & node)871 bool IsPartialInput(const AnfNodePtr &node) {
872   MS_EXCEPTION_IF_NULL(node);
873   const auto &abstract = node->abstract();
874   if (abstract != nullptr) {
875     if (abstract->isa<abstract::AbstractFunction>()) {
876       return true;
877     }
878     return false;
879   }
880 
881   if (!node->isa<CNode>()) {
882     return false;
883   }
884 
885   // If the abstract is empty and the node is a cnode, check its true branch.
886   const auto &cnode = node->cast<CNodePtr>();
887   MS_EXCEPTION_IF_NULL(cnode);
888 
889   const auto &inputs = cnode->inputs();
890   if (inputs.size() < kSwitchTrueBranchIndex + 1) {
891     MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid switch node:" << node->DebugString();
892   }
893   const auto &branch_node = inputs[kSwitchTrueBranchIndex];
894   MS_EXCEPTION_IF_NULL(branch_node);
895   const auto &branch_abstract = branch_node->abstract();
896   // If abstract is empty, the default is true.
897   if (branch_abstract == nullptr) {
898     MS_LOG(DEBUG) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
899     return true;
900   }
901 
902   if (branch_abstract->isa<abstract::AbstractFunction>()) {
903     return true;
904   } else if (branch_abstract->isa<abstract::AbstractSequence>()) {
905     // In switch layer, the true branch input is a make tuple.
906     auto sequence_abstract = branch_abstract->cast<abstract::AbstractSequencePtr>();
907     MS_EXCEPTION_IF_NULL(sequence_abstract);
908     const auto &sub_abstracts = sequence_abstract->elements();
909     if (sub_abstracts.empty() || sub_abstracts[0] == nullptr) {
910       MS_LOG(DEBUG) << "Failed to get abstract by true branch input of switch node:" << node->DebugString();
911       return true;
912     }
913     if (sub_abstracts[0]->isa<abstract::AbstractFunction>()) {
914       return true;
915     }
916   }
917   return false;
918 }
919 
920 // Fetch the depend nodes according to the monad node.
FetchRealDependNodeByAutoMonad(const AnfNodePtr & node,std::set<AnfNodePtr> * const depend_nodes)921 void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> *const depend_nodes) {
922   // Find the real input node, include the monad node and make tuple node.
923   const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
924                                                   prim::kPrimMakeTuple};
925   const auto &node_with_index = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_types);
926   auto real_node = node_with_index.first;
927   MS_EXCEPTION_IF_NULL(real_node);
928   if (!real_node->isa<CNode>()) {
929     return;
930   }
931 
932   const auto &real_cnode = real_node->cast<CNodePtr>();
933   MS_EXCEPTION_IF_NULL(real_cnode);
934   const auto &real_inputs = real_cnode->inputs();
935 
936   // Make tuple node needs to be expanded.
937   if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
938     for (size_t i = 1; i < real_inputs.size(); ++i) {
939       MS_EXCEPTION_IF_NULL(real_inputs[i]);
940       FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
941     }
942     return;
943   }
944 
945   const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
946     prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
947   if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimDepend) ||
948       common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimLoad)) {
949     FetchRealDependNodeByAutoMonad(real_inputs[kDependAttachNodeIndex], depend_nodes);
950     // The real input may be this scene:  depend/load --> load/depend, so need add the control arrow for real input
951     // node in this scene.
952     if (IsOneOfPrimitiveCNode(real_inputs[kRealInputIndexInDepend], recursion_prims)) {
953       FetchRealDependNodeByAutoMonad(real_inputs[kRealInputIndexInDepend], depend_nodes);
954     }
955   } else if (common::AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimUpdateState)) {
956     for (size_t i = kUpdateStateRealInput; i < real_inputs.size(); ++i) {
957       FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
958     }
959   } else {
960     MS_EXCEPTION_IF_NULL(depend_nodes);
961     (void)depend_nodes->emplace(real_node);
962   }
963 }
964 
965 // Get all the depend nodes of node in side effect.
FetchAllMonadNodeByNode(const AnfNodePtr & node)966 std::vector<AnfNodePtr> FetchAllMonadNodeByNode(const AnfNodePtr &node) {
967   MS_EXCEPTION_IF_NULL(node);
968   if (!node->isa<CNode>()) {
969     return {};
970   }
971   if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) ||
972       common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) ||
973       common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad)) {
974     return {node};
975   }
976 
977   std::vector<AnfNodePtr> results;
978   if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
979     const auto &cnode = node->cast<CNodePtr>();
980     MS_EXCEPTION_IF_NULL(cnode);
981     for (auto &weak_input : cnode->weak_inputs()) {
982       auto input = weak_input.lock();
983       MS_EXCEPTION_IF_NULL(input);
984       const auto &result = FetchAllMonadNodeByNode(input);
985       (void)results.insert(results.end(), result.begin(), result.end());
986     }
987   }
988   return results;
989 }
990 
Parse(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const FuncGraphPtr & root_graph,const FuncGraphToKernelGraphGroup & func_graph_to_kernel_graphs)991 void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
992                               const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
993                               const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
994   if (graphs.size() != device_contexts.size()) {
995     MS_LOG(EXCEPTION) << "Graph num is not equal to device context, graph:" << graphs.size()
996                       << " device context num:" << device_contexts.size();
997   }
998 
999   if (control_nodes.size() <= 1) {
1000     MS_LOG(DEBUG) << "Control node parser is not inited.";
1001     return;
1002   }
1003   MS_LOG(INFO) << "Control node parse start.";
1004 
1005   // Fetch default device context.
1006   auto context_ptr = MsContext::GetInstance();
1007   MS_EXCEPTION_IF_NULL(context_ptr);
1008   std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1009   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
1010   DeviceContext *default_context = nullptr;
1011   if (device_contexts.empty()) {
1012     default_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
1013   } else {
1014     default_context = device_contexts[0];
1015   }
1016   MS_EXCEPTION_IF_NULL(default_context);
1017 
1018   KernelGraphToDeviceContext kernel_graph_to_device_contexts;
1019   for (size_t i = 0; i < graphs.size(); ++i) {
1020     kernel_graph_to_device_contexts[graphs[i]] = device_contexts[i];
1021   }
1022 
1023   for (const auto &control_node : control_nodes) {
1024     MS_EXCEPTION_IF_NULL(control_node);
1025     MS_LOG(DEBUG) << "Print control node:" << control_node->DebugString();
1026   }
1027 
1028   is_inited_ = true;
1029 
1030   root_func_graph_ = root_graph;
1031 
1032   root_graph_parameters_ = root_graph->parameters();
1033 
1034   func_graph_to_kernel_graph_groups_ = func_graph_to_kernel_graphs;
1035   for (const auto &func_graph_to_kernel_graph_groups : func_graph_to_kernel_graph_groups_) {
1036     for (const auto &kernel_graph_group : func_graph_to_kernel_graph_groups.second) {
1037       for (const auto &kernel_graph : kernel_graph_group) {
1038         MS_EXCEPTION_IF_NULL(func_graph_to_kernel_graph_groups.first);
1039         MS_EXCEPTION_IF_NULL(kernel_graph);
1040         MS_LOG(DEBUG) << "Funcgraph to kernel graph, func:" << func_graph_to_kernel_graph_groups.first->ToString()
1041                       << " kernel_graph:" << kernel_graph->ToString();
1042       }
1043     }
1044   }
1045 
1046   CreateBranchIDForCallNode(control_nodes);
1047 
1048   ParseFrontNodeToKernelGraph(graphs);
1049 
1050   ParseCallNodeToFuncGraph(control_nodes);
1051 
1052   ParseUnRecursionCallNode();
1053 
1054   InsertDependForParallelCall(control_nodes);
1055 
1056   ParseKernelGraphGroup(kernel_graph_to_device_contexts);
1057 
1058   ParseNodeLevel(control_nodes);
1059 
1060   ParseNeedStackControlNode(control_nodes);
1061 
1062   ParseFormalToRealParameter(control_nodes);
1063 
1064   ParseFrontToBackendParameter(graphs, device_contexts);
1065 
1066   CreateDeviceTensorForRootGraphParameter(default_context);
1067 
1068   ParseFrontToBackendKernel(graphs, device_contexts);
1069 
1070   ParseDeviceContext(control_nodes, graphs, device_contexts, default_context, func_graph_to_kernel_graphs);
1071 
1072   FetchFrontValueNode(control_nodes, default_context);
1073 
1074   ParseControlNodeParameter(control_nodes);
1075 
1076   ParseFirstControlNodeAndKernelGraphForFuncGraph(control_nodes);
1077 
1078   ParseDynamicLenFormalParameter(control_nodes);
1079   MS_LOG(INFO) << "Control node parse end.";
1080 }
1081 
1082 namespace {
GetArgumentIndexForDynamicLenParameter(const abstract::AbstractBasePtr & argument_abs,size_t argument_index,const abstract::AbstractBasePtr & parameter_abs,mindspore::HashMap<size_t,size_t> * indexes)1083 void GetArgumentIndexForDynamicLenParameter(const abstract::AbstractBasePtr &argument_abs, size_t argument_index,
1084                                             const abstract::AbstractBasePtr &parameter_abs,
1085                                             mindspore::HashMap<size_t, size_t> *indexes) {
1086   if (argument_abs == nullptr || parameter_abs == nullptr) {
1087     return;
1088   }
1089   MS_EXCEPTION_IF_NULL(indexes);
1090   if ((!argument_abs->isa<abstract::AbstractSequence>()) || (!parameter_abs->isa<abstract::AbstractSequence>())) {
1091     return;
1092   }
1093   const auto &arg_seq_abs = argument_abs->cast<abstract::AbstractSequencePtr>();
1094   const auto &para_seq_abs = parameter_abs->cast<abstract::AbstractSequencePtr>();
1095   MS_EXCEPTION_IF_NULL(arg_seq_abs);
1096   MS_EXCEPTION_IF_NULL(para_seq_abs);
1097   if (arg_seq_abs->dynamic_len() && para_seq_abs->dynamic_len()) {
1098     return;
1099   }
1100   if ((!arg_seq_abs->dynamic_len()) && para_seq_abs->dynamic_len()) {
1101     MS_LOG(DEBUG) << "Add argument index:" << argument_index << " size:" << arg_seq_abs->size();
1102     (*indexes)[argument_index] = arg_seq_abs->size();
1103     return;
1104   }
1105   if (arg_seq_abs->dynamic_len() || para_seq_abs->dynamic_len() || arg_seq_abs->size() != para_seq_abs->size()) {
1106     MS_LOG(EXCEPTION) << "Invalid dynamic len flag for argument abstract:" << arg_seq_abs->ToString()
1107                       << " parameter abstract:" << para_seq_abs->ToString();
1108   }
1109   size_t start_index = argument_index;
1110   for (size_t i = 0; i < arg_seq_abs->size(); ++i) {
1111     GetArgumentIndexForDynamicLenParameter(arg_seq_abs->elements()[i], start_index, para_seq_abs->elements()[i],
1112                                            indexes);
1113     start_index += common::AnfAlgo::GetOutputNumByAbstract(arg_seq_abs->elements()[i]);
1114   }
1115 }
1116 }  // namespace
1117 
ParseDynamicLenFormalParameterByCallNode(const AnfNodePtr & node)1118 void ControlNodeParser::ParseDynamicLenFormalParameterByCallNode(const AnfNodePtr &node) {
1119   MS_EXCEPTION_IF_NULL(node);
1120   const auto &cnode = node->cast<CNodePtr>();
1121   MS_EXCEPTION_IF_NULL(cnode);
1122   const auto &func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
1123   if (func_graphs.empty()) {
1124     MS_LOG(EXCEPTION) << "Get func_graph from abstract failed.";
1125   }
1126   mindspore::HashMap<size_t, size_t> sequence_indexes;
1127   for (auto func_graph : func_graphs) {
1128     MS_EXCEPTION_IF_NULL(func_graph);
1129     // Check the consistency of return outputs and call outputs.
1130     MS_EXCEPTION_IF_NULL(func_graph->return_node());
1131     mindspore::HashMap<size_t, size_t> return_sequence_indexes;
1132     GetArgumentIndexForDynamicLenParameter(func_graph->return_node()->abstract(), 0, node->abstract(),
1133                                            &return_sequence_indexes);
1134     if (!return_sequence_indexes.empty()) {
1135       return_to_call_with_dynamic_sequence_index_[func_graph->return_node()][node] = return_sequence_indexes;
1136     }
1137     // Check the consistency of arguments and parameters.
1138     if (cnode->inputs().empty()) {
1139       MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid cnode:" << cnode->DebugString();
1140     }
1141     size_t args_num = cnode->size() - 1;
1142     size_t para_num = func_graph->parameters().size();
1143     MS_LOG(DEBUG) << "for call node:" << cnode->DebugString() << " arg size:" << args_num << " para size:" << para_num;
1144     if (args_num > para_num) {
1145       MS_LOG(EXCEPTION) << "Invalid args num:" << args_num << " for funcgraph:" << func_graph->ToString()
1146                         << " parameters num:" << func_graph->parameters().size();
1147     }
1148     size_t start_index = 1;
1149     for (size_t i = 0; i < args_num; ++i) {
1150       MS_EXCEPTION_IF_NULL(cnode->input(i + 1));
1151       MS_EXCEPTION_IF_NULL((func_graph->parameters())[i + para_num - args_num]);
1152       MS_LOG(DEBUG) << "Check formal parameter:" << cnode->input(i + 1)->DebugString()
1153                     << " real node:" << (func_graph->parameters())[i + para_num - args_num]->DebugString();
1154       GetArgumentIndexForDynamicLenParameter(cnode->input(i + 1)->abstract(), start_index,
1155                                              (func_graph->parameters())[i + para_num - args_num]->abstract(),
1156                                              &sequence_indexes);
1157       start_index += common::AnfAlgo::GetOutputNumByAbstract(cnode->input(i + 1)->abstract());
1158     }
1159     if (!sequence_indexes.empty()) {
1160       for (const auto &pair : sequence_indexes) {
1161         MS_LOG(DEBUG) << "Add dynamic len formal parameter for call node:" << node->DebugString()
1162                       << " funcgraph:" << func_graph->ToString() << " argument index:" << pair.first
1163                       << " size:" << pair.second;
1164       }
1165       control_node_to_funcgraph_with_dynamic_sequence_index_[node][func_graph.get()] = sequence_indexes;
1166     }
1167   }
1168 }
1169 
ParseDynamicLenFormalParameterByPartial(const AnfNodePtr & node)1170 void ControlNodeParser::ParseDynamicLenFormalParameterByPartial(const AnfNodePtr &node) {
1171   MS_EXCEPTION_IF_NULL(node);
1172   const auto &cnode = node->cast<CNodePtr>();
1173   MS_EXCEPTION_IF_NULL(cnode);
1174   size_t input_num = cnode->size();
1175   if (input_num <= kPartialFuncGraphPos || cnode->input(kPartialFuncGraphPos) == nullptr ||
1176       (!cnode->input(kPartialFuncGraphPos)->isa<ValueNode>())) {
1177     MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid partial node:" << node->DebugString();
1178   }
1179   const auto &func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kPartialFuncGraphPos));
1180   if (func_graph == nullptr) {
1181     MS_LOG(DEBUG) << "Failed to get funcgraph in partial node:" << node->DebugString();
1182     return;
1183   }
1184   if (func_graph->parameters().size() < input_num - kPartialInputStartPos) {
1185     MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid args num:" << input_num - kPartialInputStartPos
1186                                        << " in partial node:" << cnode->DebugString()
1187                                        << " for fungraph:" << func_graph->ToString()
1188                                        << " parameter num:" << func_graph->parameters().size();
1189   }
1190   size_t start_index = 1;
1191   mindspore::HashMap<size_t, size_t> sequence_indexes;
1192   for (size_t i = kPartialInputStartPos; i < input_num; ++i) {
1193     MS_EXCEPTION_IF_NULL(cnode->input(i));
1194     MS_EXCEPTION_IF_NULL(func_graph->parameters()[i - kPartialInputStartPos]);
1195     GetArgumentIndexForDynamicLenParameter(cnode->input(i)->abstract(), start_index,
1196                                            func_graph->parameters()[i - kPartialInputStartPos]->abstract(),
1197                                            &sequence_indexes);
1198     start_index += common::AnfAlgo::GetOutputNumByAbstract(cnode->input(i)->abstract());
1199   }
1200   if (!sequence_indexes.empty()) {
1201     mindspore::HashMap<size_t, size_t> new_sequence_indexes;
1202     for (const auto &index_pair : sequence_indexes) {
1203       new_sequence_indexes[index_pair.first + 1] = index_pair.second;
1204     }
1205     control_node_to_funcgraph_with_dynamic_sequence_index_[node][func_graph.get()] = new_sequence_indexes;
1206   }
1207 }
1208 
ParseDynamicLenFormalParameter(const std::vector<AnfNodePtr> & control_nodes)1209 void ControlNodeParser::ParseDynamicLenFormalParameter(const std::vector<AnfNodePtr> &control_nodes) {
1210   for (const auto &node : control_nodes) {
1211     MS_EXCEPTION_IF_NULL(node);
1212     if (common::AnfAlgo::IsCallNode(node)) {
1213       ParseDynamicLenFormalParameterByCallNode(node);
1214     } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
1215       ParseDynamicLenFormalParameterByPartial(node);
1216     }
1217   }
1218   for (const auto &node_to_func_with_index : control_node_to_funcgraph_with_dynamic_sequence_index_) {
1219     const auto &node = node_to_func_with_index.first;
1220     MS_EXCEPTION_IF_NULL(node);
1221     for (const auto &func_with_index : node_to_func_with_index.second) {
1222       const auto &func_graph = func_with_index.first;
1223       MS_EXCEPTION_IF_NULL(func_graph);
1224       for (const auto &indexes : func_with_index.second) {
1225         MS_LOG(DEBUG) << "Node:" << node->DebugString() << " func_graph:" << func_graph->ToString()
1226                       << " start index:" << indexes.first << " size:" << indexes.second;
1227       }
1228     }
1229   }
1230   for (const auto &node_to_call_with_index : return_to_call_with_dynamic_sequence_index_) {
1231     const auto &node = node_to_call_with_index.first;
1232     MS_EXCEPTION_IF_NULL(node);
1233     for (const auto &call_with_index : node_to_call_with_index.second) {
1234       const auto &call = call_with_index.first;
1235       MS_EXCEPTION_IF_NULL(call);
1236       for (const auto &indexes : call_with_index.second) {
1237         MS_LOG(DEBUG) << "Node:" << node->DebugString() << " call node:" << call->DebugString()
1238                       << " start index:" << indexes.first << " size:" << indexes.second;
1239       }
1240     }
1241   }
1242 }
1243 
1244 // Fetch all the funcgraph recursively that the call node will call.
FetchAllCalledFuncGraph(const AnfNodePtr & call_node,std::set<FuncGraphPtr> * called_graphs,const CallNodeToFuncGraph & call_node_to_func_graphs,const FuncGraphToCallNode & func_graph_to_call_nodes)1245 void FetchAllCalledFuncGraph(const AnfNodePtr &call_node, std::set<FuncGraphPtr> *called_graphs,
1246                              const CallNodeToFuncGraph &call_node_to_func_graphs,
1247                              const FuncGraphToCallNode &func_graph_to_call_nodes) {
1248   MS_EXCEPTION_IF_NULL(call_node);
1249   MS_EXCEPTION_IF_NULL(called_graphs);
1250   const auto &call_iter = call_node_to_func_graphs.find(call_node);
1251   if (call_iter == call_node_to_func_graphs.end()) {
1252     return;
1253   }
1254   for (const auto &func_graph : call_iter->second) {
1255     MS_EXCEPTION_IF_NULL(func_graph);
1256     if (called_graphs->find(func_graph) != called_graphs->end()) {
1257       continue;
1258     }
1259     (void)called_graphs->emplace(func_graph);
1260     const auto &graph_iter = func_graph_to_call_nodes.find(func_graph);
1261     if (graph_iter == func_graph_to_call_nodes.end()) {
1262       continue;
1263     }
1264 
1265     // Fetch the funcgraph recursively.
1266     for (const auto &node : graph_iter->second) {
1267       FetchAllCalledFuncGraph(node, called_graphs, call_node_to_func_graphs, func_graph_to_call_nodes);
1268     }
1269   }
1270 }
1271 
CreateTensorForValue(const ValuePtr & value)1272 tensor::TensorPtr ControlNodeParser::CreateTensorForValue(const ValuePtr &value) {
1273   MS_EXCEPTION_IF_NULL(value);
1274   tensor::TensorPtr tensor = nullptr;
1275   if (value->isa<Monad>()) {
1276     tensor = std::make_shared<tensor::Tensor>(int8_t('U'), TypeIdToType(kNumberTypeInt8));
1277   } else if (value->isa<Scalar>()) {
1278     const auto scalar_value = value->cast<ScalarPtr>();
1279     MS_EXCEPTION_IF_NULL(scalar_value);
1280     tensor = ScalarToTensor(scalar_value);
1281   } else {
1282     MS_LOG(EXCEPTION) << "Invalid value:" << value->ToString();
1283   }
1284   control_node_tensors_.emplace_back(tensor);
1285   return tensor;
1286 }
1287 
IsParallelCallRecursionGraph(const AnfNodePtr & call_node1,const AnfNodePtr & call_node2,const FuncGraphToCallNode & func_graph_to_call_nodes)1288 bool ControlNodeParser::IsParallelCallRecursionGraph(const AnfNodePtr &call_node1, const AnfNodePtr &call_node2,
1289                                                      const FuncGraphToCallNode &func_graph_to_call_nodes) {
1290   // Fetch all funcgraphs the two call nodes will call both.
1291   std::set<FuncGraphPtr> called_graphs_1;
1292   FetchAllCalledFuncGraph(call_node1, &called_graphs_1, call_node_to_func_graphs_, func_graph_to_call_nodes);
1293   std::set<FuncGraphPtr> called_graphs_2;
1294   FetchAllCalledFuncGraph(call_node2, &called_graphs_2, call_node_to_func_graphs_, func_graph_to_call_nodes);
1295   std::vector<FuncGraphPtr> common_called_graphs;
1296   (void)std::set_intersection(called_graphs_1.begin(), called_graphs_1.end(), called_graphs_2.begin(),
1297                               called_graphs_2.end(), std::back_inserter(common_called_graphs));
1298 
1299   // Check for recursive calls in funcgraph.
1300   for (const auto &func_graph : common_called_graphs) {
1301     MS_EXCEPTION_IF_NULL(func_graph);
1302     const auto &iter = func_graph_to_call_nodes.find(func_graph);
1303     if (iter == func_graph_to_call_nodes.end()) {
1304       continue;
1305     }
1306     for (const auto &call_node : iter->second) {
1307       MS_EXCEPTION_IF_NULL(call_node);
1308       if (IsRecursionCallNode(call_node)) {
1309         MS_LOG(INFO) << "Call node:" << call_node1->DebugString() << " and:" << call_node2->DebugString()
1310                      << " would call the same recursion in graph:" << func_graph
1311                      << " which has a recursion call:" << call_node->DebugString();
1312         return true;
1313       }
1314     }
1315   }
1316   return false;
1317 }
1318 
InsertDependForParallelCall(const std::vector<AnfNodePtr> & control_nodes)1319 void ControlNodeParser::InsertDependForParallelCall(const std::vector<AnfNodePtr> &control_nodes) {
1320   MS_LOG(INFO) << "InsertDependForParallelCall start";
1321   std::vector<AnfNodePtr> call_nodes;
1322   for (const auto &control_node : control_nodes) {
1323     MS_EXCEPTION_IF_NULL(control_node);
1324     if (!common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
1325       if (common::AnfAlgo::IsCallNode(control_node)) {
1326         // Fetch all the call nodes in the same graph.
1327         (void)call_nodes.emplace_back(control_node);
1328       }
1329       continue;
1330     }
1331 
1332     // Check whether there is a topology relationship between call nodes.
1333     for (size_t i = 0; i < call_nodes.size(); ++i) {
1334       for (size_t j = 0; j < i; ++j) {
1335         std::set<AnfNodePtr> checked_nodes;
1336         if ((!IsParallelCallRecursionGraph(call_nodes[i], call_nodes[j], func_graph_to_call_nodes_)) ||
1337             IsTopoDependNode(call_nodes[i], call_nodes[j], &checked_nodes)) {
1338           continue;
1339         }
1340         // If there is no topological relationship between call nodes, and the same recursive graph will be called
1341         // at the same time, then a depend node needs to be inserted between call nodes.
1342         auto func_graph = call_nodes[i]->func_graph();
1343         MS_EXCEPTION_IF_NULL(func_graph);
1344         auto cnode = call_nodes[i]->cast<CNodePtr>();
1345         MS_EXCEPTION_IF_NULL(cnode);
1346         const auto &inputs = cnode->inputs();
1347         MS_EXCEPTION_IF_NULL(inputs[0]);
1348 
1349         // Create a depend node.
1350         std::vector<AnfNodePtr> depend_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
1351                                                  cnode->input(0), call_nodes[j]};
1352         auto new_depend = func_graph->NewCNode(depend_inputs);
1353         MS_EXCEPTION_IF_NULL(new_depend);
1354         new_depend->set_abstract(cnode->input(0)->abstract());
1355 
1356         // Set depend node to call input.
1357         std::vector<AnfNodePtr> new_call_inputs{new_depend};
1358         for (size_t k = 1; k < inputs.size(); ++k) {
1359           (void)new_call_inputs.emplace_back(inputs[k]);
1360         }
1361         cnode->set_inputs(new_call_inputs);
1362         MS_LOG(INFO) << "Add depend node:" << new_depend->DebugString()
1363                      << " for call node:" << call_nodes[i]->DebugString() << " and:" << call_nodes[j]->DebugString();
1364       }
1365     }
1366     call_nodes.clear();
1367   }
1368   MS_LOG(INFO) << "InsertDependForParallelCall end";
1369 }
1370 
IsControlFlowDataArrow(const KernelGraphPtr & graph,const AnfNodePtr & backend_node)1371 bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) {
1372   MS_EXCEPTION_IF_NULL(graph);
1373   // Has no control flow node.
1374   if (!IsInited()) {
1375     return false;
1376   }
1377 
1378   MS_EXCEPTION_IF_NULL(backend_node);
1379   if (!backend_node->isa<Parameter>()) {
1380     return false;
1381   }
1382   auto parameter_node = backend_node->cast<ParameterPtr>();
1383   MS_EXCEPTION_IF_NULL(parameter_node);
1384 
1385   // Parameter input should be linked to its entrance actor.
1386   auto front_node = graph->GetFrontAnfByBackendAnf(backend_node);
1387   auto internal_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
1388   front_node = (front_node != nullptr ? front_node : internal_node_with_index.first);
1389   if (front_node == nullptr) {
1390     auto front_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(backend_node);
1391     front_node = front_node_with_index.first;
1392   }
1393   MS_EXCEPTION_IF_NULL(front_node);
1394   const auto &real_front_node = common::AnfAlgo::VisitKernelWithReturnType(front_node, 0).first;
1395   if (real_front_node != nullptr && real_front_node->isa<ValueNode>() && (!HasAbstractMonad(real_front_node))) {
1396     // If the real front node is a value node, we have two situations:
1397     // 1. if the value in value node is a tensor, it should be set into device tensor store by graph scheduler;
1398     // 2. if the value is a monad state, it should be converted to control arrow, which should link by control
1399     //    node scheduler.
1400     MS_LOG(DEBUG) << "Front node:" << real_front_node->DebugString()
1401                   << " of backend node:" << backend_node->DebugString() << " is a valuenode.";
1402     return false;
1403   }
1404 
1405   // If parameter is a weight node in root funcgraph, it should be set to kernel actor directly.
1406   if (IsRootGraphPersistentDeviceTensor(front_node)) {
1407     MS_LOG(DEBUG) << "backend node:" << backend_node->DebugString()
1408                   << " front node:" << (front_node == nullptr ? "null" : front_node->DebugString());
1409     return false;
1410   }
1411 
1412   // If the input front node and graph not in same graph group, the input arrow should be link to the exit actor
1413   // of the graph.
1414   if (!IsSameKernelGraphGroup(front_node, graph)) {
1415     return true;
1416   }
1417 
1418   // If the graph has a call input, all of its inputs in the graph should be linked to its stack actor.
1419   if (IsCallInputKernelGraph(graph.get())) {
1420     // If the input come from a kernel graph belong the same group, it should be linked by internal parameter.
1421     if (front_node != nullptr && (IsSameKernelGraphGroup(front_node, graph) || front_node->isa<ValueNode>())) {
1422       return false;
1423     }
1424     return true;
1425   }
1426 
1427   return (front_node != nullptr && front_node->isa<Parameter>());
1428 }
1429 
IsRootGraphPersistentDeviceTensor(const AnfNodePtr & node)1430 bool ControlNodeParser::IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node) {
1431   MS_EXCEPTION_IF_NULL(node);
1432   if (!IsPersistentDeviceTensor(node)) {
1433     return false;
1434   }
1435 
1436   // No control flow.
1437   if (!is_inited_) {
1438     return true;
1439   }
1440 
1441   // Maybe the load node, need fetch the real parameter node.
1442   auto real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
1443   MS_EXCEPTION_IF_NULL(real_node);
1444   return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), real_node) != root_graph_parameters_.end();
1445 }
1446 
IsNeedStackControlNode(const AnfNodePtr & node)1447 bool ControlNodeParser::IsNeedStackControlNode(const AnfNodePtr &node) {
1448   MS_EXCEPTION_IF_NULL(node);
1449   if (!(node->isa<CNode>())) {
1450     return false;
1451   }
1452 
1453   return need_stack_control_nodes_.find(node) != need_stack_control_nodes_.end();
1454 }
1455 
IsRecursionCallNode(const AnfNodePtr & node)1456 bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) {
1457   MS_EXCEPTION_IF_NULL(node);
1458   if (!common::AnfAlgo::IsCallNode(node)) {
1459     return false;
1460   }
1461   return unrecursion_call_nodes_.find(node) == unrecursion_call_nodes_.end();
1462 }
1463 
IsRecursionKernelGraph(const KernelGraphPtr & graph)1464 bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) {
1465   MS_EXCEPTION_IF_NULL(graph);
1466   auto group_info_iter = kernel_graphs_to_group_info_.find(graph);
1467   if (group_info_iter == kernel_graphs_to_group_info_.end()) {
1468     MS_LOG(EXCEPTION) << "Invalid kernel graph:" << graph->ToString();
1469   }
1470   MS_EXCEPTION_IF_NULL(group_info_iter->second);
1471   if (!group_info_iter->second->need_stack_) {
1472     return false;
1473   }
1474   for (const auto &front_input_node : group_info_iter->second->front_input_nodes_) {
1475     const auto &node = front_input_node.first.first;
1476     MS_EXCEPTION_IF_NULL(node);
1477     if (IsRecursionCallNode(node)) {
1478       return true;
1479     }
1480   }
1481   return false;
1482 }
1483 
IsSameKernelGraphGroup(const AnfNodePtr & node,const KernelGraphPtr & graph)1484 bool ControlNodeParser::IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph) {
1485   MS_EXCEPTION_IF_NULL(node);
1486   MS_EXCEPTION_IF_NULL(graph);
1487   if (!node->isa<CNode>()) {
1488     MS_LOG(DEBUG) << "Not a cnode:" << node->DebugString();
1489     return false;
1490   }
1491 
1492   const auto node_graph = FetchKernelGraphByFrontNode(node);
1493   if (node_graph == nullptr) {
1494     MS_LOG(DEBUG) << "Fail to get kernel graph for cnode:" << node->DebugString();
1495     return false;
1496   }
1497   MS_LOG(DEBUG) << "Get kernel graph:" << node_graph->ToString() << " for cnode:" << node->DebugString()
1498                 << " compare to graph:" << graph->ToString();
1499   const auto iter1 = kernel_graphs_to_group_info_.find(node_graph);
1500   const auto iter2 = kernel_graphs_to_group_info_.find(graph);
1501 
1502   return iter1 != kernel_graphs_to_group_info_.end() && iter2 != kernel_graphs_to_group_info_.end() &&
1503          iter1->second == iter2->second;
1504 }
1505 
ParseDeviceContext(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & kernel_graphs,const std::vector<DeviceContext * > & device_contexts,DeviceContext * default_context,const FuncGraphToKernelGraphGroup & func_graph_to_kernel_graphs)1506 void ControlNodeParser::ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
1507                                            const std::vector<KernelGraphPtr> &kernel_graphs,
1508                                            const std::vector<DeviceContext *> &device_contexts,
1509                                            DeviceContext *default_context,
1510                                            const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
1511   MS_EXCEPTION_IF_NULL(default_context);
1512   ParseDeviceContextForFuncGraph(kernel_graphs, device_contexts, default_context, func_graph_to_kernel_graphs);
1513   ParseDeviceContextForReturnNode(default_context);
1514   ParseDeviceContextForCallNode(control_nodes);
1515   ParseDeviceContextForPartialNode(control_nodes);
1516 }
1517 
ParseDeviceContextForFuncGraph(const std::vector<KernelGraphPtr> & kernel_graphs,const std::vector<DeviceContext * > & device_contexts,DeviceContext * default_context,const FuncGraphToKernelGraphGroup & func_graph_to_kernel_graphs)1518 void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<KernelGraphPtr> &kernel_graphs,
1519                                                        const std::vector<DeviceContext *> &device_contexts,
1520                                                        DeviceContext *default_context,
1521                                                        const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs) {
1522   MS_EXCEPTION_IF_NULL(default_context);
1523   if (device_contexts.size() != kernel_graphs.size()) {
1524     MS_LOG(EXCEPTION) << "Invalid device context size:" << device_contexts.size()
1525                       << " graph size:" << kernel_graphs.size();
1526   }
1527   mindspore::HashMap<KernelGraphPtr, DeviceContext *> kernel_graph_to_device_context;
1528   for (size_t i = 0; i < kernel_graphs.size(); ++i) {
1529     kernel_graph_to_device_context[kernel_graphs[i]] = device_contexts[i];
1530   }
1531 
1532   // Collect the device context type of the parameter in the kernel graph as the type of the real parameters.
1533   for (const auto &func_graph_to_kernel_graph : func_graph_to_kernel_graphs) {
1534     const auto &func_graph = func_graph_to_kernel_graph.first;
1535     MS_EXCEPTION_IF_NULL(func_graph);
1536     std::vector<KernelWithIndex> front_parameters;
1537     for (const auto &parameter : func_graph->parameters()) {
1538       const auto &abstract = parameter->abstract();
1539       MS_EXCEPTION_IF_NULL(abstract);
1540       for (size_t i = 0; i < common::AnfAlgo::GetOutputNumByAbstract(abstract); ++i) {
1541         (void)front_parameters.emplace_back(parameter, i);
1542       }
1543     }
1544     std::vector<const DeviceContext *> parameter_device_contexts(front_parameters.size(), default_context);
1545     std::map<KernelWithIndex, DeviceContext *> front_parameter_to_device_context;
1546 
1547     for (const auto &kernel_graph_group : func_graph_to_kernel_graph.second) {
1548       for (const auto &kernel_graph : kernel_graph_group) {
1549         MS_EXCEPTION_IF_NULL(kernel_graph);
1550         const auto &backend_parameters = kernel_graph->parameters();
1551 
1552         for (const auto &backend_parameter : backend_parameters) {
1553           auto front_parameter = KernelWithIndex(kernel_graph->GetFrontAnfByBackendAnf(backend_parameter), 0);
1554           if (front_parameter.first == nullptr) {
1555             front_parameter = kernel_graph->GetElementInTupleBackendFrontIndexMap(backend_parameter);
1556           }
1557           if (front_parameter.first != nullptr && front_parameter.first->isa<Parameter>()) {
1558             front_parameter_to_device_context[front_parameter] = kernel_graph_to_device_context[kernel_graph];
1559           }
1560         }
1561       }
1562     }
1563 
1564     for (size_t i = 0; i < front_parameters.size(); ++i) {
1565       const auto &front_parameter = front_parameters[i];
1566       const auto &iter = front_parameter_to_device_context.find(front_parameter);
1567       if (iter != front_parameter_to_device_context.end()) {
1568         parameter_device_contexts[i] = iter->second;
1569       }
1570     }
1571     func_graph_to_device_contexts_[func_graph] = parameter_device_contexts;
1572   }
1573 
1574   // If there is no kernel in funcgraph, the parameter uses the default device context type.
1575   MS_EXCEPTION_IF_NULL(root_func_graph_);
1576   MS_EXCEPTION_IF_NULL(root_func_graph_->manager());
1577   FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
1578   for (auto sub_graph : sub_graphs) {
1579     MS_EXCEPTION_IF_NULL(sub_graph);
1580     if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
1581       size_t output_num = 0;
1582       for (const auto &parameter : sub_graph->parameters()) {
1583         const auto &abstract = parameter->abstract();
1584         MS_EXCEPTION_IF_NULL(abstract);
1585         output_num += common::AnfAlgo::GetOutputNumByAbstract(abstract);
1586       }
1587       func_graph_to_device_contexts_[sub_graph] = std::vector<const DeviceContext *>(output_num, default_context);
1588     }
1589   }
1590 }
1591 
ParseDeviceContextForPartialNode(const std::vector<AnfNodePtr> & control_nodes)1592 void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNodePtr> &control_nodes) {
1593   for (const auto &control_node : control_nodes) {
1594     if (!common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
1595       continue;
1596     }
1597 
1598     MS_EXCEPTION_IF_NULL(control_node);
1599     const auto &cnode = control_node->cast<CNodePtr>();
1600     MS_EXCEPTION_IF_NULL(cnode);
1601     const auto &inputs = cnode->inputs();
1602     if (inputs.size() <= kPartialFuncGraphPos) {
1603       MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid input size for partial node:" << cnode->DebugString();
1604     }
1605     auto &func_node = inputs[kPartialFuncGraphPos];
1606     // Ignore if the node is 'Partial(DeadNode,)'.
1607     if (IsDeadNode(func_node)) {
1608       MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
1609       continue;
1610     }
1611     // Fetch the funcgraph in partial node.
1612     const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
1613     if (func_graph == nullptr) {
1614       MS_LOG_WITH_NODE(EXCEPTION, func_node)
1615         << "Invalid funcgraph node:" << func_node->DebugString() << " for partial node:" << cnode->DebugString();
1616     }
1617 
1618     // Fetch the device contexts for the formal parameters in the funcgraph of partial node.
1619     auto iter = func_graph_to_device_contexts_.find(func_graph);
1620     if (iter == func_graph_to_device_contexts_.end()) {
1621       MS_LOG(EXCEPTION) << "Failed to get device contexts for funcgraph:" << func_graph->ToString();
1622     }
1623 
1624     size_t input_num = 0;
1625     for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
1626       MS_EXCEPTION_IF_NULL(inputs[i]);
1627       const auto &abstract = inputs[i]->abstract();
1628       MS_EXCEPTION_IF_NULL(abstract);
1629       input_num += common::AnfAlgo::GetOutputNumByAbstract(abstract);
1630     }
1631     if (input_num > iter->second.size()) {
1632       MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid input num:" << input_num
1633                                          << " for funcgraph:" << func_graph->ToString()
1634                                          << " device context size:" << iter->second.size()
1635                                          << " for partial node:" << cnode->DebugString();
1636     }
1637 
1638     // Get the device contexts for the real parameters.
1639     std::vector<const DeviceContext *> device_contexts;
1640     // In partial node, the first input is always a partial, maybe a funcgraph or a partial node, so we need
1641     // to insert an empty device context for it.
1642     (void)device_contexts.emplace_back(nullptr);
1643     for (size_t i = 0; i < input_num; ++i) {
1644       MS_EXCEPTION_IF_NULL(iter->second[i]);
1645       (void)device_contexts.emplace_back(iter->second[i]);
1646     }
1647     control_node_to_device_contexts_[control_node] = device_contexts;
1648   }
1649 }
1650 
CollectDeviceContextByDynamicLen(const CNodePtr & cnode,const FuncGraphPtr & func_graph,const std::vector<const DeviceContext * > & parameter_contexts,std::vector<const DeviceContext * > * arg_context)1651 void CollectDeviceContextByDynamicLen(const CNodePtr &cnode, const FuncGraphPtr &func_graph,
1652                                       const std::vector<const DeviceContext *> &parameter_contexts,
1653                                       std::vector<const DeviceContext *> *arg_context) {
1654   MS_EXCEPTION_IF_NULL(cnode);
1655   MS_EXCEPTION_IF_NULL(func_graph);
1656   MS_EXCEPTION_IF_NULL(arg_context);
1657   size_t para_num = func_graph->parameters().size();
1658   size_t arg_num = cnode->size() - 1;
1659   if (arg_num > para_num) {
1660     MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid arg size:" << arg_num << " parameter size:" << para_num
1661                                        << "for call node:" << cnode->DebugString()
1662                                        << " funcgraph:" << func_graph->ToString();
1663   }
1664   if (para_num != parameter_contexts.size()) {
1665     MS_LOG(EXCEPTION) << "Invalid parameter context size:" << parameter_contexts.size()
1666                       << " parameter size:" << para_num;
1667   }
1668   for (size_t i = para_num - arg_num; i < para_num; ++i) {
1669     size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(cnode->input(i + 1)->abstract());
1670     for (size_t j = 0; j < output_num; ++j) {
1671       arg_context->emplace_back(parameter_contexts[0]);
1672     }
1673   }
1674 }
1675 
ParseDeviceContextForCallNode(const std::vector<AnfNodePtr> & control_nodes)1676 void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodePtr> &control_nodes) {
1677   for (const auto &control_node : control_nodes) {
1678     MS_EXCEPTION_IF_NULL(control_node);
1679     if (!common::AnfAlgo::IsCallNode(control_node)) {
1680       continue;
1681     }
1682 
1683     // Fetch the device contexts of the funcgraph the node called.
1684     const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
1685     if (func_graphs.empty()) {
1686       MS_LOG_WITH_NODE(EXCEPTION, control_node)
1687         << "Failed to get funcgraph by call node:" << control_node->DebugString();
1688     }
1689     const auto &func_graph = *(func_graphs.begin());
1690     MS_EXCEPTION_IF_NULL(func_graph);
1691     auto iter = func_graph_to_device_contexts_.find(func_graph);
1692     if (iter == func_graph_to_device_contexts_.end()) {
1693       MS_LOG(EXCEPTION) << "Failed to get device contexts for funcgraph:" << func_graph->ToString();
1694     }
1695 
1696     std::vector<const DeviceContext *> device_contexts;
1697     // In call node, the first input is always a partial, maybe a funcgraph or a partial node, so we need
1698     // to insert an empty device context for it.
1699     (void)device_contexts.emplace_back(nullptr);
1700     const auto &cnode = control_node->cast<CNodePtr>();
1701     MS_EXCEPTION_IF_NULL(cnode);
1702     const auto &inputs = cnode->inputs();
1703     size_t call_input_num = 0;
1704     for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1705       MS_EXCEPTION_IF_NULL(inputs[i]);
1706       const auto &abstract = inputs[i]->abstract();
1707       MS_EXCEPTION_IF_NULL(abstract);
1708       call_input_num += common::AnfAlgo::GetOutputNumByAbstract(abstract);
1709     }
1710 
1711     if (call_input_num > iter->second.size()) {
1712       MS_LOG(INFO) << "Call input size:" << call_input_num << " context size:" << iter->second.size() << "for funcgraph"
1713                    << func_graph->ToString() << " for call node:" << cnode->DebugString();
1714       CollectDeviceContextByDynamicLen(cnode, func_graph, iter->second, &device_contexts);
1715       control_node_to_device_contexts_[control_node] = device_contexts;
1716       continue;
1717     }
1718 
1719     // Fetch the device contexts for the real parameters on the call node.
1720     for (size_t i = iter->second.size() - call_input_num; i < iter->second.size(); ++i) {
1721       MS_EXCEPTION_IF_NULL(iter->second[i]);
1722       (void)device_contexts.emplace_back(iter->second[i]);
1723     }
1724     control_node_to_device_contexts_[control_node] = device_contexts;
1725   }
1726 }
1727 
FetchDeviceContextByNode(const std::vector<KernelWithIndex> & output_nodes,std::vector<const DeviceContext * > * return_device_contexts,const FuncGraphPtr & func_graph,const DeviceContext * default_context)1728 void ControlNodeParser::FetchDeviceContextByNode(const std::vector<KernelWithIndex> &output_nodes,
1729                                                  std::vector<const DeviceContext *> *return_device_contexts,
1730                                                  const FuncGraphPtr &func_graph, const DeviceContext *default_context) {
1731   MS_EXCEPTION_IF_NULL(return_device_contexts);
1732   for (const auto &output_node : output_nodes) {
1733     MS_EXCEPTION_IF_NULL(output_node.first);
1734     if (output_node.first->isa<Parameter>()) {
1735       // If the output is parameter, get the device context type from the formal parameter.
1736       const auto &iter = find(func_graph->parameters().begin(), func_graph->parameters().end(), output_node.first);
1737       if (iter == func_graph->parameters().end()) {
1738         MS_LOG_WITH_NODE(EXCEPTION, output_node.first)
1739           << "Invalid parameter:" << output_node.first->DebugString() << " for func_graph:" << func_graph->ToString();
1740       }
1741       const auto &func_graph_iter = func_graph_to_device_contexts_.find(func_graph);
1742       if (func_graph_iter == func_graph_to_device_contexts_.end()) {
1743         MS_LOG(EXCEPTION) << "Cannot find device context for funcgraph:" << func_graph->ToString();
1744       }
1745       size_t index = LongToSize(iter - func_graph->parameters().begin());
1746       MS_EXCEPTION_IF_NULL(func_graph_iter->second[index]);
1747       (void)return_device_contexts->emplace_back(func_graph_iter->second[index]);
1748     } else if (output_node.first->isa<ValueNode>()) {
1749       // If the output is parameter, used the default context type.
1750       (void)return_device_contexts->emplace_back(default_context);
1751     } else if (common::AnfAlgo::IsCallNode(output_node.first)) {
1752       // If the output is call node, get the device context type by the output of funcgraph.
1753       const auto &func_graphs = call_node_to_func_graphs_[output_node.first];
1754       std::vector<const DeviceContext *> call_device_contexts;
1755       for (const auto &graph : func_graphs) {
1756         MS_EXCEPTION_IF_NULL(graph);
1757         const auto &node = graph->return_node();
1758         MS_EXCEPTION_IF_NULL(node);
1759         const auto &iter = control_node_to_device_contexts_.find(node);
1760         if (iter != control_node_to_device_contexts_.end()) {
1761           call_device_contexts = iter->second;
1762           break;
1763         }
1764       }
1765       // Since funcgraph has been topo-sorted according to the calling relationship, when there is a call node in
1766       // the output, the output type of the funcgraph called by it should have been determined, if not, an exception
1767       // will be thrown.
1768       if (call_device_contexts.empty() || call_device_contexts.size() <= output_node.second) {
1769         MS_LOG(DEBUG) << "Cannot find device context for call node:" << output_node.first->DebugString()
1770                       << " device contexts size:" << call_device_contexts.size() << " index:" << output_node.second;
1771         (void)return_device_contexts->emplace_back(default_context);
1772       } else {
1773         MS_EXCEPTION_IF_NULL(call_device_contexts[output_node.second]);
1774         (void)return_device_contexts->emplace_back(call_device_contexts[output_node.second]);
1775       }
1776     } else if (common::AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimPartial) ||
1777                common::AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimSwitch)) {
1778       (void)return_device_contexts->emplace_back(default_context);
1779     } else if (output_node.first->isa<CNode>()) {
1780       // If the output is a cnode, get the device context type by the kernel.
1781       const auto &iter = front_to_backend_kernels_.find(output_node);
1782       if (iter == front_to_backend_kernels_.end()) {
1783         MS_LOG(DEBUG) << "Cannot find backend kernel for cnode:" << output_node.first->DebugString();
1784         (void)return_device_contexts->emplace_back(default_context);
1785         continue;
1786       }
1787       MS_EXCEPTION_IF_NULL(iter->second.second);
1788       (void)return_device_contexts->emplace_back(iter->second.second);
1789     } else {
1790       MS_LOG_WITH_NODE(EXCEPTION, output_node.first) << "Invalid node for return:" << output_node.first->DebugString();
1791     }
1792   }
1793 }
1794 
ParseDeviceContextForReturnNode(const DeviceContext * default_context)1795 void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *default_context) {
1796   MS_EXCEPTION_IF_NULL(default_context);
1797   // Collect the call realationship between funcgraphs.
1798   FuncGraphCallRelation func_graph_call_relation;
1799   for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
1800     const auto &call_node = call_node_to_func_graphs.first;
1801     MS_EXCEPTION_IF_NULL(call_node);
1802     const auto &func_graph = call_node->func_graph();
1803     MS_EXCEPTION_IF_NULL(func_graph);
1804     (void)func_graph_call_relation[func_graph].emplace_back(call_node_to_func_graphs.second);
1805   }
1806 
1807   // Topologically sort all funcgraphs according to the function call relationship.
1808   const auto &topo_sort_func_graphs = TopoSortForFuncGraph(root_func_graph_, &func_graph_call_relation);
1809 
1810   // Deduces the device context type of funcgraph outputs according to the topological order.
1811   for (const auto &func_graph : topo_sort_func_graphs) {
1812     MS_EXCEPTION_IF_NULL(func_graph);
1813     const auto &return_node = func_graph->return_node();
1814     MS_EXCEPTION_IF_NULL(return_node);
1815     const auto &cnode = return_node->cast<CNodePtr>();
1816     MS_EXCEPTION_IF_NULL(cnode);
1817     const auto &inputs = cnode->inputs();
1818     if (inputs.size() <= kReturnInputPos) {
1819       MS_LOG_WITH_NODE(EXCEPTION, cnode) << "Invalid return node:" << cnode->DebugString();
1820     }
1821     const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]);
1822     std::vector<const DeviceContext *> return_device_contexts;
1823 
1824     FetchDeviceContextByNode(output_nodes, &return_device_contexts, func_graph, default_context);
1825     control_node_to_device_contexts_[return_node] = return_device_contexts;
1826   }
1827 }
1828 
ParseFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> & graphs)1829 void ControlNodeParser::ParseFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) {
1830   for (const auto &graph : graphs) {
1831     MS_EXCEPTION_IF_NULL(graph);
1832     if (graph->execution_order().empty()) {
1833       continue;
1834     }
1835     const auto &front_to_backend_nodes = graph->front_backend_anf_map();
1836     for (const auto &front_to_backend_node : front_to_backend_nodes) {
1837       MS_LOG(DEBUG) << "Add front node:" << front_to_backend_node.first->DebugString()
1838                     << " for kernel graph:" << graph->ToString();
1839       front_node_to_kernel_graph_[front_to_backend_node.first] = graph;
1840     }
1841   }
1842 }
1843 
FetchBranchIDByCallNode(const AnfNodePtr & call_node)1844 int ControlNodeParser::FetchBranchIDByCallNode(const AnfNodePtr &call_node) {
1845   MS_EXCEPTION_IF_NULL(call_node);
1846 
1847   if (call_node_to_branch_id_.find(call_node) == call_node_to_branch_id_.end()) {
1848     MS_LOG_WITH_NODE(EXCEPTION, call_node) << "Invalid branch id for call_node:" << call_node->DebugString();
1849   }
1850   return call_node_to_branch_id_[call_node];
1851 }
1852 
FetchKernelGraphByFrontNode(const AnfNodePtr & kernel)1853 KernelGraphPtr ControlNodeParser::FetchKernelGraphByFrontNode(const AnfNodePtr &kernel) {
1854   const auto &iter = front_node_to_kernel_graph_.find(kernel);
1855   if (iter == front_node_to_kernel_graph_.end()) {
1856     return nullptr;
1857   }
1858   return iter->second;
1859 }
1860 
IsCallInputKernelGraph(KernelGraph * const graph)1861 bool ControlNodeParser::IsCallInputKernelGraph(KernelGraph *const graph) {
1862   if (call_input_kernel_graphs_.find(graph) == call_input_kernel_graphs_.end()) {
1863     return false;
1864   }
1865   return true;
1866 }
1867 
IsCallInputKernelGraphGroup(const std::string & group_name)1868 bool ControlNodeParser::IsCallInputKernelGraphGroup(const std::string &group_name) {
1869   for (const auto &graph_group : kernel_graph_group_infos_) {
1870     MS_EXCEPTION_IF_NULL(graph_group);
1871     if (group_name.find(graph_group->group_name_) != std ::string::npos) {
1872       return graph_group->need_stack_;
1873     }
1874   }
1875   MS_LOG(EXCEPTION) << "Invalid kernel graph group name:" << group_name;
1876 }
1877 
// Return the backend kernel (with output index) recorded for the given front node output, or an empty
// KernelWithIndex when no mapping exists.
KernelWithIndex ControlNodeParser::FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index) {
  const auto found = front_to_backend_kernels_.find(node_with_index);
  if (found == front_to_backend_kernels_.end()) {
    return {};
  }
  return found->second.first;
}
1885 
FetchFuncGraphByKernelGraph(const KernelGraph * const graph)1886 FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *const graph) {
1887   for (const auto &func_graph_to_kernel_graphs : func_graph_to_kernel_graph_groups_) {
1888     const auto &kernel_graph_groups = func_graph_to_kernel_graphs.second;
1889     if (std::any_of(kernel_graph_groups.begin(), kernel_graph_groups.end(), [graph](const auto &kernel_graph_group) {
1890           return std::any_of(kernel_graph_group.begin(), kernel_graph_group.end(),
1891                              [graph](const auto &kernel_graph) { return kernel_graph.get() == graph; });
1892         })) {
1893       return func_graph_to_kernel_graphs.first;
1894     }
1895   }
1896   return nullptr;
1897 }
1898 
FetchBackendParameterWithContextByFrontParameter(const KernelWithIndex & front_parameter_with_index)1899 NodeWithIndexToContext ControlNodeParser::FetchBackendParameterWithContextByFrontParameter(
1900   const KernelWithIndex &front_parameter_with_index) {
1901   MS_EXCEPTION_IF_NULL(front_parameter_with_index.first);
1902   const auto &iter = front_to_backend_parameters_.find(front_parameter_with_index);
1903   if (iter == front_to_backend_parameters_.end()) {
1904     return {};
1905   }
1906 
1907   for (const auto &node_with_index_to_context : iter->second) {
1908     const auto &node = node_with_index_to_context.first.first;
1909     MS_EXCEPTION_IF_NULL(node);
1910     const auto &abstract =
1911       AnfAlgo::GetNodeAbstractByIndex(front_parameter_with_index.first, front_parameter_with_index.second);
1912     bool is_map_parameter = abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>();
1913     if (AnfAlgo::GetOutputTensorMemSize(node, node_with_index_to_context.first.second) != 0 || is_map_parameter) {
1914       return node_with_index_to_context;
1915     }
1916     MS_LOG(DEBUG) << "Backend node:" << node->DebugString()
1917                   << " for front node:" << front_parameter_with_index.first->DebugString()
1918                   << " index:" << front_parameter_with_index.second << " output size is 0.";
1919   }
1920   return {};
1921 }
1922 
// Create device tensors for front value nodes that are inputs of switch / switch_layer, return and call nodes.
// - Switch / SwitchLayer: every value-node input gets a device tensor on default_context.
// - Return / Call: each value-node input not yet present in front_value_nodes_ (paired with the device context
//   that control_node_to_device_contexts_ assigns to that input position) gets a device tensor. The tensor is
//   created on the context of a valid backend parameter bound to the value node if one exists, otherwise on
//   default_context.
// Every handled value node is recorded in front_value_nodes_ together with the context that was used.
// Throws if a return/call node has no entry in control_node_to_device_contexts_, or has fewer contexts than inputs.
void ControlNodeParser::CreateDeviceTensors(const std::vector<AnfNodePtr> &control_nodes,
                                            const DeviceContext *const default_context) {
  for (const auto &control_node : control_nodes) {
    MS_EXCEPTION_IF_NULL(control_node);
    if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
        common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
      auto input_with_indexs = FetchInputNodeByCNode(control_node);
      for (size_t i = 0; i < input_with_indexs.size(); ++i) {
        MS_EXCEPTION_IF_NULL(input_with_indexs[i].first);
        // NOTE(review): unlike the return/call branch below, front_value_nodes_ is not consulted first, so a
        // value node already handled may get a device tensor created again — confirm this is intended.
        if (IsFrontValueNode(input_with_indexs[i])) {
          CreateDeviceTensorForFrontNode(input_with_indexs[i], default_context);
          (void)front_value_nodes_.emplace(input_with_indexs[i], default_context);
        }
      }
      continue;
    }

    // Apart from switch / switch_layer, only return and call nodes are handled.
    if ((!common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) &&
        (!common::AnfAlgo::IsCallNode(control_node))) {
      continue;
    }

    auto input_with_indexs = FetchInputNodeByCNode(control_node);
    auto iter = control_node_to_device_contexts_.find(control_node);
    if (iter == control_node_to_device_contexts_.end() || iter->second.size() < input_with_indexs.size()) {
      MS_LOG_WITH_NODE(EXCEPTION, control_node)
        << "Invalid device context for control node:" << control_node->DebugString()
        << " need:" << input_with_indexs.size() << " current:"
        << (iter == control_node_to_device_contexts_.end() ? "null" : std::to_string(iter->second.size()));
    }
    for (size_t i = 0; i < input_with_indexs.size(); ++i) {
      const auto &input_with_index = input_with_indexs[i];
      // NOTE(review): the existence check uses the context from control_node_to_device_contexts_, while the entry
      // recorded below uses the backend-parameter context or default_context, which may differ — confirm.
      if (IsFrontValueNode(input_with_index) &&
          front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
        MS_EXCEPTION_IF_NULL(input_with_index.first);
        MS_LOG(DEBUG) << "Create device tensor for value node:" << input_with_index.first->DebugString()
                      << " index:" << i << " in control node:" << control_node->DebugString();
        const auto &node_with_index_with_context = FetchBackendParameterWithContextByFrontParameter(input_with_index);
        const auto &backend_node = node_with_index_with_context.first.first;
        // Prefer the device context of a valid bound backend parameter; otherwise fall back to default_context.
        if (IsValidBackendParameter(backend_node)) {
          CreateDeviceTensorForValueNode(input_with_index, backend_node, node_with_index_with_context.second);
          (void)front_value_nodes_.emplace(input_with_index, node_with_index_with_context.second);
        } else {
          CreateDeviceTensorForFrontNode(input_with_index, default_context);
          (void)front_value_nodes_.emplace(input_with_index, default_context);
        }
      }
    }
  }
}
1973 
FetchFrontValueNode(const std::vector<AnfNodePtr> & control_nodes,const DeviceContext * const default_context)1974 void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
1975                                             const DeviceContext *const default_context) {
1976   MS_EXCEPTION_IF_NULL(default_context);
1977 
1978   for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {
1979     for (const auto &real_parameter_with_index : formal_to_real_parameter.second) {
1980       if (!IsFrontValueNode(real_parameter_with_index)) {
1981         continue;
1982       }
1983 
1984       const auto &node_with_index_to_context =
1985         FetchBackendParameterWithContextByFrontParameter(real_parameter_with_index);
1986       const auto &backend_node = node_with_index_to_context.first.first;
1987       if (IsValidBackendParameter(backend_node)) {
1988         (void)front_value_nodes_.emplace(real_parameter_with_index, node_with_index_to_context.second);
1989         CreateDeviceTensorForValueNode(real_parameter_with_index, backend_node, node_with_index_to_context.second);
1990       } else {
1991         (void)front_value_nodes_.emplace(real_parameter_with_index, default_context);
1992         CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context);
1993       }
1994     }
1995   }
1996 
1997   // Create device tensors for those value nodes which direct return by a return node.
1998   CreateDeviceTensors(control_nodes, default_context);
1999   for (const auto &front_node : front_value_nodes_) {
2000     MS_EXCEPTION_IF_NULL(front_node.first.first);
2001     MS_LOG(DEBUG) << "Print front value node:" << front_node.first.first->DebugString()
2002                   << " addr:" << front_node.first.first << " index:" << front_node.first.second;
2003   }
2004 }
2005 
// Build formal_to_real_parameters_: for every formal parameter of a funcgraph, the set of real (actual) inputs
// that may be bound to it. Real parameters that are themselves Parameters are expanded recursively through
// ParseAllRealParameterByFormalParameter, and such Parameter reals are also recorded in real_to_formal_parameters_.
// Throws on a malformed partial node (too few inputs, non-funcgraph target, or more arguments than parameters).
void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes) {
  FormalToRealParameter formal_to_real_parameters;

  // The actual parameters of the function are divided into two parts:
  // 1. Input of partial node.
  // 2. Input of call node.
  for (const auto &node : control_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (common::AnfAlgo::IsCallNode(node)) {
      const auto &cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      const auto &inputs = cnode->inputs();
      const auto &func_graphs = FetchFuncGraphbyCallNode(node);
      for (const auto &func_graph : func_graphs) {
        MS_EXCEPTION_IF_NULL(func_graph);
        const auto &parameters = func_graph->parameters();
        // Call inputs and funcgraph parameters are paired from the back: the last input goes with the last
        // parameter. Input 0 is never paired (i >= 1); pairing stops when either side runs out.
        for (int i = SizeToInt(inputs.size()) - 1, j = SizeToInt(parameters.size()) - 1; i >= 1 && j >= 0; --i, --j) {
          MS_EXCEPTION_IF_NULL(inputs[IntToSize(i)]);
          MS_EXCEPTION_IF_NULL(parameters[IntToSize(j)]);
          AddFormalToRealParameter(parameters[IntToSize(j)], inputs[IntToSize(i)], call_node_to_func_graphs_,
                                   &formal_to_real_parameters);
        }
      }
    } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
      const auto &cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      const auto &inputs = cnode->inputs();
      if (inputs.size() <= kPartialFuncGraphPos) {
        MS_LOG_WITH_NODE(EXCEPTION, node) << "Invalid input size for partial node:" << node->DebugString();
      }
      auto &func_node = inputs[kPartialFuncGraphPos];
      MS_EXCEPTION_IF_NULL(func_node);
      // Ignore if the node is 'Partial(DeadNode,)'.
      if (IsDeadNode(func_node)) {
        MS_LOG(DEBUG) << "Ignore partial dead node:" << node->DebugString();
        continue;
      }
      const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
      if (func_graph == nullptr) {
        MS_LOG_WITH_NODE(EXCEPTION, node)
          << "Invalid funcgraph node:" << func_node->DebugString() << " for partial node:" << node->DebugString();
      }
      const auto &parameters = func_graph->parameters();
      if (inputs.size() - kPartialInputStartPos > parameters.size()) {
        MS_LOG(EXCEPTION) << "Invalid partial input size:" << inputs.size()
                          << " formal parameter size:" << parameters.size();
      }
      // Partial arguments are paired from the front: input kPartialInputStartPos goes with parameter 0.
      for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
        MS_EXCEPTION_IF_NULL(inputs[i]);
        MS_EXCEPTION_IF_NULL(parameters[i - kPartialInputStartPos]);
        AddFormalToRealParameter(parameters[i - kPartialInputStartPos], inputs[i], call_node_to_func_graphs_,
                                 &formal_to_real_parameters);
      }
    }
  }

  // When the real parameter is also a parameter, the corresponding actual parameter needs to be obtained recursively.
  for (const auto &formal_to_real_parameter : formal_to_real_parameters) {
    const auto &formal_parameter = formal_to_real_parameter.first;
    const auto &real_parameters = formal_to_real_parameter.second;
    // Starts as a copy of the direct real parameters, so the emplace in the else-branch below adds nothing new.
    std::set<KernelWithIndex> total_real_parameters = real_parameters;
    for (const auto &real_parameter : real_parameters) {
      MS_EXCEPTION_IF_NULL(real_parameter.first);
      if (real_parameter.first->isa<Parameter>()) {
        // Seed the visited set with the formal parameter itself so that recursion does not loop back to it.
        std::set<KernelWithIndex> invalid_real_parameter{formal_parameter};
        ParseAllRealParameterByFormalParameter(real_parameter, formal_to_real_parameters, &total_real_parameters,
                                               &invalid_real_parameter);
        (void)real_to_formal_parameters_[real_parameter].emplace(formal_parameter);
      } else {
        (void)total_real_parameters.emplace(real_parameter);
      }
    }
    std::swap(formal_to_real_parameters_[formal_parameter], total_real_parameters);
  }

  for (const auto &formal_to_real : formal_to_real_parameters_) {
    for (const auto &real_parameter : formal_to_real.second) {
      MS_EXCEPTION_IF_NULL(formal_to_real.first.first);
      MS_EXCEPTION_IF_NULL(real_parameter.first);
      MS_LOG(DEBUG) << "Print formal to real node, formal:" << formal_to_real.first.first->DebugString()
                    << " real:" << real_parameter.first->DebugString() << " index:" << real_parameter.second;
    }
  }
}
2090 
// Recursively add to total_real_parameters every real parameter reachable from formal_parameter.
// @param formal_parameter          Parameter whose real parameters are collected.
// @param formal_to_real_parameters Direct formal -> real mapping, not yet expanded.
// @param total_real_parameters     Output set accumulating all reachable real parameters.
// @param invalid_real_parameter    Formals already visited; used as a cycle guard and updated in place.
// If formal_parameter already has an expanded entry in the member formal_to_real_parameters_, that set is merged
// and recursion stops there. Formals with no callers simply stop; outside the root funcgraph this is logged at
// DEBUG level.
void ControlNodeParser::ParseAllRealParameterByFormalParameter(const KernelWithIndex &formal_parameter,
                                                               const FormalToRealParameter &formal_to_real_parameters,
                                                               std::set<KernelWithIndex> *const total_real_parameters,
                                                               std::set<KernelWithIndex> *invalid_real_parameter) {
  MS_EXCEPTION_IF_NULL(formal_parameter.first);
  MS_EXCEPTION_IF_NULL(total_real_parameters);
  MS_EXCEPTION_IF_NULL(invalid_real_parameter);
  // Already visited: stop to avoid infinite recursion on cyclic formal/real relations.
  if (invalid_real_parameter->find(formal_parameter) != invalid_real_parameter->end()) {
    return;
  }
  (void)invalid_real_parameter->emplace(formal_parameter);

  // Get all the actual parameters corresponding to parameter recursively.
  // Reuse an already expanded result from the member map when one is available.
  const auto &dst_iter = formal_to_real_parameters_.find(formal_parameter);
  if (dst_iter != formal_to_real_parameters_.end()) {
    total_real_parameters->insert(dst_iter->second.begin(), dst_iter->second.end());
    return;
  }
  const auto &src_iter = formal_to_real_parameters.find(formal_parameter);
  if (src_iter == formal_to_real_parameters.end()) {
    const auto &func_graph = formal_parameter.first->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    // Parameters of the root funcgraph have no real parameters from call/partial nodes; this is expected.
    if (func_graph == root_func_graph_) {
      return;
    }
    MS_LOG(DEBUG) << "Invalid formal parameter:" << formal_parameter.first->DebugString()
                  << ", maybe there is no call node for funcgraph:"
                  << (formal_parameter.first->func_graph() == nullptr
                        ? "null"
                        : formal_parameter.first->func_graph()->ToString());
    return;
  }
  const auto &real_parameters = src_iter->second;
  for (const auto &real_parameter : real_parameters) {
    MS_EXCEPTION_IF_NULL(real_parameter.first);
    (void)total_real_parameters->emplace(real_parameter);
    // A real parameter that is itself a Parameter is a formal of an enclosing graph; expand it further.
    if (real_parameter.first->isa<Parameter>()) {
      ParseAllRealParameterByFormalParameter(real_parameter, formal_to_real_parameters, total_real_parameters,
                                             invalid_real_parameter);
    }
  }
}
2133 
// Record each Parameter input of the control nodes in control_node_parameters_. The dynamic-shape flag is set on
// a Parameter whose shape is a dynamic abstract::Shape or an abstract::DynamicSequenceShape.
// NOTE(review): scanning stops at the first Return node (`break`), so any control nodes listed after it are not
// examined — presumably control_nodes is ordered so that this Return is last; confirm against the caller.
void ControlNodeParser::ParseControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes) {
  for (const auto &control_node : control_nodes) {
    MS_EXCEPTION_IF_NULL(control_node);
    if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
      break;
    }

    const auto &inputs = FetchInputNodeByCNode(control_node);
    for (size_t i = 0; i < inputs.size(); ++i) {
      MS_EXCEPTION_IF_NULL(inputs[i].first);
      MS_LOG(DEBUG) << "Control node:" << control_node->DebugString()
                    << " input node:" << inputs[i].first->DebugString() << " index:" << inputs[i].second;
      if (inputs[i].first->isa<Parameter>()) {
        MS_LOG(DEBUG) << "Control node:" << control_node->DebugString()
                      << " input parameter:" << inputs[i].first->DebugString() << " index:" << inputs[i].second;
        (void)control_node_parameters_.emplace_back(inputs[i]);
        // Set Dynamic shape flag for parameter.
        const auto &parameter = inputs[i].first->cast<ParameterPtr>();
        MS_EXCEPTION_IF_NULL(parameter);
        const auto &base_shape = parameter->Shape();
        // Parameters without a shape are recorded but not flagged.
        if (base_shape == nullptr) {
          continue;
        }
        if ((base_shape->isa<abstract::Shape>() && base_shape->IsDynamic()) ||
            base_shape->isa<abstract::DynamicSequenceShape>()) {
          MS_LOG(INFO) << "Set dynamic shape flag to parameter:" << parameter->DebugString();
          parameter->set_has_dynamic_shape(true);
        }
      }
    }
  }
}
2166 
CreateBranchIDForCallNode(const std::vector<AnfNodePtr> & control_nodes)2167 void ControlNodeParser::CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes) {
2168   int branch_id = kMainBranchID;
2169 
2170   for (const auto &control_node : control_nodes) {
2171     // Root funcgraph does not need to create a gather actor.
2172     if (common::AnfAlgo::IsCallNode(control_node)) {
2173       call_node_to_branch_id_[control_node] = ++branch_id;
2174       MS_LOG(DEBUG) << "control node:" << control_node->DebugString()
2175                     << " branch id:" << call_node_to_branch_id_[control_node];
2176     }
2177   }
2178 }
2179 
// Build front_to_backend_parameters_: for each front node output (parameter or value node), the set of backend
// kernel-graph parameters bound to it, each paired with the device context of its kernel graph.
// Phase 1 maps every kernel-graph input to its front node: through internal-parameter real-parameter resolution,
// through the tuple element map, or directly. Phase 2 copies the backend parameters of each formal parameter to
// its Parameter-typed real parameters (from formal_to_real_parameters_).
// Throws if graphs and device_contexts differ in size, or if a backend parameter has no front node at all.
void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
                                                     const std::vector<DeviceContext *> &device_contexts) {
  if (graphs.size() != device_contexts.size()) {
    MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
  }

  // Fetch the mapping relationship between front parameters and backend parameters in the kernel graphs.
  for (size_t i = 0; i < graphs.size(); ++i) {
    const auto &graph = graphs[i];
    auto device_context = device_contexts[i];
    MS_EXCEPTION_IF_NULL(graph);
    MS_EXCEPTION_IF_NULL(device_context);
    for (const auto &parameter : graph->input_nodes()) {
      MS_EXCEPTION_IF_NULL(parameter);
      const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter);
      const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
      const auto &front_tuple_parameter_with_index = graph->GetElementInTupleBackendFrontIndexMap(parameter);
      if (front_node == nullptr && front_node_with_index.first == nullptr &&
          front_tuple_parameter_with_index.first == nullptr) {
        MS_LOG_WITH_NODE(EXCEPTION, parameter)
          << "Invalid backend parameter:" << parameter->DebugString() << " for kernel graph:" << graph->ToString();
      }

      if (front_node_with_index.first != nullptr) {
        // Internal parameter: bind the backend parameter to every Parameter / ValueNode the front output resolves to.
        std::set<KernelWithIndex> real_parameters;
        std::set<KernelWithIndex> invalid_call_nodes;
        FetchRealParameterByNode(front_node_with_index, &real_parameters, &invalid_call_nodes,
                                 call_node_to_func_graphs_);
        for (const auto &real_parameter : real_parameters) {
          MS_EXCEPTION_IF_NULL(real_parameter.first);
          if (real_parameter.first->isa<Parameter>() || real_parameter.first->isa<ValueNode>()) {
            (void)front_to_backend_parameters_[real_parameter].emplace(KernelWithIndex(parameter, 0), device_context);
            MS_LOG(DEBUG) << "Add front node:" << real_parameter.first->DebugString()
                          << " index:" << real_parameter.second
                          << " for backend parameter:" << parameter->DebugString();
          }
        }
      } else if (front_tuple_parameter_with_index.first != nullptr) {
        (void)front_to_backend_parameters_[front_tuple_parameter_with_index].emplace(KernelWithIndex(parameter, 0),
                                                                                     device_context);
      } else {
        (void)front_to_backend_parameters_[{front_node, 0}].emplace(KernelWithIndex(parameter, 0), device_context);
      }
    }
  }

  // Get the corresponding backend node for the real parameter according to the relationship between real
  // parameter and formal parameter.
  // NOTE(review): entries of front_to_backend_parameters_ are inserted (operator[]) while the same container is
  // being iterated. That is well-defined only if it is a node-based ordered map (e.g. std::map); with an unordered
  // map a rehash would invalidate the loop iterator. Whether newly inserted keys are visited also depends on key
  // order — confirm the container type and intent.
  for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
    const auto &front_parameter = front_to_backend_parameters.first;
    const auto &backend_parameters = front_to_backend_parameters.second;
    const auto &iter = formal_to_real_parameters_.find(front_parameter);
    if (iter != formal_to_real_parameters_.end()) {
      for (const auto &real_parameter_with_index : iter->second) {
        const auto &real_parameter = real_parameter_with_index.first;
        MS_EXCEPTION_IF_NULL(real_parameter);
        if (real_parameter->isa<Parameter>()) {
          front_to_backend_parameters_[real_parameter_with_index].insert(backend_parameters.begin(),
                                                                         backend_parameters.end());
        }
      }
    }
  }
  for (const auto &front_to_backend_parameters : front_to_backend_parameters_) {
    for (const auto &backend_parameter : front_to_backend_parameters.second) {
      MS_EXCEPTION_IF_NULL(front_to_backend_parameters.first.first);
      MS_EXCEPTION_IF_NULL(backend_parameter.first.first);
      MS_LOG(DEBUG) << "Print front to backend parameter, front:"
                    << front_to_backend_parameters.first.first->DebugString()
                    << " index:" << front_to_backend_parameters.first.second
                    << " backend:" << backend_parameter.first.first->DebugString()
                    << " index:" << backend_parameter.first.second << " node addr:" << backend_parameter.first.first;
    }
  }
}
2255 
ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> & control_nodes)2256 void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
2257   for (const auto &control_node : control_nodes) {
2258     MS_EXCEPTION_IF_NULL(control_node);
2259     if (!common::AnfAlgo::IsCallNode(control_node)) {
2260       continue;
2261     }
2262 
2263     const auto &belong_func_graph = control_node->func_graph();
2264     MS_EXCEPTION_IF_NULL(belong_func_graph);
2265     (void)func_graph_to_call_nodes_[belong_func_graph].emplace(control_node);
2266 
2267     const auto &cnode = control_node->cast<CNodePtr>();
2268     MS_EXCEPTION_IF_NULL(cnode);
2269     const auto &func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
2270     if (func_graphs.empty()) {
2271       MS_LOG(EXCEPTION) << "Get func graphs from abstract failed.";
2272     }
2273     for (auto func_graph : func_graphs) {
2274       (void)call_node_to_func_graphs_[control_node].emplace(func_graph);
2275     }
2276   }
2277 }
2278 
FetchFuncGraphbyCallNode(const AnfNodePtr & control_node)2279 const std::set<FuncGraphPtr> &ControlNodeParser::FetchFuncGraphbyCallNode(const AnfNodePtr &control_node) {
2280   MS_EXCEPTION_IF_NULL(control_node);
2281   const auto &iter = call_node_to_func_graphs_.find(control_node);
2282   if (iter == call_node_to_func_graphs_.end()) {
2283     MS_LOG_WITH_NODE(EXCEPTION, control_node) << "Invalid call node:" << control_node->DebugString();
2284   }
2285   return iter->second;
2286 }
2287 
ParseFrontToBackendKernel(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)2288 void ControlNodeParser::ParseFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
2289                                                   const std::vector<DeviceContext *> &device_contexts) {
2290   for (size_t i = 0; i < graphs.size(); ++i) {
2291     const auto &graph = graphs[i];
2292     const auto &device_context = device_contexts[i];
2293     MS_EXCEPTION_IF_NULL(graph);
2294     auto execution_order = graph->execution_order();
2295     for (auto &kernel : execution_order) {
2296       auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
2297       if (front_node != nullptr) {
2298         for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) {
2299           front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context};
2300           MS_LOG(DEBUG) << "Add front to backend kernel, front:" << common::AnfAlgo::GetNodeDebugString(front_node)
2301                         << "index:" << j << " addr:" << front_node
2302                         << " second:" << common::AnfAlgo::GetNodeDebugString(kernel) << "index:" << j
2303                         << " addr:" << kernel;
2304         }
2305       }
2306     }
2307 
2308     for (const auto &output_pair : graph->front_node_to_graph_output_map()) {
2309       MS_EXCEPTION_IF_NULL(output_pair.second.first);
2310       if (output_pair.second.first->isa<CNode>()) {
2311         front_to_backend_kernels_[output_pair.first] = {output_pair.second, device_context};
2312       }
2313     }
2314   }
2315   for (const auto &front_to_backend_kernels : front_to_backend_kernels_) {
2316     MS_EXCEPTION_IF_NULL(front_to_backend_kernels.first.first);
2317     MS_EXCEPTION_IF_NULL(front_to_backend_kernels.second.first.first);
2318     MS_LOG(DEBUG) << "Print front to backend kernel, front node:" << front_to_backend_kernels.first.first->DebugString()
2319                   << " front index:" << front_to_backend_kernels.first.second
2320                   << " backend node:" << front_to_backend_kernels.second.first.first->DebugString()
2321                   << " backend index:" << front_to_backend_kernels.second.first.second;
2322   }
2323 }
2324 
ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> & control_nodes)2325 void ControlNodeParser::ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
2326   for (const auto &control_node : control_nodes) {
2327     MS_EXCEPTION_IF_NULL(control_node);
2328     const auto &func_graph = control_node->func_graph();
2329     MS_EXCEPTION_IF_NULL(func_graph);
2330     // In the funcgraph with recursive call node, the call node is marked as level1, and the entrance actor is
2331     // notified to send data after the call node execute ends. At this time, it is necessary to ensure that the
2332     // data of all actors in the graph has been processed, so all control nodes of level0 need link control arrow
2333     // to entrance actor.
2334     if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
2335       auto iter = node_to_level_.find(control_node);
2336       if (iter != node_to_level_.end() && iter->second == 0 && (!IsPartialInput(control_node))) {
2337         (void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
2338       }
2339     }
2340 
2341     std::set<AnfNodePtr> checked_nodes;
2342     if (((common::AnfAlgo::IsCallNode(control_node) &&
2343           unrecursion_call_nodes_.find(control_node) == unrecursion_call_nodes_.end()) ||
2344          common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) &&
2345         IsFirstControlNode(control_node, &checked_nodes, unrecursion_call_nodes_)) {
2346       (void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
2347       MS_LOG(DEBUG) << "Add first control node:" << control_node->DebugString()
2348                     << " for funcgraph:" << func_graph->ToString();
2349       if (!common::AnfAlgo::IsCallNode(control_node)) {
2350         continue;
2351       }
2352 
2353       // If there is a recursive call node in the funcgraph, the kernel graph of the topo sort before the call node
2354       // needs to be executed before the call recursion, that is, the kernel graph whose level is less than the call
2355       // node needs to link a control arrow to the corresponding entry actor.
2356       // Fetch the level of control node.
2357       const auto &level_iter = node_to_level_.find(control_node);
2358       if (level_iter == node_to_level_.end()) {
2359         MS_LOG(DEBUG) << "Failed to get level for call node:" << control_node->DebugString();
2360         continue;
2361       }
2362 
2363       // Fetch all of the kernel graph group info whose level less than the control node.
2364       const auto &graph_group_iter = func_graph_to_kernel_graph_groups_.find(func_graph);
2365       if (graph_group_iter == func_graph_to_kernel_graph_groups_.end()) {
2366         continue;
2367       }
2368       for (const auto &kernel_graphs : graph_group_iter->second) {
2369         // Fetch one graph from the group.
2370         KernelGraphPtr dst_graph = nullptr;
2371         for (const auto &graph : kernel_graphs) {
2372           MS_EXCEPTION_IF_NULL(graph);
2373           if (graph->execution_order().empty()) {
2374             continue;
2375           }
2376           dst_graph = graph;
2377           break;
2378         }
2379         if (dst_graph == nullptr) {
2380           continue;
2381         }
2382 
2383         // Fetch the group info.
2384         const auto &group_info_iter = kernel_graphs_to_group_info_.find(dst_graph);
2385         if (group_info_iter == kernel_graphs_to_group_info_.end()) {
2386           MS_LOG(EXCEPTION) << "Failed to get group info for kernel_graph:" << dst_graph->ToString();
2387         }
2388         MS_EXCEPTION_IF_NULL(group_info_iter->second);
2389         if (group_info_iter->second->level_ < level_iter->second) {
2390           MS_LOG(DEBUG) << "Kernel graph group;" << group_info_iter->second->group_name_
2391                         << " need link control to entrance of funcgraph:" << func_graph->ToString();
2392           (void)func_graph_to_first_kernel_graphs_[func_graph].emplace(group_info_iter->second);
2393         }
2394       }
2395     }
2396   }
2397 }
2398 
ParseUnRecursionCallNode()2399 void ControlNodeParser::ParseUnRecursionCallNode() {
2400   std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> func_graph_call_relation;
2401   // Collect the call relationship between funcgraphs.
2402   for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
2403     const auto &call_node = call_node_to_func_graphs.first;
2404     MS_EXCEPTION_IF_NULL(call_node);
2405     const auto &func_graph = call_node->func_graph();
2406     MS_EXCEPTION_IF_NULL(func_graph);
2407     func_graph_call_relation[func_graph].insert(call_node_to_func_graphs.second.begin(),
2408                                                 call_node_to_func_graphs.second.end());
2409   }
2410 
2411   for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
2412     const auto &call_node = call_node_to_func_graphs.first;
2413     MS_EXCEPTION_IF_NULL(call_node);
2414     const auto &dest_func_graph = call_node->func_graph();
2415     MS_EXCEPTION_IF_NULL(dest_func_graph);
2416     std::set<FuncGraphPtr> exexution_func_graphs;
2417     for (const auto &func_graph : call_node_to_func_graphs.second) {
2418       FetchAllExecutionFunction(func_graph, &exexution_func_graphs, func_graph_call_relation);
2419     }
2420     if (exexution_func_graphs.find(dest_func_graph) == exexution_func_graphs.end()) {
2421       (void)unrecursion_call_nodes_.emplace(call_node);
2422       MS_LOG(DEBUG) << "Add unrecursion call control node:" << call_node->DebugString();
2423     }
2424   }
2425 }
2426 
IsCallNodeNeedStack(const AnfNodePtr & node)2427 bool ControlNodeParser::IsCallNodeNeedStack(const AnfNodePtr &node) {
2428   MS_EXCEPTION_IF_NULL(node);
2429   const auto &cnode = node->cast<CNodePtr>();
2430   MS_EXCEPTION_IF_NULL(cnode);
2431   const auto &inputs = cnode->inputs();
2432   std::set<AnfNodePtr> depend_nodes;
2433 
2434   // Fetch all the side effect inputs of call node.
2435   for (const auto &input : inputs) {
2436     MS_EXCEPTION_IF_NULL(input);
2437     std::vector<AnfNodePtr> monad_nodes = FetchAllMonadNodeByNode(input);
2438     for (const auto &monad_node : monad_nodes) {
2439       FetchRealDependNodeByAutoMonad(monad_node, &depend_nodes);
2440     }
2441   }
2442 
2443   // Fetch all the data inputs of call node.
2444   auto input_with_indexs = FetchInputNodeByCNode(node);
2445   (void)std::for_each(
2446     input_with_indexs.begin(), input_with_indexs.end(),
2447     [&depend_nodes](const auto &input_with_index) { (void)depend_nodes.emplace(input_with_index.first); });
2448 
2449   // Check if the call node need a stack.
2450   for (const auto &depend_node : depend_nodes) {
2451     MS_EXCEPTION_IF_NULL(depend_node);
2452     // If the call node has call or recursion graph input, a stack created for the call node is required.
2453     if (!common::AnfAlgo::IsCallNode(depend_node)) {
2454       if (!depend_node->isa<CNode>()) {
2455         continue;
2456       }
2457       const auto &graph = FetchKernelGraphByFrontNode(depend_node);
2458       if (graph == nullptr || (!IsRecursionKernelGraph(graph))) {
2459         continue;
2460       }
2461     }
2462     return true;
2463   }
2464   return false;
2465 }
2466 
ParseNeedStackControlNode(const std::vector<AnfNodePtr> & control_nodes)2467 void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes) {
2468   for (const auto &control_node : control_nodes) {
2469     MS_EXCEPTION_IF_NULL(control_node);
2470     if (common::AnfAlgo::IsCallNode(control_node) && IsCallNodeNeedStack(control_node)) {
2471       (void)need_stack_control_nodes_.emplace(control_node);
2472       MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
2473     }
2474   }
2475 
2476   for (const auto &control_node : control_nodes) {
2477     MS_EXCEPTION_IF_NULL(control_node);
2478     if (IsInvalidPartial(control_node)) {
2479       continue;
2480     }
2481 
2482     if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
2483       auto input_with_indexs = FetchInputNodeByCNode(control_node);
2484       size_t call_input_num = 0;
2485       for (auto input_with_index : input_with_indexs) {
2486         if (common::AnfAlgo::IsCallNode(input_with_index.first)) {
2487           ++call_input_num;
2488         }
2489       }
2490 
2491       const auto &cnode = control_node->cast<CNodePtr>();
2492       MS_EXCEPTION_IF_NULL(cnode);
2493       const auto &inputs = cnode->inputs();
2494       if (inputs.size() <= kReturnInputPos) {
2495         MS_LOG_WITH_NODE(EXCEPTION, control_node) << "Invalid return node:" << control_node->DebugString();
2496       }
2497 
2498       if ((!IsInputInSameLevel(control_node)) ||
2499           (call_input_num != 0 && (common::AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend)))) {
2500         (void)need_stack_control_nodes_.emplace(control_node);
2501         MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
2502       }
2503     } else if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) ||
2504                common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
2505                common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
2506       if (!IsInputInSameLevel(control_node)) {
2507         (void)need_stack_control_nodes_.emplace(control_node);
2508         MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString();
2509       }
2510     }
2511   }
2512 }
2513 
CollectEffectiveInputByGraph(const KernelGraphPtr & graph,const DeviceContext * const device_context,KernelGraphGroupInfo * const kernel_graph_group_info)2514 void CollectEffectiveInputByGraph(const KernelGraphPtr &graph, const DeviceContext *const device_context,
2515                                   KernelGraphGroupInfo *const kernel_graph_group_info) {
2516   MS_EXCEPTION_IF_NULL(graph);
2517   MS_EXCEPTION_IF_NULL(device_context);
2518   MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2519 
2520   const auto &outputs = kernel_graph_group_info->front_output_nodes_;
2521   const auto &monad_outputs = kernel_graph_group_info->monad_outputs_;
2522   const auto &real_parameters = graph->input_nodes();
2523   for (const auto &parameter : real_parameters) {
2524     MS_EXCEPTION_IF_NULL(parameter);
2525     auto front_node_with_index = GetFrontNodeByKernelGraph(parameter, graph.get());
2526     MS_EXCEPTION_IF_NULL(front_node_with_index.first);
2527     // If input come from the output of kernel graph belong the same group, it should not be collected in
2528     // the group inputs.
2529     if (HasAbstractMonad(front_node_with_index.first) || HasAbstractMonad(parameter) ||
2530         outputs.find(front_node_with_index) != outputs.end() || front_node_with_index.first->isa<ValueNode>()) {
2531       // The monad input is used to link the control arrow of the graph. If it comes from other graphs in the same
2532       // group, it is not used as the monad input of the group.
2533       if ((HasAbstractMonad(front_node_with_index.first) || HasAbstractMonad(parameter)) &&
2534           monad_outputs.find(front_node_with_index) == monad_outputs.end()) {
2535         (void)kernel_graph_group_info->monad_inputs_.emplace(front_node_with_index.first);
2536         MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
2537                       << " add front monad input node:" << front_node_with_index.first->DebugString();
2538       }
2539       continue;
2540     }
2541     if (common::AnfAlgo::IsCallNode(front_node_with_index.first)) {
2542       kernel_graph_group_info->need_stack_ = true;
2543     }
2544     MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
2545                   << " add front input node:" << front_node_with_index.first->DebugString()
2546                   << " index:" << front_node_with_index.second << " backend node:" << parameter->DebugString()
2547                   << " index:0";
2548     kernel_graph_group_info->front_input_nodes_[front_node_with_index] = device_context;
2549   }
2550 }
2551 
CollectEffectiveOutputByGraph(const KernelGraphPtr & graph,DeviceContext * const device_context,FrontToBackendKernelWithContext * const outputs,std::set<KernelWithIndex> * monad_outputs)2552 void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *const device_context,
2553                                    FrontToBackendKernelWithContext *const outputs,
2554                                    std::set<KernelWithIndex> *monad_outputs) {
2555   MS_EXCEPTION_IF_NULL(graph);
2556   MS_EXCEPTION_IF_NULL(device_context);
2557   MS_EXCEPTION_IF_NULL(outputs);
2558   MS_EXCEPTION_IF_NULL(monad_outputs);
2559 
2560   for (const auto &front_to_backend : graph->front_node_to_graph_output_map()) {
2561     MS_EXCEPTION_IF_NULL(front_to_backend.first.first);
2562     MS_EXCEPTION_IF_NULL(front_to_backend.second.first);
2563     if (HasAbstractMonad(front_to_backend.second.first) || HasAbstractMonad(front_to_backend.first.first) ||
2564         front_to_backend.second.first->isa<Parameter>() ||
2565         common::AnfAlgo::CheckPrimitiveType(front_to_backend.first.first, prim::kPrimPartial) ||
2566         front_to_backend.first.first->isa<ValueNode>()) {
2567       if (HasAbstractMonad(front_to_backend.first.first) || HasAbstractMonad(front_to_backend.second.first)) {
2568         MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString() << " add monad output node:"
2569                       << (front_to_backend.first.first != nullptr ? front_to_backend.first.first->DebugString()
2570                                                                   : "null")
2571                       << " index:" << front_to_backend.first.second;
2572         (void)monad_outputs->emplace(front_to_backend.first);
2573       }
2574       continue;
2575     }
2576 
2577     // Skip the function input.
2578     const auto &abstract = front_to_backend.first.first->abstract();
2579     MS_EXCEPTION_IF_NULL(abstract);
2580     const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, front_to_backend.first.second);
2581     MS_EXCEPTION_IF_NULL(real_abstract);
2582     if (real_abstract->isa<abstract::AbstractFunction>()) {
2583       continue;
2584     }
2585 
2586     MS_LOG(DEBUG) << "Kernel graph:" << graph->ToString()
2587                   << " add front output node:" << front_to_backend.first.first->DebugString()
2588                   << " index:" << front_to_backend.first.second
2589                   << " backend node:" << front_to_backend.second.first->DebugString()
2590                   << " full name:" << front_to_backend.second.first->fullname_with_scope()
2591                   << " index:" << front_to_backend.second.second;
2592     (*outputs)[front_to_backend.first] = {front_to_backend.second, device_context};
2593   }
2594 }
2595 
ParseKernelGraphGroup(const KernelGraphToDeviceContext & kernel_graph_to_device_contexts)2596 void ControlNodeParser::ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts) {
2597   for (const auto &func_graph_to_kernel_graph_groups : func_graph_to_kernel_graph_groups_) {
2598     for (const auto &kernel_graph_group : func_graph_to_kernel_graph_groups.second) {
2599       if (kernel_graph_group.empty()) {
2600         continue;
2601       }
2602 
2603       KernelGraphGroupInfoPtr kernel_graph_group_info = std::make_shared<KernelGraphGroupInfo>();
2604       MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2605       for (const auto &kernel_graph : kernel_graph_group) {
2606         MS_EXCEPTION_IF_NULL(kernel_graph);
2607         if (kernel_graph->execution_order().empty()) {
2608           continue;
2609         }
2610         auto iter = kernel_graph_to_device_contexts.find(kernel_graph);
2611         if (iter == kernel_graph_to_device_contexts.end()) {
2612           MS_LOG(EXCEPTION) << "Failed to find device context for kernel graph:" << kernel_graph->ToString();
2613         }
2614         // Collect kernel graphs in group.
2615         (void)kernel_graph_group_info->graphs_.emplace(kernel_graph);
2616 
2617         // Collect inputs in group.
2618         CollectEffectiveInputByGraph(kernel_graph, iter->second, kernel_graph_group_info.get());
2619 
2620         // Collect outputs in group.
2621         CollectEffectiveOutputByGraph(kernel_graph, iter->second, &(kernel_graph_group_info->front_output_nodes_),
2622                                       &(kernel_graph_group_info->monad_outputs_));
2623 
2624         kernel_graphs_to_group_info_[kernel_graph] = kernel_graph_group_info;
2625       }
2626       kernel_graph_group_info->group_name_ = "kernel_graph";
2627       for (const auto &graph : kernel_graph_group_info->graphs_) {
2628         if (kernel_graph_group_info->need_stack_) {
2629           MS_LOG(DEBUG) << "Add call input kernel graph:" << graph->ToString();
2630           (void)call_input_kernel_graphs_.emplace(graph.get());
2631         }
2632         kernel_graph_group_info->group_name_ += ("_" + std::to_string(graph->graph_id()));
2633       }
2634       MS_LOG(DEBUG) << "Add kernel graph info for group:" << kernel_graph_group_info->group_name_;
2635       (void)kernel_graph_group_infos_.emplace(kernel_graph_group_info);
2636     }
2637   }
2638 }
2639 
ParseControlNodeLevel(const AnfNodePtr & node,std::set<AnfNodePtr> * checked_nodes)2640 size_t ControlNodeParser::ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes) {
2641   MS_EXCEPTION_IF_NULL(node);
2642   MS_EXCEPTION_IF_NULL(checked_nodes);
2643   if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) {
2644     return 0;
2645   }
2646   (void)checked_nodes->emplace(node);
2647 
2648   auto iter = node_to_level_.find(node);
2649   if (iter != node_to_level_.end()) {
2650     return iter->second;
2651   }
2652 
2653   size_t level = 0;
2654   const auto &kernel_graph = FetchKernelGraphByFrontNode(node);
2655   if (kernel_graph == nullptr) {
2656     // If the kernel graph is not found, it means that the input does not come from the kernel graph, then
2657     // just continue to traverse the input.
2658     const auto &cnode = node->cast<CNodePtr>();
2659     MS_EXCEPTION_IF_NULL(cnode);
2660     const auto &inputs = cnode->inputs();
2661     for (const auto &input : inputs) {
2662       size_t tmp_level = ParseControlNodeLevel(input, checked_nodes);
2663       level = (tmp_level > level ? tmp_level : level);
2664     }
2665     return level;
2666   }
2667 
2668   // If the input comes from the kernel graph, you need to check all the graph's input, not just the node's input.
2669   auto group_info_iter = kernel_graphs_to_group_info_.find(kernel_graph);
2670   if (group_info_iter == kernel_graphs_to_group_info_.end()) {
2671     MS_LOG(EXCEPTION) << "Failed to get kernel graph group info for graph:" << kernel_graph->ToString();
2672   }
2673   MS_EXCEPTION_IF_NULL(group_info_iter->second);
2674   const auto &inputs = group_info_iter->second->front_input_nodes_;
2675   for (const auto &input : inputs) {
2676     const auto &input_node = input.first.first;
2677     size_t tmp_level = ParseControlNodeLevel(input_node, checked_nodes);
2678     level = (tmp_level > level ? tmp_level : level);
2679   }
2680   return level;
2681 }
2682 
2683 namespace {
GetRealOutputNode(const KernelWithIndex & front_pair,const KernelWithIndex & backend_pair)2684 AnfNodePtr GetRealOutputNode(const KernelWithIndex &front_pair, const KernelWithIndex &backend_pair) {
2685   if (front_pair.first == nullptr || backend_pair.first == nullptr) {
2686     return nullptr;
2687   }
2688   if (common::AnfAlgo::CheckPrimitiveType(backend_pair.first, prim::kPrimLoad) &&
2689       common::AnfAlgo::CheckPrimitiveType(front_pair.first, prim::kPrimLoad)) {
2690     const auto &backend_cnode = backend_pair.first->cast<CNodePtr>();
2691     const auto &front_cnode = front_pair.first->cast<CNodePtr>();
2692     MS_EXCEPTION_IF_NULL(backend_cnode);
2693     MS_EXCEPTION_IF_NULL(front_cnode);
2694     if (backend_cnode->inputs().size() > 1 && backend_cnode->input(1) != nullptr &&
2695         backend_cnode->input(1)->isa<CNode>() && front_cnode->inputs().size() > 1 && front_cnode->input(1) != nullptr &&
2696         front_cnode->input(1)->isa<CNode>()) {
2697       return front_cnode->input(1);
2698     }
2699   }
2700   return nullptr;
2701 }
2702 }  // namespace
2703 
ParseNodeLevel(const std::vector<AnfNodePtr> & control_nodes)2704 void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
2705   size_t level = 0;
2706   // 1. Parse levels of control nodes.
2707   for (const auto &control_node : control_nodes) {
2708     MS_EXCEPTION_IF_NULL(control_node);
2709     if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
2710       node_to_level_[control_node] = level;
2711       MS_LOG(DEBUG) << "Add level:" << level << " for node:" << control_node->DebugString();
2712       level = 0;
2713       const auto &func_graph = control_node->func_graph();
2714       MS_EXCEPTION_IF_NULL(func_graph);
2715       const auto &parameters = func_graph->parameters();
2716       for (const auto &parameter : parameters) {
2717         MS_EXCEPTION_IF_NULL(parameter);
2718         MS_LOG(DEBUG) << "Add level:" << level << " for node:" << parameter->DebugString();
2719         node_to_level_[parameter] = level;
2720       }
2721       continue;
2722     } else if (IsRecursionCallNode(control_node)) {
2723       ++level;
2724       MS_LOG(DEBUG) << "Add level:" << level << " for node:" << control_node->DebugString();
2725       node_to_level_[control_node] = level;
2726     } else {
2727       std::set<AnfNodePtr> checked_nodes;
2728       node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes);
2729       MS_LOG(DEBUG) << "Add level:" << node_to_level_[control_node] << " for node:" << control_node->DebugString();
2730     }
2731   }
2732 
2733   // 2. Parse the levels of kernel graph outputs.
2734   for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
2735     MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2736     level = 0;
2737     for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
2738       const auto &input_node = front_input_node.first.first;
2739       auto iter = node_to_level_.find(input_node);
2740       if (iter != node_to_level_.end() && level < iter->second) {
2741         level = iter->second;
2742       }
2743     }
2744     for (const auto &front_output_node : kernel_graph_group_info->front_output_nodes_) {
2745       MS_EXCEPTION_IF_NULL(front_output_node.second.first.first);
2746       if (front_output_node.second.first.first->isa<Parameter>()) {
2747         continue;
2748       }
2749       const auto &output_node = front_output_node.first.first;
2750       MS_EXCEPTION_IF_NULL(output_node);
2751       MS_LOG(DEBUG) << "Add level:" << level << " for node:" << output_node->DebugString();
2752       node_to_level_[output_node] = level;
2753       const auto &real_output_node = GetRealOutputNode(front_output_node.first, front_output_node.second.first);
2754       if (real_output_node != nullptr && node_to_level_.find(real_output_node) == node_to_level_.end()) {
2755         node_to_level_[real_output_node] = level;
2756       }
2757     }
2758   }
2759 
2760   // Parse the levels of kernel graph groups.
2761   for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
2762     MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2763     size_t max_level = 0;
2764     for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
2765       const auto &input_node = front_input_node.first.first;
2766       MS_EXCEPTION_IF_NULL(input_node);
2767       auto iter = node_to_level_.find(input_node);
2768       if (iter == node_to_level_.end()) {
2769         MS_LOG_WITH_NODE(EXCEPTION, input_node) << "Failed to get input node:" << input_node->DebugString()
2770                                                 << " for kernel graph:" << kernel_graph_group_info->group_name_;
2771       }
2772       max_level = (max_level > iter->second ? max_level : iter->second);
2773     }
2774     if (max_level > 0) {
2775       kernel_graph_group_info->need_stack_ = true;
2776       kernel_graph_group_info->level_ = max_level;
2777       for (const auto &kernel_graph : kernel_graph_group_info->graphs_) {
2778         (void)call_input_kernel_graphs_.emplace(kernel_graph.get());
2779       }
2780     }
2781     MS_LOG(DEBUG) << "Kernel graph group:" << kernel_graph_group_info->group_name_
2782                   << " need stack:" << kernel_graph_group_info->need_stack_
2783                   << " level:" << kernel_graph_group_info->level_;
2784   }
2785 }
2786 
IsInputInSameLevel(const AnfNodePtr & node)2787 bool ControlNodeParser::IsInputInSameLevel(const AnfNodePtr &node) {
2788   MS_EXCEPTION_IF_NULL(node);
2789   if (!node->isa<CNode>()) {
2790     return true;
2791   }
2792 
2793   auto input_with_indexes = FetchInputNodeByCNode(node);
2794   size_t level = SIZE_MAX;
2795   for (const auto &input_with_index : input_with_indexes) {
2796     auto input_node = input_with_index.first;
2797     MS_EXCEPTION_IF_NULL(input_node);
2798     if (input_node->isa<ValueNode>()) {
2799       continue;
2800     }
2801     auto iter = node_to_level_.find(input_node);
2802     if (iter == node_to_level_.end()) {
2803       MS_LOG_WITH_NODE(EXCEPTION, node) << "Failed to find input:" << input_node->DebugString()
2804                                         << " for node:" << node->DebugString() << " in graph output map.";
2805     }
2806     if (level == SIZE_MAX) {
2807       level = iter->second;
2808       continue;
2809     }
2810     if (level != iter->second) {
2811       return false;
2812     }
2813   }
2814   return true;
2815 }
2816 
CreateDeviceTensorForRootGraphParameter(DeviceContext * const default_context)2817 void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context) {
2818   MS_EXCEPTION_IF_NULL(default_context);
2819   for (const auto &parameter : root_graph_parameters_) {
2820     MS_EXCEPTION_IF_NULL(parameter);
2821     const auto &abstract = parameter->abstract();
2822     MS_EXCEPTION_IF_NULL(abstract);
2823     size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
2824     for (size_t i = 0; i < output_num; ++i) {
2825       KernelWithIndex parameter_with_index(parameter, i);
2826       if (front_to_backend_parameters_.find(parameter_with_index) == front_to_backend_parameters_.end()) {
2827         MS_LOG(DEBUG) << "Create device tensor for root graph parameter:" << parameter->DebugString();
2828         CreateDeviceTensorForFrontNode(parameter_with_index, default_context);
2829         (void)front_to_backend_parameters_[parameter_with_index].emplace(parameter_with_index, default_context);
2830       }
2831     }
2832   }
2833 }
2834 
FetchGroupNameByKernelGraph(const KernelGraphPtr & graph)2835 std::string ControlNodeParser::FetchGroupNameByKernelGraph(const KernelGraphPtr &graph) {
2836   MS_EXCEPTION_IF_NULL(graph);
2837   auto group_info_iter = kernel_graphs_to_group_info_.find(graph);
2838   if (group_info_iter == kernel_graphs_to_group_info_.end()) {
2839     MS_LOG(EXCEPTION) << "Failed to get kernel graph group info for graph:" << graph->ToString();
2840   }
2841   MS_EXCEPTION_IF_NULL(group_info_iter->second);
2842   return group_info_iter->second->group_name_;
2843 }
2844 
// Find the backend output that corresponds to `front_node_with_index` among the outputs of the group of `graph`.
// The front key is first looked up directly. Failing that, each group output whose front key resolves to
// `front_node_with_index` under VisitKernelWithReturnType is a candidate. The backend side of the first candidate
// found is returned, itself resolved through VisitKernelWithReturnType. Returns {nullptr, 0} when nothing matches,
// and also (after a warning) when `graph` has no group info.
KernelWithIndex ControlNodeParser::FetchBackendOutputByKernelGraph(const KernelGraphPtr &graph,
                                                                   const KernelWithIndex &front_node_with_index) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto group_info_iter = kernel_graphs_to_group_info_.find(graph);
  if (group_info_iter == kernel_graphs_to_group_info_.end()) {
    MS_LOG(WARNING) << "Failed to get kernel graph group info for graph:" << graph->ToString();
    return {nullptr, 0};
  }
  MS_EXCEPTION_IF_NULL(group_info_iter->second);
  const auto &front_outputs = group_info_iter->second->front_output_nodes_;

  const auto output_iter = front_outputs.find(front_node_with_index);
  if (output_iter != front_outputs.end()) {
    return output_iter->second.first;
  }

  // Capture by reference: the key holds a shared_ptr and is only read inside the predicate.
  const auto backend_iter =
    std::find_if(front_outputs.begin(), front_outputs.end(), [&front_node_with_index](const auto &pair) {
      return front_node_with_index == common::AnfAlgo::VisitKernelWithReturnType(pair.first.first, pair.first.second);
    });
  if (backend_iter == front_outputs.end()) {
    return {nullptr, 0};
  }
  const auto &backend_output = backend_iter->second.first;
  return common::AnfAlgo::VisitKernelWithReturnType(backend_output.first, backend_output.second);
}
2869 
PrintParseInfo()2870 void ControlNodeParser::PrintParseInfo() {
2871   for (const auto &group : kernel_graph_group_infos_) {
2872     MS_EXCEPTION_IF_NULL(group);
2873     for (const auto &input_pair : group->front_input_nodes_) {
2874       if (input_pair.first.first != nullptr) {
2875         MS_LOG(WARNING) << "Kernel graph group:" << group->group_name_
2876                         << " input node:" << input_pair.first.first->fullname_with_scope()
2877                         << " debug string:" << input_pair.first.first->DebugString(kDebugStrDepthTwo)
2878                         << " index:" << input_pair.first.second;
2879       }
2880     }
2881     for (const auto &output_pair : group->front_output_nodes_) {
2882       if (output_pair.first.first != nullptr && output_pair.second.first.first != nullptr) {
2883         MS_LOG(WARNING) << "Kernel graph group:" << group->group_name_
2884                         << " output node:" << output_pair.first.first->fullname_with_scope()
2885                         << " debug string:" << output_pair.first.first->DebugString(kDebugStrDepthTwo)
2886                         << " index:" << output_pair.first.second
2887                         << " backend node:" << output_pair.second.first.first->fullname_with_scope()
2888                         << " debug string:" << output_pair.second.first.first->DebugString(kDebugStrDepthTwo)
2889                         << " index:" << output_pair.second.first.second;
2890       }
2891     }
2892   }
2893   for (const auto &f_to_b : front_to_backend_kernels_) {
2894     if (f_to_b.first.first != nullptr && f_to_b.second.first.first != nullptr) {
2895       MS_LOG(WARNING) << "Front to backend map front node:" << f_to_b.first.first->fullname_with_scope()
2896                       << " debug string:" << f_to_b.first.first->DebugString(kDebugStrDepthTwo)
2897                       << " index:" << f_to_b.first.second
2898                       << " backend node:" << f_to_b.second.first.first->fullname_with_scope()
2899                       << " debug string:" << f_to_b.second.first.first->DebugString(kDebugStrDepthTwo)
2900                       << " index:" << f_to_b.second.first.second;
2901     }
2902   }
2903   for (const auto &pair : front_node_to_kernel_graph_) {
2904     if (pair.first != nullptr && pair.second == nullptr) {
2905       MS_LOG(WARNING) << "Front node:" << pair.first->fullname_with_scope()
2906                       << " debug string:" << pair.first->DebugString(kDebugStrDepthTwo)
2907                       << " to kernel graph:" << pair.second->ToString();
2908     }
2909   }
2910 }
2911 }  // namespace runtime
2912 }  // namespace mindspore
2913