1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ 19 20 #include <functional> 21 #include <unordered_map> 22 #include <map> 23 #include <set> 24 #include <string> 25 #include <queue> 26 #include <vector> 27 #include <memory> 28 #include <unordered_set> 29 #include <utility> 30 #include "runtime/base.h" 31 #include "runtime/rt_model.h" 32 #include "runtime/stream.h" 33 #include "backend/session/kernel_graph.h" 34 #include "utils/contract.h" 35 36 namespace mindspore { 37 namespace device { 38 namespace ascend { 39 using std::map; 40 using std::queue; 41 using std::shared_ptr; 42 using std::unordered_map; 43 using std::unordered_set; 44 using std::vector; 45 using CNodeKey = void *; 46 const uint32_t kInvalidStreamId = UINT32_MAX; 47 const uint32_t kInvalidEventId = UINT32_MAX; 48 class AscendResourceMng { 49 public: GetInstance()50 static AscendResourceMng &GetInstance() { 51 static AscendResourceMng instance; 52 return instance; 53 } 54 ResetResource()55 void ResetResource() { 56 cur_stream_num_ = 0; 57 cur_event_num_ = 0; 58 } ApplyNewStream()59 uint32_t ApplyNewStream() { 60 if (!cur_stream_num_) { 61 uint32_t cur_stream_id = cur_stream_num_; 62 cur_stream_num_++; 63 return cur_stream_id; 64 } 65 uint32_t cur_stream_id = cur_stream_num_; 66 cur_stream_num_++; 67 return cur_stream_id; 68 } ApplyNewEvent()69 uint32_t ApplyNewEvent() { 70 if (!cur_event_num_) { 71 uint32_t cur_event_id = cur_event_num_; 72 cur_event_num_++; 73 return cur_event_id; 74 } 75 uint32_t cur_event_id = cur_event_num_; 76 cur_event_num_++; 77 return cur_event_id; 78 } 79 DeleteEvent()80 void DeleteEvent() { 81 if (!cur_event_num_) { 82 MS_LOG(WARNING) << "total event num is 0, no event to delete"; 83 } else { 84 --cur_event_num_; 85 } 86 } get_cur_stream_num()87 uint32_t get_cur_stream_num() { return cur_stream_num_; } GetCurAllocStreamId()88 uint32_t GetCurAllocStreamId() { 89 if (!cur_stream_num_) { 90 MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; 91 } 92 return cur_stream_num_ - 1; 93 } get_cur_event_num()94 uint32_t get_cur_event_num() { return cur_event_num_; } 95 96 private: 97 uint32_t cur_stream_num_{0}; 98 uint32_t cur_event_num_{0}; 99 }; 100 101 enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail }; 102 class AscendStreamAssign { 103 public: GetInstance()104 static AscendStreamAssign &GetInstance() { 105 static AscendStreamAssign instance; // Guaranteed to be destroyed. 106 return instance; 107 } 108 109 AscendStreamAssign(const AscendStreamAssign &) = delete; 110 AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; 111 112 void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr); 113 void GetHcomStreams(std::vector<uint32_t> *streams); 114 void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); get_stream_group()115 const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; } get_event_map()116 const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; } 117 118 private: 119 AscendStreamAssign() = default; 120 ~AscendStreamAssign() = default; 121 void Reset(); 122 CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); 123 CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); 124 void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr); 125 void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr); 126 void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr); 127 void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr); 128 void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); 129 void AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr); 130 uint32_t AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph); 131 void AssignIndependent(const NotNull<KernelGraphPtr> &graph_ptr); 132 uint32_t AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph); 133 void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr); 134 void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr); 135 void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr); 136 void InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr); 137 void InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr); 138 void InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr); 139 void ActiveRootGraphHcom(const NotNull<KernelGraphPtr> &graph_ptr, const std::set<uint32_t> &hcom_streams); 140 void ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> &graph_ptr, 141 const std::set<uint32_t> &independent_streams); 142 void ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr, 143 std::map<uint32_t, std::set<uint32_t>> other_graph); 144 bool CheckStreamSwitch(const CNodePtr &switch_ptr); 145 void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr); 146 void InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr); 147 void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr); 148 void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); 149 void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr); 150 void InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr); 151 void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); 152 void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, 153 const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index); 154 void InsertEventHcomDependHcomAtSameGroup(const NotNull<KernelGraphPtr> &graph_ptr, 155 std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item); 156 std::vector<std::pair<uint32_t, vector<size_t>>> GetStreamIDHcomMap(const std::vector<CNodePtr> &cnode_ptr_list, 157 const std::string &group, size_t graph_id); 158 159 void AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr); 160 vector<CNodePtr> GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr); 161 bool IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index, const CNodePtr &node_ptr, 162 size_t index); 163 164 void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr); 165 void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr); 166 void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr); 167 168 void CheckScenario(const NotNull<KernelGraphPtr> &graph_ptr, vector<CNodePtr> *last_grad_and_status); 169 CNodePtr GetCNodesNeededMoved(vector<CNodePtr> *moved_backward_cnodes, vector<CNodePtr> *moved_forward_cnodes, 170 const vector<CNodePtr> &last_grad_and_status, const NotNull<KernelGraphPtr> &graph_ptr); 171 CNodePtr GetTargetOutputNode(const vector<CNodePtr> &moved_backward_cnodes, const CNodePtr first_node, 172 const NotNull<KernelGraphPtr> &graph_ptr); 173 bool FinetuneSubgraphExecOrder(vector<CNodePtr> *cnodes); 174 void TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> &graph_ptr); 175 176 uint32_t GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr); 177 uint32_t GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key); 178 uint32_t GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr); 179 void GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr); 180 bool IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node, const CNodePtr &cur_node, 181 bool exclude_hcom); 182 bool IsTaskSink(); 183 bool IsHcom(const CNodePtr &cur_cnode_ptr); 184 bool IsIndependentNode(const CNodePtr &node_ptr); 185 bool IsProcessedStream(uint32_t stream_id); 186 vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, 187 const CNodePtr &node, bool exclude_hcom); 188 void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams); 189 void SetLoopSink(); 190 191 // function for memory reuse 192 void GetStreamRelations(); 193 void DFS(uint32_t start, std::vector<uint32_t> *group); 194 bool IsVecExist(const std::vector<uint32_t> &group); 195 void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr); 196 void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); 197 void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index); 198 StreamActiveKind GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index); 199 uint32_t GetStreamByActivedStream(uint32_t actived_stream_id); 200 void PrintStreamRelations(); 201 void PrintStreamGroups(); 202 void FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr); 203 bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const; 204 vector<CNodePtr> GetInputKernels(const CNodePtr &cnode); 205 206 bool independent_stream_activated_{false}; 207 bool hcom_stream_activated_{false}; 208 bool loop_sink_{false}; 209 210 // key:stream id, value:node number 211 std::map<uint32_t, uint32_t> common_stream_map_{}; 212 // key:stream id, value:node number 213 std::map<uint32_t, uint32_t> independent_stream_map_{}; 214 // key:stream id, value:task number 215 std::map<uint32_t, uint32_t> hcom_stream_map_{}; 216 217 std::set<uint32_t> processed_streams_{}; 218 std::vector<uint32_t> need_first_active_streams_{}; 219 std::set<CNodeKey> independent_targets_; 220 221 // key:group name, value:key1:graph id, value1:stream id 222 std::map<std::string, std::map<uint32_t, std::set<uint32_t>>> group_hcom_graph_map_; 223 // key:graph id, value:stream set 224 std::map<uint32_t, std::set<uint32_t>> independent_graph_map_; 225 226 // attr for memory copy reuse 227 std::map<uint32_t, std::vector<uint32_t>> stream_relations_{}; 228 std::vector<std::vector<uint32_t>> stream_groups_{}; 229 std::map<CNodePtr, CNodePtr> event_map_{}; 230 std::set<uint32_t> middle_active_streams_{}; 231 // new policy end 232 bool IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode); 233 vector<CNodePtr>::iterator FindGraphEnd(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end); 234 }; 235 } // namespace ascend 236 } // namespace device 237 } // namespace mindspore 238 239 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ 240