/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONTROL_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONTROL_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <map>
#include <set>
#include <unordered_map>
#include <stack>
#include <queue>
#include <utility>
#include <algorithm>
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/actor/abstract_actor.h"
#include "runtime/graph_scheduler/actor/memory_aware_actor.h"
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"

namespace mindspore {
namespace runtime {
// Op partial represents the partial structure, including a funcgraph and its real parameters, which may be device
// tensors or partials.
struct OpPartial;
using OpPartialPtr = std::shared_ptr<OpPartial>;
struct OpPartial {
  FuncGraph *func_graph_{nullptr};
  std::vector<std::pair<size_t, DeviceTensor *>> device_tensors_;
  std::vector<std::pair<size_t, OpPartialPtr>> partials_;
};

// Op real parameters with branch ID represent the data sent by a gather actor to an entrance actor, including all
// real parameters and the id of the caller.
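// Illustrative example (hypothetical names): for a call like F(x, Partial(G, y)) issued by a caller whose branch id
// is 1, the gather actor could pack the real parameters as
//   device_tensors_ = {{0, x_device_tensor}}, partials_ = {{1, partial_of_G}}, branch_id_ = 1,
// where the size_t in each pair is the position of the real parameter in the call.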
struct OpRealParameterWithBranchID {
  std::vector<std::pair<size_t, DeviceTensor *>> device_tensors_;
  std::vector<std::pair<size_t, OpPartialPtr>> partials_;
  int branch_id_;
};

// The control actor is the base class of the control flow actors.
class ControlActor : public MemoryAwareActor {
 public:
  ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid,
               const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node);
  ~ControlActor() override = default;

  // Receive partial.
  virtual void RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context);

  // Receive branch id.
  virtual void RunBranchID(int branch_id, OpContext<DeviceTensor> *const context);

  const std::vector<DataArrowPtr> &output_partial_arrows() const { return output_partial_arrows_; }
  const std::vector<AID> &output_branch_id_arrows() const { return output_branch_id_arrows_; }
  const std::unordered_map<size_t, OpPartialPtr> &local_partials() const { return local_partials_; }
  const std::vector<AID> &input_partial_arrow_aids() const { return input_partial_arrow_aids_; }
  const std::vector<AID> &input_branch_id_arrow_aids() const { return input_branch_id_arrow_aids_; }
  const std::map<size_t, std::set<DeviceTensorPtr>> &ref_formal_parameter_device_tensors() const {
    return ref_formal_parameter_device_tensors_;
  }
  const std::map<size_t, std::set<DeviceTensorPtr>> &ref_node_formal_parameter_device_tensors() const {
    return ref_node_formal_parameter_device_tensors_;
  }
  int branch_id() const { return output_branch_id_; }

  // Free memory by decreasing the dynamic ref counts. It corresponds to EraseInput.
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;

  void set_start_time(double start_time) { start_time_ = start_time; }
  const AnfNodePtr &node() const { return node_; }

 protected:
  friend class ControlNodeScheduler;
  friend class SchedulerHelper;
  friend class MemSwapScheduler;

  void Init() override;

  // The basic interfaces for op partial and op real parameter.
  void GetAllDeviceTensors(const OpPartialPtr &op_partial, std::vector<DeviceTensor *> *device_tensors);
  void GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter,
                           std::vector<DeviceTensor *> *device_tensors);
  void IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data) const;
  void IncreaseDynamicRefCount(const OpPartialPtr &op_partial);
  void IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter);

  // Get the position of the node in the input.
  size_t FetchNodePosition(const KernelWithIndex &node) const;

  // Get all inputs, including data, partials and branch ids.
  virtual void FetchInput(OpContext<DeviceTensor> *const context);
  void Run(OpContext<DeviceTensor> *const context) override;
  bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
  void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                        const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;

  void SendOutput(OpContext<DeviceTensor> *const context) override;
  void EraseInput(const OpContext<DeviceTensor> *context) override;

  // Increase the dynamic ref counts by the outputs. It corresponds to SendOutput.
  virtual void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context);
  void MergeDeviceAddress(OpContext<DeviceTensor> *const context, const std::vector<DeviceTensor *> &addr_list,
                          DeviceTensor **device_tensor);
  void MergeEmptyAddressDeviceAddress(OpContext<DeviceTensor> *const context,
                                      const std::vector<DeviceTensor *> &addr_list, DeviceTensor **device_tensor);

  // Input data.
  // 1. Input partials.
  // Record the partials received in each step; the key of the pair indicates the position of the partial.
  std::unordered_map<int, std::vector<std::pair<size_t, OpPartialPtr>>> input_op_partials_;
  // 2. Branch ids are used to record the id corresponding to the output branch.
  // In control flow, a sub funcgraph may be called from multiple places, and the output must be returned to
  // different places. Therefore, the output of each subgraph is connected to an exit actor, and the caller sends
  // its branch id to the entrance actor of the subgraph. The branch id is then sent by the entrance actor to the
  // exit actor connected to the output.
  // In a recursive scenario, the exit actor receives the branch ids sent by the callers in sequence, so it needs
  // to store the branch ids in a stack and pop them in turn when returning.
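  // Illustrative example (hypothetical call sequence): if funcgraph F is entered with branch id 1 and then calls
  // itself recursively with branch id 2, the exit actor's stack holds {1, 2}; the inner return pops 2 and the outer
  // return pops 1, so each output is routed back to the corresponding caller.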
  std::unordered_map<int, std::stack<int>> input_branch_ids_;

  // Fetched data. After fetching the input, all the collected inputs are saved here.
  std::vector<OpPartialPtr> input_partials_;
  std::vector<DeviceTensor *> input_device_tensors_;

  // The lists of device tensors which need to be freed by the dynamic ref count; they are cleared at the end of
  // the step.
  std::queue<std::vector<DeviceTensor *>> memory_free_lists_;

  // The exit actor needs to create a new device address and take the ptr out of the device tensor coming from the
  // kernel actor. These newly created device tensors are stored in created_device_tensors_.
  std::vector<DeviceTensorPtr> created_device_tensors_;
  std::vector<DeviceTensorPtr> last_step_created_device_tensors_;
  // In control flow, when the argument is not a dynamic len tuple but the parameter is, a new real make tuple node
  // needs to be created for it.
  std::vector<FuncGraphPtr> created_new_graphs_;
  std::vector<AnfNodePtr> created_new_nodes_;
  // Input num.
  size_t input_partials_num_{0};
  size_t input_branch_ids_num_{0};

  // The dependent input actors.
  std::vector<AID> input_partial_arrow_aids_;
  std::vector<AID> input_branch_id_arrow_aids_;

  // Output arrows.
  std::vector<DataArrowPtr> output_partial_arrows_;

  std::vector<AID> output_branch_id_arrows_;
  // The branch id is the unique identifier of the control actor. In control flow, multiple control actors may call
  // the same subgraph at the same time; the output of the subgraph then needs to be returned to the calling place
  // according to the branch id.
  int output_branch_id_{0};

  // Partial data held locally. When a partial is only a funcgraph without real parameters, it is stored inside the
  // actor.
  std::unordered_map<size_t, OpPartialPtr> local_partials_;
  // Device tensors in control nodes, but not in the kernel graph.
  std::unordered_map<size_t, DeviceTensor *> local_device_tensors_;

  // Cache output data by output index to modify the output data efficiently.
  std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;

  // Formal parameters for the control actor.
  std::vector<KernelWithIndex> formal_parameters_;
  // The device tensors of backend input nodes corresponding to ref formal parameters; the key is the position index
  // of the formal parameter. Used to update the ptr of the device tensors when receiving the real parameters for
  // ref nodes.
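  // Illustrative example (hypothetical indices): if the formal parameter at position 2 is a ref parameter backed by
  // device tensors {t1, t2}, then ref_formal_parameter_device_tensors_[2] == {t1, t2}, and the ptr of t1 and t2 is
  // refreshed from the real parameter received at position 2.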
  std::map<size_t, std::set<DeviceTensorPtr>> ref_formal_parameter_device_tensors_;
  std::map<size_t, std::set<DeviceTensorPtr>> ref_node_formal_parameter_device_tensors_;

  // Backend parameters in the kernel graph. In dynamic shape scenarios, when parameters are passed between kernel
  // graphs, the shape in the backend parameters needs to be updated.
  std::vector<std::vector<AnfNodePtr>> backend_parameters_;

  // Count the time cost between this actor and the end actors: when this actor is executed, it sets the current
  // time to the start_time_ of the end actors, and when the end actors are executed, each counts the time cost
  // between its start_time_ and its current time. For example, setting the exit actor of a kernel graph as the end
  // actor of its entrance actor counts the execution time of the kernel graph.
  std::set<ControlActor *> end_actors_;
  double start_time_{0};

  // The local node of the control actor, such as the return node for the exit actor and the switch node for the
  // switch actor.
  AnfNodePtr node_;
};

using ControlActorPtr = std::shared_ptr<ControlActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONTROL_ACTOR_H_