1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_ 19 20 #include <string> 21 #include <stack> 22 #include "runtime/graph_scheduler/actor/actor_set.h" 23 24 namespace mindspore { 25 namespace runtime { 26 bool IsInlineKernelActor(const AbstractActorPtr &actor); 27 class InlineControlFlowScheduler { 28 public: 29 InlineControlFlowScheduler() = default; 30 ~InlineControlFlowScheduler() = default; 31 DISABLE_COPY_AND_ASSIGN(InlineControlFlowScheduler); 32 33 // Link control arrows and fix the member variables for condition actors. 34 void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info, bool execution_order_running); 35 void LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph, 36 const GraphCompilerInfo &graph_compiler_info) const; 37 38 private: 39 // Fix the member variables for condition actors. 40 void HandleConditionSwitchActor(const KernelActorPtr &kernel_actor); 41 void HandleConditionGatherActor(const KernelActorPtr &kernel_actor); 42 43 // Init the output branch info for condition actor. 44 // For condition switch actor, the output arrow include all the output branch and should be distinguished. 45 void InitOutputBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor, 46 const KernelGraphPtr &kernel_graph); 47 void InitOutputControlBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor, 48 const KernelGraphPtr &kernel_graph); 49 void InitOutputDataBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor, 50 const KernelGraphPtr &kernel_graph); 51 void InitInputBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor, 52 const KernelGraphPtr &kernel_graph); 53 void InitInputDataBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor, 54 const KernelGraphPtr &kernel_graph); 55 void InitInputControlBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor, 56 const KernelGraphPtr &kernel_graph); 57 58 // Fix ref count for condition actors. 59 // In condition switch actor, the ref count of actor should be change to total num for both branch. 60 // In condition gather actor, the ref count of gather input should add the ref count of gather output. 61 // The ref count of ref node should be add to the input of condition actor. 62 void FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor, 63 const KernelGraphPtr &kernel_graph); 64 void FixRefCountForRefNode(const KernelWithIndex &input_with_index, size_t ref_count, const std::string &branch_name, 65 const KernelGraph *const kernel_graph); 66 void FixRefCountForInputNode(const KernelWithIndex &input_with_index, size_t ref_count, 67 const std::string &branch_name); 68 std::string GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor, 69 KernelActor *condition_gather_actor, DataArrow *data_arrow, 70 const KernelGraphPtr &kernel_graph); 71 void FixRefCountRecursively(const KernelWithIndex &output_pair, const KernelWithIndex &input_pair, 72 const KernelGraphPtr &kernel_graph, size_t ref_count = 0); 73 void AddRefCountForConditionSwitchActor(ConditionSwitchActor *const switch_actor, const std::string &branch_name, 74 size_t output_index, size_t ref_count); 75 void LinkControlArrowForNoInputOrOutputActor( 76 ActorSet *actor_set, const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_switch_actor, 77 const mindspore::HashMap<std::string, AbstractActor *> &branch_name_to_gather_actor); 78 }; 79 } // namespace runtime 80 } // namespace mindspore 81 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_ 82