• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <unordered_map>
24 #include <stack>
25 #include <utility>
26 #include <algorithm>
27 #include "runtime/framework/device_tensor_store.h"
28 #include "runtime/framework/actor/actor_common.h"
29 #include "runtime/framework/control_node_parser.h"
30 #include "runtime/hardware/device_context.h"
31 #include "backend/session/anf_runtime_algorithm.h"
32 #include "backend/session/kernel_graph.h"
33 #include "ir/tensor.h"
34 
35 namespace mindspore {
36 namespace runtime {
37 
38 constexpr size_t kReturnInputPos = 1;
39 
40 // Gather actor is used in three places:
41 // 1. Entrance of sub funcgraph
42 // 2. call node which input0 is a funcgraph
43 // 3. There is some call nodes in the inputs of kernel graph.
44 // Gather actor will be used in the control flow. When the subgraph is called, the real parameters need to be put
45 // together and sent to the subgraph. At the same time, the entry of the subgraph needs to accept input data.
46 // Special in recursion, general inputs and call inputs of the kernel graph are used in stack mode, it needs to be
47 // collected at the entrance of the kernel graph.
48 class GatherActor : public OpActor<DeviceTensor> {
49  public:
GatherActor(const std::string & name,const std::vector<KernelWithIndex> & parameters,const bool need_branch_id_input,const AID switch_aid,const AID gather_aid,const int branch_id)50   GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const bool need_branch_id_input,
51               const AID switch_aid, const AID gather_aid, const int branch_id)
52       : OpActor(name),
53         data_nodes_(parameters),
54         need_branch_id_input_(need_branch_id_input),
55         switch_aid_(switch_aid),
56         gather_aid_(gather_aid),
57         local_branch_id_(branch_id) {
58     device_contexts_.resize(parameters.size());
59   }
60   ~GatherActor() override = default;
61 
62   // Get the index of the parameter, the data_node needs to be the front node.
63   size_t FetchDataNodePosition(const KernelWithIndex &data_node) const;
64 
65   // The gather actor run when receive the input data.
66   void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
67   // The gather actor run when receive the input control.
68   void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
69   // The gather actor run when receive the input branch id.
70   void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *const context);
71   void Init() override;
72 
73  private:
74   friend class GraphScheduler;
75 
76   // Collect the inputs of gather actor.
77   void FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser);
78   void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
79   // Check whether satisfy the condition for launch.
80   bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
81   void SendOutput(OpContext<DeviceTensor> *const context) const;
82   // Erase input data and input controls when finish gather launch.
83   void EraseInput(OpContext<DeviceTensor> *const context);
84 
85   // The device tensors for launch.
86   std::vector<DeviceTensor *> input_device_tensors_;
87   // The branch if for current step.
88   int input_branch_id_{kInvalidBranchID};
89 
90   // Input data.
91   std::unordered_map<int, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
92   // Input branch ids is used to record the id corresponding receive from gather actor.
93   // In control flow, sub funcgraph may be called in multiple places, and the output must be return to different
94   // places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
95   // its branch id to the gather actor of the subgraph. Then branch id will be sent by the gather actor to the
96   // switch actor connected to the output.
97   std::unordered_map<int, int> input_branch_ids_;
98 
99   // Output data.
100   // Cache unique output data by output index to modify the output data effectively.
101   std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
102   //  The output_data_ corresponds to the output_data_arrows_ one by one.
103   std::vector<OpData<DeviceTensor> *> output_data_;
104 
105   // Output arrows.
106   std::vector<DataArrowPtr> output_result_arrows_;
107   std::vector<AID> output_branch_arrows_;
108 
109   // Parameters of sub funcgraph, which is the front node.
110   std::vector<KernelWithIndex> data_nodes_;
111   std::vector<DeviceContext *> device_contexts_;
112   // Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
113   std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
114 
115   // When the output is a parameter of the subgraph, the gather actor needs to send the anfnode to the output actor,
116   // so all the nodes that may send the device tensor to gather actor are recorded. When the anfnode needs to be sent
117   // to the output actor, the corresponding backend node will be found from the map.
118   std::unordered_map<AnfNodePtr, std::vector<KernelWithIndex>> front_to_backend_parameter_;
119 
120   // The dependent input data number.
121   size_t input_datas_num_{0};
122   // The dependent input controls number.
123   size_t input_controls_num_{0};
124   // Whether it needs to accept the branch id. When the gather actor is the input of the subgraph, it needs to receive
125   // branch id sent by the subgraph caller, which will be true at this time.
126   bool need_branch_id_input_;
127 
128   // Actor id that needs to send the branch id to it.
129   // When the actor is corresponding to call node, the branch id needs to be sent to the input gather actor and output
130   // switch actor of the called funcgraph. When the actor is the entrance of the funcgraph, the gather actor id is
131   // empty, just need to send branch id to its output switch actor.
132   const AID switch_aid_;
133   const AID gather_aid_;
134 
135   // The branch id corresponding to the funcgraph to which the gather actor belongs.
136   int local_branch_id_;
137 };
138 
139 using GatherActorPtr = std::shared_ptr<GatherActor>;
140 }  // namespace runtime
141 }  // namespace mindspore
142 
143 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_GATHER_ACTOR_H_
144