1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
18 #include <vector>
19 #include "runtime/graph_scheduler/scheduler_helper.h"
20 #include "ops/framework_ops.h"
21
22 namespace mindspore {
23 namespace runtime {
IsInlineKernelActor(const AbstractActorPtr & actor)24 bool IsInlineKernelActor(const AbstractActorPtr &actor) {
25 MS_EXCEPTION_IF_NULL(actor);
26 if (actor->type() != KernelTransformType::kKernelActor &&
27 actor->type() != KernelTransformType::kConditionGatherActor &&
28 actor->type() != KernelTransformType::kConditionSwitchActor) {
29 return false;
30 }
31 const auto &kernel_actor = dynamic_cast<KernelActor *>(actor.get());
32 MS_EXCEPTION_IF_NULL(kernel_actor);
33 MS_EXCEPTION_IF_NULL(kernel_actor->kernel());
34 const auto &func_graph = kernel_actor->kernel()->func_graph();
35 if (func_graph == nullptr || (!func_graph->isa<KernelGraph>())) {
36 return false;
37 }
38 const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
39 MS_EXCEPTION_IF_NULL(kernel_graph);
40 return kernel_graph->inline_sub_graph_kernels().find(kernel_actor->kernel()) !=
41 kernel_graph->inline_sub_graph_kernels().end();
42 }
43
44 namespace {
GetBranchNameByKernelActor(const KernelActor * const kernel_actor)45 std::string GetBranchNameByKernelActor(const KernelActor *const kernel_actor) {
46 MS_EXCEPTION_IF_NULL(kernel_actor);
47 MS_EXCEPTION_IF_NULL(kernel_actor->kernel());
48 const auto &func_graph = kernel_actor->kernel()->func_graph();
49 if (func_graph == nullptr || (!func_graph->isa<KernelGraph>())) {
50 MS_LOG(EXCEPTION) << "Invalid funcgraph in kernel:" << kernel_actor->kernel()->fullname_with_scope();
51 }
52 const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
53 MS_EXCEPTION_IF_NULL(kernel_graph);
54 return kernel_graph->inline_sub_graph_kernels().at(kernel_actor->kernel());
55 }
56
GetBranchNameToCondtionActor(const KernelGraphPtr & graph,mindspore::HashMap<std::string,AbstractActor * > * branch_name_to_switch_actor,mindspore::HashMap<std::string,AbstractActor * > * branch_name_to_gather_actor)57 void GetBranchNameToCondtionActor(const KernelGraphPtr &graph,
58 mindspore::HashMap<std::string, AbstractActor *> *branch_name_to_switch_actor,
59 mindspore::HashMap<std::string, AbstractActor *> *branch_name_to_gather_actor) {
60 MS_EXCEPTION_IF_NULL(branch_name_to_gather_actor);
61 MS_EXCEPTION_IF_NULL(branch_name_to_switch_actor);
62 for (const auto &gather_to_switch : graph->condition_gather_to_switch()) {
63 MS_EXCEPTION_IF_NULL(gather_to_switch.first);
64 if (!common::AnfAlgo::CheckPrimitiveType(gather_to_switch.first, prim::kPrimConditionGather) ||
65 !common::AnfAlgo::CheckPrimitiveType(gather_to_switch.second, prim::kPrimConditionSwitch)) {
66 MS_LOG_WITH_NODE(EXCEPTION, gather_to_switch.first)
67 << "Invalid condition gather node:" << gather_to_switch.first->DebugString()
68 << " or condition switch node:" << gather_to_switch.second->DebugString();
69 }
70 const auto &gather_cnode = gather_to_switch.first->cast<CNodePtr>();
71 const auto &switch_cnode = gather_to_switch.second->cast<CNodePtr>();
72 MS_EXCEPTION_IF_NULL(gather_cnode);
73 MS_EXCEPTION_IF_NULL(switch_cnode);
74 if (!gather_cnode->HasAttr(kAttrBranchGraphName)) {
75 MS_LOG_WITH_NODE(EXCEPTION, gather_cnode)
76 << "Failed to get inline graph name by node:" << gather_cnode->fullname_with_scope();
77 }
78 const auto &branch_graph_names = gather_cnode->GetAttr(kAttrBranchGraphName);
79 MS_EXCEPTION_IF_NULL(branch_graph_names);
80 MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
81 << " for node:" << gather_cnode->fullname_with_scope();
82 if (!branch_graph_names->isa<ValueTuple>()) {
83 MS_LOG_WITH_NODE(EXCEPTION, gather_cnode) << "Invalid branch group name:" << branch_graph_names->ToString()
84 << " for node:" << gather_cnode->fullname_with_scope();
85 }
86 const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
87 MS_EXCEPTION_IF_NULL(tuple_name);
88 const auto &gather_actor = FetchActor(GetActorIdByKernel(gather_cnode));
89 const auto &switch_actor = FetchActor(GetActorIdByKernel(switch_cnode));
90 MS_EXCEPTION_IF_NULL(gather_actor);
91 MS_EXCEPTION_IF_NULL(switch_actor);
92 for (const auto &value : tuple_name->value()) {
93 const auto &branch_name = GetValue<std::string>(value);
94 (*branch_name_to_gather_actor)[branch_name] = gather_actor;
95 (*branch_name_to_switch_actor)[branch_name] = switch_actor;
96 }
97 }
98 }
99 } // namespace
100
LinkControlArrowByExecutionOrder(const KernelGraphPtr & graph,const GraphCompilerInfo & graph_compiler_info) const101 void InlineControlFlowScheduler::LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph,
102 const GraphCompilerInfo &graph_compiler_info) const {
103 MS_EXCEPTION_IF_NULL(graph);
104 const auto &inline_sub_graph_kernels = graph->inline_sub_graph_kernels();
105 if (graph->is_graph_run_mode() || graph->is_any_type_input() || inline_sub_graph_kernels.empty()) {
106 return;
107 }
108
109 mindspore::HashMap<std::string, AbstractActor *> branch_name_to_switch_actor;
110 mindspore::HashMap<std::string, AbstractActor *> branch_name_to_gather_actor;
111 GetBranchNameToCondtionActor(graph, &branch_name_to_switch_actor, &branch_name_to_gather_actor);
112
113 MS_LOG(DEBUG) << "Link control arrow for graph:" << graph->ToString();
114 // Only link control arrow between kernels in the same graph.
115 mindspore::HashMap<std::string, AbstractActor *> branch_last_actor;
116 for (size_t i = 0; i < graph->execution_order().size(); ++i) {
117 const auto &to_kernel = graph->execution_order()[i];
118 if (IsRpcActor(to_kernel)) {
119 MS_LOG(INFO) << "Rpc op is not available in the execution order, from kernel: "
120 << graph->execution_order()[i - 1]->fullname_with_scope()
121 << ", to kernel:" << graph->execution_order()[i]->fullname_with_scope();
122 continue;
123 }
124 const auto &iter = inline_sub_graph_kernels.find(to_kernel);
125 std::string current_branch = graph->ToString();
126 if (iter != inline_sub_graph_kernels.end()) {
127 current_branch = iter->second;
128 MS_LOG(DEBUG) << "Kernel:" << to_kernel->fullname_with_scope() << " branch:" << current_branch;
129 }
130
131 const auto to_kernel_type = FetchKernelTransformType(to_kernel, graph, {}, GraphExecutionStrategy::kPipeline);
132 auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_kernel, graph);
133 const auto &actor_iter = branch_last_actor.find(current_branch);
134 if (actor_iter == branch_last_actor.end()) {
135 if (!common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
136 branch_last_actor[current_branch] = to_actor;
137 MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
138 }
139 continue;
140 }
141 MS_LOG(DEBUG) << "Add control arrow between " << actor_iter->second->GetAID() << " and " << to_actor->GetAID();
142 SchedulerHelper::AddControlArrow(actor_iter->second, to_actor);
143 if (common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
144 // The control relation end after the condition switch node in graph.
145 branch_last_actor.erase(current_branch);
146 MS_LOG(DEBUG) << "For branch:" << current_branch << " end actor:" << to_actor->GetAID();
147 } else {
148 // The control relation start first kernel in graph.
149 branch_last_actor[current_branch] = to_actor;
150 MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
151 }
152 }
153
154 for (const auto &pair : branch_last_actor) {
155 const auto &branch_name = pair.first;
156 if (pair.second == nullptr || pair.second->type() != KernelTransformType::kKernelActor) {
157 continue;
158 }
159 const auto &iter = branch_name_to_gather_actor.find(branch_name);
160 if (iter == branch_name_to_gather_actor.end()) {
161 MS_LOG(INFO) << "Invalid branch name:" << branch_name << " in graph:" << graph->ToString();
162 continue;
163 }
164 SchedulerHelper::AddControlArrow(pair.second, iter->second);
165 MS_LOG(DEBUG) << "Add control arrow between:" << pair.second->GetAID() << " to:" << iter->second->GetAID();
166 }
167 }
168
169 // Get the branch name by input data arrow.
GetBranchNameByConditionGatherActor(KernelActor * condition_switch_actor,KernelActor * condition_gather_actor,DataArrow * data_arrow,const KernelGraphPtr & kernel_graph)170 std::string InlineControlFlowScheduler::GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor,
171 KernelActor *condition_gather_actor,
172 DataArrow *data_arrow,
173 const KernelGraphPtr &kernel_graph) {
174 MS_EXCEPTION_IF_NULL(condition_switch_actor);
175 MS_EXCEPTION_IF_NULL(condition_gather_actor);
176 MS_EXCEPTION_IF_NULL(data_arrow);
177 MS_EXCEPTION_IF_NULL(kernel_graph);
178 const auto &condition_gather_kernel = condition_gather_actor->kernel();
179 MS_EXCEPTION_IF_NULL(condition_gather_kernel);
180 auto gather_to_switch = kernel_graph->condition_gather_to_switch();
181 const auto &condition_pair_iter = gather_to_switch.find(condition_gather_kernel);
182 if (condition_pair_iter == gather_to_switch.end() ||
183 condition_pair_iter->second != condition_switch_actor->kernel()) {
184 MS_LOG(EXCEPTION) << "Condition switch actor:" << condition_switch_actor->GetAID()
185 << " and gather actor:" << condition_gather_actor << " is not match.";
186 }
187 if (!condition_gather_kernel->HasAttr(kAttrBranchOutputNum)) {
188 MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
189 }
190 // Get the output branch index in condition gather actor.
191 const auto &output_value = condition_gather_kernel->GetAttr(kAttrBranchOutputNum);
192 MS_EXCEPTION_IF_NULL(output_value);
193 size_t branch_index = IntToSize(data_arrow->to_input_index_) / GetValue<size_t>(output_value);
194 if (!condition_gather_kernel->HasAttr(kAttrBranchGraphName)) {
195 MS_LOG(EXCEPTION) << "Failed to get branch graph name by actor:" << condition_gather_actor->GetAID();
196 }
197
198 // Get output branch name by branch index.
199 const auto &branch_graph_names = condition_gather_kernel->GetAttr(kAttrBranchGraphName);
200 MS_EXCEPTION_IF_NULL(branch_graph_names);
201 MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
202 << " for actor:" << condition_gather_actor->GetAID();
203 if (!branch_graph_names->isa<ValueTuple>()) {
204 MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
205 << " for actor:" << condition_gather_actor->GetAID();
206 }
207 const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
208 MS_EXCEPTION_IF_NULL(tuple_name);
209 if (branch_index >= tuple_name->size()) {
210 MS_LOG(EXCEPTION) << "Invalid to index:" << data_arrow->to_input_index_
211 << " output num:" << GetValue<size_t>(output_value)
212 << " branch graph name:" << tuple_name->ToString()
213 << " from actor:" << condition_switch_actor->GetAID()
214 << " to actor:" << condition_gather_actor->GetAID();
215 }
216 MS_EXCEPTION_IF_NULL(tuple_name->value()[branch_index]);
217 return GetValue<std::string>(tuple_name->value()[branch_index]);
218 }
219
InitOutputDataBranchInfoForConditionSwitchActor(ConditionSwitchActor * const condition_switch_actor,const KernelGraphPtr & kernel_graph)220 void InlineControlFlowScheduler::InitOutputDataBranchInfoForConditionSwitchActor(
221 ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
222 const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
223 size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
224 condition_switch_actor->output_data_branch_indexes_.resize(condition_switch_actor->output_data_arrows().size());
225 // Get the index for each output data arrow.
226 for (size_t i = 0; i < condition_switch_actor->output_data_arrows().size(); ++i) {
227 const auto &output_node = condition_switch_actor->output_data_nodes()[i];
228 const auto &data_arrow = condition_switch_actor->output_data_arrows()[i];
229 MS_EXCEPTION_IF_NULL(output_node);
230 MS_EXCEPTION_IF_NULL(data_arrow);
231 const auto &to_actor = FetchActor(data_arrow->to_op_id_.Name());
232 if (to_actor == nullptr) {
233 MS_LOG(EXCEPTION) << "Failed to get actor:" << data_arrow->to_op_id_.Name()
234 << " from actor:" << condition_switch_actor->GetAID();
235 }
236 if (to_actor->type() != KernelTransformType::kConditionSwitchActor &&
237 to_actor->type() != KernelTransformType::kConditionGatherActor &&
238 to_actor->type() != KernelTransformType::kKernelActor) {
239 MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
240 << " from actor:" << condition_switch_actor->GetAID();
241 }
242
243 const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
244 MS_EXCEPTION_IF_NULL(to_kernel_actor);
245 MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
246 std::string current_branch_name;
247 if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
248 current_branch_name =
249 GetBranchNameByConditionGatherActor(condition_switch_actor, to_kernel_actor, data_arrow.get(), kernel_graph);
250 } else {
251 if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
252 MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by data user node:"
253 << to_kernel_actor->kernel()->fullname_with_scope()
254 << " in actor:" << condition_switch_actor->GetAID();
255 }
256 MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
257 << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
258 << " in actor:" << condition_switch_actor->GetAID()
259 << " from index:" << data_arrow->from_output_index_ << " to actor:" << data_arrow->to_op_id_
260 << " to index:" << data_arrow->to_input_index_;
261 current_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
262 }
263 // Get branch index for output data arrow.
264 const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
265 condition_switch_actor->branch_names_.end(), current_branch_name);
266 if (iter == condition_switch_actor->branch_names_.end()) {
267 MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
268 << " total branch name:" << condition_switch_actor->branch_names_
269 << " from actor:" << condition_switch_actor->GetAID() << " to actor:" << to_actor->GetAID();
270 }
271 size_t branch_index = LongToSize(iter - condition_switch_actor->branch_names_.begin());
272 if (IntToSize(data_arrow->from_output_index_) >= output_num ||
273 branch_index >= condition_switch_actor->branch_names_.size()) {
274 MS_LOG(EXCEPTION) << "Invalid output index:" << data_arrow->from_output_index_ << " total:" << output_num
275 << " and branch index:" << branch_index
276 << " total:" << condition_switch_actor->branch_names_.size()
277 << " for actor:" << condition_switch_actor->GetAID();
278 }
279 condition_switch_actor->output_data_branch_indexes_[i] = branch_index;
280 condition_switch_actor->branch_origin_ref_count_[branch_index][data_arrow->from_output_index_]++;
281 }
282 }
283
InitOutputControlBranchInfoForConditionSwitchActor(ConditionSwitchActor * const condition_switch_actor,const KernelGraphPtr & kernel_graph)284 void InlineControlFlowScheduler::InitOutputControlBranchInfoForConditionSwitchActor(
285 ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
286 const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
287 condition_switch_actor->output_control_branch_indexes_.resize(condition_switch_actor->output_control_arrows().size());
288 // Get the index for each output control arrow.
289 for (size_t i = 0; i < condition_switch_actor->output_control_arrows().size(); ++i) {
290 const auto &arrow = condition_switch_actor->output_control_arrows()[i];
291 MS_EXCEPTION_IF_NULL(arrow);
292 const auto &to_actor = FetchActor(arrow->to_op_id_.Name());
293 if (to_actor == nullptr) {
294 MS_LOG(EXCEPTION) << "Failed to get actor:" << arrow->to_op_id_.Name()
295 << " from actor:" << condition_switch_actor->GetAID();
296 }
297 if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
298 condition_switch_actor->output_control_branch_indexes_[i] = SIZE_MAX;
299 continue;
300 }
301 if (to_actor->type() != KernelTransformType::kKernelActor &&
302 to_actor->type() != KernelTransformType::kConditionSwitchActor) {
303 MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
304 << " from actor:" << condition_switch_actor->GetAID();
305 }
306 const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
307 MS_EXCEPTION_IF_NULL(to_kernel_actor);
308 MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
309 if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
310 MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by control user node:"
311 << to_kernel_actor->kernel()->fullname_with_scope()
312 << " in actor:" << condition_switch_actor->GetAID();
313 }
314 MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
315 << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
316 << " in actor:" << condition_switch_actor->GetAID() << " to actor:" << arrow->to_op_id_;
317 const auto ¤t_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
318 const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
319 condition_switch_actor->branch_names_.end(), current_branch_name);
320 if (iter == condition_switch_actor->branch_names_.end()) {
321 MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
322 << " total branch name:" << condition_switch_actor->branch_names_
323 << " for actor:" << condition_switch_actor->GetAID();
324 }
325 size_t branch_index = LongToSize(iter - condition_switch_actor->branch_names_.begin());
326 condition_switch_actor->output_control_branch_indexes_[i] = branch_index;
327 }
328 }
329
InitOutputBranchInfoForConditionSwitchActor(ConditionSwitchActor * const condition_switch_actor,const KernelGraphPtr & kernel_graph)330 void InlineControlFlowScheduler::InitOutputBranchInfoForConditionSwitchActor(
331 ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
332 if (condition_switch_actor->output_data_nodes().size() != condition_switch_actor->output_data_arrows().size()) {
333 MS_LOG(EXCEPTION) << "Invalid data node size:" << condition_switch_actor->output_data_nodes().size()
334 << " and arrow size:" << condition_switch_actor->output_data_arrows().size()
335 << " for actor:" << condition_switch_actor->GetAID();
336 }
337 InitOutputDataBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
338 InitOutputControlBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
339 MS_LOG(DEBUG) << "Branch origin ref count:" << condition_switch_actor->branch_origin_ref_count_
340 << " output data branch index:" << condition_switch_actor->output_data_branch_indexes_
341 << " output control branch index:" << condition_switch_actor->output_control_branch_indexes_
342 << " for actor:" << condition_switch_actor->GetAID();
343 }
344
HandleConditionSwitchActor(const KernelActorPtr & kernel_actor)345 void InlineControlFlowScheduler::HandleConditionSwitchActor(const KernelActorPtr &kernel_actor) {
346 MS_EXCEPTION_IF_NULL(kernel_actor);
347 const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(kernel_actor.get());
348 MS_EXCEPTION_IF_NULL(condition_switch_actor);
349 MS_EXCEPTION_IF_NULL(condition_switch_actor->kernel());
350 const auto &graph = condition_switch_actor->kernel()->func_graph();
351 if (graph == nullptr || !graph->isa<KernelGraph>()) {
352 MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_switch_actor->GetAID();
353 }
354 const auto &kernel_graph = graph->cast<KernelGraphPtr>();
355 MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
356 << " by actor:" << condition_switch_actor->GetAID();
357 if (!condition_switch_actor->kernel()->HasAttr(kInlineSubGraphName)) {
358 MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_switch_actor->GetAID();
359 }
360 const auto &inline_sub_graph_names = condition_switch_actor->kernel()->GetAttr(kInlineSubGraphName);
361 MS_EXCEPTION_IF_NULL(inline_sub_graph_names);
362 MS_LOG(DEBUG) << "inline sub graph name:" << inline_sub_graph_names->ToString()
363 << " for actor:" << condition_switch_actor->GetAID();
364 if (!inline_sub_graph_names->isa<ValueTuple>()) {
365 MS_LOG(EXCEPTION) << "Invalid input subgraph name:" << inline_sub_graph_names->ToString()
366 << " for actor:" << condition_switch_actor->GetAID();
367 }
368 const auto &tuple_name = inline_sub_graph_names->cast<ValueTuplePtr>();
369 MS_EXCEPTION_IF_NULL(tuple_name);
370 std::vector<std::string> branch_names;
371 for_each(tuple_name->value().begin(), tuple_name->value().end(),
372 [&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
373 condition_switch_actor->branch_names_ = branch_names;
374 // Fix ref count.
375 size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
376 condition_switch_actor->branch_origin_ref_count_ =
377 std::vector<std::vector<size_t>>(tuple_name->size(), vector<size_t>(output_num, 0));
378
379 InitOutputBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
380 }
381
AddRefCountForConditionSwitchActor(ConditionSwitchActor * const switch_actor,const std::string & branch_name,size_t output_index,size_t ref_count)382 void InlineControlFlowScheduler::AddRefCountForConditionSwitchActor(ConditionSwitchActor *const switch_actor,
383 const std::string &branch_name, size_t output_index,
384 size_t ref_count) {
385 const auto &iter = std::find(switch_actor->branch_names_.begin(), switch_actor->branch_names_.end(), branch_name);
386 if (iter == switch_actor->branch_names_.end()) {
387 MS_LOG(EXCEPTION) << "Failed to get branch name:" << branch_name << " total:" << switch_actor->branch_names_
388 << " in actor:" << switch_actor->GetAID();
389 }
390 size_t index = LongToSize(iter - switch_actor->branch_names_.begin());
391 if (index >= switch_actor->branch_origin_ref_count_.size()) {
392 MS_LOG(EXCEPTION) << " Invalid index:" << index
393 << " for branch origin ref count:" << switch_actor->branch_origin_ref_count_
394 << " for actor:" << switch_actor->GetAID();
395 }
396 if (output_index >= switch_actor->branch_origin_ref_count_[index].size()) {
397 MS_LOG(EXCEPTION) << " Invalid output index:" << output_index << " branch index:" << index
398 << " for branch origin ref count:" << switch_actor->branch_origin_ref_count_
399 << " for actor:" << switch_actor->GetAID();
400 }
401 MS_LOG(DEBUG) << "Add ref count:" << ref_count << " for branch index:" << index << " index:" << output_index
402 << " origin ref count:" << switch_actor->branch_origin_ref_count_
403 << " for actor:" << switch_actor->GetAID();
404 switch_actor->branch_origin_ref_count_[index][output_index] += ref_count;
405 }
406
FixRefCountForRefNode(const KernelWithIndex & input_with_index,size_t ref_count,const std::string & branch_name,const KernelGraph * const kernel_graph)407 void InlineControlFlowScheduler::FixRefCountForRefNode(const KernelWithIndex &input_with_index, size_t ref_count,
408 const std::string &branch_name,
409 const KernelGraph *const kernel_graph) {
410 MS_EXCEPTION_IF_NULL(kernel_graph);
411 MS_EXCEPTION_IF_NULL(input_with_index.first);
412 auto new_branch_name = branch_name;
413 if (common::AnfAlgo::CheckPrimitiveType(input_with_index.first, prim::kPrimConditionSwitch)) {
414 MS_LOG(DEBUG) << "Check switch node:" << input_with_index.first->fullname_with_scope()
415 << " index:" << input_with_index.second << " ref count:" << ref_count
416 << " branch name:" << branch_name;
417 const auto &actor = FetchActor(GetActorIdByKernel(input_with_index.first));
418 MS_EXCEPTION_IF_NULL(actor);
419 const auto &switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
420 MS_EXCEPTION_IF_NULL(switch_actor);
421 AddRefCountForConditionSwitchActor(switch_actor, branch_name, input_with_index.second, ref_count);
422 const auto &iter = kernel_graph->inline_sub_graph_kernels().find(input_with_index.first);
423 new_branch_name =
424 (iter == kernel_graph->inline_sub_graph_kernels().end() ? kernel_graph->ToString() : iter->second);
425 MS_LOG(DEBUG) << "Switch branch name from:" << branch_name << " to:" << new_branch_name
426 << " by switch node:" << input_with_index.first->fullname_with_scope()
427 << " in kernel graph:" << kernel_graph->ToString() << " ref count:" << ref_count;
428 } else if (common::AnfAlgo::CheckPrimitiveType(input_with_index.first, prim::kPrimConditionGather)) {
429 const auto &actor = FetchActor(GetActorIdByKernel(input_with_index.first));
430 MS_EXCEPTION_IF_NULL(actor);
431 const auto &gather_actor = dynamic_cast<ConditionGatherActor *>(actor);
432 MS_EXCEPTION_IF_NULL(gather_actor);
433 const auto &gather_cnode = input_with_index.first->cast<CNodePtr>();
434 size_t input_num = common::AnfAlgo::GetInputNum(gather_cnode);
435 if (input_num == 0 || input_num != gather_actor->branch_names_.size() * gather_actor->branch_output_num_) {
436 MS_LOG_WITH_NODE(EXCEPTION, gather_cnode)
437 << "Invalid input num:" << input_num << " branch output num:" << gather_actor->branch_output_num_
438 << " branch num:" << gather_actor->branch_names_.size() << " for node:" << gather_cnode->fullname_with_scope();
439 }
440 for (size_t i = input_with_index.second; i < input_num; i = i + gather_actor->branch_output_num_) {
441 FixRefCountForInputNode(common::AnfAlgo::VisitKernelWithReturnType(gather_cnode->input(i + 1), 0), ref_count,
442 gather_actor->branch_names_[i / gather_actor->branch_output_num_]);
443 }
444 return;
445 }
446
447 if (kernel_graph->IsInRefOutputMap(input_with_index)) {
448 const auto &ref_value = kernel_graph->GetRefCorrespondOutput(input_with_index);
449 if (ref_value.first == nullptr) {
450 return;
451 }
452 MS_LOG(DEBUG) << "Check input node:" << ref_value.first->fullname_with_scope() << " index:" << ref_value.second
453 << " output node:" << input_with_index.first->fullname_with_scope()
454 << " index:" << input_with_index.second;
455 FixRefCountForRefNode(ref_value, ref_count, new_branch_name, kernel_graph);
456 }
457 }
458
FixRefCountForInputNode(const KernelWithIndex & input_with_index,size_t ref_count,const std::string & branch_name)459 void InlineControlFlowScheduler::FixRefCountForInputNode(const KernelWithIndex &input_with_index, size_t ref_count,
460 const std::string &branch_name) {
461 const auto &node = input_with_index.first;
462 MS_EXCEPTION_IF_NULL(node);
463 const auto &device_address = AnfAlgo::GetMutableOutputAddr(node, input_with_index.second, false);
464 MS_EXCEPTION_IF_NULL(device_address);
465 if (ref_count == SIZE_MAX) {
466 MS_LOG(DEBUG) << "set ref count to max for device address:" << device_address;
467 device_address->set_original_ref_count(ref_count);
468 } else {
469 MS_LOG(DEBUG) << "set ref count from:" << device_address->original_ref_count()
470 << " to:" << device_address->original_ref_count() + ref_count
471 << " for device address:" << device_address;
472 device_address->set_original_ref_count(device_address->original_ref_count() + ref_count);
473 }
474 device_address->ResetRefCount();
475 if (node->isa<CNode>()) {
476 const auto &cnode = node->cast<CNodePtr>();
477 MS_EXCEPTION_IF_NULL(cnode);
478 const auto &graph = cnode->func_graph();
479 if (graph != nullptr && graph->isa<KernelGraph>()) {
480 const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
481 MS_EXCEPTION_IF_NULL(kernel_graph);
482 if (kernel_graph->IsInRefOutputMap(input_with_index)) {
483 FixRefCountForRefNode(input_with_index, ref_count, branch_name, kernel_graph);
484 return;
485 }
486 }
487 }
488
489 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimConditionGather)) {
490 const auto &gather_cnode = node->cast<CNodePtr>();
491 MS_EXCEPTION_IF_NULL(gather_cnode);
492 const auto &actor = FetchActor(GetActorIdByKernel(gather_cnode));
493 MS_EXCEPTION_IF_NULL(actor);
494 const auto &gather_actor = dynamic_cast<ConditionGatherActor *>(actor);
495 MS_EXCEPTION_IF_NULL(gather_actor);
496 size_t input_num = common::AnfAlgo::GetInputNum(gather_cnode);
497 if (input_num == 0 || input_num != gather_actor->branch_names_.size() * gather_actor->branch_output_num_) {
498 MS_LOG_WITH_NODE(EXCEPTION, gather_cnode)
499 << "Invalid input num:" << input_num << " branch output num:" << gather_actor->branch_output_num_
500 << " branch num:" << gather_actor->branch_names_.size() << " for node:" << gather_cnode->fullname_with_scope();
501 }
502 for (size_t i = input_with_index.second; i < input_num; i = i + gather_actor->branch_output_num_) {
503 FixRefCountForInputNode(common::AnfAlgo::VisitKernelWithReturnType(gather_cnode->input(i + 1), 0), ref_count,
504 gather_actor->branch_names_[i / gather_actor->branch_output_num_]);
505 }
506 }
507 }
508
FixRefCountByConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)509 void InlineControlFlowScheduler::FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
510 const KernelGraphPtr &kernel_graph) {
511 std::vector<size_t> need_add_ref_count;
512 size_t output_num = AnfAlgo::GetOutputTensorNum(condition_gather_actor->kernel());
513 for (size_t i = 0; i < output_num; ++i) {
514 const auto &device_address = AnfAlgo::GetMutableOutputAddr(condition_gather_actor->kernel(), i, false);
515 MS_EXCEPTION_IF_NULL(device_address);
516 need_add_ref_count.emplace_back(
517 device_address->original_ref_count() == SIZE_MAX ? SIZE_MAX : device_address->original_ref_count() - 1);
518 MS_LOG(DEBUG) << "For actor:" << condition_gather_actor->GetAID() << " output device address:" << device_address
519 << " output index:" << i << " ref_count:" << device_address->original_ref_count()
520 << " need add:" << need_add_ref_count.back();
521 }
522 size_t input_num = common::AnfAlgo::GetInputNum(condition_gather_actor->kernel());
523 if (input_num == 0 ||
524 input_num != condition_gather_actor->branch_output_num_ * condition_gather_actor->branch_names_.size()) {
525 MS_LOG(EXCEPTION) << "Invalid input num:" << input_num
526 << " branch output num:" << condition_gather_actor->branch_output_num_
527 << " for actor:" << condition_gather_actor->GetAID();
528 }
529 for (size_t i = 0; i < input_num; ++i) {
530 const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_gather_actor->kernel(), i, false);
531 MS_EXCEPTION_IF_NULL(device_address);
532 MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
533 << " input index:" << i << " ref_count:" << device_address->original_ref_count();
534 if (device_address->original_ref_count() == SIZE_MAX) {
535 continue;
536 }
537 const auto &input_with_index =
538 common::AnfAlgo::VisitKernelWithReturnType(condition_gather_actor->kernel()->input(i + 1), 0);
539 FixRefCountForInputNode(input_with_index, need_add_ref_count[i % condition_gather_actor->branch_output_num_],
540 condition_gather_actor->branch_names_[i / condition_gather_actor->branch_output_num_]);
541 MS_LOG(DEBUG) << "Condition gather actor:" << condition_gather_actor->GetAID() << " input index:" << i
542 << " input node:" << input_with_index.first->DebugString()
543 << " with index:" << input_with_index.second
544 << " need add ref count:" << need_add_ref_count[i % condition_gather_actor->branch_output_num_];
545 }
546 }
547
InitInputDataBranchInfoForConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)548 void InlineControlFlowScheduler::InitInputDataBranchInfoForConditionGatherActor(
549 ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
550 const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
551 MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
552 << " by actor:" << condition_gather_actor->GetAID();
553 for (const auto &pair : condition_gather_actor->input_data_arrow_aids_) {
554 const auto &from_aid = pair.first;
555 const auto &data_arrow = pair.second;
556 MS_EXCEPTION_IF_NULL(data_arrow);
557 const auto &from_actor = FetchActor(from_aid.Name());
558 if (from_actor == nullptr) {
559 MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
560 }
561 if (from_actor->type() != KernelTransformType::kKernelActor &&
562 from_actor->type() != KernelTransformType::kConditionSwitchActor &&
563 from_actor->type() != KernelTransformType::kConditionGatherActor) {
564 MS_LOG(EXCEPTION) << "Invalid to actor:" << from_actor->GetAID()
565 << " from actor:" << condition_gather_actor->GetAID();
566 }
567 const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
568 MS_EXCEPTION_IF_NULL(from_kernel_actor);
569 MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
570 std::string current_branch_name;
571 if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
572 current_branch_name =
573 GetBranchNameByConditionGatherActor(from_kernel_actor, condition_gather_actor, data_arrow, kernel_graph);
574 } else {
575 if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
576 MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by data user node:"
577 << from_kernel_actor->kernel()->fullname_with_scope()
578 << " in actor:" << condition_gather_actor->GetAID();
579 }
580 MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
581 << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
582 << " in actor:" << condition_gather_actor->GetAID();
583 current_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
584 }
585 const auto &iter = condition_gather_actor->branch_name_to_id_.find(current_branch_name);
586 if (iter == condition_gather_actor->branch_name_to_id_.end()) {
587 condition_gather_actor->branch_name_to_id_[current_branch_name] =
588 condition_gather_actor->branch_name_to_id_.size();
589 MS_LOG(DEBUG) << "Add branch index:" << condition_gather_actor->branch_name_to_id_[current_branch_name]
590 << " branch name:" << current_branch_name << " for actor:" << condition_gather_actor->GetAID();
591 }
592 // Get the input data num of each branch.
593 if (condition_gather_actor->branch_name_to_input_data_num_.find(current_branch_name) ==
594 condition_gather_actor->branch_name_to_input_data_num_.end()) {
595 condition_gather_actor->branch_name_to_input_data_num_[current_branch_name] = 1;
596 } else {
597 condition_gather_actor->branch_name_to_input_data_num_[current_branch_name]++;
598 }
599 }
600 }
601
InitInputControlBranchInfoForConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)602 void InlineControlFlowScheduler::InitInputControlBranchInfoForConditionGatherActor(
603 ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
604 const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
605 MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
606 << " by actor:" << condition_gather_actor->GetAID();
607
608 for (const auto &pair : condition_gather_actor->input_control_arrow_aids_) {
609 const auto &from_aid = pair.first;
610 const auto &from_actor = FetchActor(from_aid.Name());
611 if (from_actor == nullptr) {
612 MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
613 }
614 if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
615 continue;
616 }
617 if (from_actor->type() != KernelTransformType::kKernelActor &&
618 from_actor->type() != KernelTransformType::kConditionGatherActor) {
619 MS_LOG(EXCEPTION) << "Invalid from actor:" << from_actor->GetAID()
620 << " to actor:" << condition_gather_actor->GetAID();
621 }
622 const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
623 MS_EXCEPTION_IF_NULL(from_kernel_actor);
624 MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
625 if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
626 MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by control user node:"
627 << from_kernel_actor->kernel()->fullname_with_scope()
628 << " in actor:" << condition_gather_actor->GetAID();
629 }
630 MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
631 << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
632 << " in actor:" << condition_gather_actor->GetAID();
633 const auto ¤t_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
634 // Get input op control num of each branch.
635 if (condition_gather_actor->branch_name_to_input_control_num_.find(current_branch_name) ==
636 condition_gather_actor->branch_name_to_input_control_num_.end()) {
637 condition_gather_actor->branch_name_to_input_control_num_[current_branch_name] = 1;
638 } else {
639 condition_gather_actor->branch_name_to_input_control_num_[current_branch_name]++;
640 }
641 }
642 }
643
InitInputBranchInfoForConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)644 void InlineControlFlowScheduler::InitInputBranchInfoForConditionGatherActor(
645 ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
646 InitInputDataBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
647 InitInputControlBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
648 }
649
HandleConditionGatherActor(const KernelActorPtr & kernel_actor)650 void InlineControlFlowScheduler::HandleConditionGatherActor(const KernelActorPtr &kernel_actor) {
651 const auto &condition_gather_actor = dynamic_cast<ConditionGatherActor *>(kernel_actor.get());
652 MS_EXCEPTION_IF_NULL(condition_gather_actor);
653 const auto &gather_node = condition_gather_actor->kernel();
654 MS_EXCEPTION_IF_NULL(gather_node);
655 const auto &graph = gather_node->func_graph();
656 if (graph == nullptr || !graph->isa<KernelGraph>()) {
657 MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_gather_actor->GetAID();
658 }
659 const auto &kernel_graph = graph->cast<KernelGraphPtr>();
660 MS_EXCEPTION_IF_NULL(kernel_graph);
661 const auto &gather_switch_map = kernel_graph->condition_gather_to_switch();
662 const auto &gather_switch_iter = gather_switch_map.find(gather_node);
663 if (gather_switch_iter == gather_switch_map.end()) {
664 MS_LOG_WITH_NODE(EXCEPTION, gather_node)
665 << "Failed to get switch node by gather node:" << gather_node->fullname_with_scope();
666 }
667 if (gather_switch_iter->second == nullptr) {
668 MS_LOG_WITH_NODE(EXCEPTION, gather_node)
669 << "Failed to get switch node by gather node:" << gather_node->fullname_with_scope()
670 << " in kernel graph:" << kernel_graph->ToString();
671 }
672 const auto &actor = FetchActor(GetActorIdByKernel(gather_switch_iter->second));
673 MS_EXCEPTION_IF_NULL(actor);
674 const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
675 MS_EXCEPTION_IF_NULL(condition_switch_actor);
676 condition_switch_actor->gather_aid_ = const_cast<AID *>(&condition_gather_actor->GetAID());
677
678 if (!gather_node->HasAttr(kAttrBranchOutputNum)) {
679 MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
680 }
681 const auto &output_value = gather_node->GetAttr(kAttrBranchOutputNum);
682 MS_EXCEPTION_IF_NULL(output_value);
683 condition_gather_actor->branch_output_num_ = GetValue<size_t>(output_value);
684
685 if (!gather_node->HasAttr(kAttrBranchGraphName)) {
686 MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_gather_actor->GetAID();
687 }
688 const auto &branch_graph_names = gather_node->GetAttr(kAttrBranchGraphName);
689 MS_EXCEPTION_IF_NULL(branch_graph_names);
690 MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
691 << " for actor:" << condition_gather_actor->GetAID();
692 if (!branch_graph_names->isa<ValueTuple>()) {
693 MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
694 << " for actor:" << condition_gather_actor->GetAID();
695 }
696 const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
697 MS_EXCEPTION_IF_NULL(tuple_name);
698 std::vector<std::string> branch_names;
699 std::for_each(tuple_name->value().begin(), tuple_name->value().end(),
700 [&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
701 condition_gather_actor->branch_names_ = branch_names;
702 // Fix ref count.
703 FixRefCountByConditionGatherActor(condition_gather_actor, kernel_graph);
704 InitInputBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
705 }
706
LinkControlArrowForNoInputOrOutputActor(ActorSet * actor_set,const mindspore::HashMap<std::string,AbstractActor * > & branch_name_to_switch_actor,const mindspore::HashMap<std::string,AbstractActor * > & branch_name_to_gather_actor)707 void InlineControlFlowScheduler::LinkControlArrowForNoInputOrOutputActor(
708 ActorSet *actor_set, const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_switch_actor,
709 const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_gather_actor) {
710 MS_EXCEPTION_IF_NULL(actor_set);
711 for (const auto &kernel_actor : actor_set->kernel_actors_) {
712 MS_EXCEPTION_IF_NULL(kernel_actor);
713 if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0) &&
714 IsInlineKernelActor(kernel_actor)) {
715 const auto &branch_name = GetBranchNameByKernelActor(kernel_actor.get());
716 const auto &iter = branch_name_to_switch_actor.find(branch_name);
717 if (iter == branch_name_to_switch_actor.end()) {
718 MS_LOG(EXCEPTION) << "Failed to get condition switch actor by branch name:" << branch_name;
719 }
720 MS_LOG(DEBUG) << "Inline control flow scheduler add control flow from switch actor:" << iter->second->GetAID()
721 << " to kernel actor:" << kernel_actor->GetAID();
722 SchedulerHelper::AddControlArrow(iter->second, kernel_actor.get());
723 }
724 if (kernel_actor->output_data_arrows_.size() == 0 && kernel_actor->output_control_arrows_.size() == 0 &&
725 IsInlineKernelActor(kernel_actor)) {
726 const auto &branch_name = GetBranchNameByKernelActor(kernel_actor.get());
727 const auto &iter = branch_name_to_gather_actor.find(branch_name);
728 if (iter == branch_name_to_gather_actor.end()) {
729 MS_LOG(EXCEPTION) << "Failed to get condition gather actor by branch name:" << branch_name;
730 }
731 MS_LOG(DEBUG) << "Inline control flow scheduler add control flow from kernel actor:" << kernel_actor->GetAID()
732 << " to gather actor:" << iter->second->GetAID();
733 SchedulerHelper::AddControlArrow(kernel_actor.get(), iter->second);
734 }
735 }
736 }
737
Link(ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info,bool execution_order_running)738 void InlineControlFlowScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
739 bool execution_order_running) {
740 MS_EXCEPTION_IF_NULL(actor_set);
741 auto context_ptr = MsContext::GetInstance();
742 MS_EXCEPTION_IF_NULL(context_ptr);
743 mindspore::HashMap<std::string, AbstractActor *> branch_name_to_switch_actor;
744 mindspore::HashMap<std::string, AbstractActor *> branch_name_to_gather_actor;
745 for (const auto &graph : graph_compiler_info.graphs_) {
746 MS_EXCEPTION_IF_NULL(graph);
747 GetBranchNameToCondtionActor(graph, &branch_name_to_switch_actor, &branch_name_to_gather_actor);
748 }
749 LinkControlArrowForNoInputOrOutputActor(actor_set, branch_name_to_switch_actor, branch_name_to_gather_actor);
750 for (const auto &kernel_actor : actor_set->kernel_actors_) {
751 if (kernel_actor->type() == KernelTransformType::kConditionSwitchActor) {
752 HandleConditionSwitchActor(kernel_actor);
753 } else if (kernel_actor->type() == KernelTransformType::kConditionGatherActor) {
754 HandleConditionGatherActor(kernel_actor);
755 }
756 }
757 for (const auto &kernel_graph : graph_compiler_info.graphs_) {
758 MS_EXCEPTION_IF_NULL(kernel_graph);
759 if (kernel_graph->inline_sub_graph_kernels().empty()) {
760 continue;
761 }
762 for (const auto &ref_pair : kernel_graph->GetRefMap()) {
763 const auto &output_pair = ref_pair.first;
764 const auto &input_pair = ref_pair.second;
765 MS_EXCEPTION_IF_NULL(output_pair.first);
766 MS_EXCEPTION_IF_NULL(input_pair.first);
767 MS_LOG(DEBUG) << "output node:" << output_pair.first->fullname_with_scope()
768 << " input node:" << input_pair.first->fullname_with_scope();
769 const auto &actor = FetchActor(GetActorIdByKernel(output_pair.first));
770 if (actor == nullptr) {
771 MS_LOG_WITH_NODE(EXCEPTION, output_pair.first)
772 << "Failed to get actor by ref node:" << output_pair.first->fullname_with_scope()
773 << " index:" << output_pair.second << " origin node:" << input_pair.first->fullname_with_scope()
774 << " index:" << input_pair.second << " in graph:" << kernel_graph->ToString();
775 }
776 size_t ref_count = 1;
777 std::for_each(actor->output_data_arrows().begin(), actor->output_data_arrows().end(),
778 [&ref_count, &output_pair](const auto &data_arrow) {
779 MS_EXCEPTION_IF_NULL(data_arrow);
780 if (IntToSize(data_arrow->from_output_index_) == output_pair.second) {
781 ++ref_count;
782 }
783 });
784 FixRefCountRecursively(output_pair, input_pair, kernel_graph, ref_count);
785 }
786 }
787 }
788
FixRefCountRecursively(const KernelWithIndex & output_pair,const KernelWithIndex & input_pair,const KernelGraphPtr & kernel_graph,size_t ref_count)789 void InlineControlFlowScheduler::FixRefCountRecursively(const KernelWithIndex &output_pair,
790 const KernelWithIndex &input_pair,
791 const KernelGraphPtr &kernel_graph, size_t ref_count) {
792 MS_EXCEPTION_IF_NULL(output_pair.first);
793 MS_EXCEPTION_IF_NULL(input_pair.first);
794 if (common::AnfAlgo::CheckPrimitiveType(input_pair.first, prim::kPrimConditionGather)) {
795 return;
796 }
797 if (common::AnfAlgo::CheckPrimitiveType(input_pair.first, prim::kPrimConditionSwitch)) {
798 const auto &iter = kernel_graph->inline_sub_graph_kernels().find(output_pair.first);
799 if (iter == kernel_graph->inline_sub_graph_kernels().end()) {
800 MS_LOG_WITH_NODE(EXCEPTION, input_pair.first)
801 << "Invalid ref node pair, input node:" << input_pair.first->fullname_with_scope()
802 << " index:" << input_pair.second << " output node:" << output_pair.first->fullname_with_scope()
803 << " index:" << output_pair.second << " in kernel graph:" << kernel_graph->ToString();
804 }
805 const auto &branch_name = iter->second;
806 const auto &actor = FetchActor(GetActorIdByKernel(input_pair.first));
807 MS_EXCEPTION_IF_NULL(actor);
808 const auto &switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
809 MS_EXCEPTION_IF_NULL(switch_actor);
810 AddRefCountForConditionSwitchActor(switch_actor, branch_name, input_pair.second, ref_count);
811 }
812 if (kernel_graph->IsInRefOutputMap(input_pair)) {
813 FixRefCountRecursively(input_pair, kernel_graph->GetRefCorrespondOutput(input_pair), kernel_graph, ref_count);
814 }
815 }
816 } // namespace runtime
817 } // namespace mindspore
818