1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/framework/control_node_parser.h"
18 #include "runtime/framework/actor/switch_actor.h"
19 #include "runtime/framework/actor/gather_actor.h"
20 #include "abstract/utils.h"
21 #include "ir/tensor.h"
22
23 namespace mindspore {
24 namespace runtime {
25 namespace {
26 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
27 // Fetch all the weight parameters related to node. It runs like this:
28 // if we have a map like {{a, {b, c}}, {b, {d, e}}}, final we will get {{a, {b, c, d, e}}, {b, {c, d}}}.
FetchWeightbyHostParameter(const AnfNodePtr & node,std::vector<AnfNodePtr> * dest_nodes,const std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> & front_to_front_weight)29 void FetchWeightbyHostParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *dest_nodes,
30 const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &front_to_front_weight) {
31 if (find((*dest_nodes).begin(), (*dest_nodes).end(), node) != (*dest_nodes).end()) {
32 return;
33 }
34 (void)((*dest_nodes).emplace_back(node));
35 if (front_to_front_weight.find(node) == front_to_front_weight.end()) {
36 return;
37 }
38
39 const auto weight_nodes = front_to_front_weight.at(node);
40 for (const auto weight_node : weight_nodes) {
41 FetchWeightbyHostParameter(weight_node, dest_nodes, front_to_front_weight);
42 }
43 }
44
45 // Check whether the input is a valid parameter.
CheckValidFuncGraphInput(const AnfNodePtr & node)46 bool CheckValidFuncGraphInput(const AnfNodePtr &node) {
47 if (HasAbstractMonad(node)) {
48 return false;
49 } else if (node->isa<Parameter>()) {
50 return !HasAbstractRef(node);
51 }
52 return true;
53 }
54
55 // Get the funcgraph in partial node.
GetFuncGraphFromPartial(const AnfNodePtr & node)56 FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
57 const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
58 return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
59 }
60
61 // Get the relationship between funcgraph and parameters in the switch node.
FetchParameterBySwitchNode(const AnfNodePtr & switch_node,FuncGraphToParameter * graph_to_real_parameters)62 void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParameter *graph_to_real_parameters) {
63 const auto &switch_cnode = switch_node->cast<CNodePtr>();
64 const auto &switch_inputs = switch_cnode->inputs();
65 if (switch_inputs.size() != kSwitchInputNum) {
66 MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(switch_node);
67 }
68
69 for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
70 const auto &partial_node = switch_inputs[i];
71 if (IsValueNode<FuncGraph>(partial_node)) {
72 continue;
73 }
74 const auto &func_graph = GetFuncGraphFromPartial(partial_node);
75 std::vector<AnfNodePtr> parameters;
76 const auto &partial_inputs = partial_node->cast<CNodePtr>()->inputs();
77 for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
78 if (CheckValidFuncGraphInput(partial_inputs[j])) {
79 (void)parameters.emplace_back(partial_inputs[j]);
80 }
81 }
82 (void)((*graph_to_real_parameters)[func_graph].emplace_back(parameters));
83 }
84 }
85
86 // Get the corresponding relationship between funcgraph and parameters in the switch layer node.
FetchParameterBySwitchLayerNode(const AnfNodePtr & switch_layer_node,const std::vector<AnfNodePtr> & call_inputs,FuncGraphToParameter * graph_to_real_parameters)87 void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const std::vector<AnfNodePtr> &call_inputs,
88 FuncGraphToParameter *graph_to_real_parameters) {
89 const auto &switch_layer_cnode = switch_layer_node->cast<CNodePtr>();
90 const auto &switch_layer_inputs = switch_layer_cnode->inputs();
91
92 if (switch_layer_inputs.size() != kSwitchLayerInputNum) {
93 MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(switch_layer_node);
94 }
95
96 auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
97
98 // Get the parameter corresponding to each funcgraph in make tuple.
99 for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
100 if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
101 // Tuple branch is a partial node.
102 const auto &func_graph = GetFuncGraphFromPartial(tuple_inputs[i]);
103 std::vector<AnfNodePtr> parameters;
104 const auto &partial_inputs = tuple_inputs[i]->cast<CNodePtr>()->inputs();
105
106 // Get inputs in partial node.
107 for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
108 if (CheckValidFuncGraphInput(partial_inputs[j])) {
109 (void)parameters.emplace_back(partial_inputs[j]);
110 }
111 }
112
113 // Get inputs in call node.
114 for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) {
115 if (CheckValidFuncGraphInput(call_inputs[j])) {
116 (void)parameters.emplace_back(call_inputs[j]);
117 }
118 }
119 (void)((*graph_to_real_parameters)[func_graph].emplace_back(parameters));
120 } else if (tuple_inputs[i]->isa<ValueNode>() && IsValueNode<FuncGraph>(tuple_inputs[i])) {
121 // Tuple branch is a call node.
122 const auto &func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
123 std::vector<AnfNodePtr> parameters;
124
125 // Get inputs in call node.
126 for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) {
127 if (CheckValidFuncGraphInput(call_inputs[j])) {
128 (void)parameters.emplace_back(call_inputs[j]);
129 }
130 }
131
132 (void)(*graph_to_real_parameters)[func_graph].emplace_back(parameters);
133 }
134 }
135 }
136
137 // Create a device tensor for the front node.
138 // Get the output format and select kernel build info from the backend node corresponding to the front node to
139 // create the device address.
CreateDeviceTensorForValueNode(const AnfNodePtr & front_node,const AnfNodePtr & backend_node,const DeviceContext * device_context)140 void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
141 const DeviceContext *device_context) {
142 MS_EXCEPTION_IF_NULL(device_context);
143
144 const auto &node_value = front_node->cast<ValueNodePtr>()->value();
145 if (!node_value->isa<tensor::Tensor>()) {
146 return;
147 }
148
149 size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0);
150 TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0);
151 if (output_type_id == kTypeUnknown) {
152 output_type_id = AnfAlgo::GetOutputInferDataType(backend_node, 0);
153 }
154
155 if (front_node->kernel_info() == nullptr) {
156 front_node->set_kernel_info(std::make_shared<device::KernelInfo>());
157 }
158
159 // Get the select kernel build info.
160 auto kernel_info = dynamic_cast<device::KernelInfo *>(backend_node->kernel_info());
161 MS_EXCEPTION_IF_NULL(kernel_info);
162 auto build_info = kernel_info->GetMutableSelectKernelBuildInfo();
163 MS_EXCEPTION_IF_NULL(build_info);
164 AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get());
165
166 // Create device tensor.
167 std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0);
168 device::DeviceAddressPtr address =
169 device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
170 MS_EXCEPTION_IF_NULL(address);
171 AnfAlgo::SetOutputAddr(address, 0, front_node.get());
172 }
173
174 // Create a device tensor for front parameter.
175 // When the condition input of the switch and switchlayer or the output of a subgraph is a parameter, there is no
176 // corresponding backend node for this parameter, so a device tensor needs to be created for it.
CreateDeviceTensorForFrontParameter(const AnfNodePtr & node,const DeviceContext * device_context)177 void CreateDeviceTensorForFrontParameter(const AnfNodePtr &node, const DeviceContext *device_context) {
178 MS_EXCEPTION_IF_NULL(device_context);
179
180 TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0);
181
182 if (node->kernel_info() == nullptr) {
183 auto kernel_info = std::make_shared<device::KernelInfo>();
184 std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
185 builder->SetOutputsFormat({kOpFormat_DEFAULT});
186 builder->SetOutputsDeviceType({type_id});
187 kernel_info->set_select_kernel_build_info(builder->Build());
188 node->set_kernel_info(kernel_info);
189 }
190 size_t size = AnfAlgo::GetOutputTensorMemSize(node, 0);
191
192 // Create device tensor.
193 device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id);
194 MS_EXCEPTION_IF_NULL(address);
195 AnfAlgo::SetOutputAddr(address, 0, node.get());
196 }
197
198 // Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding
199 // backend parameter, then recursively find the backend parameters of other front parameters corresponding to the
200 // front_node.
FetchBackendNodeByFrontNode(const AnfNodePtr & front_node,const std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> & real_to_formal_front_parameters,const std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> & formal_to_real_front_parameters,const std::unordered_map<AnfNodePtr,std::pair<AnfNodePtr,DeviceContext * >> & front_to_backend_parameter,std::set<AnfNodePtr> * invalid_node)201 std::pair<AnfNodePtr, DeviceContext *> FetchBackendNodeByFrontNode(
202 const AnfNodePtr &front_node,
203 const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &real_to_formal_front_parameters,
204 const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &formal_to_real_front_parameters,
205 const std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>> &front_to_backend_parameter,
206 std::set<AnfNodePtr> *invalid_node) {
207 // Check whether the front_node has been looked for.
208 if ((*invalid_node).find(front_node) != (*invalid_node).end()) {
209 return std::pair<AnfNodePtr, DeviceContext *>();
210 }
211 (void)(*invalid_node).insert(front_node);
212
213 const auto front_to_backend_iter = front_to_backend_parameter.find(front_node);
214 if (front_to_backend_iter != front_to_backend_parameter.end()) {
215 return front_to_backend_iter->second;
216 }
217
218 const auto &real_to_formal_iter = real_to_formal_front_parameters.find(front_node);
219 if (real_to_formal_iter == real_to_formal_front_parameters.end()) {
220 return std::pair<AnfNodePtr, DeviceContext *>();
221 }
222 for (const auto &next_node : real_to_formal_iter->second) {
223 auto banckend_node =
224 FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters,
225 front_to_backend_parameter, invalid_node);
226 if (banckend_node.first != nullptr) {
227 return banckend_node;
228 }
229 }
230
231 const auto &formal_to_real_iter = formal_to_real_front_parameters.find(front_node);
232 if (formal_to_real_iter == formal_to_real_front_parameters.end()) {
233 return std::pair<AnfNodePtr, DeviceContext *>();
234 }
235 for (const auto &next_node : formal_to_real_iter->second) {
236 auto banckend_node =
237 FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters,
238 front_to_backend_parameter, invalid_node);
239 if (banckend_node.first != nullptr) {
240 return banckend_node;
241 }
242 }
243 return std::pair<AnfNodePtr, DeviceContext *>();
244 }
245
246 // Fetch all backend input nodes by parameter for gather actor.
FetchInputNodeByParameter(const AnfNodePtr & parameter,const std::vector<AnfNodePtr> & host_ds_parameters,std::set<AnfNodePtr> * invalid_inputs,const FuncGraphToParameter & graph_to_real_parameters)247 std::vector<AnfNodePtr> FetchInputNodeByParameter(const AnfNodePtr ¶meter,
248 const std::vector<AnfNodePtr> &host_ds_parameters,
249 std::set<AnfNodePtr> *invalid_inputs,
250 const FuncGraphToParameter &graph_to_real_parameters) {
251 std::vector<AnfNodePtr> input_nodes;
252
253 // If the node has been collected, skip it.
254 if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) {
255 return input_nodes;
256 }
257
258 // Record the node which has been collected.
259 (void)(*invalid_inputs).insert(parameter);
260
261 // If the parameter node is a parameter of host data source actor, return it.
262 if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) {
263 (void)input_nodes.emplace_back(parameter);
264 return input_nodes;
265 }
266
267 // Check the parameter which send to its funcgraph.
268 const auto &func_graph = parameter->func_graph();
269 if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) {
270 return input_nodes;
271 }
272
273 std::vector<AnfNodePtr> self_inputs;
274 for (const auto &input : func_graph->get_inputs()) {
275 // Monad input need not send to funcgraph.
276 if (HasAbstractMonad(input) || HasAbstractRef(input)) {
277 continue;
278 }
279 (void)self_inputs.emplace_back(input);
280 }
281
282 const auto iter = find(self_inputs.begin(), self_inputs.end(), parameter);
283 if (iter == self_inputs.end()) {
284 MS_LOG(EXCEPTION) << "Cannot find parameter node:" << AnfAlgo::GetNodeDebugString(parameter);
285 }
286 size_t pos = iter - self_inputs.begin();
287
288 for (const auto parameters : graph_to_real_parameters.at(func_graph)) {
289 if (parameters.size() != self_inputs.size()) {
290 MS_LOG(EXCEPTION) << "Invalid input num:" << parameters.size() << " and:" << self_inputs.size()
291 << " for func_graph:" << func_graph->ToString();
292 }
293 const auto input = parameters[pos];
294 if (input->isa<CNode>()) {
295 (void)input_nodes.emplace_back(input);
296 } else if (input->isa<Parameter>()) {
297 // If input is a parameter, you need to find its input recursively.
298 auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters);
299 (void)input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end());
300 }
301 }
302 return input_nodes;
303 }
304
305 // Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph
306 // called by the call node.
FetchFuncGraphOutput(const FuncGraphPtr & func_graph,std::vector<AnfNodePtr> * call_nodes)307 std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *call_nodes) {
308 std::vector<AnfNodePtr> outputs;
309 const auto &output = func_graph->output();
310 const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0, false, {prim::kPrimTupleGetItem});
311 if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
312 return outputs;
313 }
314 if (!IsCallNode(real_output.first)) {
315 outputs.push_back(real_output.first);
316 return outputs;
317 }
318
319 (*call_nodes).push_back(real_output.first);
320 std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(real_output.first);
321 for (const auto &graph : func_graphs) {
322 auto single_outputs = FetchFuncGraphOutput(graph, call_nodes);
323 (void)outputs.insert(outputs.end(), single_outputs.begin(), single_outputs.end());
324 }
325 return outputs;
326 }
327 std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set<AnfNodePtr> *call_nodes,
328 std::set<AnfNodePtr> *switch_nodes);
329
330 // Recursive interface, get all possible output nodes of call node.
FetchOutputByCallNode(const AnfNodePtr & call_node,std::set<AnfNodePtr> * call_nodes,std::set<AnfNodePtr> * switch_nodes)331 std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std::set<AnfNodePtr> *call_nodes,
332 std::set<AnfNodePtr> *switch_nodes) {
333 std::vector<AnfNodePtr> outputs;
334 if ((*call_nodes).find(call_node) != (*call_nodes).end()) {
335 return outputs;
336 }
337 (void)((*call_nodes).insert(call_node));
338
339 const auto func_graphs = FetchFuncGraphbyCallNode(call_node);
340
341 for (const auto func_graph : func_graphs) {
342 std::vector<AnfNodePtr> sub_call_nodes;
343 const std::vector<AnfNodePtr> graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes);
344 for (const auto &graph_output : graph_outputs) {
345 if (graph_output->isa<Parameter>()) {
346 outputs.push_back(graph_output);
347 } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
348 const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
349 (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
350 } else if (IsCallNode(graph_output)) {
351 const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
352 (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
353 } else if (graph_output->isa<CNode>()) {
354 (void)outputs.emplace_back(graph_output);
355 } else if (graph_output->isa<ValueNode>()) {
356 outputs.push_back(graph_output);
357 } else {
358 MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output);
359 }
360 }
361 }
362
363 return outputs;
364 }
365
366 // Recursive interface, get all possible output nodes of switch node.
FetchOutputBySwitchNode(const AnfNodePtr & switch_node,std::set<AnfNodePtr> * call_nodes,std::set<AnfNodePtr> * switch_nodes)367 std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set<AnfNodePtr> *call_nodes,
368 std::set<AnfNodePtr> *switch_nodes) {
369 std::vector<AnfNodePtr> outputs;
370 if ((*switch_nodes).find(switch_node) != (*switch_nodes).end()) {
371 return outputs;
372 }
373 (void)((*switch_nodes).insert(switch_node));
374
375 if (!switch_node->isa<CNode>()) {
376 MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node);
377 }
378 const auto &inputs = switch_node->cast<CNodePtr>()->inputs();
379 if (inputs.size() != kSwitchInputNum) {
380 MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node);
381 }
382
383 for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
384 if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimPartial)) {
385 continue;
386 } else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
387 const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes);
388 (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
389 } else if (IsCallNode(inputs[i])) {
390 const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes);
391 (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
392 } else {
393 (void)outputs.emplace_back(inputs[i]);
394 }
395 }
396
397 return outputs;
398 }
399
400 // Recursive interface, get the real kernel that UpdateState node depends on.
FetchSourceNodeByAutoMonad(const AnfNodePtr & node)401 AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) {
402 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
403 const auto &cnode = node->cast<CNodePtr>();
404 const auto &inputs = cnode->inputs();
405 if (inputs.size() <= kUpdateStateRealInput) {
406 MS_LOG(EXCEPTION) << "Invalid updatestate node:" << AnfAlgo::GetNodeDebugString(node);
407 }
408
409 return FetchSourceNodeByAutoMonad(inputs[kUpdateStateRealInput]);
410 }
411 return node;
412 }
413
414 // Fetch all parameters in control node of root funcgraph.
FetchParameterByControlNode(const std::vector<AnfNodePtr> & control_nodes)415 std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr> &control_nodes) {
416 std::vector<AnfNodePtr> parameters;
417
418 for (const auto &control_node : control_nodes) {
419 CNodePtr cnode = control_node->cast<CNodePtr>();
420 const auto &inputs = cnode->inputs();
421 if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
422 break;
423 } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) {
424 for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) {
425 if (inputs[i]->isa<Parameter>()) {
426 (void)parameters.emplace_back(inputs[i]);
427 }
428 }
429 } else if (cnode->input(0)->isa<CNode>() || IsValueNode<FuncGraph>(cnode->input(0))) {
430 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
431 if (inputs[i]->isa<Parameter>()) {
432 (void)parameters.emplace_back(inputs[i]);
433 }
434 }
435 } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) {
436 if (inputs.size() != kSwitchInputNum) {
437 MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
438 }
439 if (inputs[kSwitchCondPos]->isa<Parameter>()) {
440 (void)parameters.emplace_back(inputs[kSwitchCondPos]);
441 }
442 } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
443 if (inputs.size() != kSwitchLayerInputNum) {
444 MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node);
445 }
446 if (inputs[kSwitchLayerCondPos]->isa<Parameter>()) {
447 (void)parameters.emplace_back(inputs[kSwitchLayerCondPos]);
448 }
449 }
450 }
451 return parameters;
452 }
453
454 // Get funcgraph from node, the interface only accepts partial node and funcgraph value node.
FetchFuncGraphInNode(const auto & node)455 FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
456 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
457 const auto &func_graph = GetFuncGraphFromPartial(node);
458
459 if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
460 return FetchFuncGraphInNode(func_graph->output());
461 } else if (IsValueNode<FuncGraph>(func_graph->output())) {
462 // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called.
463 return FetchFuncGraphInNode(func_graph->output());
464 }
465
466 return func_graph;
467 } else if (IsValueNode<FuncGraph>(node)) {
468 const auto &func_graph = GetValueNode<FuncGraphPtr>(node);
469
470 if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
471 // When the output of funcgraph is a funcgraph, it needs to return the funcgraph that is finally called.
472 return FetchFuncGraphInNode(func_graph->output());
473 } else if (IsValueNode<FuncGraph>(func_graph->output())) {
474 // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called.
475 return FetchFuncGraphInNode(func_graph->output());
476 }
477
478 return func_graph;
479 }
480
481 return nullptr;
482 }
483 } // namespace
484
// Resolve node to a real output node. The node is first visited through tuple-get-item; a result that is not a
// call node is returned directly. For a call node, the outputs of the funcgraphs it calls are resolved
// recursively and the first non-null result is returned.
// call_nodes holds the call nodes already visited; a call node found there, or one whose funcgraphs all fail to
// resolve, yields nullptr.
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
  const auto real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem});
  const AnfNodePtr &real_node = real_node_with_index.first;
  if (!IsCallNode(real_node)) {
    return real_node;
  }
  // insert() fails when the call node has been visited before.
  if (!call_nodes->insert(real_node).second) {
    return nullptr;
  }

  for (const auto &called_graph : FetchFuncGraphbyCallNode(real_node)) {
    const auto resolved = FetchRealOutputByCallNode(called_graph->output(), call_nodes);
    if (resolved != nullptr) {
      return resolved;
    }
  }
  return nullptr;
}
504
505 // Return true if the node has Ref abstract.
HasAbstractRef(const AnfNodePtr & node)506 bool HasAbstractRef(const AnfNodePtr &node) {
507 if (node == nullptr) {
508 return false;
509 }
510 auto &abs = node->abstract();
511 return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
512 }
513
IsCallNode(const AnfNodePtr & node)514 bool IsCallNode(const AnfNodePtr &node) {
515 if (!node->isa<CNode>()) {
516 return false;
517 }
518 const auto &cnode = node->cast<CNodePtr>();
519 const auto &inputs = cnode->inputs();
520 return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
521 }
522
IsSubCallNode(const AnfNodePtr & node)523 bool IsSubCallNode(const AnfNodePtr &node) {
524 if (!node->isa<CNode>()) {
525 return false;
526 }
527
528 const auto inputs = node->cast<CNodePtr>()->inputs();
529 if (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
530 return false;
531 }
532
533 const auto &switch_layer_inputs = inputs[0]->cast<CNodePtr>()->inputs();
534 const auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
535 if (tuple_inputs.size() <= kMakeTupleInputStartPos) {
536 return false;
537 }
538
539 // Check whether the funcgraph called by the call node returns funcgraph or partial node.
540 FuncGraphPtr func_graph = nullptr;
541 if (AnfAlgo::CheckPrimitiveType(tuple_inputs[kMakeTupleInputStartPos], prim::kPrimPartial)) {
542 const auto &func_graph_node = tuple_inputs[kMakeTupleInputStartPos]->cast<CNodePtr>()->input(kPartialFuncGraphPos);
543 func_graph = GetValueNode<FuncGraphPtr>(func_graph_node);
544 } else if (tuple_inputs[kMakeTupleInputStartPos]->isa<ValueNode>() &&
545 IsValueNode<FuncGraph>(tuple_inputs[kMakeTupleInputStartPos])) {
546 func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[kMakeTupleInputStartPos]);
547 }
548
549 const auto &output = func_graph->output();
550 return AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial) ||
551 (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output));
552 }
553
FetchAllRealInputNodeByParameter(const KernelWithIndex & node)554 std::vector<KernelWithIndex> FetchAllRealInputNodeByParameter(const KernelWithIndex &node) {
555 std::vector<KernelWithIndex> parameters;
556 const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
557 const auto &real_node = real_node_with_index.first;
558 if (real_node->isa<Parameter>()) {
559 if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
560 (void)parameters.emplace_back(real_node_with_index);
561 }
562 } else if (HasAbstractMonad(real_node)) {
563 return parameters;
564 } else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
565 const auto &inputs = real_node->cast<CNodePtr>()->inputs();
566 for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
567 const auto &sub_parameters = FetchAllRealInputNodeByParameter({inputs[i], 0});
568 (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
569 }
570 } else {
571 (void)parameters.emplace_back(real_node_with_index);
572 }
573 return parameters;
574 }
575
FetchFuncGraphbyCallNode(const AnfNodePtr & node)576 std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
577 std::vector<FuncGraphPtr> func_graphs;
578 if (!node->isa<CNode>()) {
579 return func_graphs;
580 }
581
582 const auto &call_inputs = node->cast<CNodePtr>()->inputs();
583 if (call_inputs[0]->isa<CNode>()) {
584 const auto &cnode = call_inputs[0]->cast<CNodePtr>();
585 const auto &cnode_inputs = cnode->inputs();
586 if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
587 for (size_t i = kSwitchTrueBranchPos; i < cnode_inputs.size(); ++i) {
588 if (IsPrimitiveCNode(cnode_inputs[i], prim::kPrimPartial)) {
589 (void)func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i]));
590 } else if (IsValueNode<FuncGraph>(cnode_inputs[i])) {
591 (void)func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(cnode_inputs[i]));
592 }
593 }
594 } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer) &&
595 AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) {
596 const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
597
598 // Fetch all funcgraphs in make tuple node.
599 for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
600 const auto func_graph = FetchFuncGraphInNode(tuple_inputs[i]);
601 if (func_graph != nullptr) {
602 func_graphs.emplace_back(func_graph);
603 }
604 }
605 } else if (IsCallNode(cnode)) {
606 return FetchFuncGraphbyCallNode(cnode);
607 } else {
608 MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
609 }
610 } else if (call_inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(call_inputs[0])) {
611 (void)func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(call_inputs[0]));
612 } else {
613 MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
614 }
615 return func_graphs;
616 }
617
FetchOutputSizebyCallNode(const AnfNodePtr & node,std::vector<AnfNodePtr> * call_nodes)618 size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) {
619 if (!IsCallNode(node)) {
620 MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node);
621 }
622 if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) {
623 return 0;
624 }
625 (void)((*call_nodes).emplace_back(node));
626
627 const auto &func_graphs = FetchFuncGraphbyCallNode(node);
628 for (const auto &func_graph : func_graphs) {
629 const auto &output = func_graph->output();
630 const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
631
632 if (IsCallNode(real_output.first)) {
633 size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes);
634 if (output_num > 0) {
635 return output_num;
636 }
637 } else if (AnfAlgo::CheckPrimitiveType(real_output.first, prim::kPrimMakeTuple)) {
638 size_t total_num = 0;
639 const auto &tuple_cnode = real_output.first->cast<CNodePtr>();
640 const auto &inputs = tuple_cnode->inputs();
641 size_t i = 1;
642 for (; i < inputs.size(); ++i) {
643 if (IsCallNode(inputs[i])) {
644 size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes);
645 if (call_output_num == 0) {
646 break;
647 }
648 total_num += call_output_num;
649 } else if (inputs[i]->isa<ValueNode>() && inputs[i]->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
650 auto value_tuple = inputs[i]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>();
651 MS_EXCEPTION_IF_NULL(value_tuple);
652 auto tuple_value = value_tuple->value();
653 total_num += tuple_value.size();
654 } else if (!HasAbstractMonad(inputs[i])) {
655 ++total_num;
656 }
657 }
658 if (i == inputs.size()) {
659 return total_num;
660 }
661 } else {
662 return 1;
663 }
664 }
665 return 0;
666 }
667
FetchFuncGraphByNode(const AnfNodePtr & node)668 FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) {
669 auto front_node = GetFrontNodeByBackendNode(node);
670 // If the front node is nullptr, we can check its inputs.
671 if (front_node == nullptr) {
672 if (node->isa<CNode>()) {
673 const auto &cnode = node->cast<CNodePtr>();
674 const auto &inputs = cnode->inputs();
675
676 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
677 const auto &func_graph = FetchFuncGraphByNode(inputs[i]);
678 if (func_graph != nullptr) {
679 return func_graph;
680 }
681 }
682 } else {
683 return nullptr;
684 }
685 }
686
687 const auto &func_graph = front_node->func_graph();
688 return func_graph;
689 }
690
// Look up the front node of backend_node through the kernel graph that owns it.
// Returns nullptr when backend_node has no funcgraph or its funcgraph is not a KernelGraph; otherwise returns
// the result of GetFrontAnfByBackendAnf on that kernel graph.
AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
  const auto owner_graph = backend_node->func_graph();
  if (owner_graph == nullptr) {
    return nullptr;
  }
  auto *kernel_graph = dynamic_cast<KernelGraph *>(owner_graph.get());
  if (kernel_graph != nullptr) {
    return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
  }
  return nullptr;
}
701
GetFrontNodeByKernelGraph(const AnfNodePtr & backend_node,const KernelGraphPtr & graph)702 KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
703 const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
704 if (front_node != nullptr) {
705 return {front_node, 0};
706 }
707 const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
708 if (front_node_with_index.first == nullptr) {
709 MS_LOG(EXCEPTION) << "Invalid parameter of kernel graph, parameter:" << AnfAlgo::GetNodeDebugString(backend_node);
710 }
711 return front_node_with_index;
712 }
713
GetFuncgraphByBackendNode(const AnfNodePtr & backend_node)714 FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
715 auto front_node = GetFrontNodeByBackendNode(backend_node);
716 if (front_node == nullptr) {
717 return nullptr;
718 }
719 return front_node->func_graph();
720 }
721
Parse(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const FuncGraphPtr & root_graph)722 void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
723 const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
724 if (graphs.size() != device_contexts.size()) {
725 MS_LOG(EXCEPTION) << "Graph num is not equal to device context, graph:" << graphs.size()
726 << " device context num:" << device_contexts.size();
727 }
728 if (graphs.empty()) {
729 return;
730 }
731
732 root_func_graph_ = root_graph;
733
734 root_graph_parameters_ = root_graph->parameters();
735
736 CreateBranchIDForFuncGraph(control_nodes);
737
738 RealToFormalNode real_to_formal_front_parameters;
739 FetchFrontToFrontParameter(control_nodes, &real_to_formal_front_parameters);
740
741 RealToFormalNode formal_to_real_front_parameters;
742 for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) {
743 for (const auto formal_parameter : real_to_formal_front_parameter.second) {
744 (void)formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first);
745 }
746 }
747
748 FetchFrontToBackendParameter(graphs, device_contexts, real_to_formal_front_parameters,
749 formal_to_real_front_parameters);
750
751 FetchFuncGraphToParameter(control_nodes);
752
753 FetchHostParameterToWeight(real_to_formal_front_parameters);
754
755 FetchCallInputKernelGraph(graphs, device_contexts);
756
757 FetchFrontValueNode(control_nodes, graphs, device_contexts);
758
759 FetchFrontToBackendKernel(graphs, device_contexts);
760
761 FetchCallInputKernelGraph(graphs, device_contexts);
762
763 control_node_parameters_ = FetchControlNodeParameter(control_nodes, device_contexts[0]);
764
765 FetchFuncGraphCallNum(control_nodes);
766
767 FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters);
768
769 FetchAutoMonadNode(control_nodes);
770 }
771
GetBackendInputByParameter(const AnfNodePtr & parameter)772 std::vector<KernelWithIndex> ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr ¶meter) {
773 return formal_to_real_parameters_[parameter];
774 }
775
FetchBackendInputNodeByFrontNode(const AnfNodePtr & front_output)776 std::set<KernelWithIndex> ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) {
777 std::set<AnfNodePtr> call_nodes;
778 std::set<AnfNodePtr> switch_nodes;
779 std::set<KernelWithIndex> results;
780 FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes, &results);
781 return results;
782 }
783
GetBranchIDByFuncGraph(const FuncGraphPtr & func_graph)784 int ControlNodeParser::GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph) {
785 MS_EXCEPTION_IF_NULL(func_graph);
786
787 if (func_graph_to_branch_id_.find(func_graph) == func_graph_to_branch_id_.end()) {
788 MS_LOG(EXCEPTION) << "Invalid branch id for funcgraph:" << func_graph->ToString();
789 }
790 return func_graph_to_branch_id_[func_graph];
791 }
792
IsCallInputKernelGraph(const KernelGraphPtr & graph)793 bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) {
794 if (call_input_kernel_graphs_.find(graph) == call_input_kernel_graphs_.end()) {
795 return false;
796 }
797 return true;
798 }
799
IsKernelInRootFuncGraph(const AnfNodePtr & kernel)800 bool ControlNodeParser::IsKernelInRootFuncGraph(const AnfNodePtr &kernel) {
801 if (kernel == nullptr) {
802 return true;
803 }
804
805 const auto &graph = kernel->func_graph();
806 if (kernel != nullptr && graph != nullptr) {
807 const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
808 if (kernel_graph == nullptr) {
809 return true;
810 }
811
812 const auto func_graph = kernel_graph->GetFuncGraph();
813 if (func_graph != nullptr && func_graph != root_func_graph_) {
814 return false;
815 }
816 }
817
818 return true;
819 }
820
GetCallNumByFuncGraph(const FuncGraphPtr & func_graph)821 size_t ControlNodeParser::GetCallNumByFuncGraph(const FuncGraphPtr &func_graph) {
822 if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
823 MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString();
824 }
825
826 return func_graph_to_call_num_[func_graph];
827 }
828
FetchAllBranchOutputs(const FuncGraphPtr & func_graph)829 std::vector<AnfNodePtr> ControlNodeParser::FetchAllBranchOutputs(const FuncGraphPtr &func_graph) {
830 std::vector<AnfNodePtr> call_nodes;
831 return FetchFuncGraphOutput(func_graph, &call_nodes);
832 }
833
GetFrontValueNodeDeviceContext(const AnfNodePtr & value_node)834 DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node) {
835 auto iter = std::find_if(
836 front_value_nodes_.begin(), front_value_nodes_.end(),
837 [value_node](const auto &front_node_with_context) { return front_node_with_context.first == value_node; });
838 if (iter != front_value_nodes_.end()) {
839 return iter->second;
840 }
841 return nullptr;
842 }
843
// Find the backend node of a front weight node.
// Every host parameter whose weight list contains the node is tried; the backend node of the first such host
// parameter that is present in front_to_backend_parameters_ is returned. Returns nullptr when there is none.
AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &node) {
  for (const auto &weight_entry : host_parameter_to_weights_) {
    const auto &front_weights = weight_entry.second;
    if (std::find(front_weights.begin(), front_weights.end(), node) == front_weights.end()) {
      continue;
    }

    // The host parameter owns this weight, but it may still lack a backend parameter; keep searching then.
    const auto backend_iter = front_to_backend_parameters_.find(weight_entry.first);
    if (backend_iter != front_to_backend_parameters_.end()) {
      return backend_iter->second.first;
    }
  }

  return nullptr;
}
858
FetchValueNodeBySwitchNode(const AnfNodePtr & switch_node,std::vector<AnfNodePtr> * value_nodes)859 void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node,
860 std::vector<AnfNodePtr> *value_nodes) {
861 const auto &cnode = switch_node->cast<CNodePtr>();
862 const auto &inputs = cnode->inputs();
863 if (inputs.size() != kSwitchInputNum) {
864 MS_LOG(EXCEPTION) << "Invalid switch node input num:" << inputs.size();
865 }
866
867 for (const auto &input : inputs) {
868 if (input->isa<ValueNode>()) {
869 const auto &node_value = input->cast<ValueNodePtr>()->value();
870 if (node_value->isa<tensor::Tensor>()) {
871 (void)((*value_nodes).emplace_back(input));
872 }
873 } else if (IsCallNode(input)) {
874 // If input is a call not, should check the switch node in its input.
875 const auto &call_node = input->cast<CNodePtr>();
876 const auto &call_inputs = call_node->inputs();
877 if (call_inputs.empty() || (!AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch))) {
878 continue;
879 }
880 FetchValueNodeBySwitchNode(call_inputs[0], value_nodes);
881 } else if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) {
882 const auto &partial_node = input->cast<CNodePtr>();
883 const auto &partial_inputs = partial_node->inputs();
884 if (partial_inputs.size() <= kPartialFuncGraphPos) {
885 MS_LOG(EXCEPTION) << "Invalid partial node input num:" << partial_inputs.size();
886 }
887
888 // if input is a partial node, get the value node in its funcgraph.
889 const auto &func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
890 if (func_graph->output()->isa<ValueNode>()) {
891 (void)((*value_nodes).emplace_back(func_graph->output()));
892 }
893 }
894 }
895 }
896
FetchFrontValueNode(const std::vector<AnfNodePtr> & control_nodes,const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)897 void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
898 const std::vector<KernelGraphPtr> &graphs,
899 const std::vector<DeviceContext *> &device_contexts) {
900 for (const auto &control_node : control_nodes) {
901 CNodePtr cnode = control_node->cast<CNodePtr>();
902 auto inputs = cnode->inputs();
903 if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
904 auto func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
905 const auto parameters = func_graph->parameters();
906 if (parameters.size() != inputs.size() - kCallInputStartPos) {
907 MS_LOG(EXCEPTION) << "Invalid parameters num, need:" << parameters.size()
908 << " has:" << inputs.size() - kCallInputStartPos;
909 }
910 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
911 if (inputs[i]->isa<ValueNode>()) {
912 const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value();
913 if (!node_value->isa<tensor::Tensor>()) {
914 continue;
915 }
916 if (front_to_backend_parameters_.find(parameters[i - kCallInputStartPos]) ==
917 front_to_backend_parameters_.end()) {
918 MS_LOG(INFO) << "Cannot find backend parameter for front parameter:"
919 << AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos])
920 << ", used the default format";
921 CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]);
922 (void)front_value_nodes_.emplace_back(inputs[i], device_contexts[0]);
923 continue;
924 }
925
926 const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first;
927 const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second;
928 CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context);
929 (void)front_value_nodes_.emplace_back(inputs[i], device_context);
930 }
931 }
932 }
933 }
934
935 for (size_t index = 0; index < graphs.size(); ++index) {
936 const auto &graph = graphs[index];
937 MS_EXCEPTION_IF_NULL(graph);
938
939 for (const auto ¶meter : graph->input_nodes()) {
940 MS_EXCEPTION_IF_NULL(parameter);
941
942 if (IsInternalParameter(parameter, graph)) {
943 auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter);
944 MS_EXCEPTION_IF_NULL(front_node_with_index.first);
945 const auto &front_output_with_index =
946 AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
947 auto front_output_node = front_output_with_index.first;
948 MS_EXCEPTION_IF_NULL(front_output_node);
949 if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) {
950 std::vector<AnfNodePtr> value_nodes;
951 FetchValueNodeBySwitchNode(front_output_node, &value_nodes);
952 for (const auto value_node : value_nodes) {
953 CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]);
954 (void)front_value_nodes_.emplace_back(value_node, device_contexts[index]);
955 }
956 }
957 }
958 }
959 }
960
961 // When funcgraph called by call node returns to the value node, device addresses should be created for these
962 // value nodes.
963 for (const auto &call_node_to_backend_parameter : call_node_to_backend_parameters_) {
964 const auto func_graphs = FetchFuncGraphbyCallNode(call_node_to_backend_parameter.first.first);
965 for (const auto &func_graph : func_graphs) {
966 const auto &output = func_graph->output();
967 if (output->isa<ValueNode>() && GetFrontValueNodeDeviceContext(output) == nullptr) {
968 const auto &device_context = call_node_to_backend_parameter.second.second;
969 CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context);
970 (void)front_value_nodes_.emplace_back(output, device_context);
971 }
972 }
973 }
974 }
975
FetchFrontToFrontParameter(const std::vector<AnfNodePtr> & control_nodes,std::unordered_map<AnfNodePtr,std::vector<AnfNodePtr>> * front_to_front_parameter)976 void ControlNodeParser::FetchFrontToFrontParameter(
977 const std::vector<AnfNodePtr> &control_nodes,
978 std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter) {
979 for (const auto &node : control_nodes) {
980 CNodePtr cnode = node->cast<CNodePtr>();
981 const auto &inputs = cnode->inputs();
982 if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
983 // Call node which the first input node is a valuenode of funcgraph.
984 const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
985 const auto ¶meters = func_graph->parameters();
986 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
987 if (inputs[i]->isa<Parameter>()) {
988 (*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kCallInputStartPos]);
989 }
990 }
991 }
992 }
993 }
994
FetchControlNodeParameter(const std::vector<AnfNodePtr> & control_nodes,DeviceContext * device_context)995 std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
996 DeviceContext *device_context) {
997 std::vector<AnfNodePtr> parameters = FetchParameterByControlNode(control_nodes);
998
999 for (const auto &graph_with_device_context : call_input_kernel_graphs_) {
1000 const auto &graph = graph_with_device_context.first;
1001 const auto &func_graph = graph->GetFuncGraph();
1002 if (func_graph == nullptr) {
1003 MS_LOG(WARNING) << "Cannot get funcgraph by kernel graph:" << graph->ToString();
1004 continue;
1005 }
1006 if (func_graph != root_func_graph_) {
1007 continue;
1008 }
1009
1010 const auto &inputs = graph->input_nodes();
1011 for (const auto &input : inputs) {
1012 const auto &front_node = graph->GetFrontAnfByBackendAnf(input);
1013 if (front_node != nullptr && front_node->isa<Parameter>() && (!HasAbstractRef(front_node))) {
1014 (void)parameters.emplace_back(front_node);
1015 }
1016 }
1017 }
1018
1019 for (const auto ¶meter : parameters) {
1020 auto backend_iter = front_to_backend_parameters_.find(parameter);
1021 if (backend_iter == front_to_backend_parameters_.end()) {
1022 CreateDeviceTensorForFrontParameter(parameter, device_context);
1023 front_to_backend_parameters_[parameter] = {parameter, device_context};
1024 (void)front_parameters_.emplace_back(parameter, device_context);
1025 }
1026 }
1027
1028 return parameters;
1029 }
1030
FetchFuncGraphCallNum(const std::vector<AnfNodePtr> & control_nodes)1031 void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) {
1032 for (const auto &control_node : control_nodes) {
1033 if (IsCallNode(control_node)) {
1034 const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
1035
1036 for (const auto &func_graph : func_graphs) {
1037 MS_EXCEPTION_IF_NULL(func_graph);
1038
1039 if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
1040 func_graph_to_call_num_[func_graph] = 1;
1041 } else {
1042 func_graph_to_call_num_[func_graph]++;
1043 }
1044 }
1045 }
1046 }
1047 }
1048
FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)1049 void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
1050 const std::vector<DeviceContext *> &device_contexts) {
1051 for (size_t i = 0; i < graphs.size(); ++i) {
1052 const auto &graph = graphs[i];
1053 const auto &device_context = device_contexts[i];
1054
1055 const auto inputs = graph->input_nodes();
1056 for (const auto &input : inputs) {
1057 const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
1058 if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) {
1059 call_input_kernel_graphs_[graph] = device_context;
1060 call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
1061 }
1062 }
1063 }
1064 }
1065
CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> & control_nodes)1066 void ControlNodeParser::CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
1067 int branch_id = 0;
1068
1069 for (const auto &control_node : control_nodes) {
1070 // Root funcgraph does not need to create a gather actor.
1071 if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
1072 auto func_graph = control_node->func_graph();
1073 func_graph_to_branch_id_[func_graph] = branch_id++;
1074 }
1075 }
1076 }
1077
FetchInputParameterbyControlNode(const AnfNodePtr & node,std::set<AnfNodePtr> * switch_nodes,std::set<AnfNodePtr> * call_nodes)1078 std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *switch_nodes,
1079 std::set<AnfNodePtr> *call_nodes) {
1080 std::vector<AnfNodePtr> parameters;
1081
1082 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
1083 if ((*switch_nodes).find(node) != (*switch_nodes).end()) {
1084 return parameters;
1085 }
1086 (void)(*switch_nodes).insert(node);
1087
1088 const auto &cnode = node->cast<CNodePtr>();
1089 const auto &inputs = cnode->inputs();
1090 if (inputs.size() != kSwitchInputNum) {
1091 MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(node);
1092 }
1093
1094 for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
1095 if (inputs[i]->isa<Parameter>()) {
1096 (void)parameters.emplace_back(inputs[i]);
1097 } else if (IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
1098 const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes);
1099 (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
1100 }
1101 }
1102 } else if (IsCallNode(node)) {
1103 if ((*call_nodes).find(node) != (*call_nodes).end()) {
1104 return parameters;
1105 }
1106 (void)(*call_nodes).insert(node);
1107
1108 const auto &func_graphs = FetchFuncGraphbyCallNode(node);
1109 for (const auto &func_graph : func_graphs) {
1110 if (func_graph->output()->isa<Parameter>()) {
1111 (void)parameters.emplace_back(func_graph->output());
1112 }
1113 }
1114 }
1115 return parameters;
1116 }
1117
FetchParameterbyKernelGraph(const KernelGraphPtr & graph)1118 std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
1119 std::vector<KernelWithIndex> parameters;
1120 const auto &graph_parameters = graph->input_nodes();
1121
1122 for (const auto &graph_parameter : graph_parameters) {
1123 const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
1124 const auto &internal_front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
1125 const auto &internal_front_node = internal_front_node_with_index.first;
1126
1127 if (external_front_node == nullptr && internal_front_node == nullptr) {
1128 MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
1129 << AnfAlgo::GetNodeDebugString(graph_parameter);
1130 continue;
1131 }
1132
1133 const auto &front_node_with_index =
1134 ((external_front_node != nullptr) ? KernelWithIndex(external_front_node, 0) : internal_front_node_with_index);
1135 const auto &sub_parameters = FetchAllRealInputNodeByParameter(front_node_with_index);
1136 (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
1137 }
1138
1139 return parameters;
1140 }
1141
FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const RealToFormalNode & real_to_formal_front_parameters,const RealToFormalNode & formal_to_real_front_parameters)1142 void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
1143 const std::vector<DeviceContext *> &device_contexts,
1144 const RealToFormalNode &real_to_formal_front_parameters,
1145 const RealToFormalNode &formal_to_real_front_parameters) {
1146 if (graphs.size() != device_contexts.size()) {
1147 MS_LOG(EXCEPTION) << "Graph num is not equal to device context num.";
1148 }
1149
1150 // Fetch the mapping relationship between front parameters and backend parameters in the kernel graphs.
1151 for (size_t i = 0; i < graphs.size(); ++i) {
1152 const auto &graph = graphs[i];
1153 auto device_context = device_contexts[i];
1154 for (const auto ¶meter : graph->input_nodes()) {
1155 auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
1156 if (front_node != nullptr && front_node->isa<Parameter>() &&
1157 front_to_backend_parameters_.find(front_node) == front_to_backend_parameters_.end()) {
1158 front_to_backend_parameters_[front_node] = {parameter, device_context};
1159 }
1160 }
1161 }
1162
1163 // This for loop cannot be combined with the for loop above, because the relationship between front
1164 // and backend needs to be consistent with HostDataSource.
1165 for (size_t i = 0; i < graphs.size(); ++i) {
1166 const auto &graph = graphs[i];
1167 auto device_context = device_contexts[i];
1168 for (const auto ¶meter : graph->input_nodes()) {
1169 const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(parameter);
1170
1171 if (internal_front_node.first != nullptr) {
1172 std::set<AnfNodePtr> call_nodes;
1173 std::set<AnfNodePtr> switch_nodes;
1174 const auto &front_paramters =
1175 FetchInputParameterbyControlNode(internal_front_node.first, &switch_nodes, &call_nodes);
1176 for (const auto &front_paramter : front_paramters) {
1177 if (front_to_backend_parameters_.find(front_paramter) == front_to_backend_parameters_.end()) {
1178 front_to_backend_parameters_[front_paramter] = {parameter, device_context};
1179 }
1180 }
1181 }
1182 }
1183 }
1184
1185 for (const auto &front_pair : real_to_formal_front_parameters) {
1186 std::set<AnfNodePtr> invalid_node;
1187 const auto &backend_node =
1188 FetchBackendNodeByFrontNode(front_pair.first, real_to_formal_front_parameters, formal_to_real_front_parameters,
1189 front_to_backend_parameters_, &invalid_node);
1190 if (backend_node.first != nullptr) {
1191 if (front_to_backend_parameters_.find(front_pair.first) == front_to_backend_parameters_.end()) {
1192 front_to_backend_parameters_[front_pair.first] = backend_node;
1193 }
1194 }
1195 }
1196 }
1197
FetchHostParameterToWeight(const RealToFormalNode & front_to_front_parameters)1198 void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front_to_front_parameters) {
1199 for (const auto &pair : front_to_front_parameters) {
1200 std::vector<AnfNodePtr> dest_nodes;
1201 FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameters);
1202 host_parameter_to_weights_[pair.first] = dest_nodes;
1203
1204 if (std::find(root_graph_parameters_.begin(), root_graph_parameters_.end(), pair.first) !=
1205 root_graph_parameters_.end()) {
1206 for (auto &sub_front_node : dest_nodes) {
1207 sub_front_node_to_root_front_node_[sub_front_node] = pair.first;
1208 }
1209 }
1210 }
1211 }
1212
FetchFuncGraphToParameter(const std::vector<AnfNodePtr> & control_nodes)1213 void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes) {
1214 for (const auto &control_node : control_nodes) {
1215 const auto &cnode = control_node->cast<CNodePtr>();
1216 const auto &inputs = cnode->inputs();
1217 if (inputs.empty()) {
1218 MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node);
1219 }
1220
1221 // Call node which the first input is a cnode.
1222 if (inputs[0]->isa<CNode>()) {
1223 const auto &switch_cnode = inputs[0]->cast<CNodePtr>();
1224
1225 if (AnfAlgo::CheckPrimitiveType(switch_cnode, prim::kPrimSwitch)) {
1226 // Switch node.
1227 FetchParameterBySwitchNode(inputs[0], &func_graph_to_parameters_);
1228 } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
1229 // Switchlayer node.
1230 FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
1231 } else if (IsCallNode(inputs[0])) {
1232 continue;
1233 } else {
1234 MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
1235 }
1236 } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
1237 // Call node which the first input is a value node of funcgraph.
1238 const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
1239 std::vector<AnfNodePtr> parameters;
1240 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1241 if (CheckValidFuncGraphInput(inputs[i])) {
1242 (void)parameters.emplace_back(inputs[i]);
1243 }
1244 }
1245 (void)func_graph_to_parameters_[func_graph].emplace_back(parameters);
1246 }
1247 }
1248 }
1249
FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts)1250 void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
1251 const std::vector<DeviceContext *> &device_contexts) {
1252 for (size_t i = 0; i < graphs.size(); ++i) {
1253 const auto &graph = graphs[i];
1254 const auto &device_context = device_contexts[i];
1255 MS_EXCEPTION_IF_NULL(graph);
1256 auto execution_order = graph->execution_order();
1257 for (auto &kernel : execution_order) {
1258 if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
1259 auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
1260 if (front_node != nullptr) {
1261 for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) {
1262 front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context};
1263 }
1264 }
1265 }
1266 }
1267
1268 const auto graph_output_map = graph->graph_output_map();
1269 for (const auto &output_pair : graph_output_map) {
1270 front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
1271 }
1272 }
1273 }
1274
FetchBackendOutputByFrontOutput(const AnfNodePtr & front_output,std::set<AnfNodePtr> * call_nodes,std::set<AnfNodePtr> * switch_nodes,std::set<KernelWithIndex> * results)1275 void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output,
1276 std::set<AnfNodePtr> *call_nodes,
1277 std::set<AnfNodePtr> *switch_nodes,
1278 std::set<KernelWithIndex> *results) {
1279 if (front_output->isa<ValueNode>()) {
1280 (void)(*results).emplace(front_output, 0);
1281
1282 const auto &iter = formal_to_real_parameters_.find(front_output);
1283 if (iter != formal_to_real_parameters_.end()) {
1284 for (const auto &node : iter->second) {
1285 (void)(*results).emplace(node);
1286 }
1287 }
1288 } else if (front_output->isa<Parameter>()) {
1289 // Output is a parameter.
1290 const auto iter = formal_to_real_parameters_.find(front_output);
1291 if (iter != formal_to_real_parameters_.end()) {
1292 for (const auto &node : iter->second) {
1293 (void)(*results).emplace(node);
1294 }
1295 } else {
1296 MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output);
1297 }
1298 } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimSwitch)) {
1299 // Output is a switch.
1300 const auto &switch_outputs = FetchOutputBySwitchNode(front_output, call_nodes, switch_nodes);
1301
1302 for (const auto &switch_output : switch_outputs) {
1303 FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results);
1304 }
1305 } else if (IsCallNode(front_output)) {
1306 // Output is a call.
1307 const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes);
1308
1309 for (const auto &call_output : call_outputs) {
1310 FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes, results);
1311 }
1312 } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimMakeTuple)) {
1313 // Output is a make tuple.
1314 const auto &cnode = front_output->cast<CNodePtr>();
1315 const auto &inputs = cnode->inputs();
1316
1317 for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
1318 FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes, results);
1319 }
1320 } else if (front_output->isa<CNode>()) {
1321 // Output is a kernel.
1322 const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0));
1323 if (iter != front_to_backend_kernels_.end()) {
1324 (void)(*results).emplace(iter->second.first);
1325 } else {
1326 MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output);
1327 }
1328 } else {
1329 MS_LOG(EXCEPTION) << "Invalid front node:" << AnfAlgo::GetNodeDebugString(front_output);
1330 }
1331 }
1332
FetchBackendInputNodebyFrontNode(const AnfNodePtr & real_parameter,const AnfNodePtr & formal_parameter,const FrontToBackendNodeWithContext & front_to_backend_parameters)1333 void ControlNodeParser::FetchBackendInputNodebyFrontNode(
1334 const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter,
1335 const FrontToBackendNodeWithContext &front_to_backend_parameters) {
1336 if (real_parameter->isa<Parameter>()) {
1337 // Input node is a parameter from host data source actor.
1338 std::set<AnfNodePtr> invalid_inputs;
1339 std::vector<AnfNodePtr> front_inputs =
1340 FetchInputNodeByParameter(real_parameter, root_graph_parameters_, &invalid_inputs, func_graph_to_parameters_);
1341
1342 for (const auto &front_input : front_inputs) {
1343 const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0);
1344 if (node_with_index.first->isa<Parameter>()) {
1345 const auto &iter = front_to_backend_parameters.find(real_parameter);
1346 if (iter == front_to_backend_parameters.end()) {
1347 MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
1348 continue;
1349 }
1350 (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first, 0);
1351 } else {
1352 const auto iter = front_to_backend_kernels_.find(node_with_index);
1353 if (iter == front_to_backend_kernels_.end()) {
1354 MS_LOG(EXCEPTION) << "Cannot find actor of front node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
1355 }
1356 (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
1357 }
1358 }
1359 } else if (real_parameter->isa<ValueNode>()) {
1360 (void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
1361 } else if (IsCallNode(real_parameter)) {
1362 const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
1363 for (const auto func_graph : func_graphs) {
1364 FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters);
1365 }
1366 } else {
1367 // Input node is a cnode.
1368 const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(real_parameter, 0);
1369 const auto iter = front_to_backend_kernels_.find(node_with_index);
1370 if (iter == front_to_backend_kernels_.end()) {
1371 MS_LOG(EXCEPTION) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first);
1372 }
1373 (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first);
1374 }
1375 }
1376
FetchBackendParameterNode(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const RealToFormalNode & real_to_formal_front_parameters,const RealToFormalNode & formal_to_real_front_parameters,FrontToBackendNodeWithContext * front_to_backend_parameters)1377 void ControlNodeParser::FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs,
1378 const std::vector<DeviceContext *> &device_contexts,
1379 const RealToFormalNode &real_to_formal_front_parameters,
1380 const RealToFormalNode &formal_to_real_front_parameters,
1381 FrontToBackendNodeWithContext *front_to_backend_parameters) {
1382 for (size_t i = 0; i < graphs.size(); ++i) {
1383 const auto &graph = graphs[i];
1384 const auto &device_context = device_contexts[i];
1385 if (graph->GetFuncGraph() != root_func_graph_) {
1386 continue;
1387 }
1388 for (const auto ¶meter : graph->input_nodes()) {
1389 auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
1390 if (front_node != nullptr && front_node->isa<Parameter>() &&
1391 (*front_to_backend_parameters).find(front_node) == (*front_to_backend_parameters).end()) {
1392 (*front_to_backend_parameters)[front_node] = {parameter, device_context};
1393 }
1394 }
1395 }
1396 for (const auto &control_node_parameter : control_node_parameters_) {
1397 const auto &iter = front_to_backend_parameters_.find(control_node_parameter);
1398 if (iter == front_to_backend_parameters_.end()) {
1399 MS_LOG(EXCEPTION) << "Cannot find backend node for control node parameter:"
1400 << AnfAlgo::GetNodeDebugString(control_node_parameter);
1401 }
1402 (*front_to_backend_parameters)[control_node_parameter] = iter->second;
1403 }
1404
1405 for (const auto &front_pair : formal_to_real_front_parameters) {
1406 std::set<AnfNodePtr> invalid_node;
1407 const auto &backend_node =
1408 FetchBackendNodeByFrontNode(front_pair.first, real_to_formal_front_parameters, formal_to_real_front_parameters,
1409 (*front_to_backend_parameters), &invalid_node);
1410 if (backend_node.first != nullptr) {
1411 if ((*front_to_backend_parameters).find(front_pair.first) == (*front_to_backend_parameters).end()) {
1412 (*front_to_backend_parameters)[front_pair.first] = backend_node;
1413 }
1414 }
1415 }
1416 }
1417
FetchBackendInputNode(const std::vector<KernelGraphPtr> & graphs,const std::vector<DeviceContext * > & device_contexts,const RealToFormalNode & real_to_formal_front_parameters,const RealToFormalNode & formal_to_real_front_parameters)1418 void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs,
1419 const std::vector<DeviceContext *> &device_contexts,
1420 const RealToFormalNode &real_to_formal_front_parameters,
1421 const RealToFormalNode &formal_to_real_front_parameters) {
1422 FrontToBackendNodeWithContext front_to_backend_parameters;
1423 FetchBackendParameterNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters,
1424 &front_to_backend_parameters);
1425
1426 for (size_t i = 0; i < graphs.size(); ++i) {
1427 const auto &graph = graphs[i];
1428 for (const auto &value_node : graph->graph_value_nodes()) {
1429 auto front_node = graph->GetFrontAnfByBackendAnf(value_node);
1430 if (front_node != nullptr) {
1431 (void)formal_to_real_parameters_[front_node].emplace_back(value_node, 0);
1432 }
1433 }
1434 }
1435
1436 for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
1437 for (const auto &front_weight : host_parameter_to_weight.second) {
1438 const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
1439 if (iter != front_to_backend_parameters_.end()) {
1440 (void)formal_to_real_parameters_[front_weight].emplace_back(iter->second.first, 0);
1441 }
1442 }
1443 }
1444
1445 for (const auto &func_graph_to_parameters : func_graph_to_parameters_) {
1446 const auto &func_graph = func_graph_to_parameters.first;
1447 std::vector<AnfNodePtr> graph_inputs;
1448 for (const auto &input : func_graph->get_inputs()) {
1449 // Monad input would not send to gather actor.
1450 if (HasAbstractMonad(input) || (input->isa<Parameter>() && HasAbstractRef(input))) {
1451 continue;
1452 }
1453 (void)graph_inputs.emplace_back(input);
1454 }
1455
1456 // Collect all backend input node to gather, There are two situations:
1457 // 1. The parameter from the host data source.
1458 // 2. Output the kernel actor.
1459 for (const auto parameters : func_graph_to_parameters.second) {
1460 if (parameters.size() != graph_inputs.size()) {
1461 MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size()
1462 << " need:" << graph_inputs.size() << " func_graph:" << func_graph->ToString();
1463 }
1464
1465 for (size_t i = 0; i < parameters.size(); ++i) {
1466 FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i], front_to_backend_parameters);
1467 }
1468 }
1469 }
1470 for (const auto parameter_pair : front_to_backend_parameters) {
1471 (void)formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
1472 }
1473 for (const auto parameter_pair : front_to_backend_parameters_) {
1474 (void)formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0);
1475 }
1476 }
1477
FetchAutoMonadNode(const std::vector<AnfNodePtr> & control_nodes)1478 void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes) {
1479 for (const auto &control_node : control_nodes) {
1480 const auto &cnode = control_node->cast<CNodePtr>();
1481 const auto &inputs = cnode->inputs();
1482 if (inputs.empty()) {
1483 MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node);
1484 }
1485
1486 if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
1487 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1488 if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState)) {
1489 const auto &node = FetchSourceNodeByAutoMonad(inputs[i]);
1490 const auto &iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(node, 0));
1491 if (iter != front_to_backend_kernels_.end()) {
1492 kernel_to_call_nodes_[iter->second.first.first] = control_node;
1493 }
1494 }
1495 }
1496 }
1497 }
1498 }
1499
// Return the node recorded for sub_front_node in sub_front_node_to_root_front_node_.
// When no entry is recorded, sub_front_node itself is returned.
AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node) {
  // A single find replaces the count + operator[] pair, so the key is looked up only once.
  const auto iter = sub_front_node_to_root_front_node_.find(sub_front_node);
  if (iter == sub_front_node_to_root_front_node_.end()) {
    return sub_front_node;
  }
  return iter->second;
}
1506 } // namespace runtime
1507 } // namespace mindspore
1508