• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
19 
20 #include <string>
21 #include <stack>
22 #include "runtime/graph_scheduler/actor/actor_set.h"
23 
24 namespace mindspore {
25 namespace runtime {
26 bool IsInlineKernelActor(const AbstractActorPtr &actor);
27 class InlineControlFlowScheduler {
28  public:
29   InlineControlFlowScheduler() = default;
30   ~InlineControlFlowScheduler() = default;
31   DISABLE_COPY_AND_ASSIGN(InlineControlFlowScheduler);
32 
33   // Link control arrows and fix the member variables for condition actors.
34   void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info, bool execution_order_running);
35   void LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph,
36                                         const GraphCompilerInfo &graph_compiler_info) const;
37 
38  private:
39   // Fix the member variables for condition actors.
40   void HandleConditionSwitchActor(const KernelActorPtr &kernel_actor);
41   void HandleConditionGatherActor(const KernelActorPtr &kernel_actor);
42 
43   // Init the output branch info for condition actor.
44   // For condition switch actor, the output arrow include all the output branch and should be distinguished.
45   void InitOutputBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
46                                                    const KernelGraphPtr &kernel_graph);
47   void InitOutputControlBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
48                                                           const KernelGraphPtr &kernel_graph);
49   void InitOutputDataBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
50                                                        const KernelGraphPtr &kernel_graph);
51   void InitInputBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
52                                                   const KernelGraphPtr &kernel_graph);
53   void InitInputDataBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
54                                                       const KernelGraphPtr &kernel_graph);
55   void InitInputControlBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
56                                                          const KernelGraphPtr &kernel_graph);
57 
58   // Fix ref count for condition actors.
59   // In condition switch actor, the ref count of actor should be change to total num for both branch.
60   // In condition gather actor, the ref count of gather input should add the ref count of gather output.
61   // The ref count of ref node should be add to the input of condition actor.
62   void FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
63                                          const KernelGraphPtr &kernel_graph);
64   void FixRefCountForRefNode(const KernelWithIndex &input_with_index, size_t ref_count, const std::string &branch_name,
65                              const KernelGraph *const kernel_graph);
66   void FixRefCountForInputNode(const KernelWithIndex &input_with_index, size_t ref_count,
67                                const std::string &branch_name);
68   std::string GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor,
69                                                   KernelActor *condition_gather_actor, DataArrow *data_arrow,
70                                                   const KernelGraphPtr &kernel_graph);
71   void FixRefCountRecursively(const KernelWithIndex &output_pair, const KernelWithIndex &input_pair,
72                               const KernelGraphPtr &kernel_graph, size_t ref_count = 0);
73   void AddRefCountForConditionSwitchActor(ConditionSwitchActor *const switch_actor, const std::string &branch_name,
74                                           size_t output_index, size_t ref_count);
75   void LinkControlArrowForNoInputOrOutputActor(
76     ActorSet *actor_set, const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_switch_actor,
77     const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_gather_actor);
78 };
79 }  // namespace runtime
80 }  // namespace mindspore
81 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
82