1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <unordered_map> 24 #include <stack> 25 #include <utility> 26 #include <algorithm> 27 #include "runtime/framework/device_tensor_store.h" 28 #include "runtime/framework/actor/actor_common.h" 29 #include "runtime/framework/control_node_parser.h" 30 #include "runtime/hardware/device_context.h" 31 #include "backend/session/anf_runtime_algorithm.h" 32 #include "backend/session/kernel_graph.h" 33 #include "ir/tensor.h" 34 35 namespace mindspore { 36 namespace runtime { 37 38 constexpr size_t kReturnInputPos = 1; 39 40 // Gather actor is used in three places: 41 // 1. Entrance of sub funcgraph 42 // 2. call node which input0 is a funcgraph 43 // 3. There is some call nodes in the inputs of kernel graph. 44 // Gather actor will be used in the control flow. When the subgraph is called, the real parameters need to be put 45 // together and sent to the subgraph. At the same time, the entry of the subgraph needs to accept input data. 46 // Special in recursion, general inputs and call inputs of the kernel graph are used in stack mode, it needs to be 47 // collected at the entrance of the kernel graph. 48 class GatherActor : public OpActor<DeviceTensor> { 49 public: GatherActor(const std::string & name,const std::vector<KernelWithIndex> & parameters,const bool need_branch_id_input,const AID switch_aid,const AID gather_aid,const int branch_id)50 GatherActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, const bool need_branch_id_input, 51 const AID switch_aid, const AID gather_aid, const int branch_id) 52 : OpActor(name), 53 data_nodes_(parameters), 54 need_branch_id_input_(need_branch_id_input), 55 switch_aid_(switch_aid), 56 gather_aid_(gather_aid), 57 local_branch_id_(branch_id) { 58 device_contexts_.resize(parameters.size()); 59 } 60 ~GatherActor() override = default; 61 62 // Get the index of the parameter, the data_node needs to be the front node. 63 size_t FetchDataNodePosition(const KernelWithIndex &data_node) const; 64 65 // The gather actor run when receive the input data. 66 void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override; 67 // The gather actor run when receive the input control. 68 void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override; 69 // The gather actor run when receive the input branch id. 70 void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context); 71 void Init() override; 72 73 private: 74 friend class GraphScheduler; 75 76 // Collect the inputs of gather actor. 77 void FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser); 78 void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context); 79 // Check whether satisfy the condition for launch. 80 bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const; 81 void SendOutput(OpContext<DeviceTensor> *const context) const; 82 // Erase input data and input controls when finish gather launch. 83 void EraseInput(OpContext<DeviceTensor> *const context); 84 85 // The device tensors for launch. 86 std::vector<DeviceTensor *> input_device_tensors_; 87 // The branch if for current step. 88 int input_branch_id_{kInvalidBranchID}; 89 90 // Input data. 91 std::unordered_map<int, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_; 92 // Input branch ids is used to record the id corresponding receive from gather actor. 93 // In control flow, sub funcgraph may be called in multiple places, and the output must be return to different 94 // places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send 95 // its branch id to the gather actor of the subgraph. Then branch id will be sent by the gather actor to the 96 // switch actor connected to the output. 97 std::unordered_map<int, int> input_branch_ids_; 98 99 // Output data. 100 // Cache unique output data by output index to modify the output data effectively. 101 std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_; 102 // The output_data_ corresponds to the output_data_arrows_ one by one. 103 std::vector<OpData<DeviceTensor> *> output_data_; 104 105 // Output arrows. 106 std::vector<DataArrowPtr> output_result_arrows_; 107 std::vector<AID> output_branch_arrows_; 108 109 // Parameters of sub funcgraph, which is the front node. 110 std::vector<KernelWithIndex> data_nodes_; 111 std::vector<DeviceContext *> device_contexts_; 112 // Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store. 113 std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_; 114 115 // When the output is a parameter of the subgraph, the gather actor needs to send the anfnode to the output actor, 116 // so all the nodes that may send the device tensor to gather actor are recorded. When the anfnode needs to be sent 117 // to the output actor, the corresponding backend node will be found from the map. 118 std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> front_to_backend_parameter_; 119 120 // The dependent input data number. 121 size_t input_datas_num_{0}; 122 // The dependent input controls number. 123 size_t input_controls_num_{0}; 124 // Whether it needs to accept the branch id. When the gather actor is the input of the subgraph, it needs to receive 125 // branch id sent by the subgraph caller, which will be true at this time. 126 bool need_branch_id_input_; 127 128 // Actor id that needs to send the branch id to it. 129 // When the actor is corresponding to call node, the branch id needs to be sent to the input gather actor and output 130 // switch actor of the called funcgraph. When the actor is the entrance of the funcgraph, the gather actor id is 131 // empty, just need to send branch id to its output switch actor. 132 const AID switch_aid_; 133 const AID gather_aid_; 134 135 // The branch id corresponding to the funcgraph to which the gather actor belongs. 136 int local_branch_id_; 137 }; 138 139 using GatherActorPtr = std::shared_ptr<GatherActor>; 140 } // namespace runtime 141 } // namespace mindspore 142 143 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_ 144