• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
18 
19 #include <atomic>
20 #include <deque>
21 #include <iostream>
22 #include <map>
23 #include <memory>
24 #include <mutex>
25 #include "minddata/dataset/util/allocator.h"
26 #include "minddata/dataset/util/system_pool.h"
27 #include "minddata/dataset/util/semaphore.h"
28 #include "minddata/dataset/util/services.h"
29 namespace mindspore {
30 namespace dataset {
31 template <typename K, typename T>
32 /// \brief QueueMap is like a Queue but instead of there is a map of deque<T>.
33 /// Consumer will block if the corresponding deque is empty.
34 /// Producer can add an element of type T with key of type K to the map and
35 /// wake up any waiting consumer.
36 /// \tparam K key type
37 /// \tparam T payload of the map
38 class QueueMap {
39  public:
40   using key_type = K;
41   using value_type = T;
42 
QueueMap()43   QueueMap() : num_rows_(0) {}
44   virtual ~QueueMap() = default;
45 
46   /// Add an element <key, T> to the map and wake up any consumer that is waiting
47   /// \param key
48   /// \param payload
49   /// \return Status object
Add(key_type key,T && payload)50   virtual Status Add(key_type key, T &&payload) {
51     RequestQueue *rq = nullptr;
52     RETURN_IF_NOT_OK(GetRq(key, &rq));
53     RETURN_IF_NOT_OK(rq->WakeUpAny(std::move(payload)));
54     ++num_rows_;
55     return Status::OK();
56   }
57 
58   /// Pop the front of the deque with key. Block if the deque is empty.
PopFront(key_type key,T * out)59   virtual Status PopFront(key_type key, T *out) {
60     RequestQueue *rq = nullptr;
61     RETURN_IF_NOT_OK(GetRq(key, &rq));
62     RETURN_IF_NOT_OK(rq->Wait(out));
63     --num_rows_;
64     return Status::OK();
65   }
66 
67   /// Get the number of elements in the container
68   /// \return The number of elements in the container
size()69   int64_t size() const { return num_rows_; }
70 
71   /// \return if the container is empty
empty()72   bool empty() const { return num_rows_ == 0; }
73 
74   /// Print out some useful information about the container
75   friend std::ostream &operator<<(std::ostream &out, const QueueMap &qm) {
76     std::unique_lock<std::mutex> lck(qm.mux_);
77     out << "Number of elements: " << qm.num_rows_ << "\n";
78     out << "Dumping internal info:\n";
79     int64_t k = 0;
80     for (auto &it : qm.all_) {
81       auto key = it.first;
82       const RequestQueue *rq = it.second.GetPointer();
83       out << "(k:" << key << "," << *rq << ") ";
84       ++k;
85       if (k % 6 == 0) {
86         out << "\n";
87       }
88     }
89     return out;
90   }
91 
92  protected:
93   /// This is a handshake structure between producer and consumer
94   class RequestQueue {
95    public:
RequestQueue()96     RequestQueue() : use_count_(0) {}
97     ~RequestQueue() = default;
98 
Wait(T * out)99     Status Wait(T *out) {
100       RETURN_UNEXPECTED_IF_NULL(out);
101       // Block until the missing row is in the pool.
102       RETURN_IF_NOT_OK(use_count_.P());
103       std::unique_lock<std::mutex> lck(dq_mux_);
104       CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
105       *out = std::move(row_.front());
106       row_.pop_front();
107       return Status::OK();
108     }
109 
WakeUpAny(T && row)110     Status WakeUpAny(T &&row) {
111       std::unique_lock<std::mutex> lck(dq_mux_);
112       row_.push_back(std::move(row));
113       // Bump up the use count by 1. This wake up any parallel worker which is waiting
114       // for this row.
115       use_count_.V();
116       return Status::OK();
117     }
118 
119     friend std::ostream &operator<<(std::ostream &out, const RequestQueue &rq) {
120       out << "sz:" << rq.row_.size() << ",uc:" << rq.use_count_.Peek();
121       return out;
122     }
123 
124    private:
125     mutable std::mutex dq_mux_;
126     Semaphore use_count_;
127     std::deque<T> row_;
128   };
129 
130   /// Create or locate an element with matching key
131   /// \param key
132   /// \param out
133   /// \return Status object
GetRq(key_type key,RequestQueue ** out)134   Status GetRq(key_type key, RequestQueue **out) {
135     RETURN_UNEXPECTED_IF_NULL(out);
136     std::unique_lock<std::mutex> lck(mux_);
137     auto it = all_.find(key);
138     if (it != all_.end()) {
139       *out = it->second.GetMutablePointer();
140     } else {
141       // We will create a new one.
142       auto alloc = SystemPool::GetAllocator<RequestQueue>();
143       auto r = all_.emplace(key, MemGuard<RequestQueue, Allocator<RequestQueue>>(alloc));
144       if (r.second) {
145         auto &mem = r.first->second;
146         RETURN_IF_NOT_OK(mem.allocate(1));
147         *out = mem.GetMutablePointer();
148       } else {
149         RETURN_STATUS_UNEXPECTED("Map insert fail.");
150       }
151     }
152     return Status::OK();
153   }
154 
155  private:
156   mutable std::mutex mux_;
157   std::map<K, MemGuard<RequestQueue, Allocator<RequestQueue>>> all_;
158   std::atomic<int64_t> num_rows_;
159 };
160 }  // namespace dataset
161 }  // namespace mindspore
162 
163 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_QUEUE_MAP_H_
164