1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "plugin/device/ascend/hal/device/ascend_stream_manager.h"

#include <algorithm>

#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#ifndef ENABLE_SECURITY
#include "include/backend/debug/data_dump/dump_json_parser.h"
#endif
#include "acl/error_codes/rt_error_codes.h"
#include "plugin/device/ascend/hal/device/ascend_gmem_adapter.h"
#include "transform/symbol/acl_rt_symbol.h"
#include "transform/symbol/symbol_utils.h"
28
29 namespace mindspore {
30 namespace device {
31 namespace ascend {
GetInstance()32 AscendStreamMng &AscendStreamMng::GetInstance() {
33 static AscendStreamMng instance{};
34 return instance;
35 }
36
DestroyAllRtEvents()37 void AscendStreamMng::DestroyAllRtEvents() {
38 for (size_t i = 0; i < events_.size(); ++i) {
39 if (events_[i] != nullptr) {
40 auto rt_ret = CALL_ASCEND_API(aclrtDestroyEvent, events_[i]);
41 if (rt_ret != ACL_ERROR_NONE) {
42 MS_LOG(ERROR) << "Call aclrtDestroyEvent failed, ret:" << rt_ret;
43 }
44 }
45 }
46 events_.clear();
47 }
48
DeleteEvent()49 void AscendStreamMng::DeleteEvent() {
50 if (cur_event_num_ == 0) {
51 MS_LOG(WARNING) << "total event num is 0, no event to delete";
52 } else {
53 --cur_event_num_;
54 }
55 }
56
DeleteStream()57 void AscendStreamMng::DeleteStream() {
58 if (cur_stream_num_ == 0) {
59 MS_LOG(WARNING) << " total stream num is 0, no stream to delete";
60 } else {
61 --cur_stream_num_;
62 }
63 }
64
GetCurAllocStreamId() const65 uint32_t AscendStreamMng::GetCurAllocStreamId() const {
66 if (cur_stream_num_ == 0) {
67 MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get";
68 }
69 return cur_stream_num_ - 1;
70 }
71
CreateStream(aclrtStream * stream,int32_t priority)72 void AscendStreamMng::CreateStream(aclrtStream *stream, int32_t priority) {
73 std::lock_guard<std::mutex> lock_streams(stream_mutex_);
74 auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, stream, IntToUint(priority),
75 (ACL_STREAM_FAST_LAUNCH | ACL_STREAM_FAST_SYNC));
76 if (ret != ACL_ERROR_NONE) {
77 MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
78 }
79 ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, *stream, ACL_STOP_ON_FAILURE);
80 if (ret != ACL_ERROR_NONE) {
81 MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
82 }
83 (void)streams_.emplace_back(*stream);
84 // If this is the first stream ever created, set it as default stream.
85 if (streams_.size() == 1) {
86 default_stream_ = *stream;
87 default_stream_id_ = kIndex0;
88 }
89 RegCallback(*stream);
90 }
91
CreateStream(size_t * stream_id,int32_t priority)92 void AscendStreamMng::CreateStream(size_t *stream_id, int32_t priority) {
93 std::lock_guard<std::mutex> lock_streams(stream_mutex_);
94 aclrtStream stream;
95 auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, &stream, IntToUint(priority),
96 (ACL_STREAM_FAST_LAUNCH | ACL_STREAM_FAST_SYNC));
97 if (ret != ACL_ERROR_NONE) {
98 MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
99 }
100 ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, stream, ACL_STOP_ON_FAILURE);
101 if (ret != ACL_ERROR_NONE) {
102 MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
103 }
104 *stream_id = streams_.size();
105 (void)streams_.emplace_back(stream);
106 RegCallback(stream);
107 }
108
RegCallback(aclrtStream stream)109 void AscendStreamMng::RegCallback(aclrtStream stream) {
110 MS_LOG(INFO) << "Register callback thread, stream : " << stream << ".";
111 (void)callback_cached_streams_.emplace_back(stream);
112 if (callback_cached_streams_.size() > 1 && !is_enable_callback_) {
113 is_enable_callback_ = true;
114 }
115 if (!is_enable_callback_) {
116 return;
117 }
118 #ifdef WITH_BACKEND
119 for (const auto &callback_cached_stream : callback_cached_streams_) {
120 if (stream_call_backs_.count(callback_cached_stream) > 0) {
121 MS_LOG(WARNING) << "Register callback thread failed, stream : " << callback_cached_stream
122 << " is already registered.";
123 continue;
124 }
125
126 auto callback_thread = std::make_shared<CallbackThread>();
127 callback_thread->create();
128 auto ret = CALL_ASCEND_API(aclrtSubscribeReport, callback_thread->thread_, (aclrtStream)callback_cached_stream);
129 if (!ret) {
130 MS_LOG(INFO) << "Register callback thread success, stream : " << callback_cached_stream << ".";
131 (void)stream_call_backs_.emplace(callback_cached_stream, callback_thread);
132 } else {
133 MS_LOG(INTERNAL_EXCEPTION) << "Register callback thread failed, stream : " << callback_cached_stream
134 << ", ret : " << ret;
135 }
136 }
137 #endif
138 callback_cached_streams_.clear();
139 }
140
UnRegCallback(aclrtStream stream)141 void AscendStreamMng::UnRegCallback(aclrtStream stream) {
142 MS_LOG(INFO) << "Unregister callback thread, stream : " << stream << ".";
143 if (!is_enable_callback_) {
144 return;
145 }
146 #ifdef WITH_BACKEND
147 if (stream_call_backs_.count(stream) == 0) {
148 MS_LOG(WARNING) << "Unregister callback thread failed, stream : " << stream << " is not exist.";
149 return;
150 }
151 auto callback_thread = stream_call_backs_.at(stream);
152 // Cannot call aclrtUnSubscribeReport.
153 callback_thread->cancel();
154 stream_call_backs_.erase(stream);
155 #endif
156 }
157
CreateStreamWithFlags(aclrtStream * stream,uint32_t flags,int32_t priority)158 void AscendStreamMng::CreateStreamWithFlags(aclrtStream *stream, uint32_t flags, int32_t priority) {
159 std::lock_guard<std::mutex> lock_streams(stream_mutex_);
160 auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, stream, IntToUint(priority), flags);
161 if (ret != ACL_ERROR_NONE) {
162 MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
163 }
164 ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, *stream, ACL_STOP_ON_FAILURE);
165 if (ret != ACL_ERROR_NONE) {
166 MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
167 }
168 (void)streams_.emplace_back(*stream);
169 RegCallback(*stream);
170 }
171
CreateStreamWithFlags(size_t * stream_id,uint32_t flags,int32_t priority)172 void AscendStreamMng::CreateStreamWithFlags(size_t *stream_id, uint32_t flags, int32_t priority) {
173 std::lock_guard<std::mutex> lock_streams(stream_mutex_);
174 aclrtStream stream;
175 auto ret = CALL_ASCEND_API(aclrtCreateStreamWithConfig, &stream, IntToUint(priority), flags);
176 if (ret != ACL_ERROR_NONE) {
177 MS_LOG(EXCEPTION) << "Create stream failed, ret:" << ret;
178 }
179 ret = CALL_ASCEND_API(aclrtSetStreamFailureMode, stream, ACL_STOP_ON_FAILURE);
180 if (ret != ACL_ERROR_NONE) {
181 MS_LOG(EXCEPTION) << "aclrtSetStreamFailureMode failed, ret:" << ret;
182 }
183 *stream_id = streams_.size();
184 (void)streams_.emplace_back(stream);
185 RegCallback(stream);
186 }
187
ApplyRtEvent()188 aclrtEvent AscendStreamMng::ApplyRtEvent() {
189 aclrtEvent rt_event = nullptr;
190 auto ret = CALL_ASCEND_API(aclrtCreateEvent, &rt_event);
191 if (ret != ACL_ERROR_NONE) {
192 MS_LOG(EXCEPTION) << "aclrtCreateEvent failed, ret:" << ret;
193 }
194 (void)events_.emplace_back(rt_event);
195 return rt_event;
196 }
197
DestroyStream(size_t stream_id)198 bool AscendStreamMng::DestroyStream(size_t stream_id) {
199 std::lock_guard<std::mutex> lock_streams(stream_mutex_);
200 if (stream_id >= streams_.size()) {
201 MS_LOG(ERROR) << "Ascend stream not found for stream id " << stream_id;
202 return false;
203 }
204 if (streams_.at(stream_id) == nullptr) {
205 MS_LOG(WARNING) << "Ascend stream hsa been destroyed for stream id " << stream_id;
206 return true;
207 }
208 const auto ret = CALL_ASCEND_API(aclrtDestroyStream, streams_.at(stream_id));
209 if (ret != ACL_ERROR_NONE) {
210 MS_LOG(EXCEPTION) << "Call aclrtDestroyStream, ret[" << ret << "]";
211 }
212 UnRegCallback(streams_.at(stream_id));
213 streams_[stream_id] = nullptr;
214 return true;
215 }
216
DestroyAllStreams()217 bool AscendStreamMng::DestroyAllStreams() {
218 std::lock_guard<std::mutex> lock_streams(stream_mutex_);
219 for (const auto &stream : streams_) {
220 if (stream == nullptr) {
221 continue;
222 }
223 const auto ret = CALL_ASCEND_API(aclrtDestroyStream, stream);
224 if (ret != ACL_ERROR_NONE) {
225 MS_LOG(EXCEPTION) << "Call aclrtDestroyStream, ret[" << ret << "]";
226 }
227 UnRegCallback(stream);
228 }
229 streams_.clear();
230 return true;
231 }
232
GetStream(size_t stream_id) const233 aclrtStream AscendStreamMng::GetStream(size_t stream_id) const {
234 if (stream_id >= streams_.size()) {
235 MS_LOG(DEBUG) << "Stream for stream id[" << stream_id << "] not found, return nullptr.";
236 return nullptr;
237 }
238 return streams_[stream_id];
239 }
240
SyncStream(size_t stream_id) const241 bool AscendStreamMng::SyncStream(size_t stream_id) const {
242 if (stream_id >= streams_.size()) {
243 MS_LOG(EXCEPTION) << "Stream for stream id[" << stream_id << "] has not been created.";
244 }
245 const auto stream = streams_[stream_id];
246 if (stream == nullptr) {
247 MS_LOG(WARNING) << "Stream for stream id[" << stream_id << "] has been destroyed.";
248 return false;
249 }
250 return SyncStream(stream);
251 }
252
SyncStream(aclrtStream stream) const253 bool AscendStreamMng::SyncStream(aclrtStream stream) const {
254 MS_EXCEPTION_IF_NULL(stream);
255 auto RET = CALL_ASCEND_API(aclrtSynchronizeStreamWithTimeout, stream, -1);
256 if (RET != ACL_ERROR_NONE && RET != ACL_ERROR_RT_AICORE_OVER_FLOW) { // o for switch stream
257 MS_LOG(ERROR) << "Call runtime aclrtSynchronizeStreamWithTimeout error.";
258 return false;
259 }
260 if (RET == ACL_ERROR_RT_AICORE_OVER_FLOW) {
261 MS_LOG(WARNING) << "Call runtime aclrtSynchronizeStreamWithTimeout, the stream get overflow.";
262 }
263 return true;
264 }
265
SyncAllStreams() const266 bool AscendStreamMng::SyncAllStreams() const {
267 for (size_t i = 0; i < streams_.size(); ++i) {
268 const auto stream = streams_[i];
269 if (stream != nullptr && !SyncStream(stream)) {
270 MS_LOG(ERROR) << "SyncStream for stream id " << i << " failed.";
271 return false;
272 }
273 }
274 return true;
275 }
276
SyncNotDefaultStreams() const277 bool AscendStreamMng::SyncNotDefaultStreams() const {
278 bool res = true;
279 for (size_t i = 0; i < streams_.size(); i++) {
280 if (i != default_stream_id_ && !SyncStream(i)) {
281 MS_LOG(ERROR) << "Failed to sync for ascend stream id: " << i;
282 res = false;
283 }
284 }
285 return res;
286 }
287
SyncExceptStreamsInList(const std::set<aclrtStream> & except_streams) const288 bool AscendStreamMng::SyncExceptStreamsInList(const std::set<aclrtStream> &except_streams) const {
289 bool res = true;
290 for (size_t i = 0; i < streams_.size(); i++) {
291 if (except_streams.count(streams_[i]) > 0) {
292 MS_LOG(DEBUG) << "Stream id:" << i << " is been synchronized.";
293 continue;
294 }
295 if (!SyncStream(i)) {
296 MS_LOG(ERROR) << "Failed to sync for ascend stream id: " << i;
297 res = false;
298 }
299 }
300 return res;
301 }
302
QueryStreamSize() const303 size_t AscendStreamMng::QueryStreamSize() const { return streams_.size(); }
304
QueryStream(size_t stream_id)305 bool AscendStreamMng::QueryStream(size_t stream_id) {
306 if (stream_id >= streams_.size()) {
307 MS_LOG(EXCEPTION) << "Stream for stream id[" << stream_id << "] has not been created.";
308 }
309 const auto stream = streams_[stream_id];
310 if (stream == nullptr) {
311 MS_LOG(WARNING) << "Stream for stream id[" << stream_id << "] has been destroyed.";
312 return false;
313 }
314
315 aclrtStreamStatus status;
316 auto ret = CALL_ASCEND_API(aclrtStreamQuery, stream, &status);
317 if (ret != ACL_SUCCESS) {
318 MS_LOG(EXCEPTION) << "Failed to query completion status for stream id: " << stream_id;
319 }
320 return status == ACL_STREAM_STATUS_COMPLETE;
321 }
322
GetStreamId(void * stream_ptr)323 size_t AscendStreamMng::GetStreamId(void *stream_ptr) {
324 auto iter = std::find(streams_.begin(), streams_.end(), stream_ptr);
325 if (iter == streams_.end()) {
326 MS_LOG(EXCEPTION) << "Failed to find stream_ptr in streams_, stream_ptr:" << stream_ptr;
327 }
328
329 return LongToSize(std::distance(streams_.begin(), iter));
330 }
331
GetStreamIds() const332 std::vector<uint32_t> AscendStreamMng::GetStreamIds() const {
333 std::vector<uint32_t> stream_ids;
334 for (size_t i = 0; i < streams_.size(); i++) {
335 if (streams_[i] != nullptr) {
336 (void)stream_ids.emplace_back(static_cast<uint32_t>(i));
337 }
338 }
339 return stream_ids;
340 }
341 } // namespace ascend
342 } // namespace device
343 } // namespace mindspore
344