• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
18 #define MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <new>
#include <vector>
21 
namespace mindspore {
// Bounded lock-free FIFO queue of T* based on the non-blocking queue of
// Michael & Scott, "Simple, Fast, and Practical Non-Blocking and Blocking
// Concurrent Queue Algorithms" (PODC 1996):
// https://www.cs.rochester.edu/u/scott/papers/1996_PODC_queues.pdf
//
// Instead of heap pointers, links are indices into a node pool allocated once
// by HQueue::Init(). Each link carries a version counter that is incremented
// on every successful CAS, in the spirit of the paper's counted pointers.
// Nodes are recycled through a per-node `free` flag.
template <typename T>
class HQueue;

// Versioned link: `index` selects a node in the pool (-1 means "no node") and
// `version` distinguishes successive values stored at the same index.
struct Pointer {
  int32_t index = -1;
  uint32_t version = 0;
  bool operator==(const Pointer &that) const { return (index == that.index && version == that.version); }
  bool operator!=(const Pointer &that) const { return !(*this == that); }
};

// One slot of the HQueue node pool.
template <typename T>
struct HQNode {
  // Successor in the queue; index -1 while this node is the last one.
  std::atomic<Pointer> next;
  // Element stored in this node. The queue never deletes it.
  T *value = nullptr;
  // True while the slot is unused and may be claimed by Enqueue().
  std::atomic_bool free = {true};
};

template <typename T>
class HQueue {
 public:
  HQueue(const HQueue &) = delete;
  HQueue &operator=(const HQueue &) = delete;
  HQueue() {}
  // Frees the node pool so it is not leaked; a no-op if Clean() already ran.
  virtual ~HQueue() { Clean(); }

  // True once Init() has succeeded and until Clean() is called.
  bool IsInit() const { return nodes.size() != 0; }

  // Allocates a pool of `sz` nodes. One node always serves as the dummy head,
  // so at most sz - 1 elements can be queued at a time.
  // Returns false if the queue is already initialized, if sz <= 0, or if a node
  // allocation fails (the queue is then left uninitialized).
  // Not synchronized: must not run concurrently with other member functions.
  bool Init(int32_t sz) {
    if (IsInit() || sz <= 0) {
      return false;
    }
    // Reserve first so that emplace_back below cannot throw and leak a node.
    nodes.reserve(static_cast<std::size_t>(sz));
    for (int32_t i = 0; i < sz; i++) {
      // nothrow form: allocation failure is reported through the return value,
      // which is what the null check below expects.
      auto node = new (std::nothrow) HQNode<T>();
      if (node == nullptr) {
        Clean();
        return false;
      }
      node->value = nullptr;
      node->free = true;
      node->next = {-1, 0};
      nodes.emplace_back(node);
    }

    // The first node starts out as the dummy head.
    qhead = {0, 0};
    qtail = {0, 0};
    nodes[0]->free = false;
    queue_size = sz;
    free_index = 1;
    return true;
  }

  // Deletes the node pool (not the stored elements) and returns the queue to the
  // uninitialized state, after which Init() may be called again.
  // Not synchronized: must not run concurrently with other member functions.
  void Clean() {
    for (auto node : nodes) {
      delete node;
    }
    nodes.clear();
    // Reset the bookkeeping too, so Enqueue/Dequeue/Empty see an uninitialized
    // queue instead of indexing the deleted pool.
    queue_size = 0;
    free_index = 1;
    qhead = Pointer{};
    qtail = Pointer{};
  }

  // Appends `t` at the tail. Returns false when no free node is available
  // (queue full or not initialized). Intended for concurrent use by multiple
  // producers and consumers.
  bool Enqueue(T *t) {
    // Claim a free node: first from the hint onward, then by rescanning the whole
    // pool. The rescan starts at index 0 because the dummy head moves as elements
    // are dequeued, so node 0 is eventually freed and must be reusable.
    int32_t nodeIdx = ClaimFreeNode(free_index, queue_size);
    if (nodeIdx == -1) {
      nodeIdx = ClaimFreeNode(0, queue_size);
      if (nodeIdx == -1) {
        return false;
      }
    }
    HQNode<T> *node = nodes[nodeIdx];
    node->value = t;
    node->next = {-1, 0};

    while (true) {
      Pointer tail = qtail;
      if (tail.index == -1) {
        // No dummy node means the queue is not initialized; release the slot.
        node->free = true;
        return false;
      }
      Pointer next = nodes[tail.index]->next;
      if (tail != qtail) {
        continue;  // tail moved while reading `next`; take a fresh snapshot
      }
      if (next.index != -1) {
        // Tail lags behind an already linked node; help advance it.
        qtail.compare_exchange_strong(tail, {next.index, tail.version + 1});
        continue;
      }
      if (nodes[tail.index]->next.compare_exchange_strong(next, {nodeIdx, next.version + 1})) {
        // Linked. Try to swing tail to the new node; if this CAS fails, another
        // thread has already advanced it.
        qtail.compare_exchange_strong(tail, {nodeIdx, tail.version + 1});
        return true;
      }
    }
  }

  // Removes and returns the element at the head, or nullptr if the queue is
  // empty or not initialized. Intended for concurrent use.
  T *Dequeue() {
    while (true) {
      Pointer head = qhead;
      Pointer tail = qtail;
      if (head.index == -1) {
        // Only an uninitialized (or cleaned) queue lacks a dummy head.
        return nullptr;
      }
      Pointer next = nodes[head.index]->next;
      if (head != qhead) {
        continue;  // head moved while reading; take a fresh snapshot
      }
      if (head.index == tail.index) {
        if (next.index == -1) {
          return nullptr;  // only the dummy node is present
        }
        // Tail lags behind a linked node; help advance it, then retry.
        qtail.compare_exchange_strong(tail, {next.index, tail.version + 1});
        continue;
      }
      if (next.index == -1) {
        continue;
      }
      // Read the value before swinging head: once head moves, another consumer
      // may dequeue and free `next`, after which a producer may overwrite it.
      T *ret = nodes[next.index]->value;
      if (qhead.compare_exchange_strong(head, {next.index, head.version + 1})) {
        // `next` is now the dummy head; the old dummy goes back to the pool.
        nodes[head.index]->free = true;
        return ret;
      }
    }
  }

  // Returns true if, at the moment of the check, the queue holds only its dummy
  // node. Under concurrent updates the answer may be stale by the time it is
  // used. Returns false for an uninitialized queue.
  bool Empty() {
    Pointer head = qhead;
    Pointer tail = qtail;
    if (head.index < 0) {
      return false;
    }
    Pointer next = nodes[head.index]->next;
    return head == qhead && head.index == tail.index && next.index == -1;
  }

 private:
  // Tries to claim a free node with index in [begin, end). On success, stores the
  // following index as the scan hint and returns the claimed index; returns -1
  // if every node in the range is in use.
  int32_t ClaimFreeNode(int32_t begin, int32_t end) {
    for (int32_t idx = begin; idx < end; ++idx) {
      bool expected = true;
      if (nodes[idx]->free.compare_exchange_strong(expected, false)) {
        free_index = idx + 1;
        return idx;
      }
    }
    return -1;
  }

  std::atomic<Pointer> qhead{Pointer{}};
  std::atomic<Pointer> qtail{Pointer{}};
  std::vector<HQNode<T> *> nodes;
  int32_t queue_size{};
  // Index where Enqueue starts looking for a free node (a hint only).
  std::atomic<int32_t> free_index{1};
};
}  // namespace mindspore
193 
194 #endif  // MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
195