/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef OHOS_DISTRIBUTED_DATA_FRAMEWORKS_COMMON_PRIORITY_QUEUE_H
#define OHOS_DISTRIBUTED_DATA_FRAMEWORKS_COMMON_PRIORITY_QUEUE_H
#include <chrono>
#include <condition_variable>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <shared_mutex>
#include <utility>
namespace OHOS {
/**
 * @brief Mutex-protected queue of tasks ordered by their scheduled time.
 *
 * Every task carries an id. Queued tasks are kept in a time-ordered multimap (several tasks may
 * share a time), with a per-id index so they can be found, rescheduled or removed. Pop() hands
 * out the earliest task once its time has been reached and keeps a copy of it in a "running" set
 * until Finish() is called for its id. All public member functions lock the same internal mutex.
 *
 * @tparam _Tsk Task type. Returned by value (so it must be copyable) and must provide a Valid()
 *              member, which Push() uses to reject tasks.
 * @tparam _Tme Scheduling time. Compared against std::chrono::steady_clock::now() and passed to
 *              std::condition_variable::wait_until() in Pop().
 * @tparam _Tid Task id. Used as a std::map key.
 */
template<typename _Tsk, typename _Tme, typename _Tid>
class PriorityQueue {
public:
    /** A queued or running task together with its id. */
    struct PQMatrix {
        _Tsk task_;
        _Tid id_;
        PQMatrix(_Tsk task, _Tid id) : task_(std::move(task)), id_(id) {}
    };
    /** Position of a queued task inside the time-ordered task multimap. */
    using TskIndex = typename std::multimap<_Tme, PQMatrix>::iterator;
    /**
     * Callback that may modify a task in place. It returns {changed, nextTime}: whether the
     * task should be (re)scheduled, and at which time.
     */
    using TskUpdater = std::function<std::pair<bool, _Tme>(_Tsk &element)>;

    /**
     * @param task    Value returned by Pop() and Find() when there is no task to return.
     * @param updater Called by Finish() on a completed task to decide whether it repeats. If
     *                empty, finished tasks never repeat.
     */
    PriorityQueue(const _Tsk &task, TskUpdater updater = nullptr)
        : INVALID_TSK(task), updater_(std::move(updater))
    {
        if (!updater_) {
            updater_ = [](_Tsk &) { return std::pair{false, _Tme()}; };
        }
    }

    /**
     * @brief Takes the earliest queued task once its scheduled time has been reached.
     *
     * Blocks while the earliest task is not yet due. Returns INVALID_TSK immediately when the
     * queue is empty, and also if the queue becomes empty while waiting (e.g. via Remove() or
     * Clean()). The returned task is recorded as running until Finish() is called with its id.
     */
    _Tsk Pop()
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        while (!tasks_.empty()) {
            auto waitTme = tasks_.begin()->first;
            if (waitTme > std::chrono::steady_clock::now()) {
                // Push/Update/Remove/Clean/Finish notify popCv_, so the head is re-examined
                // whenever the queue changes during the wait.
                popCv_.wait_until(lock, waitTme);
                continue;
            }
            auto head = tasks_.begin();
            auto id = head->second.id_;
            // Keep a copy for Finish(): the updater decides from it whether the task repeats.
            running_.emplace(id, head->second);
            auto res = std::move(head->second.task_);
            tasks_.erase(head);
            indexes_.erase(id);
            return res;
        }
        return INVALID_TSK;
    }

    /**
     * @brief Queues @p tsk under @p id, to become available at @p tme.
     * @return false if the task is not Valid() or a task with the same id is already queued;
     *         true otherwise.
     */
    bool Push(_Tsk tsk, _Tid id, _Tme tme)
    {
        std::lock_guard<decltype(pqMtx_)> lock(pqMtx_);
        if (!tsk.Valid()) {
            return false;
        }
        // A second queued entry for the same id would have no index of its own, and popping it
        // would erase the index of the first entry, leaving both unreachable.
        if (indexes_.find(id) != indexes_.end()) {
            return false;
        }
        auto pos = tasks_.emplace(tme, PQMatrix(std::move(tsk), id));
        indexes_.emplace(id, pos);
        popCv_.notify_all();
        return true;
    }

