• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "runtime/device/multi_stream_controller.h"
17 
18 #include <algorithm>
19 
20 namespace mindspore {
21 namespace device {
// Stream count used by MultiStreamController::Refresh when the device reports a stream size of zero.
constexpr size_t kDefaultStreamRefreshSize{2};
23 
GetInstance()24 MultiStreamControllerPtr &MultiStreamController::GetInstance() {
25   static std::once_flag init_flag = {};
26   static MultiStreamControllerPtr multi_stream_controller = nullptr;
27   std::call_once(init_flag, [&]() {
28     if (multi_stream_controller == nullptr) {
29       MS_LOG(INFO) << "Create MultiStreamController.";
30       multi_stream_controller = std::make_shared<MultiStreamController>();
31     }
32   });
33 
34   return multi_stream_controller;
35 }
36 
Refresh(const DeviceContext * device_context)37 void MultiStreamController::Refresh(const DeviceContext *device_context) {
38   auto stream_size = device_context->device_res_manager_->QueryStreamSize();
39   MS_LOG(INFO) << "Stream manager initialize, device_context : " << device_context << ", stream_size : " << stream_size
40                << ".";
41   if (stream_size == 0) {
42     // CPU has no concept of stream, stream size must be zero.
43     MS_LOG(INFO) << "Stream size is 0, will initialize with 2 streams.";
44     stream_size = kDefaultStreamRefreshSize;
45   }
46   task_id_on_stream_manager_[device_context].Resize(stream_size);
47   if (event_pools_.count(device_context) == 0) {
48     (void)event_pools_.emplace(device_context, std::make_shared<EventPool>([device_context]() {
49                                  // Event in pool need to do synchronization between streams, need to enable blocking.
50                                  return device_context->device_res_manager_->CreateRuntimeEvent(true, false);
51                                }));
52   }
53 }
54 
UpdateTaskIdOnStream(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id,uint32_t memory_stream_id)55 bool MultiStreamController::UpdateTaskIdOnStream(const DeviceContext *device_context, int64_t task_id_on_stream,
56                                                  uint32_t user_stream_id, uint32_t memory_stream_id) {
57   return task_id_on_stream_manager_[device_context].Update(task_id_on_stream, user_stream_id, memory_stream_id);
58 }
59 
QueryTaskIdOnStream(const DeviceContext * device_context,uint32_t user_stream_id,uint32_t memory_stream_id)60 int64_t MultiStreamController::QueryTaskIdOnStream(const DeviceContext *device_context, uint32_t user_stream_id,
61                                                    uint32_t memory_stream_id) {
62   return task_id_on_stream_manager_[device_context].Query(user_stream_id, memory_stream_id);
63 }
64 
LaunchTaskIdOnStream(const DeviceContext * device_context,uint32_t stream_id)65 int64_t MultiStreamController::LaunchTaskIdOnStream(const DeviceContext *device_context, uint32_t stream_id) {
66   auto iter = task_id_on_stream_manager_.find(device_context);
67   if (iter == task_id_on_stream_manager_.end()) {
68     if (device_context->GetDeviceType() == DeviceType::kCPU) {
69       return INT64_MAX;
70     }
71 
72     MS_LOG(WARNING) << "LaunchTaskIdOnStream device context is not found, device_context name : "
73                     << device_context->device_context_key().device_name_ << ", stream id : " << stream_id
74                     << ", refresh context.";
75     Refresh(device_context);
76     return task_id_on_stream_manager_[device_context].Launch(stream_id);
77   }
78   return iter->second.Launch(stream_id);
79 }
80 
GetTaskIdOnStream(const DeviceContext * device_context,uint32_t stream_id)81 int64_t MultiStreamController::GetTaskIdOnStream(const DeviceContext *device_context, uint32_t stream_id) {
82   return task_id_on_stream_manager_[device_context].Get(stream_id);
83 }
84 
GetStreamMutex(const DeviceContext * device_context,size_t stream_id)85 std::mutex &MultiStreamController::GetStreamMutex(const DeviceContext *device_context, size_t stream_id) {
86   return stream_mutexes_[device_context][stream_id];
87 }
88 
RecordEvent(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id,const std::vector<std::pair<uint32_t,DeviceMemPtr>> & memory_stream_addresses)89 bool MultiStreamController::RecordEvent(const DeviceContext *device_context, int64_t task_id_on_stream,
90                                         uint32_t user_stream_id,
91                                         const std::vector<std::pair<uint32_t, DeviceMemPtr>> &memory_stream_addresses) {
92   auto mem_manager = device_context->device_res_manager_->mem_manager();
93   if (mem_manager == nullptr) {
94     MS_LOG(WARNING) << "mem_manager_ is nullptr.";
95     return false;
96   }
97 
98   auto event = device_context->device_res_manager_->CreateRuntimeEvent(false, true);
99   if (event == nullptr) {
100     return true;
101   }
102   event->RecordEvent(user_stream_id);
103   // Record event on mem buf.
104   return mem_manager->RecordEvent(task_id_on_stream, user_stream_id, memory_stream_addresses, event);
105 }
106 
WaitEvent(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id,uint32_t memory_stream_id)107 bool MultiStreamController::WaitEvent(const DeviceContext *device_context, int64_t task_id_on_stream,
108                                       uint32_t user_stream_id, uint32_t memory_stream_id) {
109   auto mem_manager = device_context->device_res_manager_->mem_manager();
110   if (mem_manager == nullptr) {
111     MS_LOG(WARNING) << "mem_manager_ is nullptr.";
112     return false;
113   }
114   // If update task id on stream failed, means task id on stream is elder one, no need to wait event on mem manager.
115   if (!UpdateTaskIdOnStream(device_context, task_id_on_stream, user_stream_id, memory_stream_id)) {
116     MS_LOG(DEBUG) << "Skip Wait Event.";
117     return false;
118   }
119   return mem_manager->WaitEvent(task_id_on_stream, user_stream_id, memory_stream_id);
120 }
121 
WaitEvent(const DeviceContext * device_context,int64_t task_id_on_stream,uint32_t user_stream_id)122 bool MultiStreamController::WaitEvent(const DeviceContext *device_context, int64_t task_id_on_stream,
123                                       uint32_t user_stream_id) {
124   auto mem_manager = device_context->device_res_manager_->mem_manager();
125   if (mem_manager == nullptr) {
126     MS_LOG(WARNING) << "mem_manager_ is nullptr.";
127     return false;
128   }
129 
130   return mem_manager->WaitEvent(task_id_on_stream, user_stream_id);
131 }
132 
DispatchRecordWaitEvent(const DeviceContext * device_context,uint32_t user_stream_id,uint32_t memory_stream_id)133 bool MultiStreamController::DispatchRecordWaitEvent(const DeviceContext *device_context, uint32_t user_stream_id,
134                                                     uint32_t memory_stream_id) {
135   if (event_pools_.count(device_context) == 0) {
136     MS_LOG(INTERNAL_EXCEPTION) << "device context has not initialized.";
137   }
138   auto &event_pool = event_pools_[device_context];
139   auto event = event_pool->Get();
140   // Note : record event on memory stream id and wait event on user stream id to make sure memory is safe.
141   event->RecordEvent(memory_stream_id);
142   event->WaitEvent(user_stream_id);
143   return true;
144 }
145 
SyncStream(const DeviceContext * device_context,size_t stream_id)146 bool MultiStreamController::SyncStream(const DeviceContext *device_context, size_t stream_id) {
147   auto &device_res_manager = device_context->device_res_manager_;
148   bool ret = device_res_manager->SyncStream(stream_id);
149   auto mem_manager = device_res_manager->mem_manager();
150   if (mem_manager != nullptr) {
151     auto task_id_on_stream = GetTaskIdOnStream(device_context, stream_id);
152     mem_manager->WaitEvent(task_id_on_stream, stream_id);
153   }
154   return ret;
155 }
156 
SyncAllStreams(const DeviceContext * device_context)157 bool MultiStreamController::SyncAllStreams(const DeviceContext *device_context) {
158   auto &device_res_manager = device_context->device_res_manager_;
159   bool ret = device_res_manager->SyncAllStreams();
160   auto mem_manager = device_res_manager->mem_manager();
161   if (mem_manager != nullptr) {
162     mem_manager->SyncAllEvents();
163   }
164   return ret;
165 }
166 
SyncNotDefaultStreams(const DeviceContext * device_context)167 bool MultiStreamController::SyncNotDefaultStreams(const DeviceContext *device_context) {
168   auto &device_res_manager = device_context->device_res_manager_;
169   bool ret = device_res_manager->SyncNotDefaultStreams();
170   auto mem_manager = device_res_manager->mem_manager();
171   if (mem_manager != nullptr) {
172     auto stream_ids = device_res_manager->GetStreamIds();
173     for (auto stream_id : stream_ids) {
174       auto task_id_on_stream = GetTaskIdOnStream(device_context, stream_id);
175       mem_manager->WaitEvent(task_id_on_stream, stream_id);
176     }
177   }
178   return ret;
179 }
180 
Resize(uint32_t stream_size)181 void TaskIdOnStreamManager::Resize(uint32_t stream_size) {
182   std::lock_guard<std::mutex> lock(mutex_);
183   if (initialized_ && stream_size <= initialize_size_) {
184     MS_LOG(INFO) << "Task id on stream manager has already initialized, current size : " << initialize_size_ << ".";
185     return;
186   }
187   MS_LOG(INFO) << "Task id on stream manager initialize : " << initialized_ << ", stream_size : " << stream_size << ".";
188   uint32_t min_stream_size = 2;
189   initialize_size_ = std::max(stream_size, min_stream_size);
190   generator_.resize(initialize_size_);
191   status_.resize(initialize_size_);
192   for (auto &vec : status_) {
193     vec.resize(initialize_size_);
194   }
195   initialized_ = true;
196 }
197 
Query(uint32_t user_stream_id,uint32_t memory_stream_id)198 int64_t TaskIdOnStreamManager::Query(uint32_t user_stream_id, uint32_t memory_stream_id) {
199   std::lock_guard<std::mutex> lock(mutex_);
200   return status_[user_stream_id][memory_stream_id];
201 }
202 
Update(int64_t task_id_on_stream,uint32_t user_stream_id,uint32_t memory_stream_id)203 bool TaskIdOnStreamManager::Update(int64_t task_id_on_stream, uint32_t user_stream_id, uint32_t memory_stream_id) {
204   std::lock_guard<std::mutex> lock(mutex_);
205   if (status_[user_stream_id][memory_stream_id] >= task_id_on_stream) {
206     return false;
207   }
208   status_[user_stream_id][memory_stream_id] = task_id_on_stream;
209   return true;
210 }
211 
Launch(uint32_t stream_id)212 int64_t TaskIdOnStreamManager::Launch(uint32_t stream_id) {
213   if (stream_id >= generator_.size()) {
214     MS_LOG(WARNING) << "Launch stream id : " << stream_id << " failed, generator_ size : " << generator_.size();
215     generator_.resize(stream_id + 1);
216     status_.resize(stream_id + 1);
217   }
218   return ++generator_[stream_id].value_;
219 }
220 
Get(uint32_t stream_id)221 int64_t TaskIdOnStreamManager::Get(uint32_t stream_id) { return generator_[stream_id].value_; }
222 
EventPool(std::function<DeviceEventPtr (void)> event_creator)223 EventPool::EventPool(std::function<DeviceEventPtr(void)> event_creator) : event_creator_(std::move(event_creator)) {}
224 
~EventPool()225 EventPool::~EventPool() {
226   std::lock_guard<std::mutex> lock(mutex_);
227   expired_ = true;
228   events_.clear();
229   cached_events_.clear();
230 }
231 
Get()232 DeviceEventPtr EventPool::Get() {
233   MS_LOG(DEBUG) << "Event pool get start.";
234   std::lock_guard<std::mutex> lock(mutex_);
235   DeviceEvent *event = nullptr;
236   // Try to create event firstly before reached core size.
237   if (size_ < core_size_) {
238     auto created_event = event_creator_();
239     if (created_event != nullptr && created_event->IsReady()) {
240       cached_events_.push_back(created_event);
241       size_++;
242       event = created_event.get();
243     }
244   }
245   // Try to reuse event.
246   if (event == nullptr) {
247     auto iter = events_.begin();
248     while (iter != events_.end()) {
249       auto event_in_list = *iter;
250       if (event_in_list == nullptr) {
251         MS_LOG(INTERNAL_EXCEPTION) << "exception : event in list is nullptr, events_ size : " << events_.size() << ".";
252       }
253       if (event_in_list->QueryEvent()) {
254         event = event_in_list;
255         events_.erase(iter);
256         break;
257       }
258       iter++;
259     }
260   }
261   // Reuse failed, try to create more event.
262   if (event == nullptr) {
263     auto created_event = event_creator_();
264     if (created_event != nullptr && created_event->IsReady()) {
265       cached_events_.push_back(created_event);
266       event = created_event.get();
267       size_++;
268     } else {
269       MS_LOG(INTERNAL_EXCEPTION) << "Get event failed.";
270     }
271   }
272   MS_LOG(DEBUG) << "Get event, events_ size : " << events_.size() << ", event : " << event << ".";
273 
274   auto event_ptr = std::shared_ptr<DeviceEvent>(event, [&](DeviceEvent *e) {
275     std::lock_guard<std::mutex> lock(mutex_);
276     if (!expired_) {
277       MS_LOG(DEBUG) << "Return event : " << e << ".";
278       events_.push_back(e);
279     } else {
280       MS_LOG(DEBUG) << "Return event : " << e << "failed.";
281     }
282   });
283   return event_ptr;
284 }
285 }  // namespace device
286 }  // namespace mindspore
287