/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_

#include <vector>
#include <string>
#include <set>
#include <memory>
#include <utility>
#include <stack>
#include <unordered_map>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/control_node_parser.h"
#include "mindrt/include/actor/switch_actor.h"
#include "runtime/hardware/device_context.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;

// Input layout of a Switch CNode: {primitive, condition, true_branch, false_branch}.
// kSwitchInputNum is the total input count; the *Pos constants index into that layout.
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
// Number of partial (branch) inputs of a Switch node — presumably the true/false pair; confirm in ParseSwitchInput.
constexpr size_t kSwitchPartialNum = 2;
// Input layout of a SwitchLayer CNode: {primitive, condition index, branch tuple}.
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
// NOTE(review): upper bound on the number of switch branches — usage not visible here; verify against the .cc file.
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
// Input layout of a Partial CNode: {primitive, func_graph, args...}.
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
// First real argument position of Call / MakeTuple CNodes (position 0 is the primitive/callee).
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;

// Switch actor is used to execute the branch according to the input condition.
// Switch and SwitchLayer node will be converted to switch actor.
// The execution process is divided into:
// 1. Put input into the vector.
// 2. Check whether the input condition has been received.
// 3. Check whether all input from the branch corresponding to the index has been received.
// 4. Send the data to the corresponding branch.
// 5. Free Memory.
class SwitchActor : public SwitchActorBase<DeviceTensor> {
 public:
  // 'node' is the front Switch/SwitchLayer CNode this actor executes; 'branch_id' identifies the funcgraph this
  // actor belongs to; 'need_branch_id_input' is true when the actor is a subgraph output and must wait for a
  // branch id from the gather actor (see need_branch_id_input_ below).
  SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node, const int branch_id,
              const bool need_branch_id_input)
      : SwitchActorBase(name),
        device_context_(device_context),
        node_(node),
        local_branch_id_(branch_id),
        need_branch_id_input_(need_branch_id_input) {}
  ~SwitchActor() override = default;

  void Init() override;

  // The switch actor run when receive the input data.
  void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
  // The switch actor run when receive the input control.
  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
  // The switch actor run when receive the input branch id.
  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context);
  // Parse the input node information of the switch actor according to node_.
  void ParseInput(const ControlNodeParserPtr &parser);
  // Add input for all branches.
  void AddCommonInput(const AnfNodePtr &node);
  // Add input for a single branch.
  void AddSingleInput(const AnfNodePtr &node, size_t branch) { AddInput(node, branch); }
  // Fetch the input position of the data node in input_nodes_.
  size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;

 private:
  // The scheduler wires arrows/ids directly into the private members below.
  friend class GraphScheduler;

  // Parse one Partial input (funcgraph + bound args) for the given branch.
  void ParsePartialInput(const AnfNodePtr &node, const size_t branch_id);
  // Parse inputs when node_ is a Switch CNode (two branches: true/false).
  void ParseSwitchInput();
  // Parse inputs when node_ is a SwitchLayer CNode (N indexed branches).
  void ParseSwitchLayerInput();
  // In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
  // initialized with the return node of the subgraph.
  void ParseReturnInput(const ControlNodeParserPtr &parser);
  // Initialize the size of the vector members.
  void InitVectorSize(const size_t num);
  // Get index from DeviceTensor (the branch index chosen by the condition input).
  size_t GetIndex(const OpContext<DeviceTensor> *const context);
  // Add input for the branch.
  void AddInput(const AnfNodePtr &node, size_t branch);
  void AddInput(const KernelWithIndex node_with_index, const size_t branch);

  // Check whether satisfy the condition for send outputs.
  bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
  // Fetch the args of switch branch.
  void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
  // Send data/control/branch-id outputs to the selected branch.
  void SendOutput(OpContext<DeviceTensor> *const context);
  // Erase input data and input controls when finish switch launch.
  void EraseInput(OpContext<DeviceTensor> *const context);
  // Ask the memory manager actor to free the inputs of this step.
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context);

  // Collect all the backend inputs of switch actor.
  void FetchInputNode(const ControlNodeParserPtr &parser);

  // All inputs of the switch actor, include weight and tensor.
  // Used to receive input data, the first input is the condition of switch.
  std::vector<KernelWithIndex> input_nodes_;
  // The position of the branch output in the input_nodes_ (one position list per branch).
  std::vector<std::vector<size_t>> branch_inputs_pos_;

  // Received input data, keyed first by step id then by input position; a stack per position handles
  // recursion — NOTE(review): the outer int key is presumably the OpContext sequential id; confirm in RunOpData.
  std::unordered_map<int, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;

  // Received input controls per step: count of arrivals per upstream actor AID.
  std::unordered_map<int, std::unordered_map<AID *, size_t>> input_controls_;

  // Branch ids is used to record the id corresponding to the switch output branch.
  // In control flow, sub funcgraph may be called in multiple places, and the output must be return to different
  // places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
  // its branch id to the gather of the subgraph. Then branch id will be sent by the gather actor to the switch
  // actor connected to the output.
  // In a recursive scenario, the switch will sequentially receive the branch ids sent by the caller, and the switch
  // actor needs to store the branch ids in the stack, and pop up in turn when returning.
  std::unordered_map<int, std::stack<int>> input_branch_ids_;

  // Control arrows of different branches.
  std::vector<std::vector<AID>> output_branch_control_arrows_;
  // Branch id arrows of different branches.
  std::vector<std::vector<AID>> output_branch_branch_arrows_;
  // Result arrows of different branches.
  std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;

  // When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
  // so all the nodes that may send the device tensor to switch actor are recorded.
  std::vector<std::set<KernelWithIndex>> backend_parameters_;
  // Front (graph-level) inputs of each branch, before lowering to backend nodes.
  std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;

  // The funcgraph executed by each branch.
  std::vector<FuncGraphPtr> branch_func_graph_;

  // Maps a received branch id to a branch index — NOTE(review): usage not visible here; verify in the .cc file.
  std::unordered_map<int, size_t> branch_id_to_index_;

  // Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
  std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;

  // Device tensors fetched for the current launch, aligned with input_nodes_.
  std::vector<DeviceTensor *> input_device_tensors_;

  // Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
  const DeviceContext *device_context_;

  // The id of memory manager actor. Send message to it for alloc and free memory.
  const AID memory_manager_aid_;
  // The dependent input data number.
  size_t input_datas_num_{0};
  // The dependent input controls number.
  size_t input_controls_num_{0};
  // The front Switch/SwitchLayer (or return) CNode this actor was built from.
  CNodePtr node_;

  // The branch id corresponding to the funcgraph to which the gather actor belongs.
  int local_branch_id_;
  // Whether it needs to accept the branch id. When the switch actor is the output of the subgraph, it needs to receive
  // branch id sent by the gather actor of subgraph, which will be true at this time.
  bool need_branch_id_input_;

  // The output_data_ corresponds to the output_data_arrows_ one by one.
  std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
};

using SwitchActorPtr = std::shared_ptr<SwitchActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_