    /** @return The number of queued tasks (running tasks are not counted). */
    size_t Size()
    {
        std::lock_guard<decltype(pqMtx_)> lock(pqMtx_);
        return tasks_.size();
    }

    /** @return A copy of the queued task with @p id, or INVALID_TSK if none is queued. */
    _Tsk Find(_Tid id)
    {
        std::lock_guard<decltype(pqMtx_)> lock(pqMtx_);
        auto index = indexes_.find(id);
        if (index == indexes_.end()) {
            return INVALID_TSK;
        }
        return index->second->second.task_;
    }

    /**
     * @brief Applies @p updater to the task with @p id.
     *
     * For a queued task, when the updater reports a change, the task is rescheduled at the
     * returned time. Any in-place edits the updater made are kept either way. For a running
     * task, the updater is applied to the copy held for Finish(), and the returned time is
     * ignored.
     *
     * @return The updater's "changed" flag, or false if no task with @p id is queued or running
     *         or @p updater is empty.
     */
    bool Update(_Tid id, TskUpdater updater)
    {
        if (!updater) {
            return false;
        }
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        auto index = indexes_.find(id);
        if (index != indexes_.end()) {
            auto [updated, nextTime] = updater(index->second->second.task_);
            if (!updated) {
                return false;
            }
            auto matrix = std::move(index->second->second);
            tasks_.erase(index->second);
            index->second = tasks_.emplace(nextTime, std::move(matrix));
            popCv_.notify_all();
            return true;
        }

        auto running = running_.find(id);
        if (running != running_.end()) {
            return updater(running->second.task_).first;
        }

        return false;
    }

    /**
     * @brief Removes the queued task with @p id.
     *
     * @param wait If true and the task is currently running, blocks until Finish(id) has been
     *             called. A repeating task re-queued by Finish() is then removed as well.
     * @return true if a queued task was removed.
     */
    bool Remove(_Tid id, bool wait)
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        removeCv_.wait(lock, [this, id, wait] {
            return !wait || running_.find(id) == running_.end();
        });
        auto index = indexes_.find(id);
        if (index == indexes_.end()) {
            return false;
        }
        tasks_.erase(index->second);
        indexes_.erase(index);
        popCv_.notify_all();
        return true;
    }

    /** @brief Drops every queued task. Running tasks are left untouched. */
    void Clean()
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        indexes_.clear();
        tasks_.clear();
        popCv_.notify_all();
    }

    /**
     * @brief Marks the running task @p id as finished.
     *
     * The constructor's updater decides whether the task repeats. If it does, the task is
     * re-queued at the returned time, unless a task with the same id was pushed again while
     * this one was running. Wakes Remove() calls that are waiting for this id. Does nothing if
     * @p id is not running.
     */
    void Finish(_Tid id)
    {
        std::unique_lock<decltype(pqMtx_)> lock(pqMtx_);
        auto it = running_.find(id);
        if (it == running_.end()) {
            return;
        }
        auto [repeat, nextTime] = updater_(it->second.task_);
        // Keep at most one queued entry per id so tasks_ and indexes_ stay in one-to-one
        // correspondence.
        if (repeat && indexes_.find(id) == indexes_.end()) {
            indexes_.emplace(id, tasks_.emplace(nextTime, std::move(it->second)));
            // Wake Pop(): the re-queued task may be due earlier than the one it is waiting for.
            popCv_.notify_all();
        }
        running_.erase(it);
        removeCv_.notify_all();
    }

private:
    const _Tsk INVALID_TSK;
    std::mutex pqMtx_;
    std::condition_variable popCv_;
    std::condition_variable removeCv_;
    std::multimap<_Tme, PQMatrix> tasks_;
    std::map<_Tid, PQMatrix> running_;
    std::map<_Tid, TskIndex> indexes_;
    TskUpdater updater_;
};
} // namespace OHOS
#endif // OHOS_DISTRIBUTED_DATA_FRAMEWORKS_COMMON_PRIORITY_QUEUE_H