/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_

#include <vector>
#include <string>
#include <memory>
#include <set>
#include <queue>
#include <map>
#include <stack>
#include <utility>
#include <algorithm>
#include "utils/hash_map.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/kernel_graph.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelGraph;
using mindspore::session::KernelWithIndex;

constexpr int kInvalidBranchID = -1;
constexpr int kMainBranchID = 0;
constexpr int kSubBranchStartID = 1;
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;
constexpr size_t kCNodeInputStartPos = 1;
constexpr size_t kReturnInputPos = 1;
constexpr size_t kSingleControlNode = 1;

constexpr size_t kCsrTensorIndPtrIndex = 0;
constexpr size_t kCsrTensorIndicesIndex = 1;
constexpr size_t kCsrTensorValuesIndex = 2;
constexpr size_t kCsrTensorDenseShapeIndex = 3;
constexpr size_t kCsrParamOutputSize = 3;
constexpr size_t kCooTensorIndicesIndex = 0;
constexpr size_t kCooTensorValuesIndex = 1;
constexpr size_t kCooTensorDenseShapeIndex = 2;
constexpr size_t kMakeCSRTensorInputStartPos = 1;
constexpr size_t kMakeTensorInputStartPos = 1;
constexpr size_t kMakeCSRTensorInputNum = 4;
constexpr size_t kMakeCOOTensorInputNum = 3;

using NodeWithIndexToContext = std::pair<KernelWithIndex, DeviceContext *>;
struct NodeWithContextCmp {
  bool operator()(const NodeWithIndexToContext &node1, const NodeWithIndexToContext &node2) const {
    return node1.second->GetDeviceType() < node2.second->GetDeviceType();
  }
};
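
// A minimal usage sketch for the comparator above (illustrative only; `node`, `cpu_context` and
// `gpu_context` are hypothetical values, not declared in this header). Because the comparator orders
// entries only by device type, a std::set keyed with it holds at most one entry per device type:
//   std::set<NodeWithIndexToContext, NodeWithContextCmp> backends;
//   (void)backends.emplace(node, cpu_context);
//   (void)backends.emplace(node, gpu_context);  // kept: different device type from the first entry
//   // *backends.begin() is the entry whose context has the smallest DeviceType value.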

using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<NodeWithIndexToContext, NodeWithContextCmp>>;
using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>;
using FuncGraphToKernelGraphGroup = mindspore::HashMap<FuncGraphPtr, std::vector<std::vector<KernelGraphPtr>>>;
using HostParameterToWeight = std::map<AnfNodePtr, std::set<AnfNodePtr>>;
using NodeWithDeviceContext = std::set<std::pair<KernelWithIndex, const DeviceContext *>>;
using RealToFormalNode = mindspore::HashMap<AnfNodePtr, std::vector<AnfNodePtr>>;
using FormalToRealParameter = std::map<KernelWithIndex, std::set<KernelWithIndex>>;
using RealToFormalParameter = std::map<KernelWithIndex, std::set<KernelWithIndex>>;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
using FrontNodeToKernelGraph = mindspore::HashMap<AnfNodePtr, KernelGraphPtr>;
using FuncGraphCallRelation = mindspore::HashMap<FuncGraphPtr, std::vector<std::set<FuncGraphPtr>>>;
using FuncGraphToCallNode = mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>>;
using CallNodeToFuncGraph = mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr>>;
using KernelGraphToDeviceContext = mindspore::HashMap<KernelGraphPtr, DeviceContext *>;
using GroupNameToCommuNodes =
  mindspore::HashMap<std::string, std::pair<std::vector<CNodePtr>, std::vector<KernelGraphPtr>>>;
using ReturnDynamicLenArgIndex =
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, size_t>>>;
using ControlNodeDynamicLenArgIndex =
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<FuncGraph *, mindspore::HashMap<size_t, size_t>>>;

// In control flow, heterogeneous kernel graphs need to be reconnected in the same group, and the kernel graph
// group info stores the inputs and outputs of the group.
// need_stack_ indicates whether a stack actor needs to be created for the group.
// level_ indicates the level of the outputs of the graphs in the group.
struct KernelGraphGroupInfo {
  bool need_stack_{false};
  size_t level_{0};
  std::string group_name_;
  std::set<KernelGraphPtr> graphs_;
  std::set<AnfNodePtr> monad_inputs_;
  std::set<KernelWithIndex> monad_outputs_;
  std::map<KernelWithIndex, const DeviceContext *> front_input_nodes_;
  FrontToBackendKernelWithContext front_output_nodes_;
};
using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;

// Check whether the node is a CSR node.
bool IsCsrNode(const AnfNodePtr &node);
// Check whether the node is a COO node.
bool IsCooNode(const AnfNodePtr &node);
// Get the front node corresponding to the backend node; if the front node is not a parameter node, return the
// corresponding cnode.
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph);
// Get all the real inputs of the front node, skipping virtual nodes like maketuple and tuplegetitem.
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
// Fetch the real input of a tuplegetitem node.
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
// Check if the partial node is valid.
// Invalid partial nodes are those partial cnodes whose funcgraph is a dead node.
bool IsInvalidPartial(const AnfNodePtr &node);
// Check whether the switch node abstract is functional.
bool IsPartialInput(const AnfNodePtr &node);
// Fetch the depend nodes according to the monad node.
void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> *const depend_nodes);
// Get all the monad nodes that the node depends on for side effects.
std::vector<AnfNodePtr> FetchAllMonadNodeByNode(const AnfNodePtr &node);
void CreateBuildInfoForFrontNode(const KernelWithIndex &front_node_with_index, const AnfNodePtr &backend_node);
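
// Illustrative sketch of how the helpers above might be used (hypothetical caller code; `front_cnode`
// and `getitem_with_index` are assumed values, not declared in this header):
//   // Real inputs, with virtual nodes such as maketuple/tuplegetitem already skipped:
//   std::vector<KernelWithIndex> real_inputs = FetchInputNodeByCNode(front_cnode);
//   // Resolve a single tuplegetitem node to its real input:
//   KernelWithIndex real_node = FetchRealNodeByGetItem(getitem_with_index);
//   // Sparse nodes can be filtered out before linking data arrows; their members are addressed
//   // with the kCsrTensor*/kCooTensor* indexes defined above:
//   bool is_sparse = IsCsrNode(real_node.first) || IsCooNode(real_node.first);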

// ControlNodeParser is used to parse control nodes and get the edges between nodes.
class ControlNodeParser {
 public:
  ControlNodeParser() : is_inited_(false), root_func_graph_(nullptr) {}

  // Parse the control nodes and put the results of the parsing into member variables.
  void Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
             const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
             const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs);

  bool IsInited() const { return is_inited_; }
  // Check whether there is a call node in the front input nodes of the kernel graph.
  bool IsCallInputKernelGraph(KernelGraph *const graph);
  // Check whether there is a call node in the front input nodes of the kernel graph group.
  bool IsCallInputKernelGraphGroup(const std::string &group_name);
  // Check whether the data arrow of the kernel actor needs to be connected to a control actor.
  // There are two situations:
  // 1. In control flow, a parameter input needs to be connected to the entrance actor of the funcgraph.
  // 2. In a kernel graph with call node inputs, the data arrow needs to be connected to the stack actor.
  bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node);
  // Only the parameters of the root graph are persistent and fetched from the store; the parameters of subgraphs
  // are not persistent and are passed as real parameters.
  bool IsRootGraphPersistentDeviceTensor(const AnfNodePtr &node);
  bool IsRecursionCallNode(const AnfNodePtr &node);
  bool IsNeedStackControlNode(const AnfNodePtr &node);
  // If there is a recursive call node in the input of the kernel graph, the graph is recursive.
  bool IsRecursionKernelGraph(const KernelGraphPtr &graph);
  bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph);
  bool IsInputInSameLevel(const AnfNodePtr &node);
  // Check whether the two input call nodes may call the same recursion graph at the same time.
  bool IsParallelCallRecursionGraph(const AnfNodePtr &call_node1, const AnfNodePtr &call_node2,
                                    const FuncGraphToCallNode &func_graph_to_call_node);
  const std::vector<KernelWithIndex> &control_node_parameters() const { return control_node_parameters_; }
  const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }
  const NodeWithDeviceContext &front_value_nodes() const { return front_value_nodes_; }

  // Fetch all funcgraphs that the call node may call.
  const std::set<FuncGraphPtr> &FetchFuncGraphbyCallNode(const AnfNodePtr &control_node);
  // Fetch the branch id corresponding to the call node.
  int FetchBranchIDByCallNode(const AnfNodePtr &call_node);
  // Fetch the kernel graph to which the kernel belongs.
  KernelGraphPtr FetchKernelGraphByFrontNode(const AnfNodePtr &kernel);
  KernelWithIndex FetchBackendOutputByKernelGraph(const KernelGraphPtr &graph,
                                                  const KernelWithIndex &front_node_with_index);
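
  // Typical call sequence for this parser (a sketch; the arguments are assumed to be produced by the
  // graph compiler and are not defined in this header):
  //   ControlNodeParserPtr parser = std::make_shared<ControlNodeParser>();
  //   parser->Parse(control_nodes, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
  //   if (parser->IsInited() && parser->IsControlFlowDataArrow(graph, backend_node)) {
  //     // Link the data arrow to the entrance/stack actor instead of directly to the kernel actor.
  //   }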

  // Fetch the backend kernel of the front node.
  KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index);
  FuncGraphPtr FetchFuncGraphByKernelGraph(const KernelGraph *const graph);
  std::string FetchGroupNameByKernelGraph(const KernelGraphPtr &graph);
  NodeWithIndexToContext FetchBackendParameterWithContextByFrontParameter(
    const KernelWithIndex &front_parameter_with_index);
  // Create a tensor for a value such as a scalar or the monad U.
  tensor::TensorPtr CreateTensorForValue(const ValuePtr &value);
  void AddControlNodeTensor(const tensor::TensorPtr &tensor) { (void)control_node_tensors_.emplace_back(tensor); }
  void PrintParseInfo();

 private:
  friend class GraphScheduler;
  friend class ControlNodeScheduler;
  // Collect all front value nodes. In control flow, when the input of a switch actor is a value node, these
  // value nodes do not enter the kernel graph, so they need to be saved separately, and space is allocated for
  // them separately during initialization.
  // The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
  void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, const DeviceContext *const default_context);
  void CreateDeviceTensors(const std::vector<AnfNodePtr> &control_nodes, const DeviceContext *const default_context);
  // Create a branch id for every call node in the control flow.
  void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);

  // Parse all the relationships between front parameters and backend parameters. The front parameters
  // include two parts:
  // 1. The parameters from kernel graphs.
  // 2. The parameters from control nodes.
  void ParseFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
                                    const std::vector<DeviceContext *> &device_contexts);
  // The relationship between front parameters indicates that the parameter is directly used as the input of a
  // funcgraph. There are two situations (see the worked example below):
  // 1. The parameter is used as the input of a call node.
  // 2. The parameter is used as the input of a partial and will be an input to the funcgraph of the partial in a
  //    subsequent call node.
  void ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes);
  // Recursively get all the real parameters corresponding to the formal parameters.
  void ParseAllRealParameterByFormalParameter(const KernelWithIndex &formal_parameter,
                                              const FormalToRealParameter &formal_to_real_parameters,
                                              std::set<KernelWithIndex> *const total_real_parameters,
                                              std::set<KernelWithIndex> *invalid_real_parameter);
  // Get all the call nodes without a recursion call relation.
  void ParseUnRecursionCallNode();
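
  // Worked example for the formal/real parameter relation above (graph pseudocode, not a real API):
  // given `out = call(g, x)` where funcgraph g has the formal parameter %p, the parser records roughly:
  //   formal_to_real_parameters_[{%p, 0}] -> {{x, 0}}
  //   real_to_formal_parameters_[{x, 0}]  -> {{%p, 0}}
  // A real parameter that reaches several call sites therefore fans out to every formal parameter it
  // feeds, and vice versa.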

  // Parse the device context of the control nodes. In a heterogeneous scenario, data needs to be copied between
  // the memories of different devices, so the device context of each node needs to be known in advance.
  // The analysis steps:
  // 1. Get the device context of each funcgraph parameter according to the device type of the kernels in the
  //    funcgraph.
  // 2. Determine the type of device context output by a funcgraph according to the call relationship of the
  //    funcgraph.
  // 3. Determine the type of device context output for the real parameters on the partial nodes and call nodes.
  void ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
                          const std::vector<KernelGraphPtr> &kernel_graphs,
                          const std::vector<DeviceContext *> &device_contexts, DeviceContext *default_context,
                          const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs);
  void ParseDeviceContextForFuncGraph(const std::vector<KernelGraphPtr> &kernel_graphs,
                                      const std::vector<DeviceContext *> &device_contexts,
                                      DeviceContext *default_context,
                                      const FuncGraphToKernelGraphGroup &func_graph_to_kernel_graphs);
  void ParseDeviceContextForReturnNode(const DeviceContext *default_context);
  void ParseDeviceContextForCallNode(const std::vector<AnfNodePtr> &control_nodes);
  void ParseDeviceContextForPartialNode(const std::vector<AnfNodePtr> &control_nodes);
  void FetchDeviceContextByNode(const std::vector<KernelWithIndex> &output_nodes,
                                std::vector<const DeviceContext *> *return_device_contexts,
                                const FuncGraphPtr &func_graph, const DeviceContext *default_context);
  // In the actor model, when a funcgraph comes to an end temporarily, the exit of the funcgraph needs to notify
  // the entrance actor so that it can process the next parameters. This is used to obtain the nodes corresponding
  // to all actors in the funcgraph that need to send control messages to the entrance.
  // These nodes are the control nodes without control node inputs in the topological sort of the funcgraph.
  void ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
  // Parse all funcgraphs that call nodes may call.
  void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes);

  // Get the relationship between the front and backend of the executable kernels in all kernel graphs.
  void ParseFrontToBackendKernel(const std::vector<KernelGraphPtr> &graphs,
                                 const std::vector<DeviceContext *> &device_contexts);
  void ParseFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs);
  // Get the control node parameters that come from the host actor; they generally come from the partial
  // nodes and call nodes of the root funcgraph.
  void ParseControlNodeParameter(const std::vector<AnfNodePtr> &control_nodes);
  // Get the control nodes and kernel graphs which need a stack actor to be added for them.
  // When a control node or kernel graph has an input that is a call node, a stack actor needs to be added for it.
  void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
  bool IsCallNodeNeedStack(const AnfNodePtr &node);
  void ParseKernelGraphGroup(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts);
  // Parse the level of the inputs and outputs of graphs and all control nodes.
  void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes);
  // Get the level of a control node: recursively traverse all the inputs of the node and find the largest level
  // among them.
  size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes);
  // When the same funcgraph may be called at multiple places in the graph, the calls cannot be executed in
  // parallel, and all call nodes need to be executed serially (see the sketch below).
  void InsertDependForParallelCall(const std::vector<AnfNodePtr> &control_nodes);
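
  // Sketch of the serialization above (a hypothetical graph rewrite, inferred from the method name,
  // not a confirmed implementation detail): if call nodes c1 and c2 may both enter the same recursive
  // funcgraph, a depend edge such as
  //   c2' = Depend(c2, c1)
  // forces c2 to wait for c1, so the two calls run serially rather than in parallel.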

  // When a parameter is directly used as the condition of a switch, there is no backend node for it, and a device
  // tensor needs to be created for it.
  void CreateDeviceTensorForRootGraphParameter(DeviceContext *const default_context);
  void ParseDynamicLenFormalParameter(const std::vector<AnfNodePtr> &control_nodes);
  void ParseDynamicLenFormalParameterByCallNode(const AnfNodePtr &node);
  void ParseDynamicLenFormalParameterByPartial(const AnfNodePtr &node);

  // In control flow, a funcgraph is cut into multiple kernel graphs for execution, and this relationship is
  // recorded in this map.
  FuncGraphToKernelGraphGroup func_graph_to_kernel_graph_groups_;
  // The kernel graph to which the front node belongs after the funcgraph is cut.
  FrontNodeToKernelGraph front_node_to_kernel_graph_;

  // The front-to-backend parameters are used to build and link the host data source actor in the control flow
  // scenario.
  FrontToBackendNodeWithContext front_to_backend_parameters_;
  // Relationship between the front and backend of the executable kernels in all kernel graphs.
  FrontToBackendKernelWithContext front_to_backend_kernels_;

  // Relationship between formal parameters and real parameters.
  FormalToRealParameter formal_to_real_parameters_;
  RealToFormalParameter real_to_formal_parameters_;

  // Branch ids of funcgraphs.
  // In control flow, a funcgraph is called from multiple places, and the output of the funcgraph needs to return
  // to different places. Therefore, a branch id is created for each call of a funcgraph. When the funcgraph is
  // called, the branch id is sent to the gather actor corresponding to the funcgraph, and the gather actor sends
  // the branch id to its output switch actor.
  mindspore::HashMap<AnfNodePtr, int> call_node_to_branch_id_;
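
  // Worked example for call_node_to_branch_id_ above (the ids shown are hypothetical): if funcgraph g
  // is called from call nodes c1 and c2, the parser assigns distinct ids starting at kSubBranchStartID,
  // e.g. call_node_to_branch_id_[c1] = 1 and call_node_to_branch_id_[c2] = 2, so that g's output can be
  // routed back to the correct call site by the switch actor that receives the id from the gather actor.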

  // The level indicates how much a node's input depends on the recursive call nodes in the funcgraph.
  // During graph scheduling, the inputs need to be graded according to their dependence on the recursive call
  // nodes, and according to this level, the lower-level inputs are pushed into the stack actor. When assigning
  // levels, first sort the call nodes in the funcgraph according to their topological relationships, and then
  // confirm the dependencies of the other nodes on these call nodes in turn.
  // For example, with the dependencies a -> b, b -> d, c -> d, where b is a call node, the level of a and c is 0,
  // and the level of b and d is 1; since d has inputs with different levels (from b and c), it is necessary to
  // add a stack to d.
  mindspore::HashMap<AnfNodePtr, size_t> node_to_level_;
  CallNodeToFuncGraph call_node_to_func_graphs_;
  FuncGraphToCallNode func_graph_to_call_nodes_;
  // The front value nodes save all value nodes that are not in a kernel graph. These nodes are generally the
  // inputs of control nodes.
  NodeWithDeviceContext front_value_nodes_;

  // Parameters of control nodes which come from the host actor.
  std::vector<KernelWithIndex> control_node_parameters_;
  // Kernel graphs that have a call node in their front input nodes.
  // In a funcgraph recursive-call scenario, general inputs and call inputs are passed recursively, so a stack
  // actor is created for every kernel graph which has a call input.
  std::set<KernelGraph *> call_input_kernel_graphs_;
  std::set<KernelGraphGroupInfoPtr> kernel_graph_group_infos_;
  // Control nodes without a control node input in the topological sorting of the funcgraph.
  mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
  // Kernel graphs that need to link a control arrow to their entrance actor.
  // In the recursive scene, some kernel graphs need to be completed before the next set of data is sent by the
  // entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
  // to the entrance actor.
  mindspore::HashMap<FuncGraphPtr, std::set<KernelGraphGroupInfoPtr>> func_graph_to_first_kernel_graphs_;
  // Call nodes without recursive calls. The funcgraphs called by these call nodes do not call back into the
  // funcgraph to which the call node belongs.
  std::set<AnfNodePtr> unrecursion_call_nodes_;
  // Control nodes that need a corresponding stack actor: when there is a call node in the inputs of a control
  // node, a stack actor is needed to collect these inputs.
  std::set<AnfNodePtr> need_stack_control_nodes_;

  // In a heterogeneous scenario, each parameter has its own device context type, so the device context
  // corresponding to the type needs to be parsed in advance so that copy operations can be added in the scheduler.
  // 1. The device context types of the formal parameters of each funcgraph.
  mindspore::HashMap<FuncGraphPtr, std::vector<const DeviceContext *>> func_graph_to_device_contexts_;
  // 2. The device context types of the control node inputs.
  mindspore::HashMap<AnfNodePtr, std::vector<const DeviceContext *>> control_node_to_device_contexts_;

  // Kernel graph to the group info to which it belongs.
  mindspore::HashMap<KernelGraphPtr, KernelGraphGroupInfoPtr> kernel_graphs_to_group_info_;
  // Scalar values are converted to tensors in control flow; these tensors are stored in this vector.
  std::vector<tensor::TensorPtr> control_node_tensors_;
  // The indexes of the arguments that need to be converted into dynamic len sequences.
  ReturnDynamicLenArgIndex return_to_call_with_dynamic_sequence_index_;
  ControlNodeDynamicLenArgIndex control_node_to_funcgraph_with_dynamic_sequence_index_;
  // Whether the parser has been initialized, which indicates that control flow is enabled.
  bool is_inited_;

  // Root funcgraph and its parameters.
  FuncGraphPtr root_func_graph_;
  std::vector<AnfNodePtr> root_graph_parameters_;
};

using ControlNodeParserPtr = std::shared_ptr<ControlNodeParser>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_CONTROL_NODE_PARSER_H_