• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
18 #include <vector>
19 #include "runtime/graph_scheduler/scheduler_helper.h"
20 #include "ops/framework_ops.h"
21 
22 namespace mindspore {
23 namespace runtime {
IsInlineKernelActor(const AbstractActorPtr & actor)24 bool IsInlineKernelActor(const AbstractActorPtr &actor) {
25   MS_EXCEPTION_IF_NULL(actor);
26   if (actor->type() != KernelTransformType::kKernelActor &&
27       actor->type() != KernelTransformType::kConditionGatherActor &&
28       actor->type() != KernelTransformType::kConditionSwitchActor) {
29     return false;
30   }
31   const auto &kernel_actor = dynamic_cast<KernelActor *>(actor.get());
32   MS_EXCEPTION_IF_NULL(kernel_actor);
33   MS_EXCEPTION_IF_NULL(kernel_actor->kernel());
34   const auto &func_graph = kernel_actor->kernel()->func_graph();
35   if (func_graph == nullptr || (!func_graph->isa<KernelGraph>())) {
36     return false;
37   }
38   const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
39   MS_EXCEPTION_IF_NULL(kernel_graph);
40   return kernel_graph->inline_sub_graph_kernels().find(kernel_actor->kernel()) !=
41          kernel_graph->inline_sub_graph_kernels().end();
42 }
43 
44 namespace {
GetBranchNameByKernelActor(const KernelActor * const kernel_actor)45 std::string GetBranchNameByKernelActor(const KernelActor *const kernel_actor) {
46   MS_EXCEPTION_IF_NULL(kernel_actor);
47   MS_EXCEPTION_IF_NULL(kernel_actor->kernel());
48   const auto &func_graph = kernel_actor->kernel()->func_graph();
49   if (func_graph == nullptr || (!func_graph->isa<KernelGraph>())) {
50     MS_LOG(EXCEPTION) << "Invalid funcgraph in kernel:" << kernel_actor->kernel()->fullname_with_scope();
51   }
52   const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
53   MS_EXCEPTION_IF_NULL(kernel_graph);
54   return kernel_graph->inline_sub_graph_kernels().at(kernel_actor->kernel());
55 }
56 
GetBranchNameToCondtionActor(const KernelGraphPtr & graph,mindspore::HashMap<std::string,AbstractActor * > * branch_name_to_switch_actor,mindspore::HashMap<std::string,AbstractActor * > * branch_name_to_gather_actor)57 void GetBranchNameToCondtionActor(const KernelGraphPtr &graph,
58                                   mindspore::HashMap<std::string, AbstractActor *> *branch_name_to_switch_actor,
59                                   mindspore::HashMap<std::string, AbstractActor *> *branch_name_to_gather_actor) {
60   MS_EXCEPTION_IF_NULL(branch_name_to_gather_actor);
61   MS_EXCEPTION_IF_NULL(branch_name_to_switch_actor);
62   for (const auto &gather_to_switch : graph->condition_gather_to_switch()) {
63     MS_EXCEPTION_IF_NULL(gather_to_switch.first);
64     if (!common::AnfAlgo::CheckPrimitiveType(gather_to_switch.first, prim::kPrimConditionGather) ||
65         !common::AnfAlgo::CheckPrimitiveType(gather_to_switch.second, prim::kPrimConditionSwitch)) {
66       MS_LOG_WITH_NODE(EXCEPTION, gather_to_switch.first)
67         << "Invalid condition gather node:" << gather_to_switch.first->DebugString()
68         << " or condition switch node:" << gather_to_switch.second->DebugString();
69     }
70     const auto &gather_cnode = gather_to_switch.first->cast<CNodePtr>();
71     const auto &switch_cnode = gather_to_switch.second->cast<CNodePtr>();
72     MS_EXCEPTION_IF_NULL(gather_cnode);
73     MS_EXCEPTION_IF_NULL(switch_cnode);
74     if (!gather_cnode->HasAttr(kAttrBranchGraphName)) {
75       MS_LOG_WITH_NODE(EXCEPTION, gather_cnode)
76         << "Failed to get inline graph name by node:" << gather_cnode->fullname_with_scope();
77     }
78     const auto &branch_graph_names = gather_cnode->GetAttr(kAttrBranchGraphName);
79     MS_EXCEPTION_IF_NULL(branch_graph_names);
80     MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
81                   << " for node:" << gather_cnode->fullname_with_scope();
82     if (!branch_graph_names->isa<ValueTuple>()) {
83       MS_LOG_WITH_NODE(EXCEPTION, gather_cnode) << "Invalid branch group name:" << branch_graph_names->ToString()
84                                                 << " for node:" << gather_cnode->fullname_with_scope();
85     }
86     const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
87     MS_EXCEPTION_IF_NULL(tuple_name);
88     const auto &gather_actor = FetchActor(GetActorIdByKernel(gather_cnode));
89     const auto &switch_actor = FetchActor(GetActorIdByKernel(switch_cnode));
90     MS_EXCEPTION_IF_NULL(gather_actor);
91     MS_EXCEPTION_IF_NULL(switch_actor);
92     for (const auto &value : tuple_name->value()) {
93       const auto &branch_name = GetValue<std::string>(value);
94       (*branch_name_to_gather_actor)[branch_name] = gather_actor;
95       (*branch_name_to_switch_actor)[branch_name] = switch_actor;
96     }
97   }
98 }
99 }  // namespace
100 
LinkControlArrowByExecutionOrder(const KernelGraphPtr & graph,const GraphCompilerInfo & graph_compiler_info) const101 void InlineControlFlowScheduler::LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph,
102                                                                   const GraphCompilerInfo &graph_compiler_info) const {
103   MS_EXCEPTION_IF_NULL(graph);
104   const auto &inline_sub_graph_kernels = graph->inline_sub_graph_kernels();
105   if (graph->is_graph_run_mode() || graph->is_any_type_input() || inline_sub_graph_kernels.empty()) {
106     return;
107   }
108 
109   mindspore::HashMap<std::string, AbstractActor *> branch_name_to_switch_actor;
110   mindspore::HashMap<std::string, AbstractActor *> branch_name_to_gather_actor;
111   GetBranchNameToCondtionActor(graph, &branch_name_to_switch_actor, &branch_name_to_gather_actor);
112 
113   MS_LOG(DEBUG) << "Link control arrow for graph:" << graph->ToString();
114   // Only link control arrow between kernels in the same graph.
115   mindspore::HashMap<std::string, AbstractActor *> branch_last_actor;
116   for (size_t i = 0; i < graph->execution_order().size(); ++i) {
117     const auto &to_kernel = graph->execution_order()[i];
118     if (IsRpcActor(to_kernel)) {
119       MS_LOG(INFO) << "Rpc op is not available in the execution order, from kernel: "
120                    << graph->execution_order()[i - 1]->fullname_with_scope()
121                    << ", to kernel:" << graph->execution_order()[i]->fullname_with_scope();
122       continue;
123     }
124     const auto &iter = inline_sub_graph_kernels.find(to_kernel);
125     std::string current_branch = graph->ToString();
126     if (iter != inline_sub_graph_kernels.end()) {
127       current_branch = iter->second;
128       MS_LOG(DEBUG) << "Kernel:" << to_kernel->fullname_with_scope() << " branch:" << current_branch;
129     }
130 
131     const auto to_kernel_type = FetchKernelTransformType(to_kernel, graph, {}, GraphExecutionStrategy::kPipeline);
132     auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_kernel, graph);
133     const auto &actor_iter = branch_last_actor.find(current_branch);
134     if (actor_iter == branch_last_actor.end()) {
135       if (!common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
136         branch_last_actor[current_branch] = to_actor;
137         MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
138       }
139       continue;
140     }
141     MS_LOG(DEBUG) << "Add control arrow between " << actor_iter->second->GetAID() << " and " << to_actor->GetAID();
142     SchedulerHelper::AddControlArrow(actor_iter->second, to_actor);
143     if (common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
144       // The control relation end after the condition switch node in graph.
145       branch_last_actor.erase(current_branch);
146       MS_LOG(DEBUG) << "For branch:" << current_branch << " end actor:" << to_actor->GetAID();
147     } else {
148       // The control relation start first kernel in graph.
149       branch_last_actor[current_branch] = to_actor;
150       MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
151     }
152   }
153 
154   for (const auto &pair : branch_last_actor) {
155     const auto &branch_name = pair.first;
156     if (pair.second == nullptr || pair.second->type() != KernelTransformType::kKernelActor) {
157       continue;
158     }
159     const auto &iter = branch_name_to_gather_actor.find(branch_name);
160     if (iter == branch_name_to_gather_actor.end()) {
161       MS_LOG(INFO) << "Invalid branch name:" << branch_name << " in graph:" << graph->ToString();
162       continue;
163     }
164     SchedulerHelper::AddControlArrow(pair.second, iter->second);
165     MS_LOG(DEBUG) << "Add control arrow between:" << pair.second->GetAID() << " to:" << iter->second->GetAID();
166   }
167 }
168 
169 // Get the branch name by input data arrow.
GetBranchNameByConditionGatherActor(KernelActor * condition_switch_actor,KernelActor * condition_gather_actor,DataArrow * data_arrow,const KernelGraphPtr & kernel_graph)170 std::string InlineControlFlowScheduler::GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor,
171                                                                             KernelActor *condition_gather_actor,
172                                                                             DataArrow *data_arrow,
173                                                                             const KernelGraphPtr &kernel_graph) {
174   MS_EXCEPTION_IF_NULL(condition_switch_actor);
175   MS_EXCEPTION_IF_NULL(condition_gather_actor);
176   MS_EXCEPTION_IF_NULL(data_arrow);
177   MS_EXCEPTION_IF_NULL(kernel_graph);
178   const auto &condition_gather_kernel = condition_gather_actor->kernel();
179   MS_EXCEPTION_IF_NULL(condition_gather_kernel);
180   auto gather_to_switch = kernel_graph->condition_gather_to_switch();
181   const auto &condition_pair_iter = gather_to_switch.find(condition_gather_kernel);
182   if (condition_pair_iter == gather_to_switch.end() ||
183       condition_pair_iter->second != condition_switch_actor->kernel()) {
184     MS_LOG(EXCEPTION) << "Condition switch actor:" << condition_switch_actor->GetAID()
185                       << " and gather actor:" << condition_gather_actor << " is not match.";
186   }
187   if (!condition_gather_kernel->HasAttr(kAttrBranchOutputNum)) {
188     MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
189   }
190   // Get the output branch index in condition gather actor.
191   const auto &output_value = condition_gather_kernel->GetAttr(kAttrBranchOutputNum);
192   MS_EXCEPTION_IF_NULL(output_value);
193   size_t branch_index = IntToSize(data_arrow->to_input_index_) / GetValue<size_t>(output_value);
194   if (!condition_gather_kernel->HasAttr(kAttrBranchGraphName)) {
195     MS_LOG(EXCEPTION) << "Failed to get branch graph name by actor:" << condition_gather_actor->GetAID();
196   }
197 
198   // Get output branch name by branch index.
199   const auto &branch_graph_names = condition_gather_kernel->GetAttr(kAttrBranchGraphName);
200   MS_EXCEPTION_IF_NULL(branch_graph_names);
201   MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
202                 << " for actor:" << condition_gather_actor->GetAID();
203   if (!branch_graph_names->isa<ValueTuple>()) {
204     MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
205                       << " for actor:" << condition_gather_actor->GetAID();
206   }
207   const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
208   MS_EXCEPTION_IF_NULL(tuple_name);
209   if (branch_index >= tuple_name->size()) {
210     MS_LOG(EXCEPTION) << "Invalid to index:" << data_arrow->to_input_index_
211                       << " output num:" << GetValue<size_t>(output_value)
212                       << " branch graph name:" << tuple_name->ToString()
213                       << " from actor:" << condition_switch_actor->GetAID()
214                       << " to actor:" << condition_gather_actor->GetAID();
215   }
216   MS_EXCEPTION_IF_NULL(tuple_name->value()[branch_index]);
217   return GetValue<std::string>(tuple_name->value()[branch_index]);
218 }
219 
InitOutputDataBranchInfoForConditionSwitchActor(ConditionSwitchActor * const condition_switch_actor,const KernelGraphPtr & kernel_graph)220 void InlineControlFlowScheduler::InitOutputDataBranchInfoForConditionSwitchActor(
221   ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
222   const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
223   size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
224   condition_switch_actor->output_data_branch_indexes_.resize(condition_switch_actor->output_data_arrows().size());
225   // Get the index for each output data arrow.
226   for (size_t i = 0; i < condition_switch_actor->output_data_arrows().size(); ++i) {
227     const auto &output_node = condition_switch_actor->output_data_nodes()[i];
228     const auto &data_arrow = condition_switch_actor->output_data_arrows()[i];
229     MS_EXCEPTION_IF_NULL(output_node);
230     MS_EXCEPTION_IF_NULL(data_arrow);
231     const auto &to_actor = FetchActor(data_arrow->to_op_id_.Name());
232     if (to_actor == nullptr) {
233       MS_LOG(EXCEPTION) << "Failed to get actor:" << data_arrow->to_op_id_.Name()
234                         << " from actor:" << condition_switch_actor->GetAID();
235     }
236     if (to_actor->type() != KernelTransformType::kConditionSwitchActor &&
237         to_actor->type() != KernelTransformType::kConditionGatherActor &&
238         to_actor->type() != KernelTransformType::kKernelActor) {
239       MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
240                         << " from actor:" << condition_switch_actor->GetAID();
241     }
242 
243     const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
244     MS_EXCEPTION_IF_NULL(to_kernel_actor);
245     MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
246     std::string current_branch_name;
247     if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
248       current_branch_name =
249         GetBranchNameByConditionGatherActor(condition_switch_actor, to_kernel_actor, data_arrow.get(), kernel_graph);
250     } else {
251       if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
252         MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by data user node:"
253                           << to_kernel_actor->kernel()->fullname_with_scope()
254                           << " in actor:" << condition_switch_actor->GetAID();
255       }
256       MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
257                     << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
258                     << " in actor:" << condition_switch_actor->GetAID()
259                     << " from index:" << data_arrow->from_output_index_ << " to actor:" << data_arrow->to_op_id_
260                     << " to index:" << data_arrow->to_input_index_;
261       current_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
262     }
263     // Get branch index for output data arrow.
264     const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
265                                  condition_switch_actor->branch_names_.end(), current_branch_name);
266     if (iter == condition_switch_actor->branch_names_.end()) {
267       MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
268                         << " total branch name:" << condition_switch_actor->branch_names_
269                         << " from actor:" << condition_switch_actor->GetAID() << " to actor:" << to_actor->GetAID();
270     }
271     size_t branch_index = LongToSize(iter - condition_switch_actor->branch_names_.begin());
272     if (IntToSize(data_arrow->from_output_index_) >= output_num ||
273         branch_index >= condition_switch_actor->branch_names_.size()) {
274       MS_LOG(EXCEPTION) << "Invalid output index:" << data_arrow->from_output_index_ << " total:" << output_num
275                         << " and branch index:" << branch_index
276                         << " total:" << condition_switch_actor->branch_names_.size()
277                         << " for actor:" << condition_switch_actor->GetAID();
278     }
279     condition_switch_actor->output_data_branch_indexes_[i] = branch_index;
280     condition_switch_actor->branch_origin_ref_count_[branch_index][data_arrow->from_output_index_]++;
281   }
282 }
283 
InitOutputControlBranchInfoForConditionSwitchActor(ConditionSwitchActor * const condition_switch_actor,const KernelGraphPtr & kernel_graph)284 void InlineControlFlowScheduler::InitOutputControlBranchInfoForConditionSwitchActor(
285   ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
286   const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
287   condition_switch_actor->output_control_branch_indexes_.resize(condition_switch_actor->output_control_arrows().size());
288   // Get the index for each output control arrow.
289   for (size_t i = 0; i < condition_switch_actor->output_control_arrows().size(); ++i) {
290     const auto &arrow = condition_switch_actor->output_control_arrows()[i];
291     MS_EXCEPTION_IF_NULL(arrow);
292     const auto &to_actor = FetchActor(arrow->to_op_id_.Name());
293     if (to_actor == nullptr) {
294       MS_LOG(EXCEPTION) << "Failed to get actor:" << arrow->to_op_id_.Name()
295                         << " from actor:" << condition_switch_actor->GetAID();
296     }
297     if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
298       condition_switch_actor->output_control_branch_indexes_[i] = SIZE_MAX;
299       continue;
300     }
301     if (to_actor->type() != KernelTransformType::kKernelActor &&
302         to_actor->type() != KernelTransformType::kConditionSwitchActor) {
303       MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
304                         << " from actor:" << condition_switch_actor->GetAID();
305     }
306     const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
307     MS_EXCEPTION_IF_NULL(to_kernel_actor);
308     MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
309     if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
310       MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by control user node:"
311                         << to_kernel_actor->kernel()->fullname_with_scope()
312                         << " in actor:" << condition_switch_actor->GetAID();
313     }
314     MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
315                   << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
316                   << " in actor:" << condition_switch_actor->GetAID() << " to actor:" << arrow->to_op_id_;
317     const auto &current_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
318     const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
319                                  condition_switch_actor->branch_names_.end(), current_branch_name);
320     if (iter == condition_switch_actor->branch_names_.end()) {
321       MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
322                         << " total branch name:" << condition_switch_actor->branch_names_
323                         << " for actor:" << condition_switch_actor->GetAID();
324     }
325     size_t branch_index = LongToSize(iter - condition_switch_actor->branch_names_.begin());
326     condition_switch_actor->output_control_branch_indexes_[i] = branch_index;
327   }
328 }
329 
InitOutputBranchInfoForConditionSwitchActor(ConditionSwitchActor * const condition_switch_actor,const KernelGraphPtr & kernel_graph)330 void InlineControlFlowScheduler::InitOutputBranchInfoForConditionSwitchActor(
331   ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
332   if (condition_switch_actor->output_data_nodes().size() != condition_switch_actor->output_data_arrows().size()) {
333     MS_LOG(EXCEPTION) << "Invalid data node size:" << condition_switch_actor->output_data_nodes().size()
334                       << " and arrow size:" << condition_switch_actor->output_data_arrows().size()
335                       << " for actor:" << condition_switch_actor->GetAID();
336   }
337   InitOutputDataBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
338   InitOutputControlBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
339   MS_LOG(DEBUG) << "Branch origin ref count:" << condition_switch_actor->branch_origin_ref_count_
340                 << " output data branch index:" << condition_switch_actor->output_data_branch_indexes_
341                 << " output control branch index:" << condition_switch_actor->output_control_branch_indexes_
342                 << " for actor:" << condition_switch_actor->GetAID();
343 }
344 
HandleConditionSwitchActor(const KernelActorPtr & kernel_actor)345 void InlineControlFlowScheduler::HandleConditionSwitchActor(const KernelActorPtr &kernel_actor) {
346   MS_EXCEPTION_IF_NULL(kernel_actor);
347   const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(kernel_actor.get());
348   MS_EXCEPTION_IF_NULL(condition_switch_actor);
349   MS_EXCEPTION_IF_NULL(condition_switch_actor->kernel());
350   const auto &graph = condition_switch_actor->kernel()->func_graph();
351   if (graph == nullptr || !graph->isa<KernelGraph>()) {
352     MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_switch_actor->GetAID();
353   }
354   const auto &kernel_graph = graph->cast<KernelGraphPtr>();
355   MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
356                 << " by actor:" << condition_switch_actor->GetAID();
357   if (!condition_switch_actor->kernel()->HasAttr(kInlineSubGraphName)) {
358     MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_switch_actor->GetAID();
359   }
360   const auto &inline_sub_graph_names = condition_switch_actor->kernel()->GetAttr(kInlineSubGraphName);
361   MS_EXCEPTION_IF_NULL(inline_sub_graph_names);
362   MS_LOG(DEBUG) << "inline sub graph name:" << inline_sub_graph_names->ToString()
363                 << " for actor:" << condition_switch_actor->GetAID();
364   if (!inline_sub_graph_names->isa<ValueTuple>()) {
365     MS_LOG(EXCEPTION) << "Invalid input subgraph name:" << inline_sub_graph_names->ToString()
366                       << " for actor:" << condition_switch_actor->GetAID();
367   }
368   const auto &tuple_name = inline_sub_graph_names->cast<ValueTuplePtr>();
369   MS_EXCEPTION_IF_NULL(tuple_name);
370   std::vector<std::string> branch_names;
371   for_each(tuple_name->value().begin(), tuple_name->value().end(),
372            [&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
373   condition_switch_actor->branch_names_ = branch_names;
374   // Fix ref count.
375   size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
376   condition_switch_actor->branch_origin_ref_count_ =
377     std::vector<std::vector<size_t>>(tuple_name->size(), vector<size_t>(output_num, 0));
378 
379   InitOutputBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
380 }
381 
AddRefCountForConditionSwitchActor(ConditionSwitchActor * const switch_actor,const std::string & branch_name,size_t output_index,size_t ref_count)382 void InlineControlFlowScheduler::AddRefCountForConditionSwitchActor(ConditionSwitchActor *const switch_actor,
383                                                                     const std::string &branch_name, size_t output_index,
384                                                                     size_t ref_count) {
385   const auto &iter = std::find(switch_actor->branch_names_.begin(), switch_actor->branch_names_.end(), branch_name);
386   if (iter == switch_actor->branch_names_.end()) {
387     MS_LOG(EXCEPTION) << "Failed to get branch name:" << branch_name << " total:" << switch_actor->branch_names_
388                       << " in actor:" << switch_actor->GetAID();
389   }
390   size_t index = LongToSize(iter - switch_actor->branch_names_.begin());
391   if (index >= switch_actor->branch_origin_ref_count_.size()) {
392     MS_LOG(EXCEPTION) << " Invalid index:" << index
393                       << " for branch origin ref count:" << switch_actor->branch_origin_ref_count_
394                       << " for actor:" << switch_actor->GetAID();
395   }
396   if (output_index >= switch_actor->branch_origin_ref_count_[index].size()) {
397     MS_LOG(EXCEPTION) << " Invalid output index:" << output_index << " branch index:" << index
398                       << " for branch origin ref count:" << switch_actor->branch_origin_ref_count_
399                       << " for actor:" << switch_actor->GetAID();
400   }
401   MS_LOG(DEBUG) << "Add ref count:" << ref_count << " for branch index:" << index << " index:" << output_index
402                 << " origin ref count:" << switch_actor->branch_origin_ref_count_
403                 << " for actor:" << switch_actor->GetAID();
404   switch_actor->branch_origin_ref_count_[index][output_index] += ref_count;
405 }
406 
FixRefCountForRefNode(const KernelWithIndex & input_with_index,size_t ref_count,const std::string & branch_name,const KernelGraph * const kernel_graph)407 void InlineControlFlowScheduler::FixRefCountForRefNode(const KernelWithIndex &input_with_index, size_t ref_count,
408                                                        const std::string &branch_name,
409                                                        const KernelGraph *const kernel_graph) {
410   MS_EXCEPTION_IF_NULL(kernel_graph);
411   MS_EXCEPTION_IF_NULL(input_with_index.first);
412   auto new_branch_name = branch_name;
413   if (common::AnfAlgo::CheckPrimitiveType(input_with_index.first, prim::kPrimConditionSwitch)) {
414     MS_LOG(DEBUG) << "Check switch node:" << input_with_index.first->fullname_with_scope()
415                   << " index:" << input_with_index.second << " ref count:" << ref_count
416                   << " branch name:" << branch_name;
417     const auto &actor = FetchActor(GetActorIdByKernel(input_with_index.first));
418     MS_EXCEPTION_IF_NULL(actor);
419     const auto &switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
420     MS_EXCEPTION_IF_NULL(switch_actor);
421     AddRefCountForConditionSwitchActor(switch_actor, branch_name, input_with_index.second, ref_count);
422     const auto &iter = kernel_graph->inline_sub_graph_kernels().find(input_with_index.first);
423     new_branch_name =
424       (iter == kernel_graph->inline_sub_graph_kernels().end() ? kernel_graph->ToString() : iter->second);
425     MS_LOG(DEBUG) << "Switch branch name from:" << branch_name << " to:" << new_branch_name
426                   << " by switch node:" << input_with_index.first->fullname_with_scope()
427                   << " in kernel graph:" << kernel_graph->ToString() << " ref count:" << ref_count;
428   } else if (common::AnfAlgo::CheckPrimitiveType(input_with_index.first, prim::kPrimConditionGather)) {
429     const auto &actor = FetchActor(GetActorIdByKernel(input_with_index.first));
430     MS_EXCEPTION_IF_NULL(actor);
431     const auto &gather_actor = dynamic_cast<ConditionGatherActor *>(actor);
432     MS_EXCEPTION_IF_NULL(gather_actor);
433     const auto &gather_cnode = input_with_index.first->cast<CNodePtr>();
434     size_t input_num = common::AnfAlgo::GetInputNum(gather_cnode);
435     if (input_num == 0 || input_num != gather_actor->branch_names_.size() * gather_actor->branch_output_num_) {
436       MS_LOG_WITH_NODE(EXCEPTION, gather_cnode)
437         << "Invalid input num:" << input_num << " branch output num:" << gather_actor->branch_output_num_
438         << " branch num:" << gather_actor->branch_names_.size() << " for node:" << gather_cnode->fullname_with_scope();
439     }
440     for (size_t i = input_with_index.second; i < input_num; i = i + gather_actor->branch_output_num_) {
441       FixRefCountForInputNode(common::AnfAlgo::VisitKernelWithReturnType(gather_cnode->input(i + 1), 0), ref_count,
442                               gather_actor->branch_names_[i / gather_actor->branch_output_num_]);
443     }
444     return;
445   }
446 
447   if (kernel_graph->IsInRefOutputMap(input_with_index)) {
448     const auto &ref_value = kernel_graph->GetRefCorrespondOutput(input_with_index);
449     if (ref_value.first == nullptr) {
450       return;
451     }
452     MS_LOG(DEBUG) << "Check input node:" << ref_value.first->fullname_with_scope() << " index:" << ref_value.second
453                   << " output node:" << input_with_index.first->fullname_with_scope()
454                   << " index:" << input_with_index.second;
455     FixRefCountForRefNode(ref_value, ref_count, new_branch_name, kernel_graph);
456   }
457 }
458 
FixRefCountForInputNode(const KernelWithIndex & input_with_index,size_t ref_count,const std::string & branch_name)459 void InlineControlFlowScheduler::FixRefCountForInputNode(const KernelWithIndex &input_with_index, size_t ref_count,
460                                                          const std::string &branch_name) {
461   const auto &node = input_with_index.first;
462   MS_EXCEPTION_IF_NULL(node);
463   const auto &device_address = AnfAlgo::GetMutableOutputAddr(node, input_with_index.second, false);
464   MS_EXCEPTION_IF_NULL(device_address);
465   if (ref_count == SIZE_MAX) {
466     MS_LOG(DEBUG) << "set ref count to max for device address:" << device_address;
467     device_address->set_original_ref_count(ref_count);
468   } else {
469     MS_LOG(DEBUG) << "set ref count from:" << device_address->original_ref_count()
470                   << " to:" << device_address->original_ref_count() + ref_count
471                   << " for device address:" << device_address;
472     device_address->set_original_ref_count(device_address->original_ref_count() + ref_count);
473   }
474   device_address->ResetRefCount();
475   if (node->isa<CNode>()) {
476     const auto &cnode = node->cast<CNodePtr>();
477     MS_EXCEPTION_IF_NULL(cnode);
478     const auto &graph = cnode->func_graph();
479     if (graph != nullptr && graph->isa<KernelGraph>()) {
480       const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
481       MS_EXCEPTION_IF_NULL(kernel_graph);
482       if (kernel_graph->IsInRefOutputMap(input_with_index)) {
483         FixRefCountForRefNode(input_with_index, ref_count, branch_name, kernel_graph);
484         return;
485       }
486     }
487   }
488 
489   if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimConditionGather)) {
490     const auto &gather_cnode = node->cast<CNodePtr>();
491     MS_EXCEPTION_IF_NULL(gather_cnode);
492     const auto &actor = FetchActor(GetActorIdByKernel(gather_cnode));
493     MS_EXCEPTION_IF_NULL(actor);
494     const auto &gather_actor = dynamic_cast<ConditionGatherActor *>(actor);
495     MS_EXCEPTION_IF_NULL(gather_actor);
496     size_t input_num = common::AnfAlgo::GetInputNum(gather_cnode);
497     if (input_num == 0 || input_num != gather_actor->branch_names_.size() * gather_actor->branch_output_num_) {
498       MS_LOG_WITH_NODE(EXCEPTION, gather_cnode)
499         << "Invalid input num:" << input_num << " branch output num:" << gather_actor->branch_output_num_
500         << " branch num:" << gather_actor->branch_names_.size() << " for node:" << gather_cnode->fullname_with_scope();
501     }
502     for (size_t i = input_with_index.second; i < input_num; i = i + gather_actor->branch_output_num_) {
503       FixRefCountForInputNode(common::AnfAlgo::VisitKernelWithReturnType(gather_cnode->input(i + 1), 0), ref_count,
504                               gather_actor->branch_names_[i / gather_actor->branch_output_num_]);
505     }
506   }
507 }
508 
FixRefCountByConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)509 void InlineControlFlowScheduler::FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
510                                                                    const KernelGraphPtr &kernel_graph) {
511   std::vector<size_t> need_add_ref_count;
512   size_t output_num = AnfAlgo::GetOutputTensorNum(condition_gather_actor->kernel());
513   for (size_t i = 0; i < output_num; ++i) {
514     const auto &device_address = AnfAlgo::GetMutableOutputAddr(condition_gather_actor->kernel(), i, false);
515     MS_EXCEPTION_IF_NULL(device_address);
516     need_add_ref_count.emplace_back(
517       device_address->original_ref_count() == SIZE_MAX ? SIZE_MAX : device_address->original_ref_count() - 1);
518     MS_LOG(DEBUG) << "For actor:" << condition_gather_actor->GetAID() << " output device address:" << device_address
519                   << " output index:" << i << " ref_count:" << device_address->original_ref_count()
520                   << " need add:" << need_add_ref_count.back();
521   }
522   size_t input_num = common::AnfAlgo::GetInputNum(condition_gather_actor->kernel());
523   if (input_num == 0 ||
524       input_num != condition_gather_actor->branch_output_num_ * condition_gather_actor->branch_names_.size()) {
525     MS_LOG(EXCEPTION) << "Invalid input num:" << input_num
526                       << " branch output num:" << condition_gather_actor->branch_output_num_
527                       << " for actor:" << condition_gather_actor->GetAID();
528   }
529   for (size_t i = 0; i < input_num; ++i) {
530     const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_gather_actor->kernel(), i, false);
531     MS_EXCEPTION_IF_NULL(device_address);
532     MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
533                   << " input index:" << i << " ref_count:" << device_address->original_ref_count();
534     if (device_address->original_ref_count() == SIZE_MAX) {
535       continue;
536     }
537     const auto &input_with_index =
538       common::AnfAlgo::VisitKernelWithReturnType(condition_gather_actor->kernel()->input(i + 1), 0);
539     FixRefCountForInputNode(input_with_index, need_add_ref_count[i % condition_gather_actor->branch_output_num_],
540                             condition_gather_actor->branch_names_[i / condition_gather_actor->branch_output_num_]);
541     MS_LOG(DEBUG) << "Condition gather actor:" << condition_gather_actor->GetAID() << " input index:" << i
542                   << " input node:" << input_with_index.first->DebugString()
543                   << " with index:" << input_with_index.second
544                   << " need add ref count:" << need_add_ref_count[i % condition_gather_actor->branch_output_num_];
545   }
546 }
547 
InitInputDataBranchInfoForConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)548 void InlineControlFlowScheduler::InitInputDataBranchInfoForConditionGatherActor(
549   ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
550   const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
551   MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
552                 << " by actor:" << condition_gather_actor->GetAID();
553   for (const auto &pair : condition_gather_actor->input_data_arrow_aids_) {
554     const auto &from_aid = pair.first;
555     const auto &data_arrow = pair.second;
556     MS_EXCEPTION_IF_NULL(data_arrow);
557     const auto &from_actor = FetchActor(from_aid.Name());
558     if (from_actor == nullptr) {
559       MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
560     }
561     if (from_actor->type() != KernelTransformType::kKernelActor &&
562         from_actor->type() != KernelTransformType::kConditionSwitchActor &&
563         from_actor->type() != KernelTransformType::kConditionGatherActor) {
564       MS_LOG(EXCEPTION) << "Invalid to actor:" << from_actor->GetAID()
565                         << " from actor:" << condition_gather_actor->GetAID();
566     }
567     const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
568     MS_EXCEPTION_IF_NULL(from_kernel_actor);
569     MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
570     std::string current_branch_name;
571     if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
572       current_branch_name =
573         GetBranchNameByConditionGatherActor(from_kernel_actor, condition_gather_actor, data_arrow, kernel_graph);
574     } else {
575       if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
576         MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by data user node:"
577                           << from_kernel_actor->kernel()->fullname_with_scope()
578                           << " in actor:" << condition_gather_actor->GetAID();
579       }
580       MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
581                     << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
582                     << " in actor:" << condition_gather_actor->GetAID();
583       current_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
584     }
585     const auto &iter = condition_gather_actor->branch_name_to_id_.find(current_branch_name);
586     if (iter == condition_gather_actor->branch_name_to_id_.end()) {
587       condition_gather_actor->branch_name_to_id_[current_branch_name] =
588         condition_gather_actor->branch_name_to_id_.size();
589       MS_LOG(DEBUG) << "Add branch index:" << condition_gather_actor->branch_name_to_id_[current_branch_name]
590                     << " branch name:" << current_branch_name << " for actor:" << condition_gather_actor->GetAID();
591     }
592     // Get the input data num of each branch.
593     if (condition_gather_actor->branch_name_to_input_data_num_.find(current_branch_name) ==
594         condition_gather_actor->branch_name_to_input_data_num_.end()) {
595       condition_gather_actor->branch_name_to_input_data_num_[current_branch_name] = 1;
596     } else {
597       condition_gather_actor->branch_name_to_input_data_num_[current_branch_name]++;
598     }
599   }
600 }
601 
InitInputControlBranchInfoForConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)602 void InlineControlFlowScheduler::InitInputControlBranchInfoForConditionGatherActor(
603   ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
604   const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
605   MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
606                 << " by actor:" << condition_gather_actor->GetAID();
607 
608   for (const auto &pair : condition_gather_actor->input_control_arrow_aids_) {
609     const auto &from_aid = pair.first;
610     const auto &from_actor = FetchActor(from_aid.Name());
611     if (from_actor == nullptr) {
612       MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
613     }
614     if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
615       continue;
616     }
617     if (from_actor->type() != KernelTransformType::kKernelActor &&
618         from_actor->type() != KernelTransformType::kConditionGatherActor) {
619       MS_LOG(EXCEPTION) << "Invalid from actor:" << from_actor->GetAID()
620                         << " to actor:" << condition_gather_actor->GetAID();
621     }
622     const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
623     MS_EXCEPTION_IF_NULL(from_kernel_actor);
624     MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
625     if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
626       MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by control user node:"
627                         << from_kernel_actor->kernel()->fullname_with_scope()
628                         << " in actor:" << condition_gather_actor->GetAID();
629     }
630     MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
631                   << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
632                   << " in actor:" << condition_gather_actor->GetAID();
633     const auto &current_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
634     // Get input op control num of each branch.
635     if (condition_gather_actor->branch_name_to_input_control_num_.find(current_branch_name) ==
636         condition_gather_actor->branch_name_to_input_control_num_.end()) {
637       condition_gather_actor->branch_name_to_input_control_num_[current_branch_name] = 1;
638     } else {
639       condition_gather_actor->branch_name_to_input_control_num_[current_branch_name]++;
640     }
641   }
642 }
643 
InitInputBranchInfoForConditionGatherActor(ConditionGatherActor * const condition_gather_actor,const KernelGraphPtr & kernel_graph)644 void InlineControlFlowScheduler::InitInputBranchInfoForConditionGatherActor(
645   ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
646   InitInputDataBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
647   InitInputControlBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
648 }
649 
HandleConditionGatherActor(const KernelActorPtr & kernel_actor)650 void InlineControlFlowScheduler::HandleConditionGatherActor(const KernelActorPtr &kernel_actor) {
651   const auto &condition_gather_actor = dynamic_cast<ConditionGatherActor *>(kernel_actor.get());
652   MS_EXCEPTION_IF_NULL(condition_gather_actor);
653   const auto &gather_node = condition_gather_actor->kernel();
654   MS_EXCEPTION_IF_NULL(gather_node);
655   const auto &graph = gather_node->func_graph();
656   if (graph == nullptr || !graph->isa<KernelGraph>()) {
657     MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_gather_actor->GetAID();
658   }
659   const auto &kernel_graph = graph->cast<KernelGraphPtr>();
660   MS_EXCEPTION_IF_NULL(kernel_graph);
661   const auto &gather_switch_map = kernel_graph->condition_gather_to_switch();
662   const auto &gather_switch_iter = gather_switch_map.find(gather_node);
663   if (gather_switch_iter == gather_switch_map.end()) {
664     MS_LOG_WITH_NODE(EXCEPTION, gather_node)
665       << "Failed to get switch node by gather node:" << gather_node->fullname_with_scope();
666   }
667   if (gather_switch_iter->second == nullptr) {
668     MS_LOG_WITH_NODE(EXCEPTION, gather_node)
669       << "Failed to get switch node by gather node:" << gather_node->fullname_with_scope()
670       << " in kernel graph:" << kernel_graph->ToString();
671   }
672   const auto &actor = FetchActor(GetActorIdByKernel(gather_switch_iter->second));
673   MS_EXCEPTION_IF_NULL(actor);
674   const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
675   MS_EXCEPTION_IF_NULL(condition_switch_actor);
676   condition_switch_actor->gather_aid_ = const_cast<AID *>(&condition_gather_actor->GetAID());
677 
678   if (!gather_node->HasAttr(kAttrBranchOutputNum)) {
679     MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
680   }
681   const auto &output_value = gather_node->GetAttr(kAttrBranchOutputNum);
682   MS_EXCEPTION_IF_NULL(output_value);
683   condition_gather_actor->branch_output_num_ = GetValue<size_t>(output_value);
684 
685   if (!gather_node->HasAttr(kAttrBranchGraphName)) {
686     MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_gather_actor->GetAID();
687   }
688   const auto &branch_graph_names = gather_node->GetAttr(kAttrBranchGraphName);
689   MS_EXCEPTION_IF_NULL(branch_graph_names);
690   MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
691                 << " for actor:" << condition_gather_actor->GetAID();
692   if (!branch_graph_names->isa<ValueTuple>()) {
693     MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
694                       << " for actor:" << condition_gather_actor->GetAID();
695   }
696   const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
697   MS_EXCEPTION_IF_NULL(tuple_name);
698   std::vector<std::string> branch_names;
699   std::for_each(tuple_name->value().begin(), tuple_name->value().end(),
700                 [&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
701   condition_gather_actor->branch_names_ = branch_names;
702   // Fix ref count.
703   FixRefCountByConditionGatherActor(condition_gather_actor, kernel_graph);
704   InitInputBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
705 }
706 
LinkControlArrowForNoInputOrOutputActor(ActorSet * actor_set,const mindspore::HashMap<std::string,AbstractActor * > & branch_name_to_switch_actor,const mindspore::HashMap<std::string,AbstractActor * > & branch_name_to_gather_actor)707 void InlineControlFlowScheduler::LinkControlArrowForNoInputOrOutputActor(
708   ActorSet *actor_set, const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_switch_actor,
709   const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_gather_actor) {
710   MS_EXCEPTION_IF_NULL(actor_set);
711   for (const auto &kernel_actor : actor_set->kernel_actors_) {
712     MS_EXCEPTION_IF_NULL(kernel_actor);
713     if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0) &&
714         IsInlineKernelActor(kernel_actor)) {
715       const auto &branch_name = GetBranchNameByKernelActor(kernel_actor.get());
716       const auto &iter = branch_name_to_switch_actor.find(branch_name);
717       if (iter == branch_name_to_switch_actor.end()) {
718         MS_LOG(EXCEPTION) << "Failed to get condition switch actor by branch name:" << branch_name;
719       }
720       MS_LOG(DEBUG) << "Inline control flow scheduler add control flow from switch actor:" << iter->second->GetAID()
721                     << " to kernel actor:" << kernel_actor->GetAID();
722       SchedulerHelper::AddControlArrow(iter->second, kernel_actor.get());
723     }
724     if (kernel_actor->output_data_arrows_.size() == 0 && kernel_actor->output_control_arrows_.size() == 0 &&
725         IsInlineKernelActor(kernel_actor)) {
726       const auto &branch_name = GetBranchNameByKernelActor(kernel_actor.get());
727       const auto &iter = branch_name_to_gather_actor.find(branch_name);
728       if (iter == branch_name_to_gather_actor.end()) {
729         MS_LOG(EXCEPTION) << "Failed to get condition gather actor by branch name:" << branch_name;
730       }
731       MS_LOG(DEBUG) << "Inline control flow scheduler add control flow from kernel actor:" << kernel_actor->GetAID()
732                     << " to gather actor:" << iter->second->GetAID();
733       SchedulerHelper::AddControlArrow(kernel_actor.get(), iter->second);
734     }
735   }
736 }
737 
Link(ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info,bool execution_order_running)738 void InlineControlFlowScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
739                                       bool execution_order_running) {
740   MS_EXCEPTION_IF_NULL(actor_set);
741   auto context_ptr = MsContext::GetInstance();
742   MS_EXCEPTION_IF_NULL(context_ptr);
743   mindspore::HashMap<std::string, AbstractActor *> branch_name_to_switch_actor;
744   mindspore::HashMap<std::string, AbstractActor *> branch_name_to_gather_actor;
745   for (const auto &graph : graph_compiler_info.graphs_) {
746     MS_EXCEPTION_IF_NULL(graph);
747     GetBranchNameToCondtionActor(graph, &branch_name_to_switch_actor, &branch_name_to_gather_actor);
748   }
749   LinkControlArrowForNoInputOrOutputActor(actor_set, branch_name_to_switch_actor, branch_name_to_gather_actor);
750   for (const auto &kernel_actor : actor_set->kernel_actors_) {
751     if (kernel_actor->type() == KernelTransformType::kConditionSwitchActor) {
752       HandleConditionSwitchActor(kernel_actor);
753     } else if (kernel_actor->type() == KernelTransformType::kConditionGatherActor) {
754       HandleConditionGatherActor(kernel_actor);
755     }
756   }
757   for (const auto &kernel_graph : graph_compiler_info.graphs_) {
758     MS_EXCEPTION_IF_NULL(kernel_graph);
759     if (kernel_graph->inline_sub_graph_kernels().empty()) {
760       continue;
761     }
762     for (const auto &ref_pair : kernel_graph->GetRefMap()) {
763       const auto &output_pair = ref_pair.first;
764       const auto &input_pair = ref_pair.second;
765       MS_EXCEPTION_IF_NULL(output_pair.first);
766       MS_EXCEPTION_IF_NULL(input_pair.first);
767       MS_LOG(DEBUG) << "output node:" << output_pair.first->fullname_with_scope()
768                     << " input node:" << input_pair.first->fullname_with_scope();
769       const auto &actor = FetchActor(GetActorIdByKernel(output_pair.first));
770       if (actor == nullptr) {
771         MS_LOG_WITH_NODE(EXCEPTION, output_pair.first)
772           << "Failed to get actor by ref node:" << output_pair.first->fullname_with_scope()
773           << " index:" << output_pair.second << " origin node:" << input_pair.first->fullname_with_scope()
774           << " index:" << input_pair.second << " in graph:" << kernel_graph->ToString();
775       }
776       size_t ref_count = 1;
777       std::for_each(actor->output_data_arrows().begin(), actor->output_data_arrows().end(),
778                     [&ref_count, &output_pair](const auto &data_arrow) {
779                       MS_EXCEPTION_IF_NULL(data_arrow);
780                       if (IntToSize(data_arrow->from_output_index_) == output_pair.second) {
781                         ++ref_count;
782                       }
783                     });
784       FixRefCountRecursively(output_pair, input_pair, kernel_graph, ref_count);
785     }
786   }
787 }
788 
FixRefCountRecursively(const KernelWithIndex & output_pair,const KernelWithIndex & input_pair,const KernelGraphPtr & kernel_graph,size_t ref_count)789 void InlineControlFlowScheduler::FixRefCountRecursively(const KernelWithIndex &output_pair,
790                                                         const KernelWithIndex &input_pair,
791                                                         const KernelGraphPtr &kernel_graph, size_t ref_count) {
792   MS_EXCEPTION_IF_NULL(output_pair.first);
793   MS_EXCEPTION_IF_NULL(input_pair.first);
794   if (common::AnfAlgo::CheckPrimitiveType(input_pair.first, prim::kPrimConditionGather)) {
795     return;
796   }
797   if (common::AnfAlgo::CheckPrimitiveType(input_pair.first, prim::kPrimConditionSwitch)) {
798     const auto &iter = kernel_graph->inline_sub_graph_kernels().find(output_pair.first);
799     if (iter == kernel_graph->inline_sub_graph_kernels().end()) {
800       MS_LOG_WITH_NODE(EXCEPTION, input_pair.first)
801         << "Invalid ref node pair, input node:" << input_pair.first->fullname_with_scope()
802         << " index:" << input_pair.second << " output node:" << output_pair.first->fullname_with_scope()
803         << " index:" << output_pair.second << " in kernel graph:" << kernel_graph->ToString();
804     }
805     const auto &branch_name = iter->second;
806     const auto &actor = FetchActor(GetActorIdByKernel(input_pair.first));
807     MS_EXCEPTION_IF_NULL(actor);
808     const auto &switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
809     MS_EXCEPTION_IF_NULL(switch_actor);
810     AddRefCountForConditionSwitchActor(switch_actor, branch_name, input_pair.second, ref_count);
811   }
812   if (kernel_graph->IsInRefOutputMap(input_pair)) {
813     FixRefCountRecursively(input_pair, kernel_graph->GetRefCorrespondOutput(input_pair), kernel_graph, ref_count);
814   }
815 }
816 }  // namespace runtime
817 }  // namespace mindspore
818