• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
18 #define MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
#include <atomic>
#include <cstdint>
#include <new>
#include <vector>
21 
namespace mindspore {
// A bounded, lock-free FIFO queue of T* values based on the Michael & Scott
// algorithm: https://www.cs.rochester.edu/u/scott/papers/1996_PODC_queues.pdf
// All nodes are preallocated by Init() and recycled through a per-node `free`
// flag, so Enqueue/Dequeue never allocate.

template <typename T>
class HQueue;

// A "counted pointer": an index into HQueue's node table plus a version tag.
// Every update of a Pointer-valued atomic bumps the version, so a
// compare-and-swap based on a stale snapshot fails even when the same index has
// been reused (ABA protection). index == -1 means "null".
struct Pointer {
  int32_t index = -1;
  uint32_t version = 0;
  bool operator==(const Pointer &that) const { return index == that.index && version == that.version; }
  bool operator!=(const Pointer &that) const { return !(*this == that); }
};

// One preallocated queue slot.
//   next  - link to the following node (index == -1 when this is the last one)
//   value - payload pointer; the queue never deletes it
//   free  - true while the slot is available for Enqueue to claim
template <typename T>
struct HQNode {
  // Explicitly initialized: before C++20 a default-constructed std::atomic is
  // left in an uninitialized state.
  std::atomic<Pointer> next{Pointer{}};
  T *value = nullptr;
  std::atomic_bool free = {true};
};

template <typename T>
class HQueue {
 public:
  HQueue(const HQueue &) = delete;
  HQueue &operator=(const HQueue &) = delete;
  HQueue() {}

  // Releases all nodes. The objects referenced by still-queued values are not
  // owned by the queue and are not deleted.
  virtual ~HQueue() { Clean(); }

  // Preallocates `sz` nodes. Node 0 becomes the dummy head, so the queue holds
  // at most sz - 1 values at a time.
  // Returns false without touching the queue when sz <= 0, and returns false
  // after releasing every node (via Clean) when an allocation fails.
  // Not thread-safe: call before any concurrent Enqueue/Dequeue.
  // NOTE(review): presumably meant to be called once; a second call appends
  // nodes without resetting the existing ones.
  bool Init(int32_t sz) {
    if (sz <= 0) {
      // nodes[0] below would otherwise be out of bounds.
      return false;
    }
    for (int32_t i = 0; i < sz; i++) {
      // nothrow so that allocation failure takes the documented `false` path
      // instead of throwing std::bad_alloc.
      auto node = new (std::nothrow) HQNode<T>();
      if (node == nullptr) {
        Clean();
        return false;
      }
      node->value = nullptr;
      node->free = true;
      node->next = {-1, 0};
      nodes.emplace_back(node);
    }

    // init first node as dummy head
    qhead = {0, 0};
    qtail = {0, 0};
    nodes[0]->free = false;
    return true;
  }

  // Deletes every node and returns the queue to its uninitialized state, in
  // which Enqueue fails, Dequeue returns nullptr and Empty returns true.
  // Safe to call repeatedly. Not thread-safe.
  void Clean() {
    for (auto node : nodes) {
      delete node;
    }
    nodes.clear();
    // Reset so no later call dereferences an index into the now-empty table.
    qhead = Pointer{};
    qtail = Pointer{};
  }

  // Appends `t` at the tail. Returns false when no free node exists (queue full
  // or not initialized). A nullptr `t` is accepted, but Dequeue would return it
  // indistinguishably from "queue empty".
  bool Enqueue(T *t) {
    // Claim the first free node with a linear scan over the node table.
    HQNode<T> *node = nullptr;
    int32_t nodeIdx;
    for (nodeIdx = 0; nodeIdx < static_cast<int32_t>(nodes.size()); nodeIdx++) {
      bool expected = true;
      if (nodes[nodeIdx]->free.compare_exchange_strong(expected, false)) {
        node = nodes[nodeIdx];
        break;
      }
    }
    if (node == nullptr) {
      return false;
    }
    node->value = t;
    // Bump the version rather than resetting it to 0, so that a stale snapshot
    // of this node's `next` taken before it was recycled can never win a CAS.
    Pointer stale = node->next;
    node->next = {-1, stale.version + 1};

    while (true) {
      Pointer tail = qtail;
      if (tail.index == -1) {
        continue;
      }
      Pointer next = nodes[tail.index]->next;

      // Re-validate: `next` is only meaningful if tail did not move meanwhile.
      if (tail != this->qtail) {
        continue;
      }

      if (next.index != -1) {
        // Tail is lagging behind the real last node; help advance it, then retry.
        this->qtail.compare_exchange_strong(tail, {next.index, tail.version + 1});
        continue;
      }

      // Link the new node after the last one. On success, try to swing the
      // tail; if that CAS fails, another thread has already advanced it.
      if (nodes[tail.index]->next.compare_exchange_strong(next, {nodeIdx, next.version + 1})) {
        this->qtail.compare_exchange_strong(tail, {nodeIdx, tail.version + 1});
        break;
      }
    }

    return true;
  }

  // Removes and returns the oldest value, or nullptr when the queue is empty or
  // not initialized.
  T *Dequeue() {
    while (true) {
      Pointer head = qhead;
      Pointer tail = qtail;
      if (head.index == -1) {
        // qhead is -1 only before Init() or after Clean(): nothing to dequeue.
        return nullptr;
      }
      Pointer next = nodes[head.index]->next;

      // Re-validate: `next` is only meaningful if head did not move meanwhile.
      if (head != this->qhead) {
        continue;
      }

      if (head.index == tail.index) {
        if (next.index == -1) {
          return nullptr;
        }
        // Tail is lagging behind; help advance it, then retry.
        this->qtail.compare_exchange_strong(tail, {next.index, tail.version + 1});
      } else {
        if (next.index == -1) {
          continue;
        }
        // Read the value before moving head: afterwards another consumer may
        // free and recycle this node.
        T *ret = nodes[next.index]->value;
        if (this->qhead.compare_exchange_strong(head, {next.index, head.version + 1})) {
          // The old dummy head is unreachable now; make it claimable again.
          nodes[head.index]->free = true;
          return ret;
        }
      }
    }
  }

  // Returns true when the queue holds no values, including when it is not
  // initialized. Under concurrent use the answer may be stale on return.
  bool Empty() {
    Pointer head = qhead;
    Pointer tail = qtail;
    if (head.index < 0) {
      return true;
    }
    Pointer next = nodes[head.index]->next;
    return head == this->qhead && head.index == tail.index && next.index == -1;
  }

 private:
  // Explicitly initialized to the "null" Pointer (index -1); before C++20 a
  // default-constructed std::atomic is uninitialized.
  std::atomic<Pointer> qhead{Pointer{}};
  std::atomic<Pointer> qtail{Pointer{}};
  std::vector<HQNode<T> *> nodes;
};

}  // namespace mindspore
174 
175 #endif  // MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
176