• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/data_queue/data_queue_mgr.h"
18 #include <algorithm>
19 #include <utility>
20 #include "utils/log_adapter.h"
21 #include "utils/ms_utils.h"
22 #include "utils/ms_context.h"
23 #include "pybind11/pybind11.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/backend/data_queue/blocking_queue.h"
26 #include "include/backend/kernel_info.h"
27 
28 namespace py = pybind11;
29 
30 namespace mindspore {
31 namespace device {
GetInstance()32 DataQueueMgr &DataQueueMgr::GetInstance() noexcept {
33   static DataQueueMgr instance;
34   return instance;
35 }
36 
RegisterDataQueueCreator(const std::string & device_name,DataQueueCreator && creator)37 void DataQueueMgr::RegisterDataQueueCreator(const std::string &device_name, DataQueueCreator &&creator) {
38   data_queue_creator_map_.emplace(device_name, std::forward<DataQueueCreator>(creator));
39 }
40 
Clear()41 void DataQueueMgr::Clear() { data_queue_creator_map_.clear(); }
42 
CreateDataQueue(const std::string & device_name,const std::string & channel_name,bool dynamic_shape,size_t capacity,const std::vector<size_t> & shape)43 std::shared_ptr<DataQueue> DataQueueMgr::CreateDataQueue(const std::string &device_name,
44                                                          const std::string &channel_name, bool dynamic_shape,
45                                                          size_t capacity, const std::vector<size_t> &shape) {
46   auto iter = data_queue_creator_map_.find(device_name);
47   if (iter == data_queue_creator_map_.end()) {
48     return nullptr;
49   }
50 
51   return iter->second(channel_name, dynamic_shape, capacity, shape);
52 }
53 
Manage(const std::string & channel_name,const std::shared_ptr<BlockingQueue> & queue)54 void DataQueueMgr::Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue) {
55   (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
56 }
57 
Create(const std::string & channel_name,const std::vector<size_t> & shape,const size_t capacity)58 DataQueueStatus DataQueueMgr::Create(const std::string &channel_name, const std::vector<size_t> &shape,
59                                      const size_t capacity) {
60   MS_LOG(INFO) << "Data queue: " << channel_name << " created";
61   if (name_queue_map_.find(channel_name) != name_queue_map_.end()) {
62     return DataQueueStatus::QUEUE_EXIST;
63   }
64 
65   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
66   std::shared_ptr<DataQueue> data_queue = CreateDataQueue(
67     MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET), channel_name, false, capacity, shape);
68   if (data_queue != nullptr) {
69     std::shared_ptr<BlockingQueue> queue = std::make_shared<BlockingQueue>();
70     DataQueueStatus rt = queue->Create(data_queue);
71     if (rt != DataQueueStatus::SUCCESS) {
72       MS_LOG(ERROR) << "Queue: " << channel_name << "create failed: " << rt;
73       return rt;
74     }
75     (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
76     init_ = true;
77     return DataQueueStatus::SUCCESS;
78   }
79   return DataQueueStatus::INTERNAL_ERROR;
80 }
81 
Open(const std::string & channel_name,const std::function<void (void *,int32_t)> func)82 DataQueueStatus DataQueueMgr::Open(const std::string &channel_name, const std::function<void(void *, int32_t)> func) {
83   MS_LOG(INFO) << "Data queue: " << channel_name << " opened by dataset";
84   if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
85     MS_LOG(ERROR) << "Data queue not exist " << channel_name;
86     return DataQueueStatus::QUEUE_NOT_EXIST;
87   }
88 
89   name_queue_map_[channel_name]->RegisterRelease(func);
90   open_by_dataset_++;
91   return DataQueueStatus::SUCCESS;
92 }
93 
CreateDynamicBufQueue(const std::string & channel_name,const size_t & capacity)94 DataQueueStatus DataQueueMgr::CreateDynamicBufQueue(const std::string &channel_name, const size_t &capacity) {
95   if (name_queue_map_.find(channel_name) != name_queue_map_.end()) {
96     return DataQueueStatus::QUEUE_EXIST;
97   }
98   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
99   std::string device_name = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
100   std::shared_ptr<DataQueue> device_queue = CreateDataQueue(device_name, channel_name, true, capacity);
101   if (device_queue != nullptr) {
102     std::shared_ptr<BlockingQueue> queue = std::make_shared<BlockingQueue>();
103     DataQueueStatus rt = queue->Create(device_queue);
104     if (rt != DataQueueStatus::SUCCESS) {
105       MS_LOG(ERROR) << "Queue: " << channel_name << "create failed: " << rt;
106       return rt;
107     }
108     (void)name_queue_map_.insert(std::make_pair(channel_name, queue));
109     init_ = true;
110     return DataQueueStatus::SUCCESS;
111   }
112 
113   MS_LOG(ERROR) << "Dynamic data queue only support Ascend/GPU target, bug got " << device_name;
114   return DataQueueStatus::INTERNAL_ERROR;
115 }
116 
Open(const std::string & channel_name) const117 DataQueueStatus DataQueueMgr::Open(const std::string &channel_name) const {
118   if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
119     MS_LOG(ERROR) << "Queue not exist " << channel_name;
120     return DataQueueStatus::QUEUE_NOT_EXIST;
121   }
122   return DataQueueStatus::SUCCESS;
123 }
124 
Push(const std::string & channel_name,const std::vector<DataQueueItem> & data,unsigned int timeout_in_sec)125 DataQueueStatus DataQueueMgr::Push(const std::string &channel_name, const std::vector<DataQueueItem> &data,
126                                    unsigned int timeout_in_sec) {
127   auto iter = name_queue_map_.find(channel_name);
128   if (iter == name_queue_map_.end()) {
129     MS_LOG(ERROR) << "Queue not exist " << channel_name;
130     return DataQueueStatus::QUEUE_NOT_EXIST;
131   }
132   return iter->second->Push(data, timeout_in_sec);
133 }
134 
Front(const std::string & channel_name,std::vector<DataQueueItem> * data)135 DataQueueStatus DataQueueMgr::Front(const std::string &channel_name, std::vector<DataQueueItem> *data) {
136   auto iter = name_queue_map_.find(channel_name);
137   if (iter == name_queue_map_.end()) {
138     MS_LOG(ERROR) << "Queue not exist " << channel_name;
139     return DataQueueStatus::QUEUE_NOT_EXIST;
140   }
141   return iter->second->Front(data);
142 }
FrontAsync(const std::string & channel_name,std::vector<DataQueueItem> * data)143 DataQueueStatus DataQueueMgr::FrontAsync(const std::string &channel_name, std::vector<DataQueueItem> *data) {
144   auto iter = name_queue_map_.find(channel_name);
145   if (iter == name_queue_map_.end()) {
146     MS_LOG(ERROR) << "Queue not exist " << channel_name;
147     return DataQueueStatus::QUEUE_NOT_EXIST;
148   }
149   return iter->second->FrontAsync(data);
150 }
151 
Pop(const std::string & channel_name)152 DataQueueStatus DataQueueMgr::Pop(const std::string &channel_name) {
153   auto iter = name_queue_map_.find(channel_name);
154   if (iter == name_queue_map_.end()) {
155     MS_LOG(ERROR) << "Queue not exist " << channel_name;
156     return DataQueueStatus::QUEUE_NOT_EXIST;
157   }
158 
159   return iter->second->Pop();
160 }
161 
Release()162 void DataQueueMgr::Release() { name_queue_map_.clear(); }
163 
Free(const std::string & channel_name)164 void DataQueueMgr::Free(const std::string &channel_name) {
165   auto iter = name_queue_map_.find(channel_name);
166   if (iter != name_queue_map_.end()) {
167     name_queue_map_.erase(iter);
168   }
169 }
170 
Clear(const std::string & channel_name)171 DataQueueStatus DataQueueMgr::Clear(const std::string &channel_name) {
172   auto iter = name_queue_map_.find(channel_name);
173   if (iter == name_queue_map_.end()) {
174     MS_LOG(ERROR) << "Queue not exist " << channel_name;
175     return DataQueueStatus::QUEUE_NOT_EXIST;
176   }
177 
178   return iter->second->Clear();
179 }
180 
Close(const std::string & channel_name) const181 void DataQueueMgr::Close(const std::string &channel_name) const noexcept {
182   MS_LOG(INFO) << "Close the queue: " << channel_name;
183   return;
184 }
185 
IsInit() const186 bool DataQueueMgr::IsInit() const { return init_; }
187 
IsClosed() const188 bool DataQueueMgr::IsClosed() const { return closed_; }
189 
IsCreated(const std::string & channel_name) const190 bool DataQueueMgr::IsCreated(const std::string &channel_name) const {
191   return name_queue_map_.find(channel_name) != name_queue_map_.end();
192 }
193 
CloseNotify()194 bool DataQueueMgr::CloseNotify() {
195   py::gil_scoped_release release;
196   bool result = true;
197   // lock scope
198   {
199     std::lock_guard<std::mutex> lk(close_mutex_);
200     // set closed_ to be true, all the dataset retry can be jumped out of the while
201     closed_ = true;
202   }
203 
204   // wati for the dataset threads' ack
205   for (int i = 0; i < open_by_dataset_; i++) {
206     if (sema.Wait() == false) {
207       MS_LOG(ERROR) << "time out of receiving signals";
208       result = false;
209     }
210     MS_LOG(DEBUG) << "receive one signal (" << (i + 1) << "/" << open_by_dataset_ << ")";
211   }
212   return result;
213 }
214 
CloseConfirm()215 void DataQueueMgr::CloseConfirm() { sema.Signal(); }
216 
Size(const std::string & channel_name)217 size_t DataQueueMgr::Size(const std::string &channel_name) {
218   if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
219     MS_LOG(ERROR) << "Queue not exist " << channel_name;
220     return 0;
221   }
222   return name_queue_map_.at(channel_name)->Size();
223 }
224 
Capacity(const std::string & channel_name)225 size_t DataQueueMgr::Capacity(const std::string &channel_name) {
226   if (name_queue_map_.find(channel_name) == name_queue_map_.end()) {
227     MS_LOG(ERROR) << "Queue not exist " << channel_name;
228     return 0;
229   }
230   return name_queue_map_.at(channel_name)->Capacity();
231 }
232 
GetDataQueue(const std::string & channel_name) const233 std::shared_ptr<BlockingQueue> DataQueueMgr::GetDataQueue(const std::string &channel_name) const {
234   auto iter = name_queue_map_.find(channel_name);
235   if (iter == name_queue_map_.end()) {
236     MS_LOG(ERROR) << "Queue not exist " << channel_name;
237     return nullptr;
238   }
239   MS_EXCEPTION_IF_NULL(iter->second);
240   return iter->second;
241 }
242 
SetThreadDevice(const std::string & channel_name) const243 DataQueueStatus DataQueueMgr::SetThreadDevice(const std::string &channel_name) const {
244   auto queue = GetDataQueue(channel_name);
245   if (queue == nullptr || queue->Queue() == nullptr) {
246     return DataQueueStatus::QUEUE_NOT_EXIST;
247   }
248   queue->Queue()->SetThreadDevice();
249   return DataQueueStatus::SUCCESS;
250 }
251 
252 #ifndef BUILD_LITE
UpdateGetNextWithDataQueueItems(const AnfNodePtr & data_kernel,const std::vector<device::DataQueueItem> & data)253 void UpdateGetNextWithDataQueueItems(const AnfNodePtr &data_kernel, const std::vector<device::DataQueueItem> &data) {
254   auto kernel_info = dynamic_cast<device::KernelInfo *>(data_kernel->kernel_info());
255   std::vector<std::shared_ptr<device::DeviceAddress>> device_tensors;
256   for (auto &device_tensor : kernel_info->output_address_list()) {
257     MS_EXCEPTION_IF_NULL(device_tensor);
258     device_tensors.push_back(device_tensor);
259   }
260   MS_EXCEPTION_IF_CHECK_FAIL(data.size() == device_tensors.size(),
261                              "The number of data tensor popped from dynamic queue is not correct");
262   std::vector<ShapeVector> shapes;
263   std::vector<TypeId> types;
264   std::vector<size_t> output_size_list;
265   for (size_t i = 0; i < data.size(); ++i) {
266     device_tensors[i]->SetSize(data[i].data_len);
267     device_tensors[i]->set_from_mem_pool(true);
268     output_size_list.push_back(data[i].data_len);
269     shapes.push_back(data[i].shapes);
270     types.push_back(common::AnfAlgo::GetOutputInferDataType(data_kernel, i));
271   }
272   auto kernel_mod = kernel_info->MutableKernelMod();
273   kernel_mod->SetOutputSizeList(output_size_list);
274   common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, data_kernel.get());
275 }
276 
UpdateGetNextWithDataQueueItems(const std::vector<kernel::KernelTensor * > & inputs,const std::vector<kernel::KernelTensor * > & outputs,const std::vector<device::DataQueueItem> & data,std::vector<size_t> * output_size_list)277 void UpdateGetNextWithDataQueueItems(const std::vector<kernel::KernelTensor *> &inputs,
278                                      const std::vector<kernel::KernelTensor *> &outputs,
279                                      const std::vector<device::DataQueueItem> &data,
280                                      std::vector<size_t> *output_size_list) {
281   MS_EXCEPTION_IF_CHECK_FAIL(data.size() == outputs.size(),
282                              "The number of data tensor popped from dynamic queue is not correct");
283   output_size_list->clear();
284   for (size_t i = 0; i < data.size(); ++i) {
285     outputs[i]->set_size(data[i].data_len);
286     outputs[i]->SetShapeVector(data[i].shapes);
287     output_size_list->push_back(data[i].data_len);
288   }
289 }
290 
// Repeatedly peek the head of `data_queue` until an item is available or the op timeout
// expires; on failure this raises via MS_LOG(EXCEPTION).
// `data_kernel` is not used by this function (one caller passes nullptr).
// NOTE(review): FrontAsync is polled in a loop with no explicit sleep here — presumably
// it blocks or rate-limits internally; confirm before reusing this pattern elsewhere.
void RetryPeakItemFromDataQueue(const AnfNodePtr &data_kernel, const std::shared_ptr<BlockingQueue> &data_queue,
                                std::vector<device::DataQueueItem> *data) {
  auto front_ret = DataQueueStatus::TIMEOUT;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // op_timeout == 0 disables the deadline (retry until success or non-timeout failure).
  uint32_t op_timeout = ms_context->get_param<uint32_t>(MS_CTX_OP_TIMEOUT);
  time_t start_time = time(nullptr);
  while (front_ret == DataQueueStatus::TIMEOUT && ((time(nullptr) - start_time) < op_timeout || op_timeout == 0)) {
    front_ret = data_queue->FrontAsync(data);
  }
  if (front_ret != DataQueueStatus::SUCCESS) {
    if (front_ret == DataQueueStatus::TIMEOUT) {
      MS_LOG(ERROR) << "Getnext gets peek data time out, that most likely caused by data processing being too slow";
    }
    MS_LOG(EXCEPTION) << "Getnext gets peek data from data queue failed: " << front_ret;
  }
}
308 
UpdateGetNextNode(const AnfNodePtr & data_kernel)309 void UpdateGetNextNode(const AnfNodePtr &data_kernel) {
310   auto queue_name = common::AnfAlgo::GetNodeAttr<std::string>(data_kernel, "shared_name");
311   device::DataQueueMgr &buf_mgr = device::DataQueueMgr::GetInstance();
312   auto ret = buf_mgr.Open(queue_name);
313   MS_EXCEPTION_IF_CHECK_FAIL(ret == device::DataQueueStatus::SUCCESS, "Open dynamic data queue failed");
314   auto data_queue = buf_mgr.GetDataQueue(queue_name);
315   std::vector<device::DataQueueItem> data;
316   RetryPeakItemFromDataQueue(data_kernel, data_queue, &data);
317   UpdateGetNextWithDataQueueItems(data_kernel, data);
318 }
319 
UpdateGetNextNode(const PrimitivePtr & primitive,const std::vector<kernel::KernelTensor * > & inputs,const std::vector<kernel::KernelTensor * > & outputs,std::vector<size_t> * output_size_list)320 void UpdateGetNextNode(const PrimitivePtr &primitive, const std::vector<kernel::KernelTensor *> &inputs,
321                        const std::vector<kernel::KernelTensor *> &outputs, std::vector<size_t> *output_size_list) {
322   auto queue_name = GetValue<std::string>(primitive->GetAttr("shared_name"));
323   device::DataQueueMgr &buf_mgr = device::DataQueueMgr::GetInstance();
324   auto ret = buf_mgr.Open(queue_name);
325   MS_EXCEPTION_IF_CHECK_FAIL(ret == device::DataQueueStatus::SUCCESS, "Open dynamic data queue failed");
326   auto data_queue = buf_mgr.GetDataQueue(queue_name);
327   std::vector<device::DataQueueItem> data;
328   RetryPeakItemFromDataQueue(nullptr, data_queue, &data);
329   UpdateGetNextWithDataQueueItems(inputs, outputs, data, output_size_list);
330 }
331 
332 #endif
333 }  // namespace device
334 }  // namespace mindspore
335