• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
18 
19 #include "utils/convert_utils_base.h"
20 #include "utils/log_adapter.h"
21 #ifndef ENABLE_SECURITY
22 #include "include/backend/debug/data_dump/dump_json_parser.h"
23 #endif
24 #include "acl/error_codes/rt_error_codes.h"
25 #include "plugin/device/ascend/hal/device/ascend_gmem_adapter.h"
26 #include "transform/symbol/acl_rt_symbol.h"
27 #include "transform/symbol/symbol_utils.h"
28 
29 namespace mindspore {
30 namespace device {
31 namespace ascend {
GetInstance()32 AscendStreamMng &AscendStreamMng::GetInstance() {
33   static AscendStreamMng instance{};
34   return instance;
35 }
36 
DestroyAllRtEvents()37 void AscendStreamMng::DestroyAllRtEvents() {
38   for (size_t i = 0; i < events_.size(); ++i) {
39     if (events_[i] != nullptr) {
40       auto rt_ret = CALL_ASCEND_API(aclrtDestroyEvent, events_[i]);
41       if (rt_ret != ACL_ERROR_NONE) {
42         MS_LOG(ERROR) << "Call aclrtDestroyEvent failed, ret:" << rt_ret;
43       }
44     }
45   }
46   events_.clear();
47 }
48 
DeleteEvent()49 void AscendStreamMng::DeleteEvent() {
50   if (cur_event_num_ == 0) {
51     MS_LOG(WARNING) << "total event num is 0, no event to delete";
52   } else {
53     --cur_event_num_;
54   }
55 }
56 
DeleteStream()57 void AscendStreamMng::DeleteStream() {
58   if (cur_stream_num_ == 0) {
59     MS_LOG(WARNING) << " total stream num is 0, no stream to delete";
60   } else {
61     --cur_stream_num_;
62   }
63 }
64 
GetCurAllocStreamId() const65 uint32_t AscendStreamMng::GetCurAllocStreamId() const {
66   if (cur_stream_num_ == 0) {
67     MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get";
68   }
69   return cur_stream_num_ - 1;
70 }
71 
CreateStream(aclrtStream * stream,int32_t priority)72 void AscendStreamMng::CreateStream(aclrtStream *stream, int32_t priority) {
73   std::lock_guard<std::mutex> lock_streams(stream_mutex_);
74   auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, stream, IntToUint(priority),
75                              (ACL_STREAM_FAST_LAUNCH | ACL_STREAM_FAST_SYNC));
76   if (ret != ACL_ERROR_NONE) {
77     MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
78   }
79   ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, *stream, ACL_STOP_ON_FAILURE);
80   if (ret != ACL_ERROR_NONE) {
81     MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
82   }
83   (void)streams_.emplace_back(*stream);
84   // If this is the first stream ever created, set it as default stream.
85   if (streams_.size() == 1) {
86     default_stream_ = *stream;
87     default_stream_id_ = kIndex0;
88   }
89   RegCallback(*stream);
90 }
91 
CreateStream(size_t * stream_id,int32_t priority)92 void AscendStreamMng::CreateStream(size_t *stream_id, int32_t priority) {
93   std::lock_guard<std::mutex> lock_streams(stream_mutex_);
94   aclrtStream stream;
95   auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, &stream, IntToUint(priority),
96                              (ACL_STREAM_FAST_LAUNCH | ACL_STREAM_FAST_SYNC));
97   if (ret != ACL_ERROR_NONE) {
98     MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
99   }
100   ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, stream, ACL_STOP_ON_FAILURE);
101   if (ret != ACL_ERROR_NONE) {
102     MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
103   }
104   *stream_id = streams_.size();
105   (void)streams_.emplace_back(stream);
106   RegCallback(stream);
107 }
108 
RegCallback(aclrtStream stream)109 void AscendStreamMng::RegCallback(aclrtStream stream) {
110   MS_LOG(INFO) << "Register callback thread, stream : " << stream << ".";
111   (void)callback_cached_streams_.emplace_back(stream);
112   if (callback_cached_streams_.size() > 1 && !is_enable_callback_) {
113     is_enable_callback_ = true;
114   }
115   if (!is_enable_callback_) {
116     return;
117   }
118 #ifdef WITH_BACKEND
119   for (const auto &callback_cached_stream : callback_cached_streams_) {
120     if (stream_call_backs_.count(callback_cached_stream) > 0) {
121       MS_LOG(WARNING) << "Register callback thread failed, stream : " << callback_cached_stream
122                       << " is already registered.";
123       continue;
124     }
125 
126     auto callback_thread = std::make_shared<CallbackThread>();
127     callback_thread->create();
128     auto ret = CALL_ASCEND_API(aclrtSubscribeReport, callback_thread->thread_, (aclrtStream)callback_cached_stream);
129     if (!ret) {
130       MS_LOG(INFO) << "Register callback thread success, stream : " << callback_cached_stream << ".";
131       (void)stream_call_backs_.emplace(callback_cached_stream, callback_thread);
132     } else {
133       MS_LOG(INTERNAL_EXCEPTION) << "Register callback thread failed, stream : " << callback_cached_stream
134                                  << ", ret : " << ret;
135     }
136   }
137 #endif
138   callback_cached_streams_.clear();
139 }
140 
UnRegCallback(aclrtStream stream)141 void AscendStreamMng::UnRegCallback(aclrtStream stream) {
142   MS_LOG(INFO) << "Unregister callback thread, stream : " << stream << ".";
143   if (!is_enable_callback_) {
144     return;
145   }
146 #ifdef WITH_BACKEND
147   if (stream_call_backs_.count(stream) == 0) {
148     MS_LOG(WARNING) << "Unregister callback thread failed, stream : " << stream << " is not exist.";
149     return;
150   }
151   auto callback_thread = stream_call_backs_.at(stream);
152   // Cannot call aclrtUnSubscribeReport.
153   callback_thread->cancel();
154   stream_call_backs_.erase(stream);
155 #endif
156 }
157 
CreateStreamWithFlags(aclrtStream * stream,uint32_t flags,int32_t priority)158 void AscendStreamMng::CreateStreamWithFlags(aclrtStream *stream, uint32_t flags, int32_t priority) {
159   std::lock_guard<std::mutex> lock_streams(stream_mutex_);
160   auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, stream, IntToUint(priority), flags);
161   if (ret != ACL_ERROR_NONE) {
162     MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
163   }
164   ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, *stream, ACL_STOP_ON_FAILURE);
165   if (ret != ACL_ERROR_NONE) {
166     MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
167   }
168   (void)streams_.emplace_back(*stream);
169   RegCallback(*stream);
170 }
171 
CreateStreamWithFlags(size_t * stream_id,uint32_t flags,int32_t priority)172 void AscendStreamMng::CreateStreamWithFlags(size_t *stream_id, uint32_t flags, int32_t priority) {
173   std::lock_guard<std::mutex> lock_streams(stream_mutex_);
174   aclrtStream stream;
175   auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, &stream, IntToUint(priority), flags);
176   if (ret != ACL_ERROR_NONE) {
177     MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
178   }
179   ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, stream, ACL_STOP_ON_FAILURE);
180   if (ret != ACL_ERROR_NONE) {
181     MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
182   }
183   *stream_id = streams_.size();
184   (void)streams_.emplace_back(stream);
185   RegCallback(stream);
186 }
187 
ApplyRtEvent()188 aclrtEvent AscendStreamMng::ApplyRtEvent() {
189   aclrtEvent rt_event = nullptr;
190   auto ret = CALL_ASCEND_API(aclrtCreateEvent, &rt_event);
191   if (ret != ACL_ERROR_NONE) {
192     MS_LOG(EXCEPTION) << "aclrtCreateEvent failed, ret:" << ret;
193   }
194   (void)events_.emplace_back(rt_event);
195   return rt_event;
196 }
197 
DestroyStream(size_t stream_id)198 bool AscendStreamMng::DestroyStream(size_t stream_id) {
199   std::lock_guard<std::mutex> lock_streams(stream_mutex_);
200   if (stream_id >= streams_.size()) {
201     MS_LOG(ERROR) << "Ascend stream not found for stream id " << stream_id;
202     return false;
203   }
204   if (streams_.at(stream_id) == nullptr) {
205     MS_LOG(WARNING) << "Ascend stream hsa been destroyed for stream id " << stream_id;
206     return true;
207   }
208   const auto ret = CALL_ASCEND_API(aclrtDestroyStream, streams_.at(stream_id));
209   if (ret != ACL_ERROR_NONE) {
210     MS_LOG(EXCEPTION) << "Call aclrtDestroyStream, ret[" << ret << "]";
211   }
212   UnRegCallback(streams_.at(stream_id));
213   streams_[stream_id] = nullptr;
214   return true;
215 }
216 
DestroyAllStreams()217 bool AscendStreamMng::DestroyAllStreams() {
218   std::lock_guard<std::mutex> lock_streams(stream_mutex_);
219   for (const auto &stream : streams_) {
220     if (stream == nullptr) {
221       continue;
222     }
223     const auto ret = CALL_ASCEND_API(aclrtDestroyStream, stream);
224     if (ret != ACL_ERROR_NONE) {
225       MS_LOG(EXCEPTION) << "Call aclrtDestroyStream, ret[" << ret << "]";
226     }
227     UnRegCallback(stream);
228   }
229   streams_.clear();
230   return true;
231 }
232 
GetStream(size_t stream_id) const233 aclrtStream AscendStreamMng::GetStream(size_t stream_id) const {
234   if (stream_id >= streams_.size()) {
235     MS_LOG(DEBUG) << "Stream for stream id[" << stream_id << "] not found, return nullptr.";
236     return nullptr;
237   }
238   return streams_[stream_id];
239 }
240 
SyncStream(size_t stream_id) const241 bool AscendStreamMng::SyncStream(size_t stream_id) const {
242   if (stream_id >= streams_.size()) {
243     MS_LOG(EXCEPTION) << "Stream for stream id[" << stream_id << "] has not been created.";
244   }
245   const auto stream = streams_[stream_id];
246   if (stream == nullptr) {
247     MS_LOG(WARNING) << "Stream for stream id[" << stream_id << "] has been destroyed.";
248     return false;
249   }
250   return SyncStream(stream);
251 }
252 
SyncStream(aclrtStream stream) const253 bool AscendStreamMng::SyncStream(aclrtStream stream) const {
254   MS_EXCEPTION_IF_NULL(stream);
255   auto RET = CALL_ASCEND_API(aclrtSynchronizeStreamWithTimeout, stream, -1);
256   if (RET != ACL_ERROR_NONE && RET != ACL_ERROR_RT_AICORE_OVER_FLOW) {  // o for switch stream
257     MS_LOG(ERROR) << "Call runtime aclrtSynchronizeStreamWithTimeout error.";
258     return false;
259   }
260   if (RET == ACL_ERROR_RT_AICORE_OVER_FLOW) {
261     MS_LOG(WARNING) << "Call runtime aclrtSynchronizeStreamWithTimeout, the stream get overflow.";
262   }
263   return true;
264 }
265 
SyncAllStreams() const266 bool AscendStreamMng::SyncAllStreams() const {
267   for (size_t i = 0; i < streams_.size(); ++i) {
268     const auto stream = streams_[i];
269     if (stream != nullptr && !SyncStream(stream)) {
270       MS_LOG(ERROR) << "SyncStream for stream id " << i << " failed.";
271       return false;
272     }
273   }
274   return true;
275 }
276 
SyncNotDefaultStreams() const277 bool AscendStreamMng::SyncNotDefaultStreams() const {
278   bool res = true;
279   for (size_t i = 0; i < streams_.size(); i++) {
280     if (i != default_stream_id_ && !SyncStream(i)) {
281       MS_LOG(ERROR) << "Failed to sync for ascend stream id: " << i;
282       res = false;
283     }
284   }
285   return res;
286 }
287 
SyncExceptStreamsInList(const std::set<aclrtStream> & except_streams) const288 bool AscendStreamMng::SyncExceptStreamsInList(const std::set<aclrtStream> &except_streams) const {
289   bool res = true;
290   for (size_t i = 0; i < streams_.size(); i++) {
291     if (except_streams.count(streams_[i]) > 0) {
292       MS_LOG(DEBUG) << "Stream id:" << i << " is been synchronized.";
293       continue;
294     }
295     if (!SyncStream(i)) {
296       MS_LOG(ERROR) << "Failed to sync for ascend stream id: " << i;
297       res = false;
298     }
299   }
300   return res;
301 }
302 
QueryStreamSize() const303 size_t AscendStreamMng::QueryStreamSize() const { return streams_.size(); }
304 
QueryStream(size_t stream_id)305 bool AscendStreamMng::QueryStream(size_t stream_id) {
306   if (stream_id >= streams_.size()) {
307     MS_LOG(EXCEPTION) << "Stream for stream id[" << stream_id << "] has not been created.";
308   }
309   const auto stream = streams_[stream_id];
310   if (stream == nullptr) {
311     MS_LOG(WARNING) << "Stream for stream id[" << stream_id << "] has been destroyed.";
312     return false;
313   }
314 
315   aclrtStreamStatus status;
316   auto ret = CALL_ASCEND_API(aclrtStreamQuery, stream, &status);
317   if (ret != ACL_SUCCESS) {
318     MS_LOG(EXCEPTION) << "Failed to query completion status for stream id: " << stream_id;
319   }
320   return status == ACL_STREAM_STATUS_COMPLETE;
321 }
322 
GetStreamId(void * stream_ptr)323 size_t AscendStreamMng::GetStreamId(void *stream_ptr) {
324   auto iter = std::find(streams_.begin(), streams_.end(), stream_ptr);
325   if (iter == streams_.end()) {
326     MS_LOG(EXCEPTION) << "Failed to find stream_ptr in streams_, stream_ptr:" << stream_ptr;
327   }
328 
329   return LongToSize(std::distance(streams_.begin(), iter));
330 }
331 
GetStreamIds() const332 std::vector<uint32_t> AscendStreamMng::GetStreamIds() const {
333   std::vector<uint32_t> stream_ids;
334   for (size_t i = 0; i < streams_.size(); i++) {
335     if (streams_[i] != nullptr) {
336       (void)stream_ids.emplace_back(static_cast<uint32_t>(i));
337     }
338   }
339   return stream_ids;
340 }
341 }  // namespace ascend
342 }  // namespace device
343 }  // namespace mindspore
344