• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
18 
19 #include <atomic>
20 #include <deque>
21 #include <iostream>
22 #include <map>
23 #include <memory>
24 #include <mutex>
25 #include "minddata/dataset/util/allocator.h"
26 #include "minddata/dataset/util/system_pool.h"
27 #include "minddata/dataset/util/semaphore.h"
28 #include "minddata/dataset/util/services.h"
29 namespace mindspore {
30 namespace dataset {
31 template <typename K, typename T>
32 /// \brief QueueMap is like a Queue but instead of there is a map of deque<T>.
33 /// Consumer will block if the corresponding deque is empty.
34 /// Producer can add an element of type T with key of type K to the map and
35 /// wake up any waiting consumer.
36 /// \tparam K key type
37 /// \tparam T payload of the map
38 class QueueMap {
39  public:
40   using key_type = K;
41   using value_type = T;
42 
QueueMap()43   QueueMap() : num_rows_(0) {}
44   virtual ~QueueMap() = default;
45 
46   /// Add an element <key, T> to the map and wake up any consumer that is waiting
47   /// \param key
48   /// \param payload
49   /// \return Status object
Add(key_type key,T && payload)50   virtual Status Add(key_type key, T &&payload) {
51     RequestQueue *rq = nullptr;
52     RETURN_IF_NOT_OK(GetRq(key, &rq));
53     RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload)));
54     ++num_rows_;
55     return Status::OK();
56   }
57 
58   /// Pop the front of the deque with key. Block if the deque is empty.
PopFront(key_type key,T * out)59   virtual Status PopFront(key_type key, T *out) {
60     RequestQueue *rq = nullptr;
61     RETURN_IF_NOT_OK(GetRq(key, &rq));
62     RETURN_IF_NOT_OK(rq->Wait(out));
63     --num_rows_;
64     return Status::OK();
65   }
66 
67   /// Get the number of elements in the container
68   /// \return The number of elements in the container
size()69   int64_t size() const { return num_rows_; }
70 
71   /// \return if the container is empty
empty()72   bool empty() const { return num_rows_ == 0; }
73 
74   /// Print out some useful information about the container
75   friend std::ostream &operator<<(std::ostream &out, const QueueMap &qm) {
76     std::unique_lock<std::mutex> lck(qm.mux_);
77     out << "Number of elements: " << qm.num_rows_ << "\n";
78     out << "Dumping internal info:\n";
79     int64_t k = 0;
80     constexpr int64_t line_breaks_number = 6;
81     for (auto &it : qm.all_) {
82       auto key = it.first;
83       const RequestQueue *rq = it.second.GetPointer();
84       out << "(k:" << key << "," << *rq << ") ";
85       ++k;
86       if (k % line_breaks_number == 0) {
87         out << "\n";
88       }
89     }
90     return out;
91   }
92 
93  protected:
94   /// This is a handshake structure between producer and consumer
95   class RequestQueue {
96    public:
RequestQueue()97     RequestQueue() : use_count_(0) {}
98     ~RequestQueue() = default;
99 
Wait(T * out)100     Status Wait(T *out) {
101       RETURN_UNEXPECTED_IF_NULL(out);
102       // Block until the missing row is in the pool.
103       RETURN_IF_NOT_OK(use_count_.P());
104       std::unique_lock<std::mutex> lck(dq_mux_);
105       CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
106       *out = std::move(row_.front());
107       row_.pop_front();
108       return Status::OK();
109     }
110 
WakeUpAny(T && row)111     Status WakeUpAny(T &&row) {
112       std::unique_lock<std::mutex> lck(dq_mux_);
113       row_.push_back(std::move(row));
114       // Bump up the use count by 1. This wake up any parallel worker which is waiting
115       // for this row.
116       use_count_.V();
117       return Status::OK();
118     }
119 
120     friend std::ostream &operator<<(std::ostream &out, const RequestQueue &rq) {
121       out << "sz:" << rq.row_.size() << ",uc:" << rq.use_count_.Peek();
122       return out;
123     }
124 
125    private:
126     mutable std::mutex dq_mux_;
127     Semaphore use_count_;
128     std::deque<T> row_;
129   };
130 
131   /// Create or locate an element with matching key
132   /// \param key
133   /// \param out
134   /// \return Status object
GetRq(key_type key,RequestQueue ** out)135   Status GetRq(key_type key, RequestQueue **out) {
136     RETURN_UNEXPECTED_IF_NULL(out);
137     std::unique_lock<std::mutex> lck(mux_);
138     auto it = all_.find(key);
139     if (it != all_.end()) {
140       *out = it->second.GetMutablePointer();
141     } else {
142       // We will create a new one.
143       auto alloc = SystemPool::GetAllocator<RequestQueue>();
144       auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc));
145       if (r.second) {
146         auto &mem = r.first->second;
147         RETURN_IF_NOT_OK(mem.allocate(1));
148         *out = mem.GetMutablePointer();
149       } else {
150         RETURN_STATUS_UNEXPECTED("Map insert fail.");
151       }
152     }
153     return Status::OK();
154   }
155 
156  private:
157   mutable std::mutex mux_;
158   std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_;
159   std::atomic<int64_t> num_rows_;
160 };
161 }  // namespace dataset
162 }  // namespace mindspore
163 
164 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
165