1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef marl_ticket_h
16 #define marl_ticket_h
17
18 #include "conditionvariable.h"
19 #include "pool.h"
20 #include "scheduler.h"
21
22 namespace marl {
23
24 // Ticket is a synchronization primitive used to serially order execution.
25 //
26 // Tickets exist in 3 mutually exclusive states: Waiting, Called and Finished.
27 //
28 // Tickets are obtained from a Ticket::Queue, using the Ticket::Queue::take()
29 // methods. The order in which tickets are taken from the queue dictates the
30 // order in which they are called.
31 //
32 // The first ticket to be taken from a queue will be in the 'called' state,
33 // subsequent tickets will be in the 'waiting' state.
34 //
35 // Ticket::wait() will block until the ticket is called.
36 //
// Ticket::done() moves the ticket into the 'finished' state. If all preceding
38 // tickets are finished, done() will call the next unfinished ticket.
39 //
40 // If the last remaining reference to an unfinished ticket is dropped then
41 // done() will be automatically called on that ticket.
42 //
43 // Example:
44 //
45 // void runTasksConcurrentThenSerially(int numConcurrentTasks)
46 // {
47 // marl::Ticket::Queue queue;
48 // for (int i = 0; i < numConcurrentTasks; i++)
49 // {
50 // auto ticket = queue.take();
51 // marl::schedule([=] {
52 // doConcurrentWork(); // <- function may be called concurrently
53 // ticket.wait(); // <- serialize tasks
54 // doSerialWork(); // <- function will not be called concurrently
55 // ticket.done(); // <- optional, as done() is called implicitly on
56 // // dropping of last reference
57 // });
58 // }
59 // }
60 class Ticket {
61 struct Shared;
62 struct Record;
63
64 public:
65 using OnCall = std::function<void()>;
66
67 // Queue hands out Tickets.
68 class Queue {
69 public:
70 // take() returns a single ticket from the queue.
71 MARL_NO_EXPORT inline Ticket take();
72
73 // take() retrieves count tickets from the queue, calling f() with each
74 // retrieved ticket.
75 // F must be a function of the signature: void(Ticket&&)
76 template <typename F>
77 MARL_NO_EXPORT inline void take(size_t count, const F& f);
78
79 private:
80 std::shared_ptr<Shared> shared = std::make_shared<Shared>();
81 UnboundedPool<Record> pool;
82 };
83
84 MARL_NO_EXPORT inline Ticket() = default;
85 MARL_NO_EXPORT inline Ticket(const Ticket& other) = default;
86 MARL_NO_EXPORT inline Ticket(Ticket&& other) = default;
87 MARL_NO_EXPORT inline Ticket& operator=(const Ticket& other) = default;
88
89 // wait() blocks until the ticket is called.
90 MARL_NO_EXPORT inline void wait() const;
91
92 // done() marks the ticket as finished and calls the next ticket.
93 MARL_NO_EXPORT inline void done() const;
94
95 // onCall() registers the function f to be invoked when this ticket is
96 // called. If the ticket is already called prior to calling onCall(), then
97 // f() will be executed immediately.
98 // F must be a function of the OnCall signature.
99 template <typename F>
100 MARL_NO_EXPORT inline void onCall(F&& f) const;
101
102 private:
103 // Internal doubly-linked-list data structure. One per ticket instance.
104 struct Record {
105 MARL_NO_EXPORT inline ~Record();
106
107 MARL_NO_EXPORT inline void done();
108 MARL_NO_EXPORT inline void callAndUnlock(marl::lock& lock);
109 MARL_NO_EXPORT inline void unlink(); // guarded by shared->mutex
110
111 ConditionVariable isCalledCondVar;
112
113 std::shared_ptr<Shared> shared;
114 Record* next = nullptr; // guarded by shared->mutex
115 Record* prev = nullptr; // guarded by shared->mutex
116 OnCall onCall; // guarded by shared->mutex
117 bool isCalled = false; // guarded by shared->mutex
118 std::atomic<bool> isDone = {false};
119 };
120
121 // Data shared between all tickets and the queue.
122 struct Shared {
123 marl::mutex mutex;
124 Record tail;
125 };
126
127 MARL_NO_EXPORT inline Ticket(Loan<Record>&& record);
128
129 Loan<Record> record;
130 };
131
132 ////////////////////////////////////////////////////////////////////////////////
133 // Ticket
134 ////////////////////////////////////////////////////////////////////////////////
135
Ticket(Loan<Record> && record)136 Ticket::Ticket(Loan<Record>&& record) : record(std::move(record)) {}
137
wait()138 void Ticket::wait() const {
139 marl::lock lock(record->shared->mutex);
140 record->isCalledCondVar.wait(lock, [this] { return record->isCalled; });
141 }
142
done()143 void Ticket::done() const {
144 record->done();
145 }
146
147 template <typename Function>
onCall(Function && f)148 void Ticket::onCall(Function&& f) const {
149 marl::lock lock(record->shared->mutex);
150 if (record->isCalled) {
151 marl::schedule(std::forward<Function>(f));
152 return;
153 }
154 if (record->onCall) {
155 struct Joined {
156 void operator()() const {
157 a();
158 b();
159 }
160 OnCall a, b;
161 };
162 record->onCall =
163 std::move(Joined{std::move(record->onCall), std::forward<Function>(f)});
164 } else {
165 record->onCall = std::forward<Function>(f);
166 }
167 }
168
169 ////////////////////////////////////////////////////////////////////////////////
170 // Ticket::Queue
171 ////////////////////////////////////////////////////////////////////////////////
172
take()173 Ticket Ticket::Queue::take() {
174 Ticket out;
175 take(1, [&](Ticket&& ticket) { out = std::move(ticket); });
176 return out;
177 }
178
179 template <typename F>
take(size_t n,const F & f)180 void Ticket::Queue::take(size_t n, const F& f) {
181 Loan<Record> first, last;
182 pool.borrow(n, [&](Loan<Record>&& record) {
183 Loan<Record> rec = std::move(record);
184 rec->shared = shared;
185 if (first.get() == nullptr) {
186 first = rec;
187 }
188 if (last.get() != nullptr) {
189 last->next = rec.get();
190 rec->prev = last.get();
191 }
192 last = rec;
193 f(std::move(Ticket(std::move(rec))));
194 });
195 last->next = &shared->tail;
196 marl::lock lock(shared->mutex);
197 first->prev = shared->tail.prev;
198 shared->tail.prev = last.get();
199 if (first->prev == nullptr) {
200 first->callAndUnlock(lock);
201 } else {
202 first->prev->next = first.get();
203 }
204 }
205
206 ////////////////////////////////////////////////////////////////////////////////
207 // Ticket::Record
208 ////////////////////////////////////////////////////////////////////////////////
209
~Record()210 Ticket::Record::~Record() {
211 if (shared != nullptr) {
212 done();
213 }
214 }
215
done()216 void Ticket::Record::done() {
217 if (isDone.exchange(true)) {
218 return;
219 }
220 marl::lock lock(shared->mutex);
221 auto callNext = (prev == nullptr && next != nullptr) ? next : nullptr;
222 unlink();
223 if (callNext != nullptr) {
224 // lock needs to be held otherwise callNext might be destructed.
225 callNext->callAndUnlock(lock);
226 }
227 }
228
callAndUnlock(marl::lock & lock)229 void Ticket::Record::callAndUnlock(marl::lock& lock) {
230 if (isCalled) {
231 return;
232 }
233 isCalled = true;
234 OnCall callback;
235 std::swap(callback, onCall);
236 isCalledCondVar.notify_all();
237 lock.unlock_no_tsa();
238
239 if (callback) {
240 marl::schedule(std::move(callback));
241 }
242 }
243
unlink()244 void Ticket::Record::unlink() {
245 if (prev != nullptr) {
246 prev->next = next;
247 }
248 if (next != nullptr) {
249 next->prev = prev;
250 }
251 prev = nullptr;
252 next = nullptr;
253 }
254
255 } // namespace marl
256
257 #endif // marl_ticket_h
258