• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <stack>
24 #include <queue>
25 #include <set>
26 #include <algorithm>
27 #include "utils/hash_map.h"
28 #include "runtime/graph_scheduler/actor/actor_common.h"
29 #include "runtime/graph_scheduler/actor/control_flow/control_actor.h"
30 
31 namespace mindspore {
32 namespace runtime {
33 // Entrance actor is used in the control flow to receive a set of result arrow and a branch id and then send
34 // the data to the corresponding actor. It is the entry point for subgraph execution.
35 class EntranceActor : public ControlActor {
36  public:
EntranceActor(const std::string & name,const AID & memory_manager_aid,const std::vector<KernelWithIndex> & parameters,const std::set<KernelWithIndex> & call_nodes,const AnfNodePtr & node)37   EntranceActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters,
38                 const std::set<KernelWithIndex> &call_nodes, const AnfNodePtr &node)
39       : ControlActor(name, KernelTransformType::kEntranceActor, memory_manager_aid, parameters, node),
40         call_nodes_(call_nodes) {
41     device_contexts_.resize(parameters.size());
42     input_device_tensors_.resize(parameters.size());
43   }
44   ~EntranceActor() override = default;
45 
46   void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
47 
48   void RunOpRealParameterWithBranchID(const OpRealParameterWithBranchID &real_parameter_with_branch_id,
49                                       OpContext<DeviceTensor> *const context);
50 
51   void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
52 
53   // Clear the data which are generated in the loop body execution.
54   void ClearDataOnStepEnd(AID *const input_control, OpContext<DeviceTensor> *const context);
55 
loop_body_input_control_arrow_aids()56   const std::vector<AID> &loop_body_input_control_arrow_aids() const { return loop_body_input_control_arrow_aids_; }
57 
58  protected:
59   void Run(OpContext<DeviceTensor> *const context) override;
60   void FetchInput(OpContext<DeviceTensor> *const context) override;
61   bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
62   void EraseInput(const OpContext<DeviceTensor> *const context) override;
63 
64  private:
65   friend class ControlNodeScheduler;
66   friend class MemorySwapNodeScheduler;
67   friend class SchedulerHelper;
68 
69   // Indicate whether the entrance actor is the execution of loop body. In the control flow, the subgraph can be
70   // triggered to execute in two ways: one is the begin execution of step, another is the execution of loop body.
71   // The input controls are different in the two ways.
72   bool is_loop_body_execution_{false};
73   // The dependent of loop body input actors.
74   mindspore::HashMap<int, std::vector<AID *>> loop_body_input_op_controls_;
75   std::vector<AID> loop_body_input_control_arrow_aids_;
76   size_t loop_body_input_controls_nums_{0};
77 
78   // Input data with branch id.
79   mindspore::HashMap<int, std::queue<OpRealParameterWithBranchID>> real_parameters_with_branch_id_;
80 
81   // Call nodes are used to record the caller of the subgraph, and are used to connect the data arrow
82   // and branch id arrow in the link process.
83   std::set<KernelWithIndex> call_nodes_;
84 };
85 
86 using EntranceActorPtr = std::shared_ptr<EntranceActor>;
87 }  // namespace runtime
88 }  // namespace mindspore
89 
90 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_ENTRANCE_ACTOR_H_
91