1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/data_queue/data_queue_mgr.h"
18 #include <algorithm>
19 #include <utility>
20 #include "utils/log_adapter.h"
21 #include "utils/ms_utils.h"
22 #include "utils/ms_context.h"
23 #include "pybind11/pybind11.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/backend/data_queue/blocking_queue.h"
26 #include "include/backend/kernel_info.h"
27
28 namespace py = pybind11;
29
30 namespace mindspore {
31 namespace device {
GetInstance()32 DataQueueMgr &DataQueueMgr::GetInstance() noexcept {
33 static DataQueueMgr instance;
34 return instance;
35 }
36
RegisterDataQueueCreator(const std::string & device_name,DataQueueCreator && creator)37 void DataQueueMgr::RegisterDataQueueCreator(const std::string &device_name, DataQueueCreator &&creator) {
38 data_queue_creator_map_.emplace(device_name, std::forward<DataQueueCreator>(creator));
39 }
40
Clear()41 void DataQueueMgr::Clear() { data_queue_creator_map_.clear(); }
42
CreateDataQueue(const std::string & device_name,const std::string & channel_name,bool dynamic_shape,size_t capacity,const std::vector<size_t> & shape)43 std::shared_ptr<DataQueue> DataQueueMgr::CreateDataQueue(const std::string &device_name,
44 const std::string &channel_name, bool dynamic_shape,
45 size_t capacity, const std::vector<size_t> &shape) {
46 auto iter = data_queue_creator_map_.find(device_name);
47 if (iter == data_queue_creator_map_.end()) {
48 return nullptr;
49 }
50
51 return iter->second(channel_name, dynamic_shape, capacity, shape);
52 }
53
Manage(const std::string & channel_name,const std::shared_ptr<BlockingQueue> & queue)54 void DataQueueMgr::Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue) {
55 (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
56 }
57
Create(const std::string & channel_name,const std::vector<size_t> & shape,const size_t capacity)58 DataQueueStatus DataQueueMgr::Create(const std::string &channel_name, const std::vector<size_t> &shape,
59 const size_t capacity) {
60 MS_LOG(INFO) << "Data queue: " << channel_name << " created";
61 if (name_queue_map_.find(channel_name) != name_queue_map_.end()) {
62 return DataQueueStatus::QUEUE_EXIST;
63 }
64
65 MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
66 std::shared_ptr<DataQueue> data_queue = CreateDataQueue(
67 MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET), channel_name, false, capacity, shape);
68 if (data_queue != nullptr) {
69 std::shared_ptr<BlockingQueue> queue = std::make_shared<BlockingQueue>();
70 DataQueueStatus rt = queue->Create(data_queue);
71 if (rt != DataQueueStatus::SUCCESS) {
72 MS_LOG(ERROR) << "Queue: " << channel_name << "create failed: " << rt;
73 return rt;
74 }
75 (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
76 init_ = true;
77 return DataQueueStatus::SUCCESS;
78 }
79 return DataQueueStatus::INTERNAL_ERROR;
80 }
81
Open(const std::string & channel_name,const std::function<void (void *,int32_t)> func)82 DataQueueStatus DataQueueMgr::Open(const std::string &channel_name, const std::function<void(void *, int32_t)> func) {
83 MS_LOG(INFO) << "Data queue: " << channel_name << " opened by dataset";
84 if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
85 MS_LOG(ERROR) << "Data queue not exist " << channel_name;
86 return DataQueueStatus::QUEUE_NOT_EXIST;
87 }
88
89 name_queue_map_[channel_name]->RegisterRelease(func);
90 open_by_dataset_++;
91 return DataQueueStatus::SUCCESS;
92 }
93
CreateDynamicBufQueue(const std::string & channel_name,const size_t & capacity)94 DataQueueStatus DataQueueMgr::CreateDynamicBufQueue(const std::string &channel_name, const size_t &capacity) {
95 if (name_queue_map_.find(channel_name) != name_queue_map_.end()) {
96 return DataQueueStatus::QUEUE_EXIST;
97 }
98 MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
99 std::string device_name = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
100 std::shared_ptr<DataQueue> device_queue = CreateDataQueue(device_name, channel_name, true, capacity);
101 if (device_queue != nullptr) {
102 std::shared_ptr<BlockingQueue> queue = std::make_shared<BlockingQueue>();
103 DataQueueStatus rt = queue->Create(device_queue);
104 if (rt != DataQueueStatus::SUCCESS) {
105 MS_LOG(ERROR) << "Queue: " << channel_name << "create failed: " << rt;
106 return rt;
107 }
108 (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
109 init_ = true;
110 return DataQueueStatus::SUCCESS;
111 }
112
113 MS_LOG(ERROR) << "Dynamic data queue only support Ascend/GPU target, bug got " << device_name;
114 return DataQueueStatus::INTERNAL_ERROR;
115 }
116
Open(const std::string & channel_name) const117 DataQueueStatus DataQueueMgr::Open(const std::string &channel_name) const {
118 if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
119 MS_LOG(ERROR) << "Queue not exist " << channel_name;
120 return DataQueueStatus::QUEUE_NOT_EXIST;
121 }
122 return DataQueueStatus::SUCCESS;
123 }
124
Push(const std::string & channel_name,const std::vector<DataQueueItem> & data,unsigned int timeout_in_sec)125 DataQueueStatus DataQueueMgr::Push(const std::string &channel_name, const std::vector<DataQueueItem> &data,
126 unsigned int timeout_in_sec) {
127 auto iter = name_queue_map_.find(channel_name);
128 if (iter == name_queue_map_.end()) {
129 MS_LOG(ERROR) << "Queue not exist " << channel_name;
130 return DataQueueStatus::QUEUE_NOT_EXIST;
131 }
132 return iter->second->Push(data, timeout_in_sec);
133 }
134
Front(const std::string & channel_name,std::vector<DataQueueItem> * data)135 DataQueueStatus DataQueueMgr::Front(const std::string &channel_name, std::vector<DataQueueItem> *data) {
136 auto iter = name_queue_map_.find(channel_name);
137 if (iter == name_queue_map_.end()) {
138 MS_LOG(ERROR) << "Queue not exist " << channel_name;
139 return DataQueueStatus::QUEUE_NOT_EXIST;
140 }
141 return iter->second->Front(data);
142 }
FrontAsync(const std::string & channel_name,std::vector<DataQueueItem> * data)143 DataQueueStatus DataQueueMgr::FrontAsync(const std::string &channel_name, std::vector<DataQueueItem> *data) {
144 auto iter = name_queue_map_.find(channel_name);
145 if (iter == name_queue_map_.end()) {
146 MS_LOG(ERROR) << "Queue not exist " << channel_name;
147 return DataQueueStatus::QUEUE_NOT_EXIST;
148 }
149 return iter->second->FrontAsync(data);
150 }
151
Pop(const std::string & channel_name)152 DataQueueStatus DataQueueMgr::Pop(const std::string &channel_name) {
153 auto iter = name_queue_map_.find(channel_name);
154 if (iter == name_queue_map_.end()) {
155 MS_LOG(ERROR) << "Queue not exist " << channel_name;
156 return DataQueueStatus::QUEUE_NOT_EXIST;
157 }
158
159 return iter->second->Pop();
160 }
161
Release()162 void DataQueueMgr::Release() { name_queue_map_.clear(); }
163
Free(const std::string & channel_name)164 void DataQueueMgr::Free(const std::string &channel_name) {
165 auto iter = name_queue_map_.find(channel_name);
166 if (iter != name_queue_map_.end()) {
167 name_queue_map_.erase(iter);
168 }
169 }
170
Clear(const std::string & channel_name)171 DataQueueStatus DataQueueMgr::Clear(const std::string &channel_name) {
172 auto iter = name_queue_map_.find(channel_name);
173 if (iter == name_queue_map_.end()) {
174 MS_LOG(ERROR) << "Queue not exist " << channel_name;
175 return DataQueueStatus::QUEUE_NOT_EXIST;
176 }
177
178 return iter->second->Clear();
179 }
180
Close(const std::string & channel_name) const181 void DataQueueMgr::Close(const std::string &channel_name) const noexcept {
182 MS_LOG(INFO) << "Close the queue: " << channel_name;
183 return;
184 }
185
IsInit() const186 bool DataQueueMgr::IsInit() const { return init_; }
187
IsClosed() const188 bool DataQueueMgr::IsClosed() const { return closed_; }
189
IsCreated(const std::string & channel_name) const190 bool DataQueueMgr::IsCreated(const std::string &channel_name) const {
191 return name_queue_map_.find(channel_name) != name_queue_map_.end();
192 }
193
CloseNotify()194 bool DataQueueMgr::CloseNotify() {
195 py::gil_scoped_release release;
196 bool result = true;
197 // lock scope
198 {
199 std::lock_guard<std::mutex> lk(close_mutex_);
200 // set closed_ to be true, all the dataset retry can be jumped out of the while
201 closed_ = true;
202 }
203
204 // wati for the dataset threads' ack
205 for (int i = 0; i < open_by_dataset_; i++) {
206 if (sema.Wait() == false) {
207 MS_LOG(ERROR) << "time out of receiving signals";
208 result = false;
209 }
210 MS_LOG(DEBUG) << "receive one signal (" << (i + 1) << "/" << open_by_dataset_ << ")";
211 }
212 return result;
213 }
214
CloseConfirm()215 void DataQueueMgr::CloseConfirm() { sema.Signal(); }
216
Size(const std::string & channel_name)217 size_t DataQueueMgr::Size(const std::string &channel_name) {
218 if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
219 MS_LOG(ERROR) << "Queue not exist " << channel_name;
220 return 0;
221 }
222 return name_queue_map_.at(channel_name)->Size();
223 }
224
Capacity(const std::string & channel_name)225 size_t DataQueueMgr::Capacity(const std::string &channel_name) {
226 if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
227 MS_LOG(ERROR) << "Queue not exist " << channel_name;
228 return 0;
229 }
230 return name_queue_map_.at(channel_name)->Capacity();
231 }
232
GetDataQueue(const std::string & channel_name) const233 std::shared_ptr<BlockingQueue> DataQueueMgr::GetDataQueue(const std::string &channel_name) const {
234 auto iter = name_queue_map_.find(channel_name);
235 if (iter == name_queue_map_.end()) {
236 MS_LOG(ERROR) << "Queue not exist " << channel_name;
237 return nullptr;
238 }
239 MS_EXCEPTION_IF_NULL(iter->second);
240 return iter->second;
241 }
242
SetThreadDevice(const std::string & channel_name) const243 DataQueueStatus DataQueueMgr::SetThreadDevice(const std::string &channel_name) const {
244 auto queue = GetDataQueue(channel_name);
245 if (queue == nullptr || queue->Queue() == nullptr) {
246 return DataQueueStatus::QUEUE_NOT_EXIST;
247 }
248 queue->Queue()->SetThreadDevice();
249 return DataQueueStatus::SUCCESS;
250 }
251
252 #ifndef BUILD_LITE
UpdateGetNextWithDataQueueItems(const AnfNodePtr & data_kernel,const std::vector<device::DataQueueItem> & data)253 void UpdateGetNextWithDataQueueItems(const AnfNodePtr &data_kernel, const std::vector<device::DataQueueItem> &data) {
254 auto kernel_info = dynamic_cast<device::KernelInfo *>(data_kernel->kernel_info());
255 std::vector<std::shared_ptr<device::DeviceAddress>> device_tensors;
256 for (auto &device_tensor : kernel_info->output_address_list()) {
257 MS_EXCEPTION_IF_NULL(device_tensor);
258 device_tensors.push_back(device_tensor);
259 }
260 MS_EXCEPTION_IF_CHECK_FAIL(data.size() == device_tensors.size(),
261 "The number of data tensor popped from dynamic queue is not correct");
262 std::vector<ShapeVector> shapes;
263 std::vector<TypeId> types;
264 std::vector<size_t> output_size_list;
265 for (size_t i = 0; i < data.size(); ++i) {
266 device_tensors[i]->SetSize(data[i].data_len);
267 device_tensors[i]->set_from_mem_pool(true);
268 output_size_list.push_back(data[i].data_len);
269 shapes.push_back(data[i].shapes);
270 types.push_back(common::AnfAlgo::GetOutputInferDataType(data_kernel, i));
271 }
272 auto kernel_mod = kernel_info->MutableKernelMod();
273 kernel_mod->SetOutputSizeList(output_size_list);
274 common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, data_kernel.get());
275 }
276
UpdateGetNextWithDataQueueItems(const std::vector<kernel::KernelTensor * > & inputs,const std::vector<kernel::KernelTensor * > & outputs,const std::vector<device::DataQueueItem> & data,std::vector<size_t> * output_size_list)277 void UpdateGetNextWithDataQueueItems(const std::vector<kernel::KernelTensor *> &inputs,
278 const std::vector<kernel::KernelTensor *> &outputs,
279 const std::vector<device::DataQueueItem> &data,
280 std::vector<size_t> *output_size_list) {
281 MS_EXCEPTION_IF_CHECK_FAIL(data.size() == outputs.size(),
282 "The number of data tensor popped from dynamic queue is not correct");
283 output_size_list->clear();
284 for (size_t i = 0; i < data.size(); ++i) {
285 outputs[i]->set_size(data[i].data_len);
286 outputs[i]->SetShapeVector(data[i].shapes);
287 output_size_list->push_back(data[i].data_len);
288 }
289 }
290
RetryPeakItemFromDataQueue(const AnfNodePtr & data_kernel,const std::shared_ptr<BlockingQueue> & data_queue,std::vector<device::DataQueueItem> * data)291 void RetryPeakItemFromDataQueue(const AnfNodePtr &data_kernel, const std::shared_ptr<BlockingQueue> &data_queue,
292 std::vector<device::DataQueueItem> *data) {
293 auto front_ret = DataQueueStatus::TIMEOUT;
294 auto ms_context = MsContext::GetInstance();
295 MS_EXCEPTION_IF_NULL(ms_context);
296 uint32_t op_timeout = ms_context->get_param<uint32_t>(MS_CTX_OP_TIMEOUT);
297 time_t start_time = time(nullptr);
298 while (front_ret == DataQueueStatus::TIMEOUT && ((time(nullptr) - start_time) < op_timeout || op_timeout == 0)) {
299 front_ret = data_queue->FrontAsync(data);
300 }
301 if (front_ret != DataQueueStatus::SUCCESS) {
302 if (front_ret == DataQueueStatus::TIMEOUT) {
303 MS_LOG(ERROR) << "Getnext gets peek data time out, that most likely caused by data processing being too slow";
304 }
305 MS_LOG(EXCEPTION) << "Getnext gets peek data from data queue failed: " << front_ret;
306 }
307 }
308
UpdateGetNextNode(const AnfNodePtr & data_kernel)309 void UpdateGetNextNode(const AnfNodePtr &data_kernel) {
310 auto queue_name = common::AnfAlgo::GetNodeAttr<std::string>(data_kernel, "shared_name");
311 device::DataQueueMgr &buf_mgr = device::DataQueueMgr::GetInstance();
312 auto ret = buf_mgr.Open(queue_name);
313 MS_EXCEPTION_IF_CHECK_FAIL(ret == device::DataQueueStatus::SUCCESS, "Open dynamic data queue failed");
314 auto data_queue = buf_mgr.GetDataQueue(queue_name);
315 std::vector<device::DataQueueItem> data;
316 RetryPeakItemFromDataQueue(data_kernel, data_queue, &data);
317 UpdateGetNextWithDataQueueItems(data_kernel, data);
318 }
319
// KernelTensor overload: resolves the dynamic queue from the primitive's
// "shared_name" attribute, peeks the next batch, and updates the output
// tensors and size list accordingly.
void UpdateGetNextNode(const PrimitivePtr &primitive, const std::vector<kernel::KernelTensor *> &inputs,
                       const std::vector<kernel::KernelTensor *> &outputs, std::vector<size_t> *output_size_list) {
  const auto queue_name = GetValue<std::string>(primitive->GetAttr("shared_name"));
  auto &buf_mgr = device::DataQueueMgr::GetInstance();
  const auto open_ret = buf_mgr.Open(queue_name);
  MS_EXCEPTION_IF_CHECK_FAIL(open_ret == device::DataQueueStatus::SUCCESS, "Open dynamic data queue failed");
  std::vector<device::DataQueueItem> items;
  // No AnfNode is involved in this overload, hence the nullptr data_kernel.
  RetryPeakItemFromDataQueue(nullptr, buf_mgr.GetDataQueue(queue_name), &items);
  UpdateGetNextWithDataQueueItems(inputs, outputs, items, output_size_list);
}
331
332 #endif
333 } // namespace device
334 } // namespace mindspore
335