/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_

#include <vector>
#include <string>
#include <memory>
#include <set>
#include <queue>
#include <map>
#include <stack>
#include <utility>
#include <algorithm>
#include "utils/hash_map.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/kernel_graph.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelGraph;
using mindspore::session::KernelWithIndex;

constexpr int kInvalidBranchID = -1;
constexpr int kMainBranchID = 0;
constexpr int kSubBranchStartID = 1;
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;
constexpr size_t kCNodeInputStartPos = 1;
constexpr size_t kReturnInputPos = 1;
constexpr size_t kSingleControlNode = 1;
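// Illustrative note (a sketch, not exhaustive): these constants index into the input vector of a
// control-flow CNode, where input 0 is the primitive itself. For example, a Switch CNode is laid out as
//   {prim::kPrimSwitch, cond, true_branch_partial, false_branch_partial}
// which gives kSwitchInputNum = 4, the condition at position 1, and the true/false branches at
// positions 2 and 3. Likewise a Partial CNode is laid out as
//   {prim::kPrimPartial, func_graph, arg0, arg1, ...}
// so the funcgraph sits at position 1 and the bound arguments start at position 2.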

constexpr size_t kCsrTensorIndPtrIndex = 0;
constexpr size_t kCsrTensorIndicesIndex = 1;
constexpr size_t kCsrTensorValuesIndex = 2;
constexpr size_t kCsrTensorDenseShapeIndex = 3;
constexpr size_t kCsrParamOutputSize = 3;
constexpr size_t kCooTensorIndicesIndex = 0;
constexpr size_t kCooTensorValuesIndex = 1;
constexpr size_t kCooTensorDenseShapeIndex = 2;
constexpr size_t kMakeCSRTensorInputStartPos = 1;
constexpr size_t kMakeTensorInputStartPos = 1;
constexpr size_t kMakeCSRTensorInputNum = 4;
constexpr size_t kMakeCOOTensorInputNum = 3;
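// Illustrative note: a CSR tensor is assembled from the components (indptr, indices, values,
// dense_shape), matching the four component indices above, while a COO tensor uses
// (indices, values, dense_shape). A MakeCSRTensor CNode therefore carries
// kMakeCSRTensorInputNum = 4 real inputs starting at position 1, and a MakeCOOTensor CNode
// carries kMakeCOOTensorInputNum = 3.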

using NodeWithIndexToContext = std::pair<KernelWithIndex, DeviceContext *>;
struct NodeWithContextCmp {
  bool operator()(const NodeWithIndexToContext &node1, const NodeWithIndexToContext &node2) const {
    return node1.second->GetDeviceType() < node2.second->GetDeviceType();
  }
};
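// Illustrative note on the comparator: a std::set<NodeWithIndexToContext, NodeWithContextCmp>
// keeps the backend candidates of one front node ordered by their device type, so iteration
// visits the candidate with the smallest device type value first.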

using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<NodeWithIndexToContext, NodeWithContextCmp>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
using FuncGraphToKernelGraphGroup = mindspore::HashMap<FuncGraphPtr, std::vector<std::vector<KernelGraphPtr>>>;
using HostParameterToWeight = std::map<AnfNodePtr, std::set<AnfNodePtr>>;
using NodeWithDeviceContext = std::set<std::pair<KernelWithIndex, const DeviceContext *>>;
using RealToFormalNode = mindspore::HashMap<AnfNodePtr, std::vector<AnfNodePtr>>;
using FormalToRealParameter = std::map<KernelWithIndex, std::set<KernelWithIndex>>;
using RealToFormalParameter = std::map<KernelWithIndex, std::set<KernelWithIndex>>;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
using FrontNodeToKernelGraph = mindspore::HashMap<AnfNodePtr, KernelGraphPtr>;
using FuncGraphCallRelation = mindspore::HashMap<FuncGraphPtr, std::vector<std::set<FuncGraphPtr>>>;
using FuncGraphToCallNode = mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>>;
using CallNodeToFuncGraph = mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr>>;
using KernelGraphToDeviceContext = mindspore::HashMap<KernelGraphPtr, DeviceContext *>;
using GroupNameToCommuNodes =
  mindspore::HashMap<std::string, std::pair<std::vector<CNodePtr>, std::vector<KernelGraphPtr>>>;
using ReturnDynamicLenArgIndex =
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, size_t>>>;
using ControlNodeDynamicLenArgIndex =
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<FuncGraph *, mindspore::HashMap<size_t, size_t>>>;
// In control flow, heterogeneous kernel graphs need to be reconnected within the same group, and the kernel graph
// group info stores the inputs and outputs of the group.
// need_stack_ indicates whether a stack actor needs to be created for the group.
// level_ indicates the level of the outputs of the graphs in the group.
struct KernelGraphGroupInfo {
  bool need_stack_{false};
  size_t level_{0};
  std::string group_name_;
  std::set<KernelGraphPtr> graphs_;
  std::set<AnfNodePtr> monad_inputs_;
  std::set<KernelWithIndex> monad_outputs_;
  std::map<KernelWithIndex, const DeviceContext *> front_input_nodes_;
  FrontToBackendKernelWithContext front_output_nodes_;
};
using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;

// Check whether the node is a CSR node.
bool IsCsrNode(const AnfNodePtr &node);
// Check whether the node is a COO node.
bool IsCooNode(const AnfNodePtr &node);
// Get the front node corresponding to the backend node; if the front node is not a parameter node, return the
// corresponding cnode.
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph);
// Get all the real inputs of the front node, skipping virtual nodes such as MakeTuple and TupleGetItem,
// as illustrated below.
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
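// For example (illustrative): for a front node TupleGetItem(MakeTuple(a, b), 1), both virtual
// nodes are skipped and the real input resolves to b.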
// Fetch the real input of a TupleGetItem node.
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
// Check whether the partial node is valid.
// Invalid partial nodes are those partial cnodes whose funcgraph is a dead node.
bool IsInvalidPartial(const AnfNodePtr &node);
// Check whether the switch node abstract is functional.
bool IsPartialInput(const AnfNodePtr &node);
// Fetch the depend nodes according to the monad node.
void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> *const depend_nodes);
// Get all the monad nodes that the node depends on for side effects.
std::vector<AnfNodePtr> FetchAllMonadNodeByNode(const AnfNodePtr &node);
void CreateBuildInfoForFrontNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node);
// ControlNodeParser parses control nodes and gets the edges between nodes.
class ControlNodeParser {
 public:
  ControlNodeParser() : is_inited_(false), root_func_graph_(nullptr) {}

  // Parse the control nodes and put the results of the parsing into member variables.
  void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
             const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
             const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs);
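  // Typical usage (a sketch; the surrounding variables are illustrative, not part of this header):
  //   ControlNodeParserPtr parser = std::make_shared<ControlNodeParser>();
  //   parser->Parse(control_nodes, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
  //   if (parser->IsInited()) { /* query the parsed relations */ }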

  bool IsInited() const { return is_inited_; }
  // Check whether there is a call node in the front input nodes of the kernel graph.
  bool IsCallInputKernelGraph(KernelGraph *const graph);
  // Check whether there is a call node in the front input nodes of the kernel graph group.
  bool IsCallInputKernelGraphGroup(const std::string &group_name);
  // Check whether the data arrow of the kernel actor needs to be connected to the control actor.
  // There are two situations:
  // 1. In control flow, the parameter input needs to be connected to the entrance actor of the funcgraph.
  // 2. In a kernel graph with a call node input, the data arrow needs to be connected to the stack actor.
  bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node);
  // Only the parameters of the root graph are persistent and fetched from the store; the parameters of subgraphs
  // are not persistent and are passed as real parameters.
  bool IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node);
  bool IsRecursionCallNode(const AnfNodePtr &node);
  bool IsNeedStackControlNode(const AnfNodePtr &node);
  // If there is a recursive call node in the input of the kernel graph, the graph is recursive.
  bool IsRecursionKernelGraph(const KernelGraphPtr &graph);
  bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph);
  bool IsInputInSameLevel(const AnfNodePtr &node);
  // Check whether the two call nodes may call the same recursive graph at the same time.
  bool IsParallelCallRecursionGraph(const AnfNodePtr &call_node1, const AnfNodePtr &call_node2,
                                    const FuncGraphToCallNode &func_graph_to_call_node);
  const std::vector<KernelWithIndex> &control_node_parameters() const { return control_node_parameters_; }
  const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
  const NodeWithDeviceContext &front_value_nodes() const { return front_value_nodes_; }

  // Fetch all funcgraphs that the call node may call.
  const std::set<FuncGraphPtr> &FetchFuncGraphbyCallNode(const AnfNodePtr &control_node);
  // Fetch the branch id corresponding to the funcgraph.
  int FetchBranchIDByCallNode(const AnfNodePtr &call_node);
  // Fetch the kernel graph to which the kernel belongs.
  KernelGraphPtr FetchKernelGraphByFrontNode(const AnfNodePtr &kernel);
  KernelWithIndex FetchBackendOutputByKernelGraph(const KernelGraphPtr &graph,
                                                  const KernelWithIndex &front_node_with_index);
  // Fetch the backend kernel of the front node.
  KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
  FuncGraphPtr FetchFuncGraphByKernelGraph(const KernelGraph *const graph);
  std::string FetchGroupNameByKernelGraph(const KernelGraphPtr &graph);
  NodeWithIndexToContext FetchBackendParameterWithContextByFrontParameter(
    const KernelWithIndex &front_parameter_with_index);
  // Create a tensor for a value such as a scalar or monad U.
  tensor::TensorPtr CreateTensorForValue(const ValuePtr &value);
  void AddControlNodeTensor(const tensor::TensorPtr &tensor) { (void)control_node_tensors_.emplace_back(tensor); }
  void PrintParseInfo();

 private:
  friend class GraphScheduler;
  friend class ControlNodeScheduler;
  // Collect all front value nodes. In control flow, when the input of the switch actor is a value node, these
  // value nodes will not enter the kernel graph, so they need to be saved separately, and space is allocated for
  // them separately during initialization.
  // The interface is initialized by finding the backend node in the kernel graph that the front node is finally
  // sent to.
  void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const DeviceContext *const default_context);
  void CreateDeviceTensors(const std::vector<AnfNodePtr> &control_nodes, const DeviceContext *const default_context);
  // Create a branch id for every call node in the control flow.
  void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);

  // Parse all the relationships between front parameters and backend parameters. The front parameters
  // include two parts:
  // 1. Parameters from kernel graphs.
  // 2. Parameters from control nodes.
  void ParseFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
                                    const std::vector<DeviceContext *> &device_contexts);
  // The relationship between front parameters indicates that a parameter is directly used as the input of a
  // funcgraph. There are two situations:
  // 1. The parameter is used as the input of a call node.
  // 2. The parameter is used as the input of a partial and will be passed to the funcgraph of the partial in the
  //    subsequent call node.
  void ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes);
  // Recursively get all the real parameters corresponding to the formal parameters.
  void ParseAllRealParameterByFormalParameter(const KernelWithIndex &formal_parameter,
                                              const FormalToRealParameter &formal_to_real_parameters,
                                              std::set<KernelWithIndex> *const total_real_parameters,
                                              std::set<KernelWithIndex> *invalid_real_parameter);
  // Get all the call nodes that have no recursive call relation.
  void ParseUnRecursionCallNode();

  // Parse the device contexts of the control nodes. In a heterogeneous scenario, data needs to be copied between
  // different device memories, so the device context of each node must be parsed in advance. The analysis steps:
  // 1. Get the device context of the funcgraph parameters according to the device type of the kernels in the
  //    funcgraph.
  // 2. Determine the type of device context output by the funcgraph according to the call relationship of the
  //    funcgraph.
  // 3. Determine the type of device context output for the real parameters on the partial nodes and call nodes.
  void ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
                          const std::vector<KernelGraphPtr> &kernel_graphs,
                          const std::vector<DeviceContext *> &device_contexts, DeviceContext *default_context,
                          const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs);
  void ParseDeviceContextForFuncGraph(const std::vector<KernelGraphPtr> &kernel_graphs,
                                      const std::vector<DeviceContext *> &device_contexts,
                                      DeviceContext *default_context,
                                      const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs);
  void ParseDeviceContextForReturnNode(const DeviceContext *default_context);
  void ParseDeviceContextForCallNode(const std::vector<AnfNodePtr> &control_nodes);
  void ParseDeviceContextForPartialNode(const std::vector<AnfNodePtr> &control_nodes);
  void FetchDeviceContextByNode(const std::vector<KernelWithIndex> &output_nodes,
                                std::vector<const DeviceContext *> *return_device_contexts,
                                const FuncGraphPtr &func_graph, const DeviceContext *default_context);
  // In the actor model, when a funcgraph comes to an end temporarily, the exit of the funcgraph needs to notify
  // the entrance actor so that it can process the next parameters. This is used to obtain the nodes corresponding
  // to all actors in the funcgraph that need to send control messages to the entrance.
  // These nodes are control nodes without a control node input in the topological sort of the funcgraph.
  void ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
  // Parse all funcgraphs that the call nodes may call.
  void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes);

  // Get the relationship between the front and backend of the executable kernels in all kernel graphs.
  void ParseFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
                                 const std::vector<DeviceContext *> &device_contexts);
  void ParseFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs);
  // Get the parameters of the control nodes and call nodes of the root funcgraph.
  void ParseControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
  // Get the control nodes and kernel graphs that need a stack actor created for them.
  // When a control node or kernel graph has an input that is a call node, a stack actor needs to be added for it.
  void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
  bool IsCallNodeNeedStack(const AnfNodePtr &node);
  void ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
  // Parse the level of the inputs and outputs of graphs and all control nodes.
  void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
  // Get the level of the control node: recursively traverse all the inputs of the node and find the largest level
  // among them.
  size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes);
  // When the same funcgraph may be called at multiple places in the graph, the calls cannot be executed in
  // parallel, and all call nodes need to be executed serially.
  void InsertDependForParallelCall(const std::vector<AnfNodePtr> &control_nodes);
  // When a parameter is directly used as the condition of a switch, there will be no backend node for it, and a
  // device tensor needs to be created for it.
  void CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context);
  void ParseDynamicLenFormalParameter(const std::vector<AnfNodePtr> &control_nodes);
  void ParseDynamicLenFormalParameterByCallNode(const AnfNodePtr &node);
  void ParseDynamicLenFormalParameterByPartial(const AnfNodePtr &node);
  // In control flow, a funcgraph will be cut into multiple kernel graphs for execution, and this relationship is
  // recorded in this map.
  FuncGraphToKernelGraphGroup func_graph_to_kernel_graph_groups_;
  // The kernel graph to which the front node belongs after the funcgraph is cut.
  FrontNodeToKernelGraph front_node_to_kernel_graph_;

  // The front to backend parameters are used to build and link the host data source actor in the control flow
  // scenario.
  FrontToBackendNodeWithContext front_to_backend_parameters_;
  // Relationship between the front and backend of the executable kernels in all kernel graphs.
  FrontToBackendKernelWithContext front_to_backend_kernels_;

  // Relationship between formal parameters and real parameters.
  FormalToRealParameter formal_to_real_parameters_;
  RealToFormalParameter real_to_formal_parameters_;

  // Branch id of the funcgraph.
  // In control flow, a funcgraph will be called in multiple places, and the output of the funcgraph needs to return
  // to different places. Therefore, a branch id is created for each funcgraph. When the funcgraph is called, the
  // branch id needs to be sent to the gather actor corresponding to the funcgraph, and the gather actor sends the
  // branch id to its output switch actor.
  mindspore::HashMap<AnfNodePtr, int> call_node_to_branch_id_;
  // The level indicates how many recursive call nodes in the funcgraph the input of the node depends on.
  // During graph scheduling, the inputs need to be graded according to their dependence on the recursive call
  // nodes, and according to this level, the lower-level inputs are pushed into the stack actor. When arranging,
  // first sort the call nodes in the funcgraph according to their topological relationships, and then confirm the
  // dependencies of the other nodes on these call nodes in turn.
  // For example, with the dependencies a -> b, b -> d, c -> d, where b is a call node, the level of a and c is 0
  // and the level of b and d is 1; since d has inputs b and c with different levels, a stack needs to be added
  // for d.
  mindspore::HashMap<AnfNodePtr, size_t> node_to_level_;
  CallNodeToFuncGraph call_node_to_func_graphs_;
  FuncGraphToCallNode func_graph_to_call_nodes_;
  // The front value nodes are all value nodes that are not in a kernel graph. These nodes are generally the
  // inputs of control nodes.
  NodeWithDeviceContext front_value_nodes_;

  // Parameters of control nodes which come from the host actor.
  std::vector<KernelWithIndex> control_node_parameters_;
  // Kernel graphs that have a call node in their front input nodes.
  // In the funcgraph recursive call scenario, general inputs and call inputs are passed recursively, so a stack
  // actor is created for a kernel graph which has a call input.
  std::set<KernelGraph *> call_input_kernel_graphs_;
  std::set<KernelGraphGroupInfoPtr> kernel_graph_group_infos_;
  // Control nodes without a control node input in the topological sorting of the funcgraph.
  mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
  // Kernel graphs that need to link a control arrow to their entrance actor.
  // In the recursive scene, some kernel graphs need to be completed before the next set of data is sent by the
  // entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
  // to the entrance actor.
  mindspore::HashMap<FuncGraphPtr, std::set<KernelGraphGroupInfoPtr>> func_graph_to_first_kernel_graphs_;
  // Call nodes without recursive calls. The funcgraphs called by these call nodes will not call the funcgraph
  // where the call node belongs.
  std::set<AnfNodePtr> unrecursion_call_nodes_;
  // Control nodes that need a corresponding stack actor created: when there is a call node in the inputs of a
  // control node, the stack actor is needed to collect these inputs.
  std::set<AnfNodePtr> need_stack_control_nodes_;

  // In a heterogeneous scenario, each parameter has its own device context type, so the device context
  // corresponding to the type needs to be parsed in advance so that copy operations can be added in the scheduler.
  // 1. The device context type of the formal parameters of the funcgraph.
  mindspore::HashMap<FuncGraphPtr, std::vector<const DeviceContext *>> func_graph_to_device_contexts_;
  // 2. The device context type of the control node inputs.
  mindspore::HashMap<AnfNodePtr, std::vector<const DeviceContext *>> control_node_to_device_contexts_;

  // Kernel graph to the group info it belongs to.
  mindspore::HashMap<KernelGraphPtr, KernelGraphGroupInfoPtr> kernel_graphs_to_group_info_;
  // Scalar values are converted to tensors in control flow; these tensors are placed in this vector.
  std::vector<tensor::TensorPtr> control_node_tensors_;
  // The index of the argument that needs to be converted into a dynamic len sequence.
  ReturnDynamicLenArgIndex return_to_call_with_dynamic_sequence_index_;
  ControlNodeDynamicLenArgIndex control_node_to_funcgraph_with_dynamic_sequence_index_;
  // Whether control flow is enabled.
  bool is_inited_;

  // Root funcgraph and its parameters.
  FuncGraphPtr root_func_graph_;
  std::vector<AnfNodePtr> root_graph_parameters_;
};

using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_