1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_SCHEDULER_HELPER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_SCHEDULER_HELPER_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <utility> 24 #include <map> 25 #include <set> 26 #include <algorithm> 27 #include "utils/hash_map.h" 28 #include "utils/hash_set.h" 29 #include "runtime/graph_scheduler/actor/actor_set.h" 30 31 namespace mindspore { 32 namespace runtime { 33 class SchedulerHelper { 34 public: 35 // Convert the actors vector by the actor set. 36 static std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set); 37 38 // Judge the input node whether need the control arrow. 39 static bool HasMonadControl(const AnfNodePtr &input_node, const KernelGraphPtr &graph); 40 41 static void AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor); 42 43 static void AddMonadDeviceTensorStore(AbstractActor *const to_actor, const CNodePtr &kernel, 44 const KernelGraphPtr &graph); 45 46 // Judge whether need ignore the input address that is not used in the kernel launch. 47 static bool IsIgnoredInputAddress(AbstractActor *const to_actor, size_t to_input_index); 48 static size_t GetIgnoredInputAddressCount(const AnfNodePtr &node); 49 50 // Add the arrow between from actor and to actor. 51 static void AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, size_t from_output_index, 52 size_t to_input_index, const AnfNodePtr &from_kernel = nullptr); 53 static void AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor, 54 const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position); 55 static void AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor); 56 57 // Add the arrow for control actor. 58 static void AddPartialArrow(ControlActor *const from_actor, ControlActor *const to_actor, size_t from_index, 59 size_t to_index); 60 static void AddBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor); 61 // Body control arrow is only exists to entrance actor.. 62 static void AddLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor); 63 // Data arrow with branch id is only exists from gather actor to entrance actor. 64 static void AddDataWithBranchIDArrow(GatherActor *const gather_actor, const EntranceActor *entrance_actor, 65 const FuncGraphPtr &func_graph); 66 // Since the output of exit actor has branches, it needs to be based on a dedicated interface. 67 static void AddDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index, 68 size_t to_index, int branch_id); 69 static void AddPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index, 70 size_t to_index, int branch_id); 71 static void AddControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id); 72 73 // Fill the device tensors of backend input nodes corresponding to ref formal parameters. 74 static void AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index, 75 const AnfNodePtr &input_node, const KernelGraphPtr &graph); 76 77 // Convert the invalid data arrow to control arrow. 78 static void ConvertDataArrowToControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, 79 const DataArrowPtr &data_arrow, size_t data_arrow_index); 80 81 // Fuse the data arrows to batch data arrow for the same destination actor. 82 static void FuseDataArrowsToBatchDataArrow(AbstractActor *const actor); 83 84 // The interface of fusing the actors to a fusion actor. 85 static void AddDependency(AbstractActor *const actor, const AbstractActor *dependent_actor); 86 static bool CheckDependency(const std::vector<AbstractActorPtr> &output_actors); 87 static FusionActorPtr BuildFusionActor(const std::vector<AbstractActorPtr> &actors); 88 static void AddArrowForFusionActor(FusionActor *fusion_actor); 89 90 // The interface of integration of dynamic and static memory. 91 static void AddMemorySign(AbstractActor *const from_actor, AbstractActor *const to_actor); 92 static KernelGraphPtr FetchKernelGraphByActor(AbstractActor *const actor); 93 // Add the memory alloc sign for the head kernel actor of graph. 94 static void AddMemoryAllocSign(AbstractActor *const from_actor, AbstractActor *const to_actor, 95 const KernelGraphPtr &to_graph); 96 // Add the memory free sign for the tail kernel actor of graph. 97 static void AddMemoryFreeSign(AbstractActor *const from_actor, AbstractActor *const to_actor, 98 const KernelGraphPtr &from_graph); 99 static void AddSomasInfo(AbstractActor *const actor); 100 static void AddSomasInfoForGraphOutput(AbstractActor *const output_actor, const AnfNodePtr &output_kernel, 101 size_t output_index, size_t graph_id); 102 103 // Check whether the actor set is valid. 104 static void CheckActorValid(const ActorSet *actor_set); 105 106 static void DumpActorSet(const ActorSet *actor_set, std::ofstream &ofs); 107 static void DumpFormatActorSet(const ActorSet *actor_set, std::ofstream &ofs); 108 109 static size_t fusion_actor_index_; 110 }; 111 } // namespace runtime 112 } // namespace mindspore 113 114 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_ 115