1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <stack> 24 #include <set> 25 #include <algorithm> 26 #include "utils/hash_map.h" 27 #include "runtime/graph_scheduler/actor/actor_common.h" 28 #include "runtime/graph_scheduler/actor/control_flow/control_actor.h" 29 30 namespace mindspore { 31 namespace runtime { 32 // Stack actor is used to record those device actors that need additional storage in recursive scenes. 33 // The execution steps of the stack actor: 34 // 1. Accept a copy of all direct parameters and push them to the stack 35 // 2. Notify gather actor can be executed 36 // 3. Receive the output of exit actor 37 // 4. Send output. 38 class StackActor : public ControlActor { 39 public: 40 StackActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> ¶meters); 41 ~StackActor() override = default; 42 43 // The input data and partial of the stack actor needs to be pushed into the stack according to the input index, 44 // so it is implemented separately. 45 void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override; 46 void RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) override; 47 void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override; 48 49 void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; 50 input_stack_data_num()51 size_t input_stack_data_num() const { return input_stack_data_num_; } input_stack_partials_num()52 size_t input_stack_partials_num() const { return input_stack_partials_num_; } input_stack_controls_num()53 size_t input_stack_controls_num() const { return input_stack_controls_num_; } 54 55 protected: 56 void Init() override; 57 void FetchInput(OpContext<DeviceTensor> *const context) override; 58 bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; 59 void EraseInput(const OpContext<DeviceTensor> *const context) override; 60 61 private: 62 friend class ControlNodeScheduler; 63 64 // Check running condition functions. 65 bool CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const; 66 bool CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const; 67 bool CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const; 68 69 // The input data and partials records that the stack actor is copied from the input nodes and needs to be 70 // stored in the device tensor in the stack. 71 mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_stack_data_; 72 mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<OpPartialPtr>>> input_stack_partials_; 73 // When the input has side effects, some control arrows need to be pushed to the stack, which needs to be 74 // recorded according to the from aids, but if the input node is a call node, the input control arrows may 75 // come from different exit actors, so the relationship between from actor and index needs to be recorded 76 // during the schedule, and the number of control arrows is recorded according to the index at runtime. 77 mindspore::HashMap<AID, size_t> control_aid_to_indexs_; 78 mindspore::HashMap<int, mindspore::HashMap<size_t, size_t>> input_stack_controls_; 79 std::set<AID> stack_control_aids_; 80 // Input parameter num represents the number of actor's input come from funcgraph itself, these inputs will 81 // be ranked at the front of input. 82 size_t input_stack_data_num_{0}; 83 size_t input_stack_partials_num_{0}; 84 size_t input_stack_controls_num_{0}; 85 bool is_branch_id_enable_{true}; 86 }; 87 88 using StackActorPtr = std::shared_ptr<StackActor>; 89 } // namespace runtime 90 } // namespace mindspore 91 92 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_STACK_ACTOR_H_ 93