• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
19 
20 #include <vector>
21 #include <string>
22 #include <map>
23 #include <memory>
24 #include <utility>
25 #include "utils/hash_map.h"
26 #include "runtime/graph_scheduler/actor/actor_common.h"
27 #include "runtime/graph_scheduler/actor/control_flow/control_actor.h"
28 
29 namespace mindspore {
30 namespace runtime {
31 // The exit actor is used to receive a set of data arrow and a branch id in the control flow, and then send the
32 // device tensors in the data to the corresponding actor. It is the exit of the end of kernel graph execution.
class ExitActor : public ControlActor {
 public:
  // Constructs the exit actor for one graph exit. The per-input bookkeeping vectors
  // (device contexts and fetched input device tensors) are pre-sized to one slot per
  // formal parameter so they can be indexed directly during execution.
  ExitActor(const std::string &name, const AID &memory_manager_aid, const std::vector<KernelWithIndex> &parameters,
            const AnfNodePtr &node)
      : ControlActor(name, KernelTransformType::kExitActor, memory_manager_aid, parameters, node) {
    device_contexts_.resize(parameters.size());
    input_device_tensors_.resize(parameters.size());
  }
  ~ExitActor() override = default;

  // Accessors for the per-branch output arrows. The int key is the branch id that
  // identifies which caller the outputs should be routed back to.
  const mindspore::HashMap<int, std::vector<AID>> &output_branch_control_arrows() const {
    return output_branch_control_arrows_;
  }
  const mindspore::HashMap<int, std::vector<DataArrowPtr>> &output_branch_data_arrows() const {
    return output_branch_data_arrows_;
  }
  const mindspore::HashMap<int, std::vector<DataArrowPtr>> &output_branch_partial_arrows() const {
    return output_branch_partial_arrows_;
  }
  // Per-input flags: whether the corresponding input device tensor must be copied on exit.
  const std::vector<bool> &is_need_copy_device_tensors() const { return is_need_copy_device_tensors_; }
  // Per-branch mapping from actor outputs to real output indices (see member comment below).
  const mindspore::HashMap<int, std::vector<std::pair<std::vector<size_t>, bool>>> &output_branch_dynamic_len_index()
    const {
    return output_branch_dynamic_len_index_;
  }
  // Callback invoked once memory allocation requested by this actor has completed.
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

 protected:
  void Init() override;
  // Collects the input device tensors/partials for the current step into the actor state.
  void FetchInput(OpContext<DeviceTensor> *const context) override;
  // Sends the outputs along the arrows of the branch selected for this step.
  void SendOutput(OpContext<DeviceTensor> *const context) override;
  void IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context) override;

 private:
  // The schedulers wire up the branch arrows and flags directly.
  friend class ControlNodeScheduler;
  friend class SchedulerHelper;

  // Copies device addresses for inputs flagged in is_need_copy_device_tensors_.
  void CopyDeviceAddress(OpContext<DeviceTensor> *const context);
  void UpdateDeviceOutputData();
  // Merges device addresses of dynamic-length (tuple) outputs into a single address.
  void MergeDynamiclenDeviceAddress(OpContext<DeviceTensor> *const context);
  // Decides whether the input at `index` needs a fresh device address copy.
  bool IsNeedCopyDeviceAddress(DeviceTensor *const input_device_tensor, size_t index);

  // Exit actor will send to different actors according to different callers, so the output data, control,
  // and partial arrows will have branch.
  mindspore::HashMap<int, std::vector<DataArrowPtr>> output_branch_data_arrows_;
  mindspore::HashMap<int, std::vector<AID>> output_branch_control_arrows_;
  mindspore::HashMap<int, std::vector<DataArrowPtr>> output_branch_partial_arrows_;
  // The real index of actor output, the first int means the output branch id and the bool value means if the
  // output is a dynamic len.
  // eg. argument: (A, (B1, B2), C)  parameter: (a, b, c)
  //     the vector would be {<{0}, false>, <{1, 2}, true>, <{3}, false>}
  mindspore::HashMap<int, std::vector<std::pair<std::vector<size_t>, bool>>> output_branch_dynamic_len_index_;

  // In exit actor, we need to copy a new device tensor for the output of the kernel actor, but parameter is not
  // needed. This mark is used to record whether it need to be copied.
  std::vector<bool> is_need_copy_device_tensors_;
  std::vector<bool> is_need_dynamic_checks_;
  // Maps a ref output (kernel, index) to the input (kernel, index) it aliases.
  // NOTE(review): semantics inferred from the name; confirm against the .cc implementation.
  std::map<KernelWithIndex, KernelWithIndex> ref_out_in_map_;
  // Cache the dynamic shape flag to optimize the running performance.
  std::vector<bool> is_dynamic_shapes_;
  // Output data.
  //  The output branch data corresponds to the output_data_arrows_ one by one.
  mindspore::HashMap<int, std::vector<std::pair<size_t, OpDataUniquePtr<DeviceTensor>>>> output_branch_data_;
  // The value of hashmap indicates the output data flag. See constant prefixed with kOutputDataFalg for details.
  mindspore::HashMap<int, std::vector<size_t>> output_branch_data_flag_;
};
98 
99 using ExitActorPtr = std::shared_ptr<ExitActor>;
100 }  // namespace runtime
101 }  // namespace mindspore
102 
103 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_EXIT_ACTOR_H_
104