• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
19 
20 #include <vector>
21 #include <string>
22 #include <memory>
23 #include <stack>
24 #include <set>
25 #include <algorithm>
26 #include "utils/hash_map.h"
27 #include "runtime/graph_scheduler/actor/actor_common.h"
28 #include "runtime/graph_scheduler/actor/control_flow/control_actor.h"
29 
30 namespace mindspore {
31 namespace runtime {
32 // Stack actor is used to record those device actors that need additional storage in recursive scenes.
33 // The execution steps of the stack actor:
34 // 1. Accept a copy of all direct parameters and push them to the stack
35 // 2. Notify gather actor can be executed
36 // 3. Receive the output of exit actor
37 // 4. Send output.
38 class StackActor : public ControlActor {
39  public:
40   StackActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters);
41   ~StackActor() override = default;
42 
43   // The input data and partial of the stack actor needs to be pushed into the stack according to the input index,
44   // so it is implemented separately.
45   void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
46   void RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) override;
47   void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
48 
49   void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
50 
input_stack_data_num()51   size_t input_stack_data_num() const { return input_stack_data_num_; }
input_stack_partials_num()52   size_t input_stack_partials_num() const { return input_stack_partials_num_; }
input_stack_controls_num()53   size_t input_stack_controls_num() const { return input_stack_controls_num_; }
54 
55  protected:
56   void Init() override;
57   void FetchInput(OpContext<DeviceTensor> *const context) override;
58   bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
59   void EraseInput(const OpContext<DeviceTensor> *const context) override;
60 
61  private:
62   friend class ControlNodeScheduler;
63 
64   // Check running condition functions.
65   bool CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const;
66   bool CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const;
67   bool CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const;
68 
69   // The input data and partials records that the stack actor is copied from the input nodes and needs to be
70   // stored in the device tensor in the stack.
71   mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_stack_data_;
72   mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<OpPartialPtr>>> input_stack_partials_;
73   // When the input has side effects, some control arrows need to be pushed to the stack, which needs to be
74   // recorded according to the from aids, but if the input node is a call node, the input control arrows may
75   // come from different exit actors, so the relationship between from actor and index needs to be recorded
76   // during the schedule, and the number of control arrows is recorded according to the index at runtime.
77   mindspore::HashMap<AID, size_t> control_aid_to_indexs_;
78   mindspore::HashMap<int, mindspore::HashMap<size_t, size_t>> input_stack_controls_;
79   std::set<AID> stack_control_aids_;
80   // Input parameter num represents the number of actor's input come from funcgraph itself, these inputs will
81   // be ranked at the front of input.
82   size_t input_stack_data_num_{0};
83   size_t input_stack_partials_num_{0};
84   size_t input_stack_controls_num_{0};
85   bool is_branch_id_enable_{true};
86 };
87 
88 using StackActorPtr = std::shared_ptr<StackActor>;
89 }  // namespace runtime
90 }  // namespace mindspore
91 
92 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_
93