1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <set> 24 #include <map> 25 #include <utility> 26 #include <unordered_map> 27 #include <algorithm> 28 #include "runtime/hardware/device_context.h" 29 #include "backend/session/kernel_graph.h" 30 31 namespace mindspore { 32 namespace runtime { 33 using mindspore::device::DeviceContext; 34 using mindspore::session::KernelGraph; 35 using mindspore::session::KernelWithIndex; 36 37 constexpr int kInvalidBranchID = -1; 38 constexpr int kMainBranchID = 0; 39 constexpr int kSubBranchStartID = 1; 40 41 using FrontToBackendNodeWithContext = std::unordered_map<AnfNodePtr, std::pair<AnfNodePtr, DeviceContext *>>; 42 using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>; 43 using FuncGraphToParameter = std::unordered_map<FuncGraphPtr, std::vector<std::vector<AnfNodePtr>>>; 44 using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>; 45 using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>; 46 using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>; 47 48 // Check whether node is a call node, there are two types of call nodes: 49 // 1. First input of node is a cnode. 50 // 2. First input of node is a funcgraph value node. 51 bool IsCallNode(const AnfNodePtr &node); 52 53 // Check if the call node is the input of another call node. 54 bool IsSubCallNode(const AnfNodePtr &node); 55 56 // Recursive interface, find the real output of funcgraph called by call node. 57 AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes); 58 59 // Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph, 60 // it is determined whether it is a weight. 61 bool HasAbstractRef(const AnfNodePtr &node); 62 63 // Recursive interface, get the funcgraph which the node belongs, if the node has a front node, return the funcgraph 64 // which the front node belongs, if not, find the funcgraph which the input of the node belongs. 65 FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node); 66 67 // Recursive interface, get the number of output nodes of funcgraph called by call node. 68 size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes); 69 70 // Get front node by backend node. 71 AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node); 72 73 // Get the front node corresponding to the backend node, if the front node is not a parameter node, return the 74 // corresponding cnode. 75 KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph); 76 77 // Get the funcgraph to which the node belongs. 78 FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node); 79 80 // Find all funcgraphs that the call node will call. 81 std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node); 82 83 // Get parameters in kernel graph. 84 std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph); 85 86 // ControlNodeParser is used to parse control nodes, and get the edges between nodes. 87 class ControlNodeParser { 88 public: 89 // Parse the control node and put the results of the parsing into member variables. 90 void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs, 91 const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph); 92 control_node_parameters()93 const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; } front_to_backend_parameters()94 const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; } host_parameter_to_weights()95 const HostParameterToWeight &host_parameter_to_weights() const { return host_parameter_to_weights_; } front_value_nodes()96 const NodeWithDeviceContext &front_value_nodes() const { return front_value_nodes_; } 97 98 // Get the output of funcgraph, usually there is only one output node, In the control flow, there are 99 // multiple branch outputs, there will be multiple output nodes. 100 std::vector<AnfNodePtr> FetchAllBranchOutputs(const FuncGraphPtr &func_graph); 101 102 // Get all possible input nodes of the output node. When the switch actor is the output, it need to send the node 103 // which device address belongs, so switch actor need to get all the possible nodes. 104 std::set<KernelWithIndex> FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output); 105 106 // Get the device context corresponding to the value node. 107 DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node); 108 109 // Get the branch id corresponding to funcgraph. 110 int GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph); 111 112 // Get the number of calls to funcgraph 113 size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph); 114 115 // Get all possible input nodes of the output node. When the gather actor is the output, it need to send the node 116 // which device address belongs, so gather actor need to get all the possible nodes. 117 std::vector<KernelWithIndex> GetBackendInputByParameter(const AnfNodePtr ¶meter); 118 119 // Check whether there is a call node in the front input nodes of the kernel graph. 120 bool IsCallInputKernelGraph(const KernelGraphPtr &graph); 121 122 // Check whether the kernel actor belongs to the root graph. 123 // In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output should 124 // be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be sent to the 125 // output switch actor. 126 bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel); 127 128 // Get the backend node corresponding to the weight node in the subgraph. 129 AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node); 130 GetBackendKernelByFrontKernel(const KernelWithIndex & front_node_with_index)131 KernelWithIndex GetBackendKernelByFrontKernel(const KernelWithIndex &front_node_with_index) { 132 return front_to_backend_kernels_[front_node_with_index].first; 133 } 134 135 AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node); 136 137 private: 138 friend class GraphScheduler; 139 140 // Collect all front value nodes. In the control flow, when the input of the switch actor is the value node, these 141 // value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for 142 // them separately during initialization. 143 // The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to. 144 void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs, 145 const std::vector<DeviceContext *> &device_contexts); 146 // Create branch id for all subgraphs in the control flow. 147 void CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr> &control_nodes); 148 // Find all value nodes in the switch recursively. 149 void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector<AnfNodePtr> *value_nodes); 150 // Fetch all the relationships between front parameters and backend parameters.The front parameters 151 // include two parts: 152 // 1. The parameter from kernel graph. 153 // 2. The parameter from control nodes. 154 void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs, 155 const std::vector<DeviceContext *> &device_contexts, 156 const RealToFormalNode &real_to_formal_front_parameters, 157 const RealToFormalNode &formal_to_real_front_parameters); 158 // Get the relationship between the front and backend of the executable kernel in all kernel graphs. 159 void FetchFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs, 160 const std::vector<DeviceContext *> &device_contexts); 161 // Get inputs of control node which come from the host actor. These inputs generally come from the partial 162 // nodes and call nodes of the root funcgraph. 163 std::vector<AnfNodePtr> FetchControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes, 164 DeviceContext *device_context); 165 // Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node, 166 // and the input of the call node is the input parameter of the corresponding funcgraph. 167 void FetchFuncGraphToParameter(const std::vector<AnfNodePtr> &control_nodes); 168 // Get all the front weight parameters related to the weight in the host parameter. 169 void FetchHostParameterToWeight(const RealToFormalNode &real_to_formal_front_parameters); 170 // The relationship between front parameters indicates that the parameter is directly used as the input of the 171 // funcgraph. There are two situations: 172 // 1. The parameter is used as the input of the call node, 173 // 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the 174 // subsequent call node. 175 void FetchFrontToFrontParameter(const std::vector<AnfNodePtr> &control_nodes, 176 std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> *front_to_front_parameter); 177 // Get the number of calls to all subgraphs in the whole funcgraph. 178 void FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes); 179 // Get all the kernel graphs where the input node has a call node. 180 void FetchCallInputKernelGraph(const std::vector<KernelGraphPtr> &graphs, 181 const std::vector<DeviceContext *> &device_contexts); 182 // Get the relationship of all real and formal nodes in the whole funcgraph. 183 void FetchBackendInputNode(const std::vector<KernelGraphPtr> &graphs, 184 const std::vector<DeviceContext *> &device_contexts, 185 const RealToFormalNode &real_to_formal_front_parameters, 186 const RealToFormalNode &formal_to_real_front_parameters); 187 // Get the relationship of all real and formal parameters in the whole funcgraph. 188 void FetchBackendParameterNode(const std::vector<KernelGraphPtr> &graphs, 189 const std::vector<DeviceContext *> &device_contexts, 190 const RealToFormalNode &real_to_formal_front_parameters, 191 const RealToFormalNode &formal_to_real_front_parameters, 192 FrontToBackendNodeWithContext *front_to_backend_parameters); 193 // Get all possible input node of real parameter. 194 void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter, 195 const FrontToBackendNodeWithContext &front_to_backend_parameters); 196 // Recursive interface, get all Backend node by front_output. 197 void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set<AnfNodePtr> *call_nodes, 198 std::set<AnfNodePtr> *switch_nodes, std::set<KernelWithIndex> *results); 199 200 // Get the dependency between kernel and call node in auto monad. 201 void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes); 202 // The front to backend parameters is used to build and link the host data source actor in the control flow scenario. 203 FrontToBackendNodeWithContext front_to_backend_parameters_; 204 205 // The relationship between all real parameters and formal parameters in the entire func_graph. 206 // In control flow, the control actor will be the output actor. Since the actor needs to send the node to the output 207 // actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor. 208 // When the control actor receives the device address, it can find the corresponding input node. 209 std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> formal_to_real_parameters_; 210 211 // Relationship between the front and backend of the executable kernel in all kernel graphs. 212 FrontToBackendKernelWithContext front_to_backend_kernels_; 213 214 // The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize 215 // the input node of gather. 216 FuncGraphToParameter func_graph_to_parameters_; 217 218 // The relationship between the valuenode inputs of the call node and the backend parameter 219 std::map<KernelWithIndex, std::pair<AnfNodePtr, DeviceContext *>> call_node_to_backend_parameters_; 220 221 // Branch id of funcgraph. 222 // In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to 223 // different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch 224 // id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id 225 // to its output switch actor. 226 std::unordered_map<FuncGraphPtr, int> func_graph_to_branch_id_; 227 228 // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph. 229 // When initializing the weights, all related weights need to be recorded as the same device tensor. 230 HostParameterToWeight host_parameter_to_weights_; 231 std::unordered_map<AnfNodePtr, AnfNodePtr> sub_front_node_to_root_front_node_; 232 233 // The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the 234 // input of the control node. 235 NodeWithDeviceContext front_value_nodes_; 236 // The front value node saves all parameters that are not in the kernel graph. These nodes are generally the 237 // output of subgraph, or the switch condition node. 238 NodeWithDeviceContext front_parameters_; 239 240 // Parameters of control node which come from the host actor. 241 std::vector<AnfNodePtr> control_node_parameters_; 242 // The number of calls to func_graph. 243 std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_; 244 // The kernel graph of call exists in the front input node. 245 // In the scene of funcgrarph recursive call, general input and call input are passed recursively, so a gather actor 246 // is created for kernel graph which has a call input. 247 std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_; 248 // Root funcgraph and its parameters. 249 FuncGraphPtr root_func_graph_; 250 std::vector<AnfNodePtr> root_graph_parameters_; 251 252 // The dependency between kernel and call node in auto monad. 253 std::unordered_map<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_; 254 }; 255 256 using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>; 257 } // namespace runtime 258 } // namespace mindspore 259 260 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_ 261