/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_

#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "./securec.h"
#include "utils/ms_utils.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/log_adapter.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
// A simple thread-safe queue using a fixed-size array
template <typename T>
class Queue {
 public:
  using value_type = T;
  using pointer = T *;
  using const_pointer = const T *;
  using reference = T &;
  using const_reference = const T &;

  explicit Queue(int sz)
      : sz_(sz), arr_(Services::GetAllocator<T>()), head_(0), tail_(0), my_name_(Services::GetUniqueID()) {
    Status rc = arr_.allocate(sz);
    if (rc.IsError()) {
      MS_LOG(ERROR) << "Failed to create a queue.";
      std::terminate();
    } else {
      MS_LOG(DEBUG) << "Create Q with uuid " << my_name_ << " of size " << sz_ << ".";
    }
  }

  virtual ~Queue() { ResetQue(); }

  size_t size() const {
    // tail_ never falls behind head_, so the difference is the current element count.
    return tail_ - head_;
  }

  size_t capacity() const { return sz_; }

  bool empty() const { return head_ == tail_; }

  void Reset() { ResetQue(); }

  // Producer
  Status Add(const_reference ele) noexcept {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when full
    Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); });
    if (rc.IsOk()) {
      auto k = tail_++ % sz_;
      *(arr_[k]) = ele;
      empty_cv_.NotifyAll();
      _lock.unlock();
    } else {
      empty_cv_.Interrupt();
    }
    return rc;
  }

  Status Add(T &&ele) noexcept {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when full
    Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); });
    if (rc.IsOk()) {
      auto k = tail_++ % sz_;
      *(arr_[k]) = std::forward<T>(ele);
      empty_cv_.NotifyAll();
      _lock.unlock();
    } else {
      empty_cv_.Interrupt();
    }
    return rc;
  }

  template <typename... Ts>
  Status EmplaceBack(Ts &&... args) noexcept {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when full
    Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); });
    if (rc.IsOk()) {
      auto k = tail_++ % sz_;
      new (arr_[k]) T(std::forward<Ts>(args)...);
      empty_cv_.NotifyAll();
      _lock.unlock();
    } else {
      empty_cv_.Interrupt();
    }
    return rc;
  }

  // Consumer
  Status PopFront(pointer p) {
    std::unique_lock<std::mutex> _lock(mux_);
    // Block when empty
    Status rc = empty_cv_.Wait(&_lock, [this]() -> bool { return !empty(); });
    if (rc.IsOk()) {
      auto k = head_++ % sz_;
      *p = std::move(*(arr_[k]));
      full_cv_.NotifyAll();
      _lock.unlock();
    } else {
      full_cv_.Interrupt();
    }
    return rc;
  }
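
  // Note on the ring-buffer arithmetic above: head_ and tail_ only grow (they are reset together
  // in ResetQue()), so an element's slot is its running index modulo sz_ and the element count is
  // tail_ - head_. All updates to the two counters happen while holding mux_, which keeps them
  // consistent with each other.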

  void ResetQue() noexcept {
    std::unique_lock<std::mutex> _lock(mux_);
    // If there are elements in the queue, drain them. We do not call PopFront here
    // because we already hold the lock; calling PopFront would deadlock.
    for (auto i = head_; i < tail_; ++i) {
      auto k = i % sz_;
      auto val = std::move(*(arr_[k]));
      // Let val go out of scope so its destructor runs automatically. Logging its
      // address keeps the compiler from warning that val is unused.
      MS_LOG(DEBUG) << "Address of val: " << &val;
    }
    empty_cv_.ResetIntrpState();
    full_cv_.ResetIntrpState();
    head_ = 0;
    tail_ = 0;
  }

  Status Register(TaskGroup *vg) {
    Status rc1 = empty_cv_.Register(vg->GetIntrpService());
    Status rc2 = full_cv_.Register(vg->GetIntrpService());
    if (rc1.IsOk()) {
      return rc2;
    } else {
      return rc1;
    }
  }

 private:
  size_t sz_;
  MemGuard<T, Allocator<T>> arr_;
  size_t head_;
  size_t tail_;
  std::string my_name_;
  std::mutex mux_;
  CondVar empty_cv_;
  CondVar full_cv_;
};
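
// Illustrative usage sketch: a producer and a consumer sharing one Queue<int>. The size 32 is an
// arbitrary example value, and the snippet assumes it runs inside a function that returns Status
// so that RETURN_IF_NOT_OK can be used.
//
//   Queue<int> q(32);
//   TaskGroup vg;
//   RETURN_IF_NOT_OK(q.Register(&vg));  // hook both condition variables to the interrupt service
//   RETURN_IF_NOT_OK(q.Add(1));         // producer: blocks while the queue is full
//   int v = 0;
//   RETURN_IF_NOT_OK(q.PopFront(&v));   // consumer: blocks while the queue is empty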

// A container of queues with [] operator accessors. Basically this is a wrapper over a vector of
// queues to help abstract/simplify code that is maintaining multiple queues.
template <typename T>
class QueueList {
 public:
  QueueList() {}

  void Init(int num_queues, int capacity) {
    queue_list_.reserve(num_queues);
    for (int i = 0; i < num_queues; i++) {
      queue_list_.emplace_back(std::make_unique<Queue<T>>(capacity));
    }
  }

  Status Register(TaskGroup *vg) {
    if (vg == nullptr) {
      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                    "Null task group during QueueList registration.");
    }
    for (size_t i = 0; i < queue_list_.size(); ++i) {
      RETURN_IF_NOT_OK(queue_list_[i]->Register(vg));
    }
    return Status::OK();
  }

  auto size() const { return queue_list_.size(); }

  std::unique_ptr<Queue<T>> &operator[](const int index) { return queue_list_[index]; }

  const std::unique_ptr<Queue<T>> &operator[](const int index) const { return queue_list_[index]; }

  ~QueueList() = default;

 private:
  // Queue is neither copyable nor movable (it owns a mutex and condition variables), so it cannot
  // be stored directly in a std::vector. Holding unique pointers instead satisfies the vector's
  // requirements and also lets us create the queues dynamically in a container.
  std::vector<std::unique_ptr<Queue<T>>> queue_list_;
};
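
// Illustrative usage sketch: distributing work across several queues via QueueList. The counts
// (4 queues of capacity 16) are arbitrary example values, and the snippet assumes a surrounding
// function that returns Status.
//
//   QueueList<int> queues;
//   queues.Init(4, 16);                      // create four queues, each holding up to 16 elements
//   TaskGroup vg;
//   RETURN_IF_NOT_OK(queues.Register(&vg));  // register every queue with the task group
//   RETURN_IF_NOT_OK(queues[0]->Add(42));    // operator[] yields the unique_ptr to a queue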
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_H_