• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_

#include <memory>
#include <set>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/control_node_parser.h"
#include "mindrt/include/actor/switch_actor.h"
#include "runtime/hardware/device_context.h"
32 
33 namespace mindspore {
34 namespace runtime {
35 using mindspore::device::DeviceContext;
36 using mindspore::session::KernelWithIndex;
37 
// Positional layout of the frontend nodes that a switch actor is built from.
// A Switch node's inputs are {primitive, condition, true_branch, false_branch}.
constexpr size_t kSwitchInputNum = 4;
constexpr size_t kSwitchCondPos = 1;
constexpr size_t kSwitchPartialNum = 2;
// A SwitchLayer node's inputs are {primitive, index, branch_tuple}.
constexpr size_t kSwitchLayerCondPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchLayerInputNum = 3;
// Upper bound on the number of switch branches; presumably the int32 condition
// is clamped to this many cases — TODO confirm against GetIndex().
constexpr size_t kMaxSwitchCondSize = 8;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kSwitchFalseBranchPos = 3;
// Partial node inputs: {primitive, funcgraph, args...}.
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kPartialInputStartPos = 2;
// First real argument position of Call and MakeTuple nodes (input 0 is the primitive).
constexpr size_t kCallInputStartPos = 1;
constexpr size_t kMakeTupleInputStartPos = 1;
51 
52 // Switch actor is used to execute the branch according to the input condition.
53 // Switch and SwitchLayer node will be converted to switch actor.
54 // The execution process is divided into:
55 // 1. Put input into the vector.
56 // 2. Check whether the input condition has been received.
57 // 3. Check whether all input from the branch corresponding to the index has been received.
58 // 4. Send the data to the corresponding branch.
59 // 5. Free Memory
60 class SwitchActor : public SwitchActorBase<DeviceTensor> {
61  public:
SwitchActor(const std::string & name,DeviceContext * device_context,const CNodePtr & node,const int branch_id,const bool need_branch_id_input)62   SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node, const int branch_id,
63               const bool need_branch_id_input)
64       : SwitchActorBase(name),
65         device_context_(device_context),
66         node_(node),
67         local_branch_id_(branch_id),
68         need_branch_id_input_(need_branch_id_input) {}
69   ~SwitchActor() override = default;
70 
71   void Init() override;
72 
73   // The switch actor run when receive the input data.
74   void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
75   // The switch actor run when receive the input control.
76   void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
77   // The switch actor run when receive the input branch id.
78   void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context);
79   // Parse the input node information of the switch actor according to node_.
80   void ParseInput(const ControlNodeParserPtr &parser);
81   // Add input for all branches.
82   void AddCommonInput(const AnfNodePtr &node);
AddSingleInput(const AnfNodePtr & node,size_t branch)83   void AddSingleInput(const AnfNodePtr &node, size_t branch) { AddInput(node, branch); }
84   // Fetch the input position of the data node.
85   size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
86 
87  private:
88   friend class GraphScheduler;
89 
90   void ParsePartialInput(const AnfNodePtr &node, const size_t branch_id);
91   void ParseSwitchInput();
92   void ParseSwitchLayerInput();
93   // In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
94   // initialized with the return node of the subgraph.
95   void ParseReturnInput(const ControlNodeParserPtr &parser);
96   // Initialize the size of the vector members.
97   void InitVectorSize(const size_t num);
98   // Get index from DeviceTensor.
99   size_t GetIndex(const OpContext<DeviceTensor> *const context);
100   // Add input for the branch.
101   void AddInput(const AnfNodePtr &node, size_t branch);
102   void AddInput(const KernelWithIndex node_with_index, const size_t branch);
103 
104   // Check whether satisfy the condition for send outputs.
105   bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
106   // Fetch the args of switch branch.
107   void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
108   void SendOutput(OpContext<DeviceTensor> *const context);
109   // Erase input data and input controls when finish switch launch.
110   void EraseInput(OpContext<DeviceTensor> *const context);
111   void SendMemoryFreeReq(OpContext<DeviceTensor> *const context);
112 
113   // Collect all the backend inputs of switch actor.
114   void FetchInputNode(const ControlNodeParserPtr &parser);
115   // All inputs of the switch actor, include weight and tensor.
116   // Used to receive input data, the first input is the condition of switch.
117   std::vector<KernelWithIndex> input_nodes_;
118   // The position of the branch output in the input_nodes_.
119   std::vector<std::vector<size_t>> branch_inputs_pos_;
120 
121   std::unordered_map<int, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
122 
123   std::unordered_map<int, std::unordered_map<AID *, size_t>> input_controls_;
124 
125   // Branch ids is used to record the id corresponding to the switch output branch.
126   // In control flow, sub funcgraph may be called in multiple places, and the output must be return to different
127   // places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
128   // its branch id to the gather of the subgraph. Then branch id will be sent by the gather actor to the switch
129   // actor connected to the output.
130   // In a recursive scenario, the switch will sequentially receive the branch ids sent by the caller, and the switch
131   // actor needs to store the branch ids in the stack, and pop up in turn when returning.
132   std::unordered_map<int, std::stack<int>> input_branch_ids_;
133 
134   // Control arrows of different branches.
135   std::vector<std::vector<AID>> output_branch_control_arrows_;
136   // Branch id arrows of different branches.
137   std::vector<std::vector<AID>> output_branch_branch_arrows_;
138   // Result arrows of different branches.
139   std::vector<std::vector<DataArrowPtr>> output_branch_result_arrows_;
140 
141   // When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
142   // so all the nodes that may send the device tensor to switch actor are recorded.
143   std::vector<std::set<KernelWithIndex>> backend_parameters_;
144   std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
145 
146   std::vector<FuncGraphPtr> branch_func_graph_;
147 
148   std::unordered_map<int, size_t> branch_id_to_index_;
149 
150   // Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
151   std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
152 
153   std::vector<DeviceTensor *> input_device_tensors_;
154 
155   // Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
156   const DeviceContext *device_context_;
157 
158   // The id of memory manager actor. Send message to it for alloc and free memory.
159   const AID memory_manager_aid_;
160   // The dependent input data number.
161   size_t input_datas_num_{0};
162   // The dependent input controls number.
163   size_t input_controls_num_{0};
164   CNodePtr node_;
165 
166   // The branch id corresponding to the funcgraph to which the gather actor belongs.
167   int local_branch_id_;
168   // Whether it needs to accept the branch id. When the switch actor is the output of the subgraph, it needs to receive
169   // branch id sent by the gather actor of subgraph, which will be true at this time.
170   bool need_branch_id_input_;
171 
172   //  The output_data_ corresponds to the output_data_arrows_ one by one.
173   std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
174 };
175 
176 using SwitchActorPtr = std::shared_ptr<SwitchActor>;
177 }  // namespace runtime
178 }  // namespace mindspore
179 
180 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SWITCH_ACTOR_H_
181