• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONTROL_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONTROL_ACTOR_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <map>
24 #include <set>
25 #include <unordered_map>
26 #include <stack>
27 #include <queue>
28 #include <utility>
29 #include <algorithm>
30 #include "runtime/graph_scheduler/actor/actor_common.h"
31 #include "runtime/graph_scheduler/actor/abstract_actor.h"
32 #include "runtime/graph_scheduler/actor/memory_aware_actor.h"
33 #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
34 
35 namespace mindspore {
36 namespace runtime {
37 // Op partial represents the partial structure, including a funcgraph and its real parameters, maybe device tensors
38 // or partials.
39 struct OpPartial;
40 using OpPartialPtr = std::shared_ptr<OpPartial>;
41 struct OpPartial {
42   FuncGraph *func_graph_{nullptr};
43   std::vector<std::pair<size_t, DeviceTensor *>> device_tensors_;
44   std::vector<std::pair<size_t, OpPartialPtr>> partials_;
45 };
46 
47 // Op real parameters with branch ID represents the data sent by gather actor to entrance actor, including all real
48 // parameters and the id of the caller.
49 struct OpRealParameterWithBranchID {
50   std::vector<std::pair<size_t, DeviceTensor *>> device_tensors_;
51   std::vector<std::pair<size_t, OpPartialPtr>> partials_;
52   int branch_id_;
53 };
54 // The control actor is the base class of control flow actor.
55 class ControlActor : public MemoryAwareActor {
56  public:
57   ControlActor(const std::string &name, KernelTransformType type, const AID &memory_manager_aid,
58                const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node);
59   ~ControlActor() override = default;
60 
61   // Receive partial.
62   virtual void RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context);
63 
64   // Receive branch id.
65   virtual void RunBranchID(int branch_id, OpContext<DeviceTensor> *const context);
66 
output_partial_arrows()67   const std::vector<DataArrowPtr> &output_partial_arrows() const { return output_partial_arrows_; }
output_branch_id_arrows()68   const std::vector<AID> &output_branch_id_arrows() const { return output_branch_id_arrows_; }
local_partials()69   const std::unordered_map<size_t, OpPartialPtr> &local_partials() const { return local_partials_; }
input_partial_arrow_aids()70   const std::vector<AID> &input_partial_arrow_aids() const { return input_partial_arrow_aids_; }
input_branch_id_arrow_aids()71   const std::vector<AID> &input_branch_id_arrow_aids() const { return input_branch_id_arrow_aids_; }
ref_formal_parameter_device_tensors()72   const std::map<size_t, std::set<DeviceTensorPtr>> &ref_formal_parameter_device_tensors() const {
73     return ref_formal_parameter_device_tensors_;
74   }
ref_node_formal_parameter_device_tensors()75   const std::map<size_t, std::set<DeviceTensorPtr>> &ref_node_formal_parameter_device_tensors() const {
76     return ref_node_formal_parameter_device_tensors_;
77   }
branch_id()78   int branch_id() const { return output_branch_id_; }
79   // Free memory by the dynamic ref count decremented. It corresponds to the EraseInput.
80   void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
81 
set_start_time(double start_time)82   void set_start_time(double start_time) { start_time_ = start_time; }
node()83   const AnfNodePtr &node() const { return node_; }
84 
85  protected:
86   friend class ControlNodeScheduler;
87   friend class SchedulerHelper;
88   friend class MemSwapScheduler;
89 
90   void Init() override;
91 
92   // The basic interfaces for op partial and op real parameter.
93   void GetAllDeviceTensors(const OpPartialPtr &op_partial, std::vector<DeviceTensor *> *device_tensors);
94   void GetAllDeviceTensors(const OpRealParameterWithBranchID &op_real_parameter,
95                            std::vector<DeviceTensor *> *device_tensors);
96   void IncreaseDynamicRefCount(const OpData<DeviceTensor> *op_data) const;
97   void IncreaseDynamicRefCount(const OpPartialPtr &op_partial);
98   void IncreaseDynamicRefCount(const OpRealParameterWithBranchID &op_real_parameter);
99 
100   // Get the position of node in the input.
101   size_t FetchNodePosition(const KernelWithIndex &node) const;
102 
103   // Get all input, including data, partial, branchid.
104   virtual void FetchInput(OpContext<DeviceTensor> *const context);
105   void Run(OpContext<DeviceTensor> *const context) override;
106   bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
107   void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
108                         const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;
109 
110   void SendOutput(OpContext<DeviceTensor> *const context) override;
111   void EraseInput(const OpContext<DeviceTensor> *context) override;
112 
113   // Increase the dynamic ref count by the outputs. It corresponds to the SendOutput.
114   virtual void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context);
115   void MergeDeviceAddress(OpContext<DeviceTensor> *const context, const std::vector<DeviceTensor *> &addr_list,
116                           DeviceTensor **deivce_tensor);
117   void MergeEmptyAddressDeviceAddress(OpContext<DeviceTensor> *const context,
118                                       const std::vector<DeviceTensor *> &addr_list, DeviceTensor **device_tensor);
119 
120   // Input data.
121   // 1.Input partial.
122   // Record the partial received by each step, the key of the pair indicates the location of the partial.
123   std::unordered_map<int, std::vector<std::pair<size_t, OpPartialPtr>>> input_op_partials_;
124   // 2. Branch ids is used to record the id corresponding to the output branch.
125   // In control flow, sub funcgraph may be called in multiple places, and the output must be return to different
126   // places. Therefore, the output of each subgraph will be connected to a exit actor, and the caller will send
127   // its branch id to the entrance actor of the subgraph. Then branch id will be sent by the entrance actor to
128   // the exit actor connected to the output.
129   // In a recursive scenario, the exit will sequentially receive the branch ids sent by the caller, and the exit
130   // actor needs to store the branch ids in the stack, and pop up in turn when returning.
131   std::unordered_map<int, std::stack<int>> input_branch_ids_;
132 
133   // Fetch data. After fetch input, all the input collected is saved here.
134   std::vector<OpPartialPtr> input_partials_;
135   std::vector<DeviceTensor *> input_device_tensors_;
136 
137   // The lists of device tensors which need free by dynamic ref count, will be cleared at the end of step.
138   std::queue<std::vector<DeviceTensor *>> memory_free_lists_;
139 
140   // The exit actor needs to create a new device address and take out the ptr from the device tensor come from
141   // the kernel actor. These new created device tensors are stored in the created device tensors.
142   std::vector<DeviceTensorPtr> created_device_tensors_;
143   std::vector<DeviceTensorPtr> last_step_created_device_tensors_;
144   // In control flow, when the argument is not a dynamic len tuple but the parameter is, need create a new
145   // real make tuple node for it.
146   std::vector<FuncGraphPtr> created_new_graphs_;
147   std::vector<AnfNodePtr> created_new_nodes_;
148   // Input num.
149   size_t input_partials_num_{0};
150   size_t input_branch_ids_num_{0};
151 
152   // The dependent input actors.
153   std::vector<AID> input_partial_arrow_aids_;
154   std::vector<AID> input_branch_id_arrow_aids_;
155 
156   // Output Arrows.
157   std::vector<DataArrowPtr> output_partial_arrows_;
158 
159   std::vector<AID> output_branch_id_arrows_;
160   // The branch id is the unique identifier of the control actor. In the control flow, there are multiple control
161   // actors calling the same subgraph at the same time. At this time, the output of the subgraph needs to be returned
162   // to the calling place according to the branch id.
163   int output_branch_id_{0};
164 
165   // Partial data in local. When partial is only funcgraph without real parameter, it is stored inside the actor.
166   std::unordered_map<size_t, OpPartialPtr> local_partials_;
167   // Device tensor in control node, but not in kernel graph.
168   std::unordered_map<size_t, DeviceTensor *> local_device_tensors_;
169 
170   // Cache output data by output index to modify the output data effectively.
171   std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
172 
173   // Formal parameters for control actor.
174   std::vector<KernelWithIndex> formal_parameters_;
175   // The device tensors of backend input nodes corresponding to ref formal parameters, the key is the position index of
176   // formal parameter. Used to update the ptr of device tensors when receive the real parameters for ref nodes.
177   std::map<size_t, std::set<DeviceTensorPtr>> ref_formal_parameter_device_tensors_;
178   std::map<size_t, std::set<DeviceTensorPtr>> ref_node_formal_parameter_device_tensors_;
179 
180   // Backend parameters in the kernel graph.In the dynamic shape, when parameters are passed between the kernel
181   // graphs, the shape in the backend parameters needs to be updated.
182   std::vector<std::vector<AnfNodePtr>> backend_parameters_;
183 
184   // Count the time cost bewtween this actor to the end actors, when this actor is executed, set current time to the
185   // start_time_ of the end actors and then when the end actors are executed, it will count the time cost between its
186   // start_time_ and its current time, for example, set exit actor of kernel graph to its entrance actor to count the
187   // execution time of the kernel graph.
188   std::set<ControlActor *> end_actors_;
189   double start_time_{0};
190 
191   // local node for control actor, such as return node for exit actor, switch node for switch actor.
192   AnfNodePtr node_;
193 };
194 
195 using ControlActorPtr = std::shared_ptr<ControlActor>;
196 }  // namespace runtime
197 }  // namespace mindspore
198 
199 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONTROL_ACTOR_H_
200