• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/framework/control_node_parser.h"
18 #include "runtime/framework/actor/switch_actor.h"
19 #include "runtime/framework/actor/gather_actor.h"
20 #include "abstract/utils.h"
21 #include "ir/tensor.h"
22 
23 namespace mindspore {
24 namespace runtime {
25 namespace {
26 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
27 // Fetch all the weight parameters related to node. It runs like this:
28 // if we have a map like {{a, {b, c}}, {b, {d, e}}}, final we will get {{a, {b, c, d, e}}, {b, {c, d}}}.
FetchWeightbyHostParameter(const AnfNodePtr & node,std::vector<AnfNodePtr> * dest_nodes,const std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> & front_to_front_weight)29 void FetchWeightbyHostParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *dest_nodes,
30                                 const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_weight) {
31   if (find((*dest_nodes).begin(), (*dest_nodes).end(), node) != (*dest_nodes).end()) {
32     return;
33   }
34   (void)((*dest_nodes).emplace_back(node));
35   if (front_to_front_weight.find(node) == front_to_front_weight.end()) {
36     return;
37   }
38 
39   const auto weight_nodes = front_to_front_weight.at(node);
40   for (const auto weight_node : weight_nodes) {
41     FetchWeightbyHostParameter(weight_node, dest_nodes, front_to_front_weight);
42   }
43 }
44 
45 // Check whether the input is a valid parameter.
CheckValidFuncGraphInput(const AnfNodePtr & node)46 bool CheckValidFuncGraphInput(const AnfNodePtr &node) {
47   if (HasAbstractMonad(node)) {
48     return false;
49   } else if (node->isa<Parameter>()) {
50     return !HasAbstractRef(node);
51   }
52   return true;
53 }
54 
55 // Get the funcgraph in partial node.
GetFuncGraphFromPartial(const AnfNodePtr & node)56 FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
57   const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
58   return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
59 }
60 
61 // Get the relationship between funcgraph and parameters in the switch node.
FetchParameterBySwitchNode(const AnfNodePtr & switch_node,FuncGraphToParameter * graph_to_real_parameters)62 void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParameter *graph_to_real_parameters) {
63   const auto &switch_cnode = switch_node->cast<CNodePtr>();
64   const auto &switch_inputs = switch_cnode->inputs();
65   if (switch_inputs.size() != kSwitchInputNum) {
66     MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(switch_node);
67   }
68 
69   for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
70     const auto &partial_node = switch_inputs[i];
71     if (IsValueNode<FuncGraph>(partial_node)) {
72       continue;
73     }
74     const auto &func_graph = GetFuncGraphFromPartial(partial_node);
75     std::vector<AnfNodePtr> parameters;
76     const auto &partial_inputs = partial_node->cast<CNodePtr>()->inputs();
77     for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
78       if (CheckValidFuncGraphInput(partial_inputs[j])) {
79         (void)parameters.emplace_back(partial_inputs[j]);
80       }
81     }
82     (void)((*graph_to_real_parameters)[func_graph].emplace_back(parameters));
83   }
84 }
85 
86 // Get the corresponding relationship between funcgraph and parameters in the switch layer node.
FetchParameterBySwitchLayerNode(const AnfNodePtr & switch_layer_node,const std::vector<AnfNodePtr> & call_inputs,FuncGraphToParameter * graph_to_real_parameters)87 void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const std::vector<AnfNodePtr> &call_inputs,
88                                      FuncGraphToParameter *graph_to_real_parameters) {
89   const auto &switch_layer_cnode = switch_layer_node->cast<CNodePtr>();
90   const auto &switch_layer_inputs = switch_layer_cnode->inputs();
91 
92   if (switch_layer_inputs.size() != kSwitchLayerInputNum) {
93     MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(switch_layer_node);
94   }
95 
96   auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
97 
98   // Get the parameter corresponding to each funcgraph in make tuple.
99   for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
100     if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
101       // Tuple branch is a partial node.
102       const auto &func_graph = GetFuncGraphFromPartial(tuple_inputs[i]);
103       std::vector<AnfNodePtr> parameters;
104       const auto &partial_inputs = tuple_inputs[i]->cast<CNodePtr>()->inputs();
105 
106       // Get inputs in partial node.
107       for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
108         if (CheckValidFuncGraphInput(partial_inputs[j])) {
109           (void)parameters.emplace_back(partial_inputs[j]);
110         }
111       }
112 
113       // Get inputs in call node.
114       for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) {
115         if (CheckValidFuncGraphInput(call_inputs[j])) {
116           (void)parameters.emplace_back(call_inputs[j]);
117         }
118       }
119       (void)((*graph_to_real_parameters)[func_graph].emplace_back(parameters));
120     } else if (tuple_inputs[i]->isa<ValueNode>() && IsValueNode<FuncGraph>(tuple_inputs[i])) {
121       // Tuple branch is a call node.
122       const auto &func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
123       std::vector<AnfNodePtr> parameters;
124 
125       // Get inputs in call node.
126       for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) {
127         if (CheckValidFuncGraphInput(call_inputs[j])) {
128           (void)parameters.emplace_back(call_inputs[j]);
129         }
130       }
131 
132       (void)(*graph_to_real_parameters)[func_graph].emplace_back(parameters);
133     }
134   }
135 }
136 
137 // Create a device tensor for the front node.
138 // Get the output format and select kernel build info from the backend node corresponding to the front node to
139 // create the device address.
CreateDeviceTensorForValueNode(const AnfNodePtr & front_node,const AnfNodePtr & backend_node,const DeviceContext * device_context)140 void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
141                                     const DeviceContext *device_context) {
142   MS_EXCEPTION_IF_NULL(device_context);
143 
144   const auto &node_value = front_node->cast<ValueNodePtr>()->value();
145   if (!node_value->isa<tensor::Tensor>()) {
146     return;
147   }
148 
149   size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
150   TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
151   if (output_type_id == kTypeUnknown) {
152     output_type_id = AnfAlgo::GetOutputInferDataType(backend_node, 0);
153   }
154 
155   if (front_node->kernel_info() == nullptr) {
156     front_node->set_kernel_info(std::make_shared<device::KernelInfo>());
157   }
158 
159   // Get the select kernel build info.
160   auto kernel_info = dynamic_cast<device::KernelInfo *>(backend_node->kernel_info());
161   MS_EXCEPTION_IF_NULL(kernel_info);
162   auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
163   MS_EXCEPTION_IF_NULL(build_info);
164   AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());
165 
166   // Create device tensor.
167   std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
168   device::DeviceAddressPtr address =
169     device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
170   MS_EXCEPTION_IF_NULL(address);
171   AnfAlgo::SetOutputAddr(address, 0, front_node.get());
172 }
173 
174 // Create a device tensor for front parameter.
175 // When the condition input of the switch and switchlayer or the output of a subgraph is a parameter, there is no
176 // corresponding backend node for this parameter, so a device tensor needs to be created for it.
CreateDeviceTensorForFrontParameter(const AnfNodePtr & node,const DeviceContext * device_context)177 void CreateDeviceTensorForFrontParameter(const AnfNodePtr &node, const DeviceContext *device_context) {
178   MS_EXCEPTION_IF_NULL(device_context);
179 
180   TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0);
181 
182   if (node->kernel_info() == nullptr) {
183     auto kernel_info = std::make_shared<device::KernelInfo>();
184     std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
185     builder->SetOutputsFormat({kOpFormat_DEFAULT});
186     builder->SetOutputsDeviceType({type_id});
187     kernel_info->set_select_kernel_build_info(builder->Build());
188     node->set_kernel_info(kernel_info);
189   }
190   size_t size = AnfAlgo::GetOutputTensorMemSize(node, 0);
191 
192   // Create device tensor.
193   device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id);
194   MS_EXCEPTION_IF_NULL(address);
195   AnfAlgo::SetOutputAddr(address, 0, node.get());
196 }
197 
198 // Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
199 // backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
200 // front_node.
FetchBackendNodeByFrontNode(const AnfNodePtr & front_node,const std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> & real_to_formal_front_parameters,const std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> & formal_to_real_front_parameters,const std::unordered_map<AnfNodePtr,std::pair<AnfNodePtr,DeviceContext * >> & front_to_backend_parameter,std::set<AnfNodePtr> * invalid_node)201 std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
202   const AnfNodePtr &front_node,
203   const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &real_to_formal_front_parameters,
204   const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &formal_to_real_front_parameters,
205   const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
206   std::set<AnfNodePtr> *invalid_node) {
207   // Check whether the front_node has been looked for.
208   if ((*invalid_node).find(front_node) != (*invalid_node).end()) {
209     return std::pair<AnfNodePtr, DeviceContext *>();
210   }
211   (void)(*invalid_node).insert(front_node);
212 
213   const auto front_to_backend_iter = front_to_backend_parameter.find(front_node);
214   if (front_to_backend_iter != front_to_backend_parameter.end()) {
215     return front_to_backend_iter->second;
216   }
217 
218   const auto &real_to_formal_iter = real_to_formal_front_parameters.find(front_node);
219   if (real_to_formal_iter == real_to_formal_front_parameters.end()) {
220     return std::pair<AnfNodePtr, DeviceContext *>();
221   }
222   for (const auto &next_node : real_to_formal_iter->second) {
223     auto banckend_node =
224       FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters,
225                                   front_to_backend_parameter, invalid_node);
226     if (banckend_node.first != nullptr) {
227       return banckend_node;
228     }
229   }
230 
231   const auto &formal_to_real_iter = formal_to_real_front_parameters.find(front_node);
232   if (formal_to_real_iter == formal_to_real_front_parameters.end()) {
233     return std::pair<AnfNodePtr, DeviceContext *>();
234   }
235   for (const auto &next_node : formal_to_real_iter->second) {
236     auto banckend_node =
237       FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters,
238                                   front_to_backend_parameter, invalid_node);
239     if (banckend_node.first != nullptr) {
240       return banckend_node;
241     }
242   }
243   return std::pair<AnfNodePtr, DeviceContext *>();
244 }
245 
246 // Fetch all backend input nodes by parameter for gather actor.
FetchInputNodeByParameter(const AnfNodePtr & parameter,const std::vector<AnfNodePtr> & host_ds_parameters,std::set<AnfNodePtr> * invalid_inputs,const FuncGraphToParameter & graph_to_real_parameters)247 std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr &parameter,
248                                                   const std::vector<AnfNodePtr> &host_ds_parameters,
249                                                   std::set<AnfNodePtr> *invalid_inputs,
250                                                   const FuncGraphToParameter &graph_to_real_parameters) {
251   std::vector<AnfNodePtr> input_nodes;
252 
253   // If the node has been collected, skip it.
254   if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) {
255     return input_nodes;
256   }
257 
258   // Record the node which has been collected.
259   (void)(*invalid_inputs).insert(parameter);
260 
261   // If the parameter node is a parameter of host data source actor, return it.
262   if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) {
263     (void)input_nodes.emplace_back(parameter);
264     return input_nodes;
265   }
266 
267   // Check the parameter which send to its funcgraph.
268   const auto &func_graph = parameter->func_graph();
269   if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) {
270     return input_nodes;
271   }
272 
273   std::vector<AnfNodePtr> self_inputs;
274   for (const auto &input : func_graph->get_inputs()) {
275     // Monad input need not send to funcgraph.
276     if (HasAbstractMonad(input) || HasAbstractRef(input)) {
277       continue;
278     }
279     (void)self_inputs.emplace_back(input);
280   }
281 
282   const auto iter = find(self_inputs.begin(), self_inputs.end(), parameter);
283   if (iter == self_inputs.end()) {
284     MS_LOG(EXCEPTION) << "Cannot find parameter node:" << AnfAlgo::GetNodeDebugString(parameter);
285   }
286   size_t pos = iter - self_inputs.begin();
287 
288   for (const auto parameters : graph_to_real_parameters.at(func_graph)) {
289     if (parameters.size() != self_inputs.size()) {
290       MS_LOG(EXCEPTION) << "Invalid input num:" << parameters.size() << " and:" << self_inputs.size()
291                         << " for func_graph:" << func_graph->ToString();
292     }
293     const auto input = parameters[pos];
294     if (input->isa<CNode>()) {
295       (void)input_nodes.emplace_back(input);
296     } else if (input->isa<Parameter>()) {
297       // If input is a parameter, you need to find its input recursively.
298       auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters);
299       (void)input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end());
300     }
301   }
302   return input_nodes;
303 }
304 
305 // Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph
306 // called by the call node.
FetchFuncGraphOutput(const FuncGraphPtr & func_graph,std::vector<AnfNodePtr> * call_nodes)307 std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *call_nodes) {
308   std::vector<AnfNodePtr> outputs;
309   const auto &output = func_graph->output();
310   const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0, false, {prim::kPrimTupleGetItem});
311   if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
312     return outputs;
313   }
314   if (!IsCallNode(real_output.first)) {
315     outputs.push_back(real_output.first);
316     return outputs;
317   }
318 
319   (*call_nodes).push_back(real_output.first);
320   std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(real_output.first);
321   for (const auto &graph : func_graphs) {
322     auto single_outputs = FetchFuncGraphOutput(graph, call_nodes);
323     (void)outputs.insert(outputs.end(), single_outputs.begin(), single_outputs.end());
324   }
325   return outputs;
326 }
327 std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set<AnfNodePtr> *call_nodes,
328                                                 std::set<AnfNodePtr> *switch_nodes);
329 
330 // Recursive interface, get all possible output nodes of call node.
FetchOutputByCallNode(const AnfNodePtr & call_node,std::set<AnfNodePtr> * call_nodes,std::set<AnfNodePtr> * switch_nodes)331 std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std::set<AnfNodePtr> *call_nodes,
332                                               std::set<AnfNodePtr> *switch_nodes) {
333   std::vector<AnfNodePtr> outputs;
334   if ((*call_nodes).find(call_node) != (*call_nodes).end()) {
335     return outputs;
336   }
337   (void)((*call_nodes).insert(call_node));
338 
339   const auto func_graphs = FetchFuncGraphbyCallNode(call_node);
340 
341   for (const auto func_graph : func_graphs) {
342     std::vector<AnfNodePtr> sub_call_nodes;
343     const std::vector<AnfNodePtr> graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes);
344     for (const auto &graph_output : graph_outputs) {
345       if (graph_output->isa<Parameter>()) {
346         outputs.push_back(graph_output);
347       } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
348         const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
349         (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
350       } else if (IsCallNode(graph_output)) {
351         const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
352         (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
353       } else if (graph_output->isa<CNode>()) {
354         (void)outputs.emplace_back(graph_output);
355       } else if (graph_output->isa<ValueNode>()) {
356         outputs.push_back(graph_output);
357       } else {
358         MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output);
359       }
360     }
361   }
362 
363   return outputs;
364 }
365 
366 // Recursive interface, get all possible output nodes of switch node.
FetchOutputBySwitchNode(const AnfNodePtr & switch_node,std::set<AnfNodePtr> * call_nodes,std::set<AnfNodePtr> * switch_nodes)367 std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set<AnfNodePtr> *call_nodes,
368                                                 std::set<AnfNodePtr> *switch_nodes) {
369   std::vector<AnfNodePtr> outputs;
370   if ((*switch_nodes).find(switch_node) != (*switch_nodes).end()) {
371     return outputs;
372   }
373   (void)((*switch_nodes).insert(switch_node));
374 
375   if (!switch_node->isa<CNode>()) {
376     MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node);
377   }
378   const auto &inputs = switch_node->cast<CNodePtr>()->inputs();
379   if (inputs.size() != kSwitchInputNum) {
380     MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node);
381   }
382 
383   for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
384     if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimPartial)) {
385       continue;
386     } else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
387       const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes);
388       (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
389     } else if (IsCallNode(inputs[i])) {
390       const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes);
391       (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
392     } else {
393       (void)outputs.emplace_back(inputs[i]);
394     }
395   }
396 
397   return outputs;
398 }
399 
400 // Recursive interface, get the real kernel that UpdateState node depends on.
FetchSourceNodeByAutoMonad(const AnfNodePtr & node)401 AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) {
402   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
403     const auto &cnode = node->cast<CNodePtr>();
404     const auto &inputs = cnode->inputs();
405     if (inputs.size() <= kUpdateStateRealInput) {
406       MS_LOG(EXCEPTION) << "Invalid updatestate node:" << AnfAlgo::GetNodeDebugString(node);
407     }
408 
409     return FetchSourceNodeByAutoMonad(inputs[kUpdateStateRealInput]);
410   }
411   return node;
412 }
413 
414 // Fetch all parameters in control node of root funcgraph.
FetchParameterByControlNode(const std::vector<AnfNodePtr> & control_nodes)415 std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr> &control_nodes) {
416   std::vector<AnfNodePtr> parameters;
417 
418   for (const auto &control_node : control_nodes) {
419     CNodePtr cnode = control_node->cast<CNodePtr>();
420     const auto &inputs = cnode->inputs();
421     if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
422       break;
423     } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
424       for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
425         if (inputs[i]->isa<Parameter>()) {
426           (void)parameters.emplace_back(inputs[i]);
427         }
428       }
429     } else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
430       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
431         if (inputs[i]->isa<Parameter>()) {
432           (void)parameters.emplace_back(inputs[i]);
433         }
434       }
435     } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
436       if (inputs.size() != kSwitchInputNum) {
437         MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
438       }
439       if (inputs[kSwitchCondPos]->isa<Parameter>()) {
440         (void)parameters.emplace_back(inputs[kSwitchCondPos]);
441       }
442     } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
443       if (inputs.size() != kSwitchLayerInputNum) {
444         MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
445       }
446       if (inputs[kSwitchLayerCondPos]->isa<Parameter>()) {
447         (void)parameters.emplace_back(inputs[kSwitchLayerCondPos]);
448       }
449     }
450   }
451   return parameters;
452 }
453 
454 // Get funcgraph from node, the interface only accepts partial node and funcgraph value node.
FetchFuncGraphInNode(const auto & node)455 FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
456   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
457     const auto &func_graph = GetFuncGraphFromPartial(node);
458 
459     if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
460       return FetchFuncGraphInNode(func_graph->output());
461     } else if (IsValueNode<FuncGraph>(func_graph->output())) {
462       // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called.
463       return FetchFuncGraphInNode(func_graph->output());
464     }
465 
466     return func_graph;
467   } else if (IsValueNode<FuncGraph>(node)) {
468     const auto &func_graph = GetValueNode<FuncGraphPtr>(node);
469 
470     if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
471       // When the output of funcgraph is a funcgraph, it needs to return the funcgraph that is finally called.
472       return FetchFuncGraphInNode(func_graph->output());
473     } else if (IsValueNode<FuncGraph>(func_graph->output())) {
474       // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called.
475       return FetchFuncGraphInNode(func_graph->output());
476     }
477 
478     return func_graph;
479   }
480 
481   return nullptr;
482 }
483 }  // namespace
484 
// Trace node to a real output that is not a call node.
// The node is first traced with VisitKernelWithReturnType (stopping at TupleGetItem). A traced node that is
// not a call node is returned directly. Otherwise the outputs of the funcgraphs it may call are traced in
// turn and the first non-null result is returned.
// call_nodes holds the call nodes already expanded; nullptr is returned for a call node found there, or when
// no called funcgraph yields a result.
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
  const auto traced = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem});
  const AnfNodePtr &real_node = traced.first;
  if (!IsCallNode(real_node)) {
    return real_node;
  }

  // insert() reports through .second whether the call node was newly recorded.
  if (!call_nodes->insert(real_node).second) {
    return nullptr;
  }

  const auto called_graphs = FetchFuncGraphbyCallNode(real_node);
  for (const auto &called_graph : called_graphs) {
    const auto real_output = FetchRealOutputByCallNode(called_graph->output(), call_nodes);
    if (real_output != nullptr) {
      return real_output;
    }
  }
  return nullptr;
}
504 
505 // Return true if the node has Ref abstract.
HasAbstractRef(const AnfNodePtr & node)506 bool HasAbstractRef(const AnfNodePtr &node) {
507   if (node == nullptr) {
508     return false;
509   }
510   auto &abs = node->abstract();
511   return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
512 }
513 
IsCallNode(const AnfNodePtr & node)514 bool IsCallNode(const AnfNodePtr &node) {
515   if (!node->isa<CNode>()) {
516     return false;
517   }
518   const auto &cnode = node->cast<CNodePtr>();
519   const auto &inputs = cnode->inputs();
520   return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
521 }
522 
IsSubCallNode(const AnfNodePtr & node)523 bool IsSubCallNode(const AnfNodePtr &node) {
524   if (!node->isa<CNode>()) {
525     return false;
526   }
527 
528   const auto inputs = node->cast<CNodePtr>()->inputs();
529   if (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
530     return false;
531   }
532 
533   const auto &switch_layer_inputs = inputs[0]->cast<CNodePtr>()->inputs();
534   const auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
535   if (tuple_inputs.size() <= kMakeTupleInputStartPos) {
536     return false;
537   }
538 
539   // Check whether the funcgraph called by the call node returns funcgraph or partial node.
540   FuncGraphPtr func_graph = nullptr;
541   if (AnfAlgo::CheckPrimitiveType(tuple_inputs[kMakeTupleInputStartPos], prim::kPrimPartial)) {
542     const auto &func_graph_node = tuple_inputs[kMakeTupleInputStartPos]->cast<CNodePtr>()->input(kPartialFuncGraphPos);
543     func_graph = GetValueNode<FuncGraphPtr>(func_graph_node);
544   } else if (tuple_inputs[kMakeTupleInputStartPos]->isa<ValueNode>() &&
545              IsValueNode<FuncGraph>(tuple_inputs[kMakeTupleInputStartPos])) {
546     func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[kMakeTupleInputStartPos]);
547   }
548 
549   const auto &output = func_graph->output();
550   return AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial) ||
551          (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output));
552 }
553 
FetchAllRealInputNodeByParameter(const KernelWithIndex & node)554 std::vector<KernelWithIndex> FetchAllRealInputNodeByParameter(const KernelWithIndex &node) {
555   std::vector<KernelWithIndex> parameters;
556   const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
557   const auto &real_node = real_node_with_index.first;
558   if (real_node->isa<Parameter>()) {
559     if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
560       (void)parameters.emplace_back(real_node_with_index);
561     }
562   } else if (HasAbstractMonad(real_node)) {
563     return parameters;
564   } else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
565     const auto &inputs = real_node->cast<CNodePtr>()->inputs();
566     for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
567       const auto &sub_parameters = FetchAllRealInputNodeByParameter({inputs[i], 0});
568       (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
569     }
570   } else {
571     (void)parameters.emplace_back(real_node_with_index);
572   }
573   return parameters;
574 }
575 
FetchFuncGraphbyCallNode(const AnfNodePtr & node)576 std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
577   std::vector<FuncGraphPtr> func_graphs;
578   if (!node->isa<CNode>()) {
579     return func_graphs;
580   }
581 
582   const auto &call_inputs = node->cast<CNodePtr>()->inputs();
583   if (call_inputs[0]->isa<CNode>()) {
584     const auto &cnode = call_inputs[0]->cast<CNodePtr>();
585     const auto &cnode_inputs = cnode->inputs();
586     if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
587       for (size_t i = kSwitchTrueBranchPos; i < cnode_inputs.size(); ++i) {
588         if (IsPrimitiveCNode(cnode_inputs[i], prim::kPrimPartial)) {
589           (void)func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i]));
590         } else if (IsValueNode<FuncGraph>(cnode_inputs[i])) {
591           (void)func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(cnode_inputs[i]));
592         }
593       }
594     } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer) &&
595                AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) {
596       const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
597 
598       // Fetch all funcgraphs in make tuple node.
599       for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
600         const auto func_graph = FetchFuncGraphInNode(tuple_inputs[i]);
601         if (func_graph != nullptr) {
602           func_graphs.emplace_back(func_graph);
603         }
604       }
605     } else if (IsCallNode(cnode)) {
606       return FetchFuncGraphbyCallNode(cnode);
607     } else {
608       MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
609     }
610   } else if (call_inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(call_inputs[0])) {
611     (void)func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(call_inputs[0]));
612   } else {
613     MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
614   }
615   return func_graphs;
616 }
617 
FetchOutputSizebyCallNode(const AnfNodePtr & node,std::vector<AnfNodePtr> * call_nodes)618 size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) {
619   if (!IsCallNode(node)) {
620     MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node);
621   }
622   if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) {
623     return 0;
624   }
625   (void)((*call_nodes).emplace_back(node));
626 
627   const auto &func_graphs = FetchFuncGraphbyCallNode(node);
628   for (const auto &func_graph : func_graphs) {
629     const auto &output = func_graph->output();
630     const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
631 
632     if (IsCallNode(real_output.first)) {
633       size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes);
634       if (output_num > 0) {
635         return output_num;
636       }
637     } else if (AnfAlgo::CheckPrimitiveType(real_output.first, prim::kPrimMakeTuple)) {
638       size_t total_num = 0;
639       const auto &tuple_cnode = real_output.first->cast<CNodePtr>();
640       const auto &inputs = tuple_cnode->inputs();
641       size_t i = 1;
642       for (; i < inputs.size(); ++i) {
643         if (IsCallNode(inputs[i])) {
644           size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes);
645           if (call_output_num == 0) {
646             break;
647           }
648           total_num += call_output_num;
649         } else if (inputs[i]->isa<ValueNode>() && inputs[i]->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
650           auto value_tuple = inputs[i]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>();
651           MS_EXCEPTION_IF_NULL(value_tuple);
652           auto tuple_value = value_tuple->value();
653           total_num += tuple_value.size();
654         } else if (!HasAbstractMonad(inputs[i])) {
655           ++total_num;
656         }
657       }
658       if (i == inputs.size()) {
659         return total_num;
660       }
661     } else {
662       return 1;
663     }
664   }
665   return 0;
666 }
667 
FetchFuncGraphByNode(const AnfNodePtr & node)668 FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) {
669   auto front_node = GetFrontNodeByBackendNode(node);
670   // If the front node is nullptr, we can check its inputs.
671   if (front_node == nullptr) {
672     if (node->isa<CNode>()) {
673       const auto &cnode = node->cast<CNodePtr>();
674       const auto &inputs = cnode->inputs();
675 
676       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
677         const auto &func_graph = FetchFuncGraphByNode(inputs[i]);
678         if (func_graph != nullptr) {
679           return func_graph;
680         }
681       }
682     } else {
683       return nullptr;
684     }
685   }
686 
687   const auto &func_graph = front_node->func_graph();
688   return func_graph;
689 }
690 
// Look up the front node of `backend_node` through the kernel graph that owns it.
// Returns nullptr when the node has no funcgraph or when that funcgraph is not a KernelGraph.
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
  const auto owner_graph = backend_node->func_graph();
  if (owner_graph != nullptr) {
    auto *kernel_graph = dynamic_cast<KernelGraph *>(owner_graph.get());
    if (kernel_graph != nullptr) {
      return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
    }
  }
  return nullptr;
}
701 
GetFrontNodeByKernelGraph(const AnfNodePtr & backend_node,const KernelGraphPtr & graph)702 KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
703   const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
704   if (front_node != nullptr) {
705     return {front_node, 0};
706   }
707   const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
708   if (front_node_with_index.first == nullptr) {
709     MS_LOG(EXCEPTION) << "Invalid parameter of kernel graph, parameter:" << AnfAlgo::GetNodeDebugString(backend_node);
710   }
711   return front_node_with_index;
712 }
713 
GetFuncgraphByBackendNode(const AnfNodePtr & backend_node)714 FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
715   auto front_node = GetFrontNodeByBackendNode(backend_node);
716   if (front_node == nullptr) {
717     return nullptr;
718   }
719   return front_node->func_graph();
720 }
721 
Parse(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const FuncGraphPtr & root_graph)722 void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
723                               const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
724   if (graphs.size() != device_contexts.size()) {
725     MS_LOG(EXCEPTION) << "Graph num is not equal to device context, graph:" << graphs.size()
726                       << " device context num:" << device_contexts.size();
727   }
728   if (graphs.empty()) {
729     return;
730   }
731 
732   root_func_graph_ = root_graph;
733 
734   root_graph_parameters_ = root_graph->parameters();
735 
736   CreateBranchIDForFuncGraph(control_nodes);
737 
738   RealToFormalNode real_to_formal_front_parameters;
739   FetchFrontToFrontParameter(control_nodes, &real_to_formal_front_parameters);
740 
741   RealToFormalNode formal_to_real_front_parameters;
742   for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) {
743     for (const auto formal_parameter : real_to_formal_front_parameter.second) {
744       (void)formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first);
745     }
746   }
747 
748   FetchFrontToBackendParameter(graphs, device_contexts, real_to_formal_front_parameters,
749                                formal_to_real_front_parameters);
750 
751   FetchFuncGraphToParameter(control_nodes);
752 
753   FetchHostParameterToWeight(real_to_formal_front_parameters);
754 
755   FetchCallInputKernelGraph(graphs, device_contexts);
756 
757   FetchFrontValueNode(control_nodes, graphs, device_contexts);
758 
759   FetchFrontToBackendKernel(graphs, device_contexts);
760 
761   FetchCallInputKernelGraph(graphs, device_contexts);
762 
763   control_node_parameters_ = FetchControlNodeParameter(control_nodes, device_contexts[0]);
764 
765   FetchFuncGraphCallNum(control_nodes);
766 
767   FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters);
768 
769   FetchAutoMonadNode(control_nodes);
770 }
771 
GetBackendInputByParameter(const AnfNodePtr & parameter)772 std::vector<KernelWithIndex> ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr &parameter) {
773   return formal_to_real_parameters_[parameter];
774 }
775 
FetchBackendInputNodeByFrontNode(const AnfNodePtr & front_output)776 std::set<KernelWithIndex> ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) {
777   std::set<AnfNodePtr> call_nodes;
778   std::set<AnfNodePtr> switch_nodes;
779   std::set<KernelWithIndex> results;
780   FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes, &results);
781   return results;
782 }
783 
GetBranchIDByFuncGraph(const FuncGraphPtr & func_graph)784 int ControlNodeParser::GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph) {
785   MS_EXCEPTION_IF_NULL(func_graph);
786 
787   if (func_graph_to_branch_id_.find(func_graph) == func_graph_to_branch_id_.end()) {
788     MS_LOG(EXCEPTION) << "Invalid branch id for funcgraph:" << func_graph->ToString();
789   }
790   return func_graph_to_branch_id_[func_graph];
791 }
792 
IsCallInputKernelGraph(const KernelGraphPtr & graph)793 bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) {
794   if (call_input_kernel_graphs_.find(graph) == call_input_kernel_graphs_.end()) {
795     return false;
796   }
797   return true;
798 }
799 
IsKernelInRootFuncGraph(const AnfNodePtr & kernel)800 bool ControlNodeParser::IsKernelInRootFuncGraph(const AnfNodePtr &kernel) {
801   if (kernel == nullptr) {
802     return true;
803   }
804 
805   const auto &graph = kernel->func_graph();
806   if (kernel != nullptr && graph != nullptr) {
807     const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
808     if (kernel_graph == nullptr) {
809       return true;
810     }
811 
812     const auto func_graph = kernel_graph->GetFuncGraph();
813     if (func_graph != nullptr && func_graph != root_func_graph_) {
814       return false;
815     }
816   }
817 
818   return true;
819 }
820 
GetCallNumByFuncGraph(const FuncGraphPtr & func_graph)821 size_t ControlNodeParser::GetCallNumByFuncGraph(const FuncGraphPtr &func_graph) {
822   if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
823     MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString();
824   }
825 
826   return func_graph_to_call_num_[func_graph];
827 }
828 
FetchAllBranchOutputs(const FuncGraphPtr & func_graph)829 std::vector<AnfNodePtr> ControlNodeParser::FetchAllBranchOutputs(const FuncGraphPtr &func_graph) {
830   std::vector<AnfNodePtr> call_nodes;
831   return FetchFuncGraphOutput(func_graph, &call_nodes);
832 }
833 
GetFrontValueNodeDeviceContext(const AnfNodePtr & value_node)834 DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node) {
835   auto iter = std::find_if(
836     front_value_nodes_.begin(), front_value_nodes_.end(),
837     [value_node](const auto &front_node_with_context) { return front_node_with_context.first == value_node; });
838   if (iter != front_value_nodes_.end()) {
839     return iter->second;
840   }
841   return nullptr;
842 }
843 
// Find the backend parameter for the front weight `node`: for each host parameter whose weight list
// contains the node, look the host parameter up in front_to_backend_parameters_ and return the first
// backend node found. Returns nullptr when no such host parameter has a backend parameter.
AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &node) {
  for (const auto &host_entry : host_parameter_to_weights_) {
    const auto &weights = host_entry.second;
    if (std::find(weights.begin(), weights.end(), node) == weights.end()) {
      continue;
    }
    const auto backend_iter = front_to_backend_parameters_.find(host_entry.first);
    if (backend_iter != front_to_backend_parameters_.end()) {
      return backend_iter->second.first;
    }
  }

  return nullptr;
}
858 
FetchValueNodeBySwitchNode(const AnfNodePtr & switch_node,std::vector<AnfNodePtr> * value_nodes)859 void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node,
860                                                    std::vector<AnfNodePtr> *value_nodes) {
861   const auto &cnode = switch_node->cast<CNodePtr>();
862   const auto &inputs = cnode->inputs();
863   if (inputs.size() != kSwitchInputNum) {
864     MS_LOG(EXCEPTION) << "Invalid switch node input num:" << inputs.size();
865   }
866 
867   for (const auto &input : inputs) {
868     if (input->isa<ValueNode>()) {
869       const auto &node_value = input->cast<ValueNodePtr>()->value();
870       if (node_value->isa<tensor::Tensor>()) {
871         (void)((*value_nodes).emplace_back(input));
872       }
873     } else if (IsCallNode(input)) {
874       // If input is a call not, should check the switch node in its input.
875       const auto &call_node = input->cast<CNodePtr>();
876       const auto &call_inputs = call_node->inputs();
877       if (call_inputs.empty() || (!AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch))) {
878         continue;
879       }
880       FetchValueNodeBySwitchNode(call_inputs[0], value_nodes);
881     } else if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) {
882       const auto &partial_node = input->cast<CNodePtr>();
883       const auto &partial_inputs = partial_node->inputs();
884       if (partial_inputs.size() <= kPartialFuncGraphPos) {
885         MS_LOG(EXCEPTION) << "Invalid partial node input num:" << partial_inputs.size();
886       }
887 
888       // if input is a partial node, get the value node in its funcgraph.
889       const auto &func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
890       if (func_graph->output()->isa<ValueNode>()) {
891         (void)((*value_nodes).emplace_back(func_graph->output()));
892       }
893     }
894   }
895 }
896 
FetchFrontValueNode(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)897 void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
898                                             const std::vector<KernelGraphPtr> &graphs,
899                                             const std::vector<DeviceContext *> &device_contexts) {
900   for (const auto &control_node : control_nodes) {
901     CNodePtr cnode = control_node->cast<CNodePtr>();
902     auto inputs = cnode->inputs();
903     if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
904       auto func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
905       const auto parameters = func_graph->parameters();
906       if (parameters.size() != inputs.size() - kCallInputStartPos) {
907         MS_LOG(EXCEPTION) << "Invalid parameters num, need:" << parameters.size()
908                           << " has:" << inputs.size() - kCallInputStartPos;
909       }
910       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
911         if (inputs[i]->isa<ValueNode>()) {
912           const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value();
913           if (!node_value->isa<tensor::Tensor>()) {
914             continue;
915           }
916           if (front_to_backend_parameters_.find(parameters[i - kCallInputStartPos]) ==
917               front_to_backend_parameters_.end()) {
918             MS_LOG(INFO) << "Cannot find backend parameter for front parameter:"
919                          << AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos])
920                          << ", used the default format";
921             CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]);
922             (void)front_value_nodes_.emplace_back(inputs[i], device_contexts[0]);
923             continue;
924           }
925 
926           const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first;
927           const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second;
928           CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context);
929           (void)front_value_nodes_.emplace_back(inputs[i], device_context);
930         }
931       }
932     }
933   }
934 
935   for (size_t index = 0; index < graphs.size(); ++index) {
936     const auto &graph = graphs[index];
937     MS_EXCEPTION_IF_NULL(graph);
938 
939     for (const auto &parameter : graph->input_nodes()) {
940       MS_EXCEPTION_IF_NULL(parameter);
941 
942       if (IsInternalParameter(parameter, graph)) {
943         auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
944         MS_EXCEPTION_IF_NULL(front_node_with_index.first);
945         const auto &front_output_with_index =
946           AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
947         auto front_output_node = front_output_with_index.first;
948         MS_EXCEPTION_IF_NULL(front_output_node);
949         if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) {
950           std::vector<AnfNodePtr> value_nodes;
951           FetchValueNodeBySwitchNode(front_output_node, &value_nodes);
952           for (const auto value_node : value_nodes) {
953             CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]);
954             (void)front_value_nodes_.emplace_back(value_node, device_contexts[index]);
955           }
956         }
957       }
958     }
959   }
960 
961   // When funcgraph called by call node returns to the value node, device addresses should be created for these
962   // value nodes.
963   for (const auto &call_node_to_backend_parameter : call_node_to_backend_parameters_) {
964     const auto func_graphs = FetchFuncGraphbyCallNode(call_node_to_backend_parameter.first.first);
965     for (const auto &func_graph : func_graphs) {
966       const auto &output = func_graph->output();
967       if (output->isa<ValueNode>() && GetFrontValueNodeDeviceContext(output) == nullptr) {
968         const auto &device_context = call_node_to_backend_parameter.second.second;
969         CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context);
970         (void)front_value_nodes_.emplace_back(output, device_context);
971       }
972     }
973   }
974 }
975 
FetchFrontToFrontParameter(const std::vector<AnfNodePtr> & control_nodes,std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> * front_to_front_parameter)976 void ControlNodeParser::FetchFrontToFrontParameter(
977   const std::vector<AnfNodePtr> &control_nodes,
978   std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter) {
979   for (const auto &node : control_nodes) {
980     CNodePtr cnode = node->cast<CNodePtr>();
981     const auto &inputs = cnode->inputs();
982     if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
983       // Call node which the first input node is a valuenode of funcgraph.
984       const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
985       const auto &parameters = func_graph->parameters();
986       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
987         if (inputs[i]->isa<Parameter>()) {
988           (*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kCallInputStartPos]);
989         }
990       }
991     }
992   }
993 }
994 
FetchControlNodeParameter(const std::vector<AnfNodePtr> & control_nodes,DeviceContext * device_context)995 std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
996                                                                      DeviceContext *device_context) {
997   std::vector<AnfNodePtr> parameters = FetchParameterByControlNode(control_nodes);
998 
999   for (const auto &graph_with_device_context : call_input_kernel_graphs_) {
1000     const auto &graph = graph_with_device_context.first;
1001     const auto &func_graph = graph->GetFuncGraph();
1002     if (func_graph == nullptr) {
1003       MS_LOG(WARNING) << "Cannot get funcgraph by kernel graph:" << graph->ToString();
1004       continue;
1005     }
1006     if (func_graph != root_func_graph_) {
1007       continue;
1008     }
1009 
1010     const auto &inputs = graph->input_nodes();
1011     for (const auto &input : inputs) {
1012       const auto &front_node = graph->GetFrontAnfByBackendAnf(input);
1013       if (front_node != nullptr && front_node->isa<Parameter>() && (!HasAbstractRef(front_node))) {
1014         (void)parameters.emplace_back(front_node);
1015       }
1016     }
1017   }
1018 
1019   for (const auto &parameter : parameters) {
1020     auto backend_iter = front_to_backend_parameters_.find(parameter);
1021     if (backend_iter == front_to_backend_parameters_.end()) {
1022       CreateDeviceTensorForFrontParameter(parameter, device_context);
1023       front_to_backend_parameters_[parameter] = {parameter, device_context};
1024       (void)front_parameters_.emplace_back(parameter, device_context);
1025     }
1026   }
1027 
1028   return parameters;
1029 }
1030 
FetchFuncGraphCallNum(const std::vector<AnfNodePtr> & control_nodes)1031 void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) {
1032   for (const auto &control_node : control_nodes) {
1033     if (IsCallNode(control_node)) {
1034       const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
1035 
1036       for (const auto &func_graph : func_graphs) {
1037         MS_EXCEPTION_IF_NULL(func_graph);
1038 
1039         if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
1040           func_graph_to_call_num_[func_graph] = 1;
1041         } else {
1042           func_graph_to_call_num_[func_graph]++;
1043         }
1044       }
1045     }
1046   }
1047 }
1048 
FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)1049 void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
1050                                                   const std::vector<DeviceContext *> &device_contexts) {
1051   for (size_t i = 0; i < graphs.size(); ++i) {
1052     const auto &graph = graphs[i];
1053     const auto &device_context = device_contexts[i];
1054 
1055     const auto inputs = graph->input_nodes();
1056     for (const auto &input : inputs) {
1057       const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
1058       if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) {
1059         call_input_kernel_graphs_[graph] = device_context;
1060         call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
1061       }
1062     }
1063   }
1064 }
1065 
CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> & control_nodes)1066 void ControlNodeParser::CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
1067   int branch_id = 0;
1068 
1069   for (const auto &control_node : control_nodes) {
1070     // Root funcgraph does not need to create a gather actor.
1071     if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
1072       auto func_graph = control_node->func_graph();
1073       func_graph_to_branch_id_[func_graph] = branch_id++;
1074     }
1075   }
1076 }
1077 
FetchInputParameterbyControlNode(const AnfNodePtr & node,std::set<AnfNodePtr> * switch_nodes,std::set<AnfNodePtr> * call_nodes)1078 std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *switch_nodes,
1079                                                          std::set<AnfNodePtr> *call_nodes) {
1080   std::vector<AnfNodePtr> parameters;
1081 
1082   if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
1083     if ((*switch_nodes).find(node) != (*switch_nodes).end()) {
1084       return parameters;
1085     }
1086     (void)(*switch_nodes).insert(node);
1087 
1088     const auto &cnode = node->cast<CNodePtr>();
1089     const auto &inputs = cnode->inputs();
1090     if (inputs.size() != kSwitchInputNum) {
1091       MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(node);
1092     }
1093 
1094     for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
1095       if (inputs[i]->isa<Parameter>()) {
1096         (void)parameters.emplace_back(inputs[i]);
1097       } else if (IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
1098         const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes);
1099         (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
1100       }
1101     }
1102   } else if (IsCallNode(node)) {
1103     if ((*call_nodes).find(node) != (*call_nodes).end()) {
1104       return parameters;
1105     }
1106     (void)(*call_nodes).insert(node);
1107 
1108     const auto &func_graphs = FetchFuncGraphbyCallNode(node);
1109     for (const auto &func_graph : func_graphs) {
1110       if (func_graph->output()->isa<Parameter>()) {
1111         (void)parameters.emplace_back(func_graph->output());
1112       }
1113     }
1114   }
1115   return parameters;
1116 }
1117 
FetchParameterbyKernelGraph(const KernelGraphPtr & graph)1118 std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
1119   std::vector<KernelWithIndex> parameters;
1120   const auto &graph_parameters = graph->input_nodes();
1121 
1122   for (const auto &graph_parameter : graph_parameters) {
1123     const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
1124     const auto &internal_front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
1125     const auto &internal_front_node = internal_front_node_with_index.first;
1126 
1127     if (external_front_node == nullptr && internal_front_node == nullptr) {
1128       MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
1129                       << AnfAlgo::GetNodeDebugString(graph_parameter);
1130       continue;
1131     }
1132 
1133     const auto &front_node_with_index =
1134       ((external_front_node != nullptr) ? KernelWithIndex(external_front_node, 0) : internal_front_node_with_index);
1135     const auto &sub_parameters = FetchAllRealInputNodeByParameter(front_node_with_index);
1136     (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
1137   }
1138 
1139   return parameters;
1140 }
1141 
FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const RealToFormalNode & real_to_formal_front_parameters,const RealToFormalNode & formal_to_real_front_parameters)1142 void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
1143                                                      const std::vector<DeviceContext *> &device_contexts,
1144                                                      const RealToFormalNode &real_to_formal_front_parameters,
1145                                                      const RealToFormalNode &formal_to_real_front_parameters) {
1146   if (graphs.size() != device_contexts.size()) {
1147     MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
1148   }
1149 
1150   // Fetch the mapping relationship between front parameters and backend parameters in the kernel graphs.
1151   for (size_t i = 0; i < graphs.size(); ++i) {
1152     const auto &graph = graphs[i];
1153     auto device_context = device_contexts[i];
1154     for (const auto &parameter : graph->input_nodes()) {
1155       auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
1156       if (front_node != nullptr && front_node->isa<Parameter>() &&
1157           front_to_backend_parameters_.find(front_node) == front_to_backend_parameters_.end()) {
1158         front_to_backend_parameters_[front_node] = {parameter, device_context};
1159       }
1160     }
1161   }
1162 
1163   // This for loop cannot be combined with the for loop above, because the relationship between front
1164   // and backend needs to be consistent with HostDataSource.
1165   for (size_t i = 0; i < graphs.size(); ++i) {
1166     const auto &graph = graphs[i];
1167     auto device_context = device_contexts[i];
1168     for (const auto &parameter : graph->input_nodes()) {
1169       const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(parameter);
1170 
1171       if (internal_front_node.first != nullptr) {
1172         std::set<AnfNodePtr> call_nodes;
1173         std::set<AnfNodePtr> switch_nodes;
1174         const auto &front_paramters =
1175           FetchInputParameterbyControlNode(internal_front_node.first, &switch_nodes, &call_nodes);
1176         for (const auto &front_paramter : front_paramters) {
1177           if (front_to_backend_parameters_.find(front_paramter) == front_to_backend_parameters_.end()) {
1178             front_to_backend_parameters_[front_paramter] = {parameter, device_context};
1179           }
1180         }
1181       }
1182     }
1183   }
1184 
1185   for (const auto &front_pair : real_to_formal_front_parameters) {
1186     std::set<AnfNodePtr> invalid_node;
1187     const auto &backend_node =
1188       FetchBackendNodeByFrontNode(front_pair.first, real_to_formal_front_parameters, formal_to_real_front_parameters,
1189                                   front_to_backend_parameters_, &invalid_node);
1190     if (backend_node.first != nullptr) {
1191       if (front_to_backend_parameters_.find(front_pair.first) == front_to_backend_parameters_.end()) {
1192         front_to_backend_parameters_[front_pair.first] = backend_node;
1193       }
1194     }
1195   }
1196 }
1197 
FetchHostParameterToWeight(const RealToFormalNode & front_to_front_parameters)1198 void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front_to_front_parameters) {
1199   for (const auto &pair : front_to_front_parameters) {
1200     std::vector<AnfNodePtr> dest_nodes;
1201     FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameters);
1202     host_parameter_to_weights_[pair.first] = dest_nodes;
1203 
1204     if (std::find(root_graph_parameters_.begin(), root_graph_parameters_.end(), pair.first) !=
1205         root_graph_parameters_.end()) {
1206       for (auto &sub_front_node : dest_nodes) {
1207         sub_front_node_to_root_front_node_[sub_front_node] = pair.first;
1208       }
1209     }
1210   }
1211 }
1212 
FetchFuncGraphToParameter(const std::vector<AnfNodePtr> & control_nodes)1213 void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes) {
1214   for (const auto &control_node : control_nodes) {
1215     const auto &cnode = control_node->cast<CNodePtr>();
1216     const auto &inputs = cnode->inputs();
1217     if (inputs.empty()) {
1218       MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node);
1219     }
1220 
1221     // Call node which the first input is a cnode.
1222     if (inputs[0]->isa<CNode>()) {
1223       const auto &switch_cnode = inputs[0]->cast<CNodePtr>();
1224 
1225       if (AnfAlgo::CheckPrimitiveType(switch_cnode, prim::kPrimSwitch)) {
1226         // Switch node.
1227         FetchParameterBySwitchNode(inputs[0], &func_graph_to_parameters_);
1228       } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
1229         // Switchlayer node.
1230         FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
1231       } else if (IsCallNode(inputs[0])) {
1232         continue;
1233       } else {
1234         MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
1235       }
1236     } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
1237       // Call node which the first input is a value node of funcgraph.
1238       const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
1239       std::vector<AnfNodePtr> parameters;
1240       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1241         if (CheckValidFuncGraphInput(inputs[i])) {
1242           (void)parameters.emplace_back(inputs[i]);
1243         }
1244       }
1245       (void)func_graph_to_parameters_[func_graph].emplace_back(parameters);
1246     }
1247   }
1248 }
1249 
FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)1250 void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
1251                                                   const std::vector<DeviceContext *> &device_contexts) {
1252   for (size_t i = 0; i < graphs.size(); ++i) {
1253     const auto &graph = graphs[i];
1254     const auto &device_context = device_contexts[i];
1255     MS_EXCEPTION_IF_NULL(graph);
1256     auto execution_order = graph->execution_order();
1257     for (auto &kernel : execution_order) {
1258       if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
1259         auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
1260         if (front_node != nullptr) {
1261           for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) {
1262             front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context};
1263           }
1264         }
1265       }
1266     }
1267 
1268     const auto graph_output_map = graph->graph_output_map();
1269     for (const auto &output_pair : graph_output_map) {
1270       front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
1271     }
1272   }
1273 }
1274 
FetchBackendOutputByFrontOutput(const AnfNodePtr & front_output,std::set<AnfNodePtr> * call_nodes,std::set<AnfNodePtr> * switch_nodes,std::set<KernelWithIndex> * results)1275 void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
1276                                                         std::set<AnfNodePtr> *call_nodes,
1277                                                         std::set<AnfNodePtr> *switch_nodes,
1278                                                         std::set<KernelWithIndex> *results) {
1279   if (front_output->isa<ValueNode>()) {
1280     (void)(*results).emplace(front_output, 0);
1281 
1282     const auto &iter = formal_to_real_parameters_.find(front_output);
1283     if (iter != formal_to_real_parameters_.end()) {
1284       for (const auto &node : iter->second) {
1285         (void)(*results).emplace(node);
1286       }
1287     }
1288   } else if (front_output->isa<Parameter>()) {
1289     // Output is a parameter.
1290     const auto iter = formal_to_real_parameters_.find(front_output);
1291     if (iter != formal_to_real_parameters_.end()) {
1292       for (const auto &node : iter->second) {
1293         (void)(*results).emplace(node);
1294       }
1295     } else {
1296       MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output);
1297     }
1298   } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimSwitch)) {
1299     // Output is a switch.
1300     const auto &switch_outputs = FetchOutputBySwitchNode(front_output, call_nodes, switch_nodes);
1301 
1302     for (const auto &switch_output : switch_outputs) {
1303       FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results);
1304     }
1305   } else if (IsCallNode(front_output)) {
1306     // Output is a call.
1307     const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes);
1308 
1309     for (const auto &call_output : call_outputs) {
1310       FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes, results);
1311     }
1312   } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimMakeTuple)) {
1313     // Output is a make tuple.
1314     const auto &cnode = front_output->cast<CNodePtr>();
1315     const auto &inputs = cnode->inputs();
1316 
1317     for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
1318       FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes, results);
1319     }
1320   } else if (front_output->isa<CNode>()) {
1321     // Output is a kernel.
1322     const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0));
1323     if (iter != front_to_backend_kernels_.end()) {
1324       (void)(*results).emplace(iter->second.first);
1325     } else {
1326       MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output);
1327     }
1328   } else {
1329     MS_LOG(EXCEPTION) << "Invalid front node:" << AnfAlgo::GetNodeDebugString(front_output);
1330   }
1331 }
1332 
FetchBackendInputNodebyFrontNode(const AnfNodePtr & real_parameter,const AnfNodePtr & formal_parameter,const FrontToBackendNodeWithContext & front_to_backend_parameters)1333 void ControlNodeParser::FetchBackendInputNodebyFrontNode(
1334   const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter,
1335   const FrontToBackendNodeWithContext &front_to_backend_parameters) {
1336   if (real_parameter->isa<Parameter>()) {
1337     // Input node is a parameter from host data source actor.
1338     std::set<AnfNodePtr> invalid_inputs;
1339     std::vector<AnfNodePtr> front_inputs =
1340       FetchInputNodeByParameter(real_parameter, root_graph_parameters_, &invalid_inputs, func_graph_to_parameters_);
1341 
1342     for (const auto &front_input : front_inputs) {
1343       const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0);
1344       if (node_with_index.first->isa<Parameter>()) {
1345         const auto &iter = front_to_backend_parameters.find(real_parameter);
1346         if (iter == front_to_backend_parameters.end()) {
1347           MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
1348           continue;
1349         }
1350         (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first, 0);
1351       } else {
1352         const auto iter = front_to_backend_kernels_.find(node_with_index);
1353         if (iter == front_to_backend_kernels_.end()) {
1354           MS_LOG(EXCEPTION) << "Cannot find actor of front node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
1355         }
1356         (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
1357       }
1358     }
1359   } else if (real_parameter->isa<ValueNode>()) {
1360     (void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
1361   } else if (IsCallNode(real_parameter)) {
1362     const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
1363     for (const auto func_graph : func_graphs) {
1364       FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters);
1365     }
1366   } else {
1367     // Input node is a cnode.
1368     const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(real_parameter, 0);
1369     const auto iter = front_to_backend_kernels_.find(node_with_index);
1370     if (iter == front_to_backend_kernels_.end()) {
1371       MS_LOG(EXCEPTION) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
1372     }
1373     (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
1374   }
1375 }
1376 
FetchBackendParameterNode(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const RealToFormalNode & real_to_formal_front_parameters,const RealToFormalNode & formal_to_real_front_parameters,FrontToBackendNodeWithContext * front_to_backend_parameters)1377 void ControlNodeParser::FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs,
1378                                                   const std::vector<DeviceContext *> &device_contexts,
1379                                                   const RealToFormalNode &real_to_formal_front_parameters,
1380                                                   const RealToFormalNode &formal_to_real_front_parameters,
1381                                                   FrontToBackendNodeWithContext *front_to_backend_parameters) {
1382   for (size_t i = 0; i < graphs.size(); ++i) {
1383     const auto &graph = graphs[i];
1384     const auto &device_context = device_contexts[i];
1385     if (graph->GetFuncGraph() != root_func_graph_) {
1386       continue;
1387     }
1388     for (const auto &parameter : graph->input_nodes()) {
1389       auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
1390       if (front_node != nullptr && front_node->isa<Parameter>() &&
1391           (*front_to_backend_parameters).find(front_node) == (*front_to_backend_parameters).end()) {
1392         (*front_to_backend_parameters)[front_node] = {parameter, device_context};
1393       }
1394     }
1395   }
1396   for (const auto &control_node_parameter : control_node_parameters_) {
1397     const auto &iter = front_to_backend_parameters_.find(control_node_parameter);
1398     if (iter == front_to_backend_parameters_.end()) {
1399       MS_LOG(EXCEPTION) << "Cannot find backend node for control node parameter:"
1400                         << AnfAlgo::GetNodeDebugString(control_node_parameter);
1401     }
1402     (*front_to_backend_parameters)[control_node_parameter] = iter->second;
1403   }
1404 
1405   for (const auto &front_pair : formal_to_real_front_parameters) {
1406     std::set<AnfNodePtr> invalid_node;
1407     const auto &backend_node =
1408       FetchBackendNodeByFrontNode(front_pair.first, real_to_formal_front_parameters, formal_to_real_front_parameters,
1409                                   (*front_to_backend_parameters), &invalid_node);
1410     if (backend_node.first != nullptr) {
1411       if ((*front_to_backend_parameters).find(front_pair.first) == (*front_to_backend_parameters).end()) {
1412         (*front_to_backend_parameters)[front_pair.first] = backend_node;
1413       }
1414     }
1415   }
1416 }
1417 
FetchBackendInputNode(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const RealToFormalNode & real_to_formal_front_parameters,const RealToFormalNode & formal_to_real_front_parameters)1418 void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs,
1419                                               const std::vector<DeviceContext *> &device_contexts,
1420                                               const RealToFormalNode &real_to_formal_front_parameters,
1421                                               const RealToFormalNode &formal_to_real_front_parameters) {
1422   FrontToBackendNodeWithContext front_to_backend_parameters;
1423   FetchBackendParameterNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters,
1424                             &front_to_backend_parameters);
1425 
1426   for (size_t i = 0; i < graphs.size(); ++i) {
1427     const auto &graph = graphs[i];
1428     for (const auto &value_node : graph->graph_value_nodes()) {
1429       auto front_node = graph->GetFrontAnfByBackendAnf(value_node);
1430       if (front_node != nullptr) {
1431         (void)formal_to_real_parameters_[front_node].emplace_back(value_node, 0);
1432       }
1433     }
1434   }
1435 
1436   for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
1437     for (const auto &front_weight : host_parameter_to_weight.second) {
1438       const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
1439       if (iter != front_to_backend_parameters_.end()) {
1440         (void)formal_to_real_parameters_[front_weight].emplace_back(iter->second.first, 0);
1441       }
1442     }
1443   }
1444 
1445   for (const auto &func_graph_to_parameters : func_graph_to_parameters_) {
1446     const auto &func_graph = func_graph_to_parameters.first;
1447     std::vector<AnfNodePtr> graph_inputs;
1448     for (const auto &input : func_graph->get_inputs()) {
1449       // Monad input would not send to gather actor.
1450       if (HasAbstractMonad(input) || (input->isa<Parameter>() && HasAbstractRef(input))) {
1451         continue;
1452       }
1453       (void)graph_inputs.emplace_back(input);
1454     }
1455 
1456     // Collect all backend input node to gather, There are two situations:
1457     // 1. The parameter from the host data source.
1458     // 2. Output the kernel actor.
1459     for (const auto parameters : func_graph_to_parameters.second) {
1460       if (parameters.size() != graph_inputs.size()) {
1461         MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size()
1462                           << " need:" << graph_inputs.size() << " func_graph:" << func_graph->ToString();
1463       }
1464 
1465       for (size_t i = 0; i < parameters.size(); ++i) {
1466         FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i], front_to_backend_parameters);
1467       }
1468     }
1469   }
1470   for (const auto parameter_pair : front_to_backend_parameters) {
1471     (void)formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
1472   }
1473   for (const auto parameter_pair : front_to_backend_parameters_) {
1474     (void)formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
1475   }
1476 }
1477 
FetchAutoMonadNode(const std::vector<AnfNodePtr> & control_nodes)1478 void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes) {
1479   for (const auto &control_node : control_nodes) {
1480     const auto &cnode = control_node->cast<CNodePtr>();
1481     const auto &inputs = cnode->inputs();
1482     if (inputs.empty()) {
1483       MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node);
1484     }
1485 
1486     if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
1487       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1488         if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState)) {
1489           const auto &node = FetchSourceNodeByAutoMonad(inputs[i]);
1490           const auto &iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(node, 0));
1491           if (iter != front_to_backend_kernels_.end()) {
1492             kernel_to_call_nodes_[iter->second.first.first] = control_node;
1493           }
1494         }
1495       }
1496     }
1497   }
1498 }
1499 
// Translate a front node of a sub graph into the front node recorded for it in
// sub_front_node_to_root_front_node_. A node without an entry is returned unchanged.
AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node) {
  const auto root_iter = sub_front_node_to_root_front_node_.find(sub_front_node);
  if (root_iter != sub_front_node_to_root_front_node_.end()) {
    return root_iter->second;
  }
  return sub_front_node;
}
1506 }  // namespace runtime
1507 }  // namespace mindspore
1508