/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <set>
24 #include <map>
25 #include <utility>
26 #include <unordered_map>
27 #include <algorithm>
28 #include "runtime/hardware/device_context.h"
29 #include "backend/session/kernel_graph.h"
30 
31 namespace mindspore {
32 namespace runtime {
33 using mindspore::device::DeviceContext;
34 using mindspore::session::KernelGraph;
35 using mindspore::session::KernelWithIndex;
36 
// Branch id values used by the control flow parser.
// kInvalidBranchID marks a missing branch id, kMainBranchID is the id 0, and kSubBranchStartID is the id 1
// (by its name, the first id handed out to sub branches).
constexpr int kInvalidBranchID = -1;
constexpr int kMainBranchID = 0;
constexpr int kSubBranchStartID = 1;
40 
41 using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>;
42 using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
43 using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>;
44 using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
45 using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
46 using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
47 
48 // Check whether node is a call node, there are two types of call nodes:
49 // 1. First input of node is a cnode.
50 // 2. First input of node is a funcgraph value node.
51 bool IsCallNode(const AnfNodePtr &node);
52 
53 // Check if the call node is the input of another call node.
54 bool IsSubCallNode(const AnfNodePtr &node);
55 
56 // Recursive interface, find the real output of funcgraph called by call node.
57 AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes);
58 
59 // Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
60 // it is determined whether it is a weight.
61 bool HasAbstractRef(const AnfNodePtr &node);
62 
63 // Recursive interface, get the funcgraph which the node belongs, if the node has a front node, return the funcgraph
64 // which the front node belongs, if not, find the funcgraph which the input of the node belongs.
65 FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node);
66 
67 // Recursive interface, get the number of output nodes of funcgraph called by call node.
68 size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes);
69 
70 // Get front node by backend node.
71 AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
72 
73 // Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
74 // corresponding cnode.
75 KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
76 
77 // Get the funcgraph to which the node belongs.
78 FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
79 
80 // Find all funcgraphs that the call node will call.
81 std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
82 
83 // Get parameters in kernel graph.
84 std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
85 
86 // ControlNodeParser is used to parse control nodes, and get the edges between nodes.
87 class ControlNodeParser {
88  public:
89   // Parse the control node and put the results of the parsing into member variables.
90   void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
91              const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph);
92 
control_node_parameters()93   const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; }
front_to_backend_parameters()94   const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
host_parameter_to_weights()95   const HostParameterToWeight &host_parameter_to_weights() const { return host_parameter_to_weights_; }
front_value_nodes()96   const NodeWithDeviceContext &front_value_nodes() const { return front_value_nodes_; }
97 
98   // Get the output of funcgraph, usually there is only one output node, In the control flow, there are
99   // multiple branch outputs, there will be multiple output nodes.
100   std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph);
101 
102   // Get all possible input nodes of the output node. When the switch actor is the output, it need to send the node
103   // which device address belongs, so switch actor need to get all the possible nodes.
104   std::set<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output);
105 
106   // Get the device context corresponding to the value node.
107   DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node);
108 
109   // Get the branch id corresponding to funcgraph.
110   int GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph);
111 
112   // Get the number of calls to funcgraph
113   size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph);
114 
115   // Get all possible input nodes of the output node. When the gather actor is the output, it need to send the node
116   // which device address belongs, so gather actor need to get all the possible nodes.
117   std::vector<KernelWithIndex> GetBackendInputByParameter(const AnfNodePtr &parameter);
118 
119   // Check whether there is a call node in the front input nodes of the kernel graph.
120   bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
121 
122   // Check whether the kernel actor belongs to the root graph.
123   // In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output should
124   // be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be sent to the
125   // output switch actor.
126   bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel);
127 
128   // Get the backend node corresponding to the weight node in the subgraph.
129   AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);
130 
GetBackendKernelByFrontKernel(const KernelWithIndex & front_node_with_index)131   KernelWithIndex GetBackendKernelByFrontKernel(const KernelWithIndex &front_node_with_index) {
132     return front_to_backend_kernels_[front_node_with_index].first;
133   }
134 
135   AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);
136 
137  private:
138   friend class GraphScheduler;
139 
140   // Collect all front value nodes. In the control flow, when the input of the switch actor is the value node, these
141   // value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
142   // them separately during initialization.
143   // The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
144   void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
145                            const std::vector<DeviceContext *> &device_contexts);
146   // Create branch id for all subgraphs in the control flow.
147   void CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
148   // Find all value nodes in the switch recursively.
149   void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes);
150   // Fetch all the relationships between front parameters and backend parameters.The front parameters
151   // include two parts:
152   // 1. The parameter from kernel graph.
153   // 2. The parameter from control nodes.
154   void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
155                                     const std::vector<DeviceContext *> &device_contexts,
156                                     const RealToFormalNode &real_to_formal_front_parameters,
157                                     const RealToFormalNode &formal_to_real_front_parameters);
158   // Get the relationship between the front and backend of the executable kernel in all kernel graphs.
159   void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
160                                  const std::vector<DeviceContext *> &device_contexts);
161   // Get inputs of control node which come from the host actor. These inputs generally come from the partial
162   // nodes and call nodes of the root funcgraph.
163   std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes,
164                                                     DeviceContext *device_context);
165   // Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node,
166   // and the input of the call node is the input parameter of the corresponding funcgraph.
167   void FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes);
168   // Get all the front weight parameters related to the weight in the host parameter.
169   void FetchHostParameterToWeight(const RealToFormalNode &real_to_formal_front_parameters);
170   // The relationship between front parameters indicates that the parameter is directly used as the input of the
171   // funcgraph. There are two situations:
172   // 1. The parameter is used as the input of the call node,
173   // 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the
174   //    subsequent call node.
175   void FetchFrontToFrontParameter(const std::vector<AnfNodePtr> &control_nodes,
176                                   std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter);
177   // Get the number of calls to all subgraphs in the whole funcgraph.
178   void FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes);
179   // Get all the kernel graphs where the input node has a call node.
180   void FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs,
181                                  const std::vector<DeviceContext *> &device_contexts);
182   // Get the relationship of all real and formal nodes in the whole funcgraph.
183   void FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs,
184                              const std::vector<DeviceContext *> &device_contexts,
185                              const RealToFormalNode &real_to_formal_front_parameters,
186                              const RealToFormalNode &formal_to_real_front_parameters);
187   // Get the relationship of all real and formal parameters in the whole funcgraph.
188   void FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs,
189                                  const std::vector<DeviceContext *> &device_contexts,
190                                  const RealToFormalNode &real_to_formal_front_parameters,
191                                  const RealToFormalNode &formal_to_real_front_parameters,
192                                  FrontToBackendNodeWithContext *front_to_backend_parameters);
193   // Get all possible input node of real parameter.
194   void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter,
195                                         const FrontToBackendNodeWithContext &front_to_backend_parameters);
196   // Recursive interface, get all Backend node by front_output.
197   void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set<AnfNodePtr> *call_nodes,
198                                        std::set<AnfNodePtr> *switch_nodes, std::set<KernelWithIndex> *results);
199 
200   // Get the dependency between kernel and call node in auto monad.
201   void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes);
202   // The front to backend parameters is used to build and link the host data source actor in the control flow scenario.
203   FrontToBackendNodeWithContext front_to_backend_parameters_;
204 
205   // The relationship between all real parameters and formal parameters in the entire func_graph.
206   // In control flow, the control actor will be the output actor. Since the actor needs to send the node to the output
207   // actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor.
208   // When the control actor receives the device address, it can find the corresponding input node.
209   std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_;
210 
211   // Relationship between the front and backend of the executable kernel in all kernel graphs.
212   FrontToBackendKernelWithContext front_to_backend_kernels_;
213 
214   // The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize
215   // the input node of gather.
216   FuncGraphToParameter func_graph_to_parameters_;
217 
218   // The relationship between the valuenode inputs of the call node and the backend parameter
219   std::map<KernelWithIndex, std::pair<AnfNodePtr, DeviceContext *>> call_node_to_backend_parameters_;
220 
221   // Branch id of funcgraph.
222   // In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to
223   // different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch
224   // id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id
225   // to its output switch actor.
226   std::unordered_map<FuncGraphPtr, int> func_graph_to_branch_id_;
227 
228   // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph.
229   // When initializing the weights, all related weights need to be recorded as the same device tensor.
230   HostParameterToWeight host_parameter_to_weights_;
231   std::unordered_map<AnfNodePtr, AnfNodePtr> sub_front_node_to_root_front_node_;
232 
233   // The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the
234   // input of the control node.
235   NodeWithDeviceContext front_value_nodes_;
236   // The front value node saves all parameters that are not in the kernel graph. These nodes are generally the
237   // output of subgraph, or the switch condition node.
238   NodeWithDeviceContext front_parameters_;
239 
240   // Parameters of control node which come from the host actor.
241   std::vector<AnfNodePtr> control_node_parameters_;
242   // The number of calls to func_graph.
243   std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
244   // The kernel graph of call exists in the front input node.
245   // In the scene of funcgrarph recursive call, general input and call input are passed recursively, so a gather actor
246   // is created for kernel graph which has a call input.
247   std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
248   // Root funcgraph and its parameters.
249   FuncGraphPtr root_func_graph_;
250   std::vector<AnfNodePtr> root_graph_parameters_;
251 
252   // The dependency between kernel and call node in auto monad.
253   std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
254 };
255 
256 using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
257 }  // namespace runtime
258 }  // namespace mindspore
259 
260 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_
261