1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "runtime/device/multi_stream_controller.h"
17
18 #include <algorithm>
19
20 namespace mindspore {
21 namespace device {
// Number of streams assumed by Refresh() when the device reports none (e.g. CPU, which has no stream concept).
constexpr size_t kDefaultStreamRefreshSize{2};
23
GetInstance()24 MultiStreamControllerPtr &MultiStreamController::GetInstance() {
25 static std::once_flag init_flag = {};
26 static MultiStreamControllerPtr multi_stream_controller = nullptr;
27 std::call_once(init_flag, [&]() {
28 if (multi_stream_controller == nullptr) {
29 MS_LOG(INFO) << "Create MultiStreamController.";
30 multi_stream_controller = std::make_shared<MultiStreamController>();
31 }
32 });
33
34 return multi_stream_controller;
35 }
36
Refresh(const DeviceContext * device_context)37 void MultiStreamController::Refresh(const DeviceContext *device_context) {
38 auto stream_size = device_context->device_res_manager_->QueryStreamSize();
39 MS_LOG(INFO) << "Stream manager initialize, device_context : " << device_context << ", stream_size : " << stream_size
40 << ".";
41 if (stream_size == 0) {
42 // CPU has no concept of stream, stream size must be zero.
43 MS_LOG(INFO) << "Stream size is 0, will initialize with 2 streams.";
44 stream_size = kDefaultStreamRefreshSize;
45 }
46 task_id_on_stream_manager_[device_context].Resize(stream_size);
47 if (event_pools_.count(device_context) == 0) {
48 (void)event_pools_.emplace(device_context, std::make_shared<EventPool>([device_context]() {
49 // Event in pool need to do synchronization between streams, need to enable blocking.
50 return device_context->device_res_manager_->CreateRuntimeEvent(true, false);
51 }));
52 }
53 }
54
UpdateTaskIdOnStream(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id,uint32_t memory_stream_id)55 bool MultiStreamController::UpdateTaskIdOnStream(const DeviceContext *device_context, int64_t task_id_on_stream,
56 uint32_t user_stream_id, uint32_t memory_stream_id) {
57 return task_id_on_stream_manager_[device_context].Update(task_id_on_stream, user_stream_id, memory_stream_id);
58 }
59
QueryTaskIdOnStream(const DeviceContext * device_context,uint32_t user_stream_id,uint32_t memory_stream_id)60 int64_t MultiStreamController::QueryTaskIdOnStream(const DeviceContext *device_context, uint32_t user_stream_id,
61 uint32_t memory_stream_id) {
62 return task_id_on_stream_manager_[device_context].Query(user_stream_id, memory_stream_id);
63 }
64
LaunchTaskIdOnStream(const DeviceContext * device_context,uint32_t stream_id)65 int64_t MultiStreamController::LaunchTaskIdOnStream(const DeviceContext *device_context, uint32_t stream_id) {
66 auto iter = task_id_on_stream_manager_.find(device_context);
67 if (iter == task_id_on_stream_manager_.end()) {
68 if (device_context->GetDeviceType() == DeviceType::kCPU) {
69 return INT64_MAX;
70 }
71
72 MS_LOG(WARNING) << "LaunchTaskIdOnStream device context is not found, device_context name : "
73 << device_context->device_context_key().device_name_ << ", stream id : " << stream_id
74 << ", refresh context.";
75 Refresh(device_context);
76 return task_id_on_stream_manager_[device_context].Launch(stream_id);
77 }
78 return iter->second.Launch(stream_id);
79 }
80
GetTaskIdOnStream(const DeviceContext * device_context,uint32_t stream_id)81 int64_t MultiStreamController::GetTaskIdOnStream(const DeviceContext *device_context, uint32_t stream_id) {
82 return task_id_on_stream_manager_[device_context].Get(stream_id);
83 }
84
GetStreamMutex(const DeviceContext * device_context,size_t stream_id)85 std::mutex &MultiStreamController::GetStreamMutex(const DeviceContext *device_context, size_t stream_id) {
86 return stream_mutexes_[device_context][stream_id];
87 }
88
RecordEvent(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id,const std::vector<std::pair<uint32_t,DeviceMemPtr>> & memory_stream_addresses)89 bool MultiStreamController::RecordEvent(const DeviceContext *device_context, int64_t task_id_on_stream,
90 uint32_t user_stream_id,
91 const std::vector<std::pair<uint32_t, DeviceMemPtr>> &memory_stream_addresses) {
92 auto mem_manager = device_context->device_res_manager_->mem_manager();
93 if (mem_manager == nullptr) {
94 MS_LOG(WARNING) << "mem_manager_ is nullptr.";
95 return false;
96 }
97
98 auto event = device_context->device_res_manager_->CreateRuntimeEvent(false, true);
99 if (event == nullptr) {
100 return true;
101 }
102 event->RecordEvent(user_stream_id);
103 // Record event on mem buf.
104 return mem_manager->RecordEvent(task_id_on_stream, user_stream_id, memory_stream_addresses, event);
105 }
106
WaitEvent(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id,uint32_t memory_stream_id)107 bool MultiStreamController::WaitEvent(const DeviceContext *device_context, int64_t task_id_on_stream,
108 uint32_t user_stream_id, uint32_t memory_stream_id) {
109 auto mem_manager = device_context->device_res_manager_->mem_manager();
110 if (mem_manager == nullptr) {
111 MS_LOG(WARNING) << "mem_manager_ is nullptr.";
112 return false;
113 }
114 // If update task id on stream failed, means task id on stream is elder one, no need to wait event on mem manager.
115 if (!UpdateTaskIdOnStream(device_context, task_id_on_stream, user_stream_id, memory_stream_id)) {
116 MS_LOG(DEBUG) << "Skip Wait Event.";
117 return false;
118 }
119 return mem_manager->WaitEvent(task_id_on_stream, user_stream_id, memory_stream_id);
120 }
121
WaitEvent(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id)122 bool MultiStreamController::WaitEvent(const DeviceContext *device_context, int64_t task_id_on_stream,
123 uint32_t user_stream_id) {
124 auto mem_manager = device_context->device_res_manager_->mem_manager();
125 if (mem_manager == nullptr) {
126 MS_LOG(WARNING) << "mem_manager_ is nullptr.";
127 return false;
128 }
129
130 return mem_manager->WaitEvent(task_id_on_stream, user_stream_id);
131 }
132
DispatchRecordWaitEvent(const DeviceContext * device_context,uint32_t user_stream_id,uint32_t memory_stream_id)133 bool MultiStreamController::DispatchRecordWaitEvent(const DeviceContext *device_context, uint32_t user_stream_id,
134 uint32_t memory_stream_id) {
135 if (event_pools_.count(device_context) == 0) {
136 MS_LOG(INTERNAL_EXCEPTION) << "device context has not initialized.";
137 }
138 auto &event_pool = event_pools_[device_context];
139 auto event = event_pool->Get();
140 // Note : record event on memory stream id and wait event on user stream id to make sure memory is safe.
141 event->RecordEvent(memory_stream_id);
142 event->WaitEvent(user_stream_id);
143 return true;
144 }
145
SyncStream(const DeviceContext * device_context,size_t stream_id)146 bool MultiStreamController::SyncStream(const DeviceContext *device_context, size_t stream_id) {
147 auto &device_res_manager = device_context->device_res_manager_;
148 bool ret = device_res_manager->SyncStream(stream_id);
149 auto mem_manager = device_res_manager->mem_manager();
150 if (mem_manager != nullptr) {
151 auto task_id_on_stream = GetTaskIdOnStream(device_context, stream_id);
152 mem_manager->WaitEvent(task_id_on_stream, stream_id);
153 }
154 return ret;
155 }
156
SyncAllStreams(const DeviceContext * device_context)157 bool MultiStreamController::SyncAllStreams(const DeviceContext *device_context) {
158 auto &device_res_manager = device_context->device_res_manager_;
159 bool ret = device_res_manager->SyncAllStreams();
160 auto mem_manager = device_res_manager->mem_manager();
161 if (mem_manager != nullptr) {
162 mem_manager->SyncAllEvents();
163 }
164 return ret;
165 }
166
SyncNotDefaultStreams(const DeviceContext * device_context)167 bool MultiStreamController::SyncNotDefaultStreams(const DeviceContext *device_context) {
168 auto &device_res_manager = device_context->device_res_manager_;
169 bool ret = device_res_manager->SyncNotDefaultStreams();
170 auto mem_manager = device_res_manager->mem_manager();
171 if (mem_manager != nullptr) {
172 auto stream_ids = device_res_manager->GetStreamIds();
173 for (auto stream_id : stream_ids) {
174 auto task_id_on_stream = GetTaskIdOnStream(device_context, stream_id);
175 mem_manager->WaitEvent(task_id_on_stream, stream_id);
176 }
177 }
178 return ret;
179 }
180
Resize(uint32_t stream_size)181 void TaskIdOnStreamManager::Resize(uint32_t stream_size) {
182 std::lock_guard<std::mutex> lock(mutex_);
183 if (initialized_ && stream_size <= initialize_size_) {
184 MS_LOG(INFO) << "Task id on stream manager has already initialized, current size : " << initialize_size_ << ".";
185 return;
186 }
187 MS_LOG(INFO) << "Task id on stream manager initialize : " << initialized_ << ", stream_size : " << stream_size << ".";
188 uint32_t min_stream_size = 2;
189 initialize_size_ = std::max(stream_size, min_stream_size);
190 generator_.resize(initialize_size_);
191 status_.resize(initialize_size_);
192 for (auto &vec : status_) {
193 vec.resize(initialize_size_);
194 }
195 initialized_ = true;
196 }
197
Query(uint32_t user_stream_id,uint32_t memory_stream_id)198 int64_t TaskIdOnStreamManager::Query(uint32_t user_stream_id, uint32_t memory_stream_id) {
199 std::lock_guard<std::mutex> lock(mutex_);
200 return status_[user_stream_id][memory_stream_id];
201 }
202
Update(int64_t task_id_on_stream,uint32_t user_stream_id,uint32_t memory_stream_id)203 bool TaskIdOnStreamManager::Update(int64_t task_id_on_stream, uint32_t user_stream_id, uint32_t memory_stream_id) {
204 std::lock_guard<std::mutex> lock(mutex_);
205 if (status_[user_stream_id][memory_stream_id] >= task_id_on_stream) {
206 return false;
207 }
208 status_[user_stream_id][memory_stream_id] = task_id_on_stream;
209 return true;
210 }
211
Launch(uint32_t stream_id)212 int64_t TaskIdOnStreamManager::Launch(uint32_t stream_id) {
213 if (stream_id >= generator_.size()) {
214 MS_LOG(WARNING) << "Launch stream id : " << stream_id << " failed, generator_ size : " << generator_.size();
215 generator_.resize(stream_id + 1);
216 status_.resize(stream_id + 1);
217 }
218 return ++generator_[stream_id].value_;
219 }
220
Get(uint32_t stream_id)221 int64_t TaskIdOnStreamManager::Get(uint32_t stream_id) { return generator_[stream_id].value_; }
222
EventPool(std::function<DeviceEventPtr (void)> event_creator)223 EventPool::EventPool(std::function<DeviceEventPtr(void)> event_creator) : event_creator_(std::move(event_creator)) {}
224
~EventPool()225 EventPool::~EventPool() {
226 std::lock_guard<std::mutex> lock(mutex_);
227 expired_ = true;
228 events_.clear();
229 cached_events_.clear();
230 }
231
Get()232 DeviceEventPtr EventPool::Get() {
233 MS_LOG(DEBUG) << "Event pool get start.";
234 std::lock_guard<std::mutex> lock(mutex_);
235 DeviceEvent *event = nullptr;
236 // Try to create event firstly before reached core size.
237 if (size_ < core_size_) {
238 auto created_event = event_creator_();
239 if (created_event != nullptr && created_event->IsReady()) {
240 cached_events_.push_back(created_event);
241 size_++;
242 event = created_event.get();
243 }
244 }
245 // Try to reuse event.
246 if (event == nullptr) {
247 auto iter = events_.begin();
248 while (iter != events_.end()) {
249 auto event_in_list = *iter;
250 if (event_in_list == nullptr) {
251 MS_LOG(INTERNAL_EXCEPTION) << "exception : event in list is nullptr, events_ size : " << events_.size() << ".";
252 }
253 if (event_in_list->QueryEvent()) {
254 event = event_in_list;
255 events_.erase(iter);
256 break;
257 }
258 iter++;
259 }
260 }
261 // Reuse failed, try to create more event.
262 if (event == nullptr) {
263 auto created_event = event_creator_();
264 if (created_event != nullptr && created_event->IsReady()) {
265 cached_events_.push_back(created_event);
266 event = created_event.get();
267 size_++;
268 } else {
269 MS_LOG(INTERNAL_EXCEPTION) << "Get event failed.";
270 }
271 }
272 MS_LOG(DEBUG) << "Get event, events_ size : " << events_.size() << ", event : " << event << ".";
273
274 auto event_ptr = std::shared_ptr<DeviceEvent>(event, [&](DeviceEvent *e) {
275 std::lock_guard<std::mutex> lock(mutex_);
276 if (!expired_) {
277 MS_LOG(DEBUG) << "Return event : " << e << ".";
278 events_.push_back(e);
279 } else {
280 MS_LOG(DEBUG) << "Return event : " << e << "failed.";
281 }
282 });
283 return event_ptr;
284 }
285 } // namespace device
286 } // namespace mindspore
287