• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_SCHEDULER_HELPER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_SCHEDULER_HELPER_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <utility>
24 #include <map>
25 #include <set>
26 #include <algorithm>
27 #include "utils/hash_map.h"
28 #include "utils/hash_set.h"
29 #include "runtime/graph_scheduler/actor/actor_set.h"
30 
31 namespace mindspore {
32 namespace runtime {
// Holder for the static helper routines declared below: collecting actors, linking actors with the different kinds
// of arrows, building fusion actors, attaching memory/somas information, and checking or dumping an actor set.
// Every member is static, so the class carries no per-instance state.
// NOTE(review): only declarations are visible in this header; the exact behavior lives in the implementation file.
33 class SchedulerHelper {
34  public:
35   // Build the vector of actors from the given actor set.
36   static std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set);
37 
38   // Judge whether the input node needs the control arrow.
39   static bool HasMonadControl(const AnfNodePtr &input_node, const KernelGraphPtr &graph);
40 
     // NOTE(review): presumably records device_tensor for anf_node in the device tensor store -- confirm.
41   static void AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor);
42 
     // NOTE(review): presumably the monad counterpart of AddDeviceTensorStore for the inputs of kernel that feed
     // to_actor -- confirm against the implementation.
43   static void AddMonadDeviceTensorStore(AbstractActor *const to_actor, const CNodePtr &kernel,
44                                         const KernelGraphPtr &graph);
45 
46   // Judge whether to ignore the input address that is not used in the kernel launch.
47   static bool IsIgnoredInputAddress(AbstractActor *const to_actor, size_t to_input_index);
     // NOTE(review): presumably returns how many input addresses of node are ignored -- confirm.
48   static size_t GetIgnoredInputAddressCount(const AnfNodePtr &node);
49 
50   // Add the arrow between the from actor and the to actor.
     // For AddDataArrow, from_kernel is optional and defaults to nullptr.
51   static void AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor, size_t from_output_index,
52                            size_t to_input_index, const AnfNodePtr &from_kernel = nullptr);
53   static void AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
54                              const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position);
55   static void AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor);
56 
57   // Add the arrows for the control actors.
58   static void AddPartialArrow(ControlActor *const from_actor, ControlActor *const to_actor, size_t from_index,
59                               size_t to_index);
60   static void AddBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor);
61   // The loop body control arrow only exists towards the entrance actor.
62   static void AddLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor);
63   // The data arrow with branch id only exists from the gather actor to the entrance actor.
64   static void AddDataWithBranchIDArrow(GatherActor *const gather_actor, const EntranceActor *entrance_actor,
65                                        const FuncGraphPtr &func_graph);
66   // Since the output of the exit actor has branches, its arrows need dedicated interfaces that take a branch id.
67   static void AddDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index,
68                                        size_t to_index, int branch_id);
69   static void AddPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index,
70                                           size_t to_index, int branch_id);
71   static void AddControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id);
72 
73   // Fill the device tensors of the backend input nodes corresponding to the ref formal parameters.
74   static void AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index,
75                                              const AnfNodePtr &input_node, const KernelGraphPtr &graph);
76 
77   // Convert the invalid data arrow to a control arrow.
     // NOTE(review): data_arrow_index is presumably the position of data_arrow in from_actor's arrow list -- confirm.
78   static void ConvertDataArrowToControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
79                                              const DataArrowPtr &data_arrow, size_t data_arrow_index);
80 
81   // Fuse the data arrows that target the same destination actor into a batch data arrow.
82   static void FuseDataArrowsToBatchDataArrow(AbstractActor *const actor);
83 
84   // The interfaces for fusing the actors into a fusion actor.
85   static void AddDependency(AbstractActor *const actor, const AbstractActor *dependent_actor);
86   static bool CheckDependency(const std::vector<AbstractActorPtr> &output_actors);
87   static FusionActorPtr BuildFusionActor(const std::vector<AbstractActorPtr> &actors);
88   static void AddArrowForFusionActor(FusionActor *fusion_actor);
89 
90   // The interfaces for the integration of dynamic and static memory.
91   static void AddMemorySign(AbstractActor *const from_actor, AbstractActor *const to_actor);
92   static KernelGraphPtr FetchKernelGraphByActor(AbstractActor *const actor);
93   // Add the memory alloc sign for the head kernel actor of the graph.
94   static void AddMemoryAllocSign(AbstractActor *const from_actor, AbstractActor *const to_actor,
95                                  const KernelGraphPtr &to_graph);
96   // Add the memory free sign for the tail kernel actor of the graph.
97   static void AddMemoryFreeSign(AbstractActor *const from_actor, AbstractActor *const to_actor,
98                                 const KernelGraphPtr &from_graph);
     // NOTE(review): presumably attach the somas (memory reuse) information to the actor / graph output -- confirm.
99   static void AddSomasInfo(AbstractActor *const actor);
100   static void AddSomasInfoForGraphOutput(AbstractActor *const output_actor, const AnfNodePtr &output_kernel,
101                                          size_t output_index, size_t graph_id);
102 
103   // Check whether the actor set is valid.
      // Returns void, so a failed check is not reported through the return value.
      // NOTE(review): presumably it raises or logs on an invalid actor set -- confirm in the implementation.
104   static void CheckActorValid(const ActorSet *actor_set);
105 
      // Write the actor set to the given output file stream (plain and formatted variants).
      // NOTE(review): std::ofstream is used here, but <fstream> is not among the headers included directly by this
      // file; presumably it arrives transitively through actor_set.h -- confirm, or include <fstream> explicitly.
106   static void DumpActorSet(const ActorSet *actor_set, std::ofstream &ofs);
107   static void DumpFormatActorSet(const ActorSet *actor_set, std::ofstream &ofs);
108 
      // Shared static counter; it is only declared here, so its definition must live in a source file.
      // NOTE(review): presumably used to number the fusion actors that BuildFusionActor creates -- confirm.
109   static size_t fusion_actor_index_;
110 };
111 }  // namespace runtime
112 }  // namespace mindspore
113 
114 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_SCHEDULER_HELPER_H_
115