• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_
18 
19 #include <memory>
20 #include <mutex>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "securec.h"
26 #include "minddata/dataset/util/allocator.h"
27 #include "minddata/dataset/util/log_adapter.h"
28 #include "minddata/dataset/util/services.h"
29 #include "minddata/dataset/util/cond_var.h"
30 #include "minddata/dataset/util/task_manager.h"
31 
32 namespace mindspore {
33 namespace dataset {
// A simple thread-safe bounded queue backed by a fixed-size circular array.
//
// head_ and tail_ are monotonically increasing counters; the physical slot of an
// element is (counter % sz_).  Producers block in Add/EmplaceBack while the queue
// is full; consumers block in PopFront while it is empty.  All blocking waits are
// interruptible through the CondVar/TaskGroup interrupt service (see Register).
template <typename T>
class Queue {
 public:
  using value_type = T;
  using pointer = T *;
  using const_pointer = const T *;
  using reference = T &;
  using const_reference = const T &;

  // Constructs a queue with capacity `sz`.  Allocation failure is treated as
  // fatal (std::terminate): a queue that cannot hold any element is unusable
  // and the constructor has no way to report a Status to the caller.
  explicit Queue(int sz)
      : sz_(sz), arr_(Services::GetAllocator<T>()), head_(0), tail_(0), my_name_(Services::GetUniqueID()) {
    Status rc = arr_.allocate(sz);
    if (rc.IsError()) {
      MS_LOG(ERROR) << "Fail to create a queue.";
      std::terminate();
    } else {
      MS_LOG(DEBUG) << "Create Q with uuid " << my_name_ << " of size " << sz_ << ".";
    }
  }

  // Drains remaining elements.  No lock is taken; the destructor assumes no
  // other thread is still using the queue.
  virtual ~Queue() { ResetQue(); }

  // Number of elements currently in the circular buffer (elements parked in
  // extra_arr_ after a shrinking Resize are NOT counted).
  size_t size() const {
    std::unique_lock<std::mutex> _lock(mux_);
    size_t v = 0;
    if (tail_ >= head_) {
      v = tail_ - head_;
    }
    return v;
  }

  // Current capacity of the circular buffer.
  size_t capacity() const {
    std::unique_lock<std::mutex> _lock(mux_);
    return sz_;
  }

  // True when the circular buffer holds no elements.
  bool empty() const {
    std::unique_lock<std::mutex> _lock(mux_);
    return head_ == tail_;
  }

  // Drops all stored elements (circular buffer AND extra_arr_ overflow) and
  // clears any pending interrupt state on both condition variables.
  void Reset() {
    std::unique_lock<std::mutex> _lock(mux_);
    ResetQue();
    extra_arr_.clear();
  }

  // Producer
  // Copy-adds `ele` at the tail; blocks while the queue is full.  If the wait
  // is interrupted (rc not OK), the consumer-side condition variable is also
  // interrupted so a peer blocked in PopFront wakes up as well.
  Status Add(const_reference ele) noexcept {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when full
    Status rc =
      full_cv_.Wait(&_lock, [this]() -> bool { return (SizeWhileHoldingLock() != CapacityWhileHoldingLock()); });
    if (rc.IsOk()) {
      this->AddWhileHoldingLock(ele);
      empty_cv_.NotifyAll();
      _lock.unlock();
    } else {
      empty_cv_.Interrupt();
    }
    return rc;
  }

  // Producer (move overload).  Same blocking/interrupt behavior as the copy
  // overload above.
  Status Add(T &&ele) noexcept {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when full
    Status rc =
      full_cv_.Wait(&_lock, [this]() -> bool { return (SizeWhileHoldingLock() != CapacityWhileHoldingLock()); });
    if (rc.IsOk()) {
      // T is a class template parameter (not deduced), so forward<T> == move here.
      this->AddWhileHoldingLock(std::forward<T>(ele));
      empty_cv_.NotifyAll();
      _lock.unlock();
    } else {
      empty_cv_.Interrupt();
    }
    return rc;
  }

  // Producer: constructs the element in place at the tail slot; blocks while
  // the queue is full.
  template <typename... Ts>
  Status EmplaceBack(Ts &&... args) noexcept {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when full
    Status rc =
      full_cv_.Wait(&_lock, [this]() -> bool { return (SizeWhileHoldingLock() != CapacityWhileHoldingLock()); });
    if (rc.IsOk()) {
      auto k = tail_++ % sz_;
      // NOTE(review): placement-new here, while Add() assigns into the slot.
      // This assumes arr_[k] does not hold a live object that needs destruction
      // before reuse — confirm against MemGuard's construction semantics.
      new (arr_[k]) T(std::forward<Ts>(args)...);
      empty_cv_.NotifyAll();
      _lock.unlock();
    } else {
      empty_cv_.Interrupt();
    }
    return rc;
  }

  // Consumer
  // Moves the front element into *p; blocks while the queue is empty.  On
  // interrupt, the producer-side condition variable is interrupted too.
  virtual Status PopFront(pointer p) {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when empty
    Status rc = empty_cv_.Wait(&_lock, [this]() -> bool { return !EmptyWhileHoldingLock(); });
    if (rc.IsOk()) {
      this->PopFrontWhileHoldingLock(p, true);
      full_cv_.NotifyAll();
      _lock.unlock();
    } else {
      full_cv_.Interrupt();
    }
    return rc;
  }

  // Registers both condition variables with the task group's interrupt service
  // so blocked Add/PopFront calls can be interrupted on shutdown.  Both
  // registrations are attempted; the first failure (if any) is returned.
  Status Register(TaskGroup *vg) {
    Status rc1 = empty_cv_.Register(vg->GetIntrpService());
    Status rc2 = full_cv_.Register(vg->GetIntrpService());
    if (rc1.IsOk()) {
      return rc2;
    } else {
      return rc1;
    }
  }

  // Changes the capacity to new_capacity, preserving element order.  Elements
  // that no longer fit are parked in extra_arr_ and re-inserted into the
  // circular buffer as slots free up (see PopFrontWhileHoldingLock).
  Status Resize(int32_t new_capacity) {
    std::unique_lock<std::mutex> _lock(mux_);
    CHECK_FAIL_RETURN_UNEXPECTED(new_capacity > 0,
                                 "New capacity: " + std::to_string(new_capacity) + ", should be larger than 0");
    RETURN_OK_IF_TRUE(new_capacity == static_cast<int32_t>(CapacityWhileHoldingLock()));
    std::vector<T> queue;
    // pop from the original queue until the new_capacity is full
    for (int32_t i = 0; i < new_capacity; ++i) {
      if (head_ < tail_) {
        // if there are elements left in queue, pop out
        T temp;
        this->PopFrontWhileHoldingLock(&temp, true);
        queue.push_back(temp);
      } else {
        // if there is nothing left in queue, check extra_arr_
        if (!extra_arr_.empty()) {
          // if extra_arr_ is not empty, push to fill the new_capacity
          queue.push_back(extra_arr_[0]);
          extra_arr_.erase(extra_arr_.begin());
        } else {
          // if everything in the queue and extra_arr_ is popped out, break the loop
          break;
        }
      }
    }
    // if there are extra elements in queue, put them to extra_arr_
    while (head_ < tail_) {
      T temp;
      // clean_extra=false: we are spilling INTO extra_arr_, not draining it.
      this->PopFrontWhileHoldingLock(&temp, false);
      extra_arr_.push_back(temp);
    }
    this->ResetQue();
    // NOTE(review): re-allocates arr_ in place; assumes MemGuard::allocate
    // releases/replaces the previous buffer — confirm MemGuard semantics.
    RETURN_IF_NOT_OK(arr_.allocate(new_capacity));
    sz_ = new_capacity;
    for (int32_t i = 0; i < static_cast<int32_t>(queue.size()); ++i) {
      this->AddWhileHoldingLock(queue[i]);
    }
    queue.clear();
    _lock.unlock();
    return Status::OK();
  }

 private:
  size_t sz_;                      // current capacity of the circular buffer
  MemGuard<T, Allocator<T>> arr_;  // fixed-size backing array (slot = counter % sz_)
  std::vector<T> extra_arr_;  // used to store extra elements after reducing capacity, will not be changed by Add,
                              // will pop when there is a space in queue (by PopFront or Resize)
  size_t head_;               // monotonically increasing pop counter
  size_t tail_;               // monotonically increasing push counter
  std::string my_name_;       // unique id, used only for logging
  mutable std::mutex mux_;    // guards all state above
  CondVar empty_cv_;          // consumers wait here while the queue is empty
  CondVar full_cv_;           // producers wait here while the queue is full

  // Helper function for Add, must be called when holding a lock
  void AddWhileHoldingLock(const_reference ele) {
    auto k = tail_++ % sz_;
    *(arr_[k]) = ele;
  }

  // Helper function for Add, must be called when holding a lock
  void AddWhileHoldingLock(T &&ele) {
    auto k = tail_++ % sz_;
    *(arr_[k]) = std::forward<T>(ele);
  }

  // Helper function for PopFront, must be called when holding a lock.
  // Moves the front slot into *p.  When clean_extra is true, one element parked
  // in extra_arr_ (if any) is promoted into the freed slot, keeping FIFO order.
  void PopFrontWhileHoldingLock(pointer p, bool clean_extra) {
    auto k = head_++ % sz_;
    *p = std::move(*(arr_[k]));
    if (!extra_arr_.empty() && clean_extra) {
      // forward<T> on a non-deduced T acts as std::move of the parked element.
      this->AddWhileHoldingLock(std::forward<T>(extra_arr_[0]));
      extra_arr_.erase(extra_arr_.begin());
    }
  }

  // Drains the circular buffer (extra_arr_ is untouched), resets the counters
  // and clears interrupt state.  Caller must hold mux_ or otherwise guarantee
  // exclusive access (e.g. the destructor path).
  void ResetQue() noexcept {
    while (head_ < tail_) {
      T val;
      this->PopFrontWhileHoldingLock(&val, false);
      MS_LOG(DEBUG) << "Address of val: " << &val;
    }
    empty_cv_.ResetIntrpState();
    full_cv_.ResetIntrpState();
    head_ = 0;
    tail_ = 0;
  }

  // Lock-free variants of size/capacity/empty for use inside CondVar predicates
  // and other paths that already hold mux_.
  size_t SizeWhileHoldingLock() const {
    size_t v = 0;
    if (tail_ >= head_) {
      v = tail_ - head_;
    }
    return v;
  }

  size_t CapacityWhileHoldingLock() const { return sz_; }

  bool EmptyWhileHoldingLock() const { return head_ == tail_; }
};
255 
256 // A container of queues with [] operator accessors.  Basically this is a wrapper over of a vector of queues
257 // to help abstract/simplify code that is maintaining multiple queues.
258 template <typename T>
259 class QueueList {
260  public:
QueueList()261   QueueList() {}
262 
Init(int num_queues,int capacity)263   void Init(int num_queues, int capacity) {
264     (void)queue_list_.reserve(num_queues);
265     for (int i = 0; i < num_queues; i++) {
266       (void)queue_list_.emplace_back(std::make_unique<Queue<T>>(capacity));
267     }
268   }
269 
Register(TaskGroup * vg)270   Status Register(TaskGroup *vg) {
271     if (vg == nullptr) {
272       RETURN_STATUS_UNEXPECTED("Null task group during QueueList registration.");
273     }
274     for (int i = 0; i < queue_list_.size(); ++i) {
275       RETURN_IF_NOT_OK(queue_list_[i]->Register(vg));
276     }
277     return Status::OK();
278   }
279 
size()280   auto size() const {
281     std::unique_lock<std::mutex> _lock(mux_);
282     return queue_list_.size();
283   }
284 
285   std::unique_ptr<Queue<T>> &operator[](const int index) {
286     std::unique_lock<std::mutex> _lock(mux_);
287     return queue_list_[index];
288   }
289 
290   const std::unique_ptr<Queue<T>> &operator[](const int index) const {
291     std::unique_lock<std::mutex> _lock(mux_);
292     return queue_list_[index];
293   }
294 
295   ~QueueList() = default;
296 
AddQueue(TaskGroup * vg)297   Status AddQueue(TaskGroup *vg) {
298     std::unique_lock<std::mutex> _lock(mux_);
299     (void)queue_list_.emplace_back(std::make_unique<Queue<T>>(queue_list_[0]->capacity()));
300     return queue_list_[queue_list_.size() - 1]->Register(vg);
301   }
RemoveLastQueue()302   Status RemoveLastQueue() {
303     std::unique_lock<std::mutex> _lock(mux_);
304     CHECK_FAIL_RETURN_UNEXPECTED(queue_list_.size() > 1, "Cannot remove more than the current queues.");
305     (void)queue_list_.pop_back();
306     return Status::OK();
307   }
308 
309  private:
310   // Queue contains non-copyable objects, so it cannot be added to a vector due to the vector
311   // requirement that objects must have copy semantics.  To resolve this, we use a vector of unique
312   // pointers.  This allows us to provide dynamic creation of queues in a container.
313   std::vector<std::unique_ptr<Queue<T>>> queue_list_;
314 
315   mutable std::mutex mux_;
316 };
317 }  // namespace dataset
318 }  // namespace mindspore
319 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_
320