• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_
19 
20 #include <functional>
21 #include <unordered_map>
22 #include <map>
23 #include <set>
24 #include <string>
25 #include <queue>
26 #include <vector>
27 #include <memory>
28 #include <unordered_set>
29 #include <utility>
30 #include "runtime/base.h"
31 #include "runtime/rt_model.h"
32 #include "runtime/stream.h"
33 #include "backend/session/kernel_graph.h"
34 #include "utils/contract.h"
35 
36 namespace mindspore {
37 namespace device {
38 namespace ascend {
// Standard-library containers and smart pointers that this header refers to without the std:: prefix.
using std::map;
using std::queue;
using std::shared_ptr;
using std::unordered_map;
using std::unordered_set;
using std::vector;

// Untyped pointer alias used as a lookup key (see independent_targets_ and GetIndexByKey).
using CNodeKey = void *;

// Sentinel ids; both are the largest value representable in uint32_t.
constexpr uint32_t kInvalidStreamId = UINT32_MAX;
constexpr uint32_t kInvalidEventId = UINT32_MAX;
48 class AscendResourceMng {
49  public:
GetInstance()50   static AscendResourceMng &GetInstance() {
51     static AscendResourceMng instance;
52     return instance;
53   }
54 
ResetResource()55   void ResetResource() {
56     cur_stream_num_ = 0;
57     cur_event_num_ = 0;
58   }
ApplyNewStream()59   uint32_t ApplyNewStream() {
60     if (!cur_stream_num_) {
61       uint32_t cur_stream_id = cur_stream_num_;
62       cur_stream_num_++;
63       return cur_stream_id;
64     }
65     uint32_t cur_stream_id = cur_stream_num_;
66     cur_stream_num_++;
67     return cur_stream_id;
68   }
ApplyNewEvent()69   uint32_t ApplyNewEvent() {
70     if (!cur_event_num_) {
71       uint32_t cur_event_id = cur_event_num_;
72       cur_event_num_++;
73       return cur_event_id;
74     }
75     uint32_t cur_event_id = cur_event_num_;
76     cur_event_num_++;
77     return cur_event_id;
78   }
79 
DeleteEvent()80   void DeleteEvent() {
81     if (!cur_event_num_) {
82       MS_LOG(WARNING) << "total event num is 0, no event to delete";
83     } else {
84       --cur_event_num_;
85     }
86   }
get_cur_stream_num()87   uint32_t get_cur_stream_num() { return cur_stream_num_; }
GetCurAllocStreamId()88   uint32_t GetCurAllocStreamId() {
89     if (!cur_stream_num_) {
90       MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get";
91     }
92     return cur_stream_num_ - 1;
93   }
get_cur_event_num()94   uint32_t get_cur_event_num() { return cur_event_num_; }
95 
96  private:
97   uint32_t cur_stream_num_{0};
98   uint32_t cur_event_num_{0};
99 };
100 
// Result type of AscendStreamAssign::GetStreamActiveKind.
// NOTE(review): what makes a stream-active node "head", "middle" or "tail" is decided in the
// implementation file - confirm there before relying on the names.
enum StreamActiveKind {
  kInvalid = 0,
  kHead = 1,
  kMiddle = 2,
  kTail = 3
};
102 class AscendStreamAssign {
103  public:
GetInstance()104   static AscendStreamAssign &GetInstance() {
105     static AscendStreamAssign instance;  // Guaranteed to be destroyed.
106     return instance;
107   }
108 
109   AscendStreamAssign(const AscendStreamAssign &) = delete;
110   AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
111 
112   void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr);
113   void GetHcomStreams(std::vector<uint32_t> *streams);
114   void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
get_stream_group()115   const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; }
get_event_map()116   const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; }
117 
118  private:
119   AscendStreamAssign() = default;
120   ~AscendStreamAssign() = default;
121   void Reset();
122   CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
123   CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
124   void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr);
125   void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr);
126   void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
127   void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr);
128   void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr);
129   void AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr);
130   uint32_t AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
131   void AssignIndependent(const NotNull<KernelGraphPtr> &graph_ptr);
132   uint32_t AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
133   void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
134   void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr);
135   void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr);
136   void InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr);
137   void InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr);
138   void InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr);
139   void ActiveRootGraphHcom(const NotNull<KernelGraphPtr> &graph_ptr, const std::set<uint32_t> &hcom_streams);
140   void ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> &graph_ptr,
141                                   const std::set<uint32_t> &independent_streams);
142   void ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr,
143                                 std::map<uint32_t, std::set<uint32_t>> other_graph);
144   bool CheckStreamSwitch(const CNodePtr &switch_ptr);
145   void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
146   void InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
147   void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
148   void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
149   void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
150   void InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr);
151   void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
152   void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
153                               const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index);
154   void InsertEventHcomDependHcomAtSameGroup(const NotNull<KernelGraphPtr> &graph_ptr,
155                                             std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item);
156   std::vector<std::pair<uint32_t, vector<size_t>>> GetStreamIDHcomMap(const std::vector<CNodePtr> &cnode_ptr_list,
157                                                                       const std::string &group, size_t graph_id);
158 
159   void AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr);
160   vector<CNodePtr> GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
161   bool IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index, const CNodePtr &node_ptr,
162                        size_t index);
163 
164   void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
165   void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr);
166   void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr);
167 
168   void CheckScenario(const NotNull<KernelGraphPtr> &graph_ptr, vector<CNodePtr> *last_grad_and_status);
169   CNodePtr GetCNodesNeededMoved(vector<CNodePtr> *moved_backward_cnodes, vector<CNodePtr> *moved_forward_cnodes,
170                                 const vector<CNodePtr> &last_grad_and_status, const NotNull<KernelGraphPtr> &graph_ptr);
171   CNodePtr GetTargetOutputNode(const vector<CNodePtr> &moved_backward_cnodes, const CNodePtr first_node,
172                                const NotNull<KernelGraphPtr> &graph_ptr);
173   bool FinetuneSubgraphExecOrder(vector<CNodePtr> *cnodes);
174   void TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> &graph_ptr);
175 
176   uint32_t GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr);
177   uint32_t GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key);
178   uint32_t GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
179   void GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr);
180   bool IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node, const CNodePtr &cur_node,
181                        bool exclude_hcom);
182   bool IsTaskSink();
183   bool IsHcom(const CNodePtr &cur_cnode_ptr);
184   bool IsIndependentNode(const CNodePtr &node_ptr);
185   bool IsProcessedStream(uint32_t stream_id);
186   vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
187                                           const CNodePtr &node, bool exclude_hcom);
188   void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
189   void SetLoopSink();
190 
191   // function for memory reuse
192   void GetStreamRelations();
193   void DFS(uint32_t start, std::vector<uint32_t> *group);
194   bool IsVecExist(const std::vector<uint32_t> &group);
195   void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr);
196   void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr);
197   void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index);
198   StreamActiveKind GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index);
199   uint32_t GetStreamByActivedStream(uint32_t actived_stream_id);
200   void PrintStreamRelations();
201   void PrintStreamGroups();
202   void FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr);
203   bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const;
204   vector<CNodePtr> GetInputKernels(const CNodePtr &cnode);
205 
206   bool independent_stream_activated_{false};
207   bool hcom_stream_activated_{false};
208   bool loop_sink_{false};
209 
210   // key:stream id, value:node number
211   std::map<uint32_t, uint32_t> common_stream_map_{};
212   // key:stream id, value:node number
213   std::map<uint32_t, uint32_t> independent_stream_map_{};
214   // key:stream id, value:task number
215   std::map<uint32_t, uint32_t> hcom_stream_map_{};
216 
217   std::set<uint32_t> processed_streams_{};
218   std::vector<uint32_t> need_first_active_streams_{};
219   std::set<CNodeKey> independent_targets_;
220 
221   // key:group name, value:key1:graph id, value1:stream id
222   std::map<std::string, std::map<uint32_t, std::set<uint32_t>>> group_hcom_graph_map_;
223   // key:graph id, value:stream set
224   std::map<uint32_t, std::set<uint32_t>> independent_graph_map_;
225 
226   // attr for memory copy reuse
227   std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
228   std::vector<std::vector<uint32_t>> stream_groups_{};
229   std::map<CNodePtr, CNodePtr> event_map_{};
230   std::set<uint32_t> middle_active_streams_{};
231   // new policy end
232   bool IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode);
233   vector<CNodePtr>::iterator FindGraphEnd(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end);
234 };
235 }  // namespace ascend
236 }  // namespace device
237 }  // namespace mindspore
238 
239 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_
240