• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <utility>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <map>
27 #include <set>
28 #include <algorithm>
29 #include <fstream>
30 #include "runtime/framework/graph_compiler.h"
31 #include "runtime/framework/actor/data_prepare_actor.h"
32 #include "runtime/framework/actor/data_source_actor.h"
33 #include "runtime/framework/actor/loop_count_actor.h"
34 #include "runtime/framework/actor/kernel_actor.h"
35 #include "runtime/framework/actor/output_actor.h"
36 #include "runtime/framework/actor/switch_actor.h"
37 #include "runtime/framework/actor/gather_actor.h"
38 #include "runtime/framework/actor/copy_actor.h"
39 #include "thread/actor_threadpool.h"
40 
41 namespace mindspore {
42 namespace runtime {
// Import the device and session types into this namespace so they can be named unqualified in this header.
using mindspore::device::DeviceContext;
using mindspore::session::KernelGraph;
using mindspore::session::KernelWithIndex;
// The string that names an actor set. GraphScheduler uses it as the key when caching and fetching actor sets.
using ActorInfo = std::string;

// Locates one output of an abstract actor: the first element is the actor, the second element is the output index
// of that actor corresponding to the graph output node.
using GraphOutputPair = std::pair<AbstractActor *, size_t>;
50 
51 // The actor set generated by graph transformer is the execution unit of actor runtime.
52 // It includes data source actor, kernel actor, switch actor, copy actor, loop count actor and output actor.
53 // The data prepare actor is used to prepare data for device tensor store and host tensor queue to represent the begin
54 // of one step.
55 // The data source actor is used to obtain data and process them into device tensors, and send them to kernel actor.
56 // The kernel actor is used to receive the device tensors to luanch kernel. Specifically notice the no input
57 // kernel actor, it means that this actor has no input device tensor, need be triggered externally.
58 // The switch actor is used to run different branches in the control flow scenario.
59 // The gather actor is used to collect the inputs of graph and send branch id to loop count actor in multi-branch
60 // output scenario.
61 // The copy actor is used to convert the device tensor between the different device kernel.
62 // The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
63 // and decide whether to loop execution by loop count.
64 // The output actor is used to receive the output result of actor which represents the graph output.
65 struct ActorSet {
ActorSetActorSet66   explicit ActorSet(const ActorInfo &name) : name_(name) {}
67   DataPrepareActorPtr data_prepare_actor_{nullptr};
68   std::vector<DataSourceActorPtr> data_source_actors_;
69   std::vector<KernelActorPtr> kernel_actors_;
70   // No input kernel actors need be triggered specifically.
71   std::vector<KernelActorPtr> no_input_kernel_actors_;
72   std::vector<SwitchActorPtr> switch_actors_;
73   std::vector<GatherActorPtr> gather_actors_;
74   std::vector<CopyActorPtr> copy_actors_;
75   LoopCountActorPtr loop_count_actor_{nullptr};
76   OutputActorPtr output_actor_{nullptr};
77   ActorInfo name_;
78 };
79 using ActorSetPtr = std::shared_ptr<ActorSet>;
80 
81 class GraphScheduler {
82  public:
GetInstance()83   static GraphScheduler &GetInstance() noexcept {
84     static GraphScheduler instance;
85     return instance;
86   }
87 
88   // 1. Thread pool creating.
89   // 2. The global actors creating and scheduling.
90   void Initialize();
91 
92   // Clear the members.
93   void Clear();
94   void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept;
95 
96   // Transform graph to actor DAG, contains build and link.
97   ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);
98 
99   // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
100   // will be supported in the future.
101   void Schedule(const ActorSet *actor_set);
102 
103   // The processing entry of actors running. The third parameter is used only in the step execution strategy.
104   bool Run(const ActorSet *actor_set, const std::vector<std::vector<TensorPtr>> &input_tensors,
105            const std::vector<TensorPtr> &input_tensors_with_value_node = {},
106            GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
107 
108   // Fetch the actor set by actor info.
109   ActorSet *Fetch(const ActorInfo &actor_info) const;
110 
111  private:
112   GraphScheduler() = default;
113   ~GraphScheduler() = default;
114   DISABLE_COPY_AND_ASSIGN(GraphScheduler);
115 
116   // The Global actors contain memory manager actor, recorder actor and debug actor.
117   void BuildAndScheduleGlobalActor();
118 
119   // Transform the nodes of graph to actors.
120   ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
121   // Link actors to DAG through the edge connection of graph and graph execution strategy.
122   void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
123 
124   // The processing of actors build.
125   std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
126                                                        const HostTensorQueuePtr &host_queue);
127   std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
128   LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
129   OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
130   DataPrepareActorPtr BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
131                                             const std::vector<DataSourceActorPtr> &data_source_actors,
132                                             const HostTensorQueuePtr &host_queue);
133   std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
134   std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
135   std::vector<GatherActorPtr> BuildGatherActor(const GraphCompilerInfo &graph_compiler_info);
136 
137   // Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
138   // previous graph and the head of next graph.
139   void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
140 
141   // The processing of actors linking.
142   // 1. The processing of linking data arrows.
143   // The gather of linking data arrows of kernel, it will call following functions by the different from actor type.
144   void LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
145                      const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
146                      const KernelWithIndex &to_kernel_with_input_idx);
147   void LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor,
148                                  const KernelWithIndex &from_kernel_with_output_idx,
149                                  const KernelWithIndex &to_kernel_with_input_idx);
150   // Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
151   void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, KernelActor *const to_actor,
152                                          const KernelWithIndex &from_kernel_with_output_idx,
153                                          const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
154   void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, KernelActor *const to_actor,
155                                          const KernelWithIndex &from_kernel_with_output_idx,
156                                          const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
157   void LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
158                                      const KernelWithIndex &from_kernel_with_output_idx,
159                                      const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
160   void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
161                                    const KernelWithIndex &from_kernel_with_output_idx,
162                                    const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
163   void LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor,
164                                    const KernelWithIndex &from_kernel_with_output_idx,
165                                    const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
166   // Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
167   void LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor,
168                                  const KernelWithIndex &from_kernel_with_output_idx,
169                                  const KernelWithIndex &to_kernel_with_input_idx);
170 
171   // 2. The processing of linking control arrows.
172   void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph);
173   // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
174   void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);
175   // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
176   void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
177 
178   // The gather of linking the global control arrows, it will call following functions:
179   void LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
180                               const std::vector<KernelActor *> &auto_monad_actors,
181                               const GraphCompilerInfo &graph_compiler_info);
182   // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
183   void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
184                                            const GraphCompilerInfo &graph_compiler_info);
185   void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
186   void LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor, const ActorSet *actor_set);
187   void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
188                                          const ControlNodeParserPtr &parser);
189 
190   // 3. The processing of linking output result arrows.
191   void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
192 
193   // 4. The processing of control flow linking.
194   void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set);
195   void LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
196                                    const KernelWithIndex &front_node_with_index,
197                                    const KernelWithIndex &to_node_with_index);
198   void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *const actor);
199   // Connect the input of the actor.
200   void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
201                                   const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
202                                   const size_t to_index);
203   // When the input of the actor is a call node, the output of the funcgraph called by the call node needs to be
204   // connected.
205   void LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, const ControlNodeParserPtr &parser,
206                                 const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *const to_actor,
207                                 const size_t to_index);
208   void LinkDataArrowForSwitchActor(SwitchActor *const from_actor, const size_t from_index,
209                                    OpActor<DeviceTensor> *const to_actor, const size_t to_index,
210                                    const size_t branch_index = SIZE_MAX);
211 
212   void LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
213                                       const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser);
214 
215   void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors, LoopCountActor *const to_actor,
216                                       const KernelMapPosition &origin_outputs_order);
217   // In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to
218   // send the branch id to the loop count actor.
219   void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info);
220   void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info);
221   void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
222   // Add input for switch actor. Since part of the input of funcgraph is on call node, these inputs need to be added
223   // to switch actor.
224   void PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes);
225 
226   // Check whether the actor set is valid.
227   bool CheckActorValid(const ActorSet *actor_set,
228                        GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) const;
229 
230   // Persist device tensors of graph's some nodes(such as weights and value nodes).
231   void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
232 
233   // The fetch results are kernel_type and kernel_name.
234   void FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,
235                                        const GraphCompilerInfo &graph_compiler_info,
236                                        KernelTransformType *const kernel_type, std::string *const kernel_name);
237 
238   // The operation of the map of actor_name_to_actor_.
239   void InsertActor(OpActor<DeviceTensor> *actor);
240   OpActor<DeviceTensor> *FetchActor(const std::string &actor_name) const;
241 
242   // Display the actor information of corresponding kernel graph.
243   void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
244   void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const;
245   void DumpDataPrepareActor(const DataPrepareActor *actor, std::ofstream &ofs) const;
246   void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
247   void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
248   void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
249   void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
250   void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const;
251   void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const;
252   void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const;
253   void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;
254 
255   // The global maps, only be cleared in the deconstruction.
256   std::unordered_map<ActorInfo, ActorSetPtr> actors_;
257   std::unordered_map<std::string, OpActor<DeviceTensor> *> actor_name_to_actor_;
258 
259   // The local maps and vectors, will be cleared at the end of each graph transform:
260   // 1.The second element of pair represents the output index of op actor corresponding to the graph output front node.
261   std::map<KernelWithIndex, GraphOutputPair, session::KernelWithIndexCmp> graph_output_to_actor_;
262   // 2.Since the control node does not have a backend node, it can only be connected through the relationship between
263   // the front node, so the mapping relationship between the front node and the actor needs to be recorded.
264   std::unordered_map<AnfNodePtr, KernelActorPtr> front_node_to_actor_;
265   // 3.Beaceuse the copy actors are built in the link, so need record the all copy actors in the link process to push
266   // into the actor set after link.
267   std::vector<CopyActorPtr> copy_actors_;
268 
269   // The id of global actor.
270   AID memory_manager_aid_;
271   const AID *recorder_aid_{nullptr};
272   const AID *debug_aid_{nullptr};
273 
274   bool init_{false};
275 };
276 }  // namespace runtime
277 }  // namespace mindspore
278 
279 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
280