• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef marl_ticket_h
16 #define marl_ticket_h
17 
18 #include "conditionvariable.h"
19 #include "pool.h"
20 #include "scheduler.h"
21 
22 namespace marl {
23 
24 // Ticket is a synchronization primitive used to serially order execution.
25 //
26 // Tickets exist in 3 mutually exclusive states: Waiting, Called and Finished.
27 //
28 // Tickets are obtained from a Ticket::Queue, using the Ticket::Queue::take()
29 // methods. The order in which tickets are taken from the queue dictates the
30 // order in which they are called.
31 //
32 // The first ticket to be taken from a queue will be in the 'called' state,
33 // subsequent tickets will be in the 'waiting' state.
34 //
35 // Ticket::wait() will block until the ticket is called.
36 //
37 // Ticket::done() moves the ticket into the 'finished' state. If all preceeding
38 // tickets are finished, done() will call the next unfinished ticket.
39 //
40 // If the last remaining reference to an unfinished ticket is dropped then
41 // done() will be automatically called on that ticket.
42 //
43 // Example:
44 //
45 //  void runTasksConcurrentThenSerially(int numConcurrentTasks)
46 //  {
47 //      marl::Ticket::Queue queue;
48 //      for (int i = 0; i < numConcurrentTasks; i++)
49 //      {
50 //          auto ticket = queue.take();
51 //          marl::schedule([=] {
52 //              doConcurrentWork(); // <- function may be called concurrently
53 //              ticket.wait(); // <- serialize tasks
54 //              doSerialWork(); // <- function will not be called concurrently
55 //              ticket.done(); // <- optional, as done() is called implicitly on
56 //                             // dropping of last reference
57 //          });
58 //      }
59 //  }
60 class Ticket {
61   struct Shared;
62   struct Record;
63 
64  public:
65   using OnCall = std::function<void()>;
66 
67   // Queue hands out Tickets.
68   class Queue {
69    public:
70     // take() returns a single ticket from the queue.
71     MARL_NO_EXPORT inline Ticket take();
72 
73     // take() retrieves count tickets from the queue, calling f() with each
74     // retrieved ticket.
75     // F must be a function of the signature: void(Ticket&&)
76     template <typename F>
77     MARL_NO_EXPORT inline void take(size_t count, const F& f);
78 
79    private:
80     std::shared_ptr<Shared> shared = std::make_shared<Shared>();
81     UnboundedPool<Record> pool;
82   };
83 
84   MARL_NO_EXPORT inline Ticket() = default;
85   MARL_NO_EXPORT inline Ticket(const Ticket& other) = default;
86   MARL_NO_EXPORT inline Ticket(Ticket&& other) = default;
87   MARL_NO_EXPORT inline Ticket& operator=(const Ticket& other) = default;
88 
89   // wait() blocks until the ticket is called.
90   MARL_NO_EXPORT inline void wait() const;
91 
92   // done() marks the ticket as finished and calls the next ticket.
93   MARL_NO_EXPORT inline void done() const;
94 
95   // onCall() registers the function f to be invoked when this ticket is
96   // called. If the ticket is already called prior to calling onCall(), then
97   // f() will be executed immediately.
98   // F must be a function of the OnCall signature.
99   template <typename F>
100   MARL_NO_EXPORT inline void onCall(F&& f) const;
101 
102  private:
103   // Internal doubly-linked-list data structure. One per ticket instance.
104   struct Record {
105     MARL_NO_EXPORT inline ~Record();
106 
107     MARL_NO_EXPORT inline void done();
108     MARL_NO_EXPORT inline void callAndUnlock(marl::lock& lock);
109     MARL_NO_EXPORT inline void unlink();  // guarded by shared->mutex
110 
111     ConditionVariable isCalledCondVar;
112 
113     std::shared_ptr<Shared> shared;
114     Record* next = nullptr;  // guarded by shared->mutex
115     Record* prev = nullptr;  // guarded by shared->mutex
116     OnCall onCall;           // guarded by shared->mutex
117     bool isCalled = false;   // guarded by shared->mutex
118     std::atomic<bool> isDone = {false};
119   };
120 
121   // Data shared between all tickets and the queue.
122   struct Shared {
123     marl::mutex mutex;
124     Record tail;
125   };
126 
127   MARL_NO_EXPORT inline Ticket(Loan<Record>&& record);
128 
129   Loan<Record> record;
130 };
131 
132 ////////////////////////////////////////////////////////////////////////////////
133 // Ticket
134 ////////////////////////////////////////////////////////////////////////////////
135 
Ticket(Loan<Record> && record)136 Ticket::Ticket(Loan<Record>&& record) : record(std::move(record)) {}
137 
wait()138 void Ticket::wait() const {
139   marl::lock lock(record->shared->mutex);
140   record->isCalledCondVar.wait(lock, [this] { return record->isCalled; });
141 }
142 
done()143 void Ticket::done() const {
144   record->done();
145 }
146 
147 template <typename Function>
onCall(Function && f)148 void Ticket::onCall(Function&& f) const {
149   marl::lock lock(record->shared->mutex);
150   if (record->isCalled) {
151     marl::schedule(std::forward<Function>(f));
152     return;
153   }
154   if (record->onCall) {
155     struct Joined {
156       void operator()() const {
157         a();
158         b();
159       }
160       OnCall a, b;
161     };
162     record->onCall =
163         std::move(Joined{std::move(record->onCall), std::forward<Function>(f)});
164   } else {
165     record->onCall = std::forward<Function>(f);
166   }
167 }
168 
169 ////////////////////////////////////////////////////////////////////////////////
170 // Ticket::Queue
171 ////////////////////////////////////////////////////////////////////////////////
172 
take()173 Ticket Ticket::Queue::take() {
174   Ticket out;
175   take(1, [&](Ticket&& ticket) { out = std::move(ticket); });
176   return out;
177 }
178 
179 template <typename F>
take(size_t n,const F & f)180 void Ticket::Queue::take(size_t n, const F& f) {
181   Loan<Record> first, last;
182   pool.borrow(n, [&](Loan<Record>&& record) {
183     Loan<Record> rec = std::move(record);
184     rec->shared = shared;
185     if (first.get() == nullptr) {
186       first = rec;
187     }
188     if (last.get() != nullptr) {
189       last->next = rec.get();
190       rec->prev = last.get();
191     }
192     last = rec;
193     f(std::move(Ticket(std::move(rec))));
194   });
195   last->next = &shared->tail;
196   marl::lock lock(shared->mutex);
197   first->prev = shared->tail.prev;
198   shared->tail.prev = last.get();
199   if (first->prev == nullptr) {
200     first->callAndUnlock(lock);
201   } else {
202     first->prev->next = first.get();
203   }
204 }
205 
206 ////////////////////////////////////////////////////////////////////////////////
207 // Ticket::Record
208 ////////////////////////////////////////////////////////////////////////////////
209 
~Record()210 Ticket::Record::~Record() {
211   if (shared != nullptr) {
212     done();
213   }
214 }
215 
done()216 void Ticket::Record::done() {
217   if (isDone.exchange(true)) {
218     return;
219   }
220   marl::lock lock(shared->mutex);
221   auto callNext = (prev == nullptr && next != nullptr) ? next : nullptr;
222   unlink();
223   if (callNext != nullptr) {
224     // lock needs to be held otherwise callNext might be destructed.
225     callNext->callAndUnlock(lock);
226   }
227 }
228 
callAndUnlock(marl::lock & lock)229 void Ticket::Record::callAndUnlock(marl::lock& lock) {
230   if (isCalled) {
231     return;
232   }
233   isCalled = true;
234   OnCall callback;
235   std::swap(callback, onCall);
236   isCalledCondVar.notify_all();
237   lock.unlock_no_tsa();
238 
239   if (callback) {
240     marl::schedule(std::move(callback));
241   }
242 }
243 
unlink()244 void Ticket::Record::unlink() {
245   if (prev != nullptr) {
246     prev->next = next;
247   }
248   if (next != nullptr) {
249     next->prev = prev;
250   }
251   prev = nullptr;
252   next = nullptr;
253 }
254 
255 }  // namespace marl
256 
257 #endif  // marl_ticket_h
258