• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // marl::DAG<> provides an ahead of time, declarative, directed acyclic
16 // task graph.
17 
18 #ifndef marl_dag_h
19 #define marl_dag_h
20 
21 #include "containers.h"
22 #include "export.h"
23 #include "memory.h"
24 #include "scheduler.h"
25 #include "waitgroup.h"
26 
27 namespace marl {
28 namespace detail {
29 using DAGCounter = std::atomic<uint32_t>;
30 template <typename T>
31 struct DAGRunContext {
32   T data;
33   Allocator::unique_ptr<DAGCounter> counters;
34 
35   template <typename F>
invokeDAGRunContext36   MARL_NO_EXPORT inline void invoke(F&& f) {
37     f(data);
38   }
39 };
40 template <>
41 struct DAGRunContext<void> {
42   Allocator::unique_ptr<DAGCounter> counters;
43 
44   template <typename F>
45   MARL_NO_EXPORT inline void invoke(F&& f) {
46     f();
47   }
48 };
// DAGWork<T>::type is the callable type stored for a DAG node: a
// std::function taking a single argument of type T and returning nothing.
template <class T>
struct DAGWork {
  using type = std::function<void(T)>;
};

// DAGWork<void>::type is the argument-less form, std::function<void()>.
template <>
struct DAGWork<void> {
  using type = std::function<void()>;
};
57 }  // namespace detail
58 
59 ///////////////////////////////////////////////////////////////////////////////
60 // Forward declarations
61 ///////////////////////////////////////////////////////////////////////////////
62 template <typename T>
63 class DAG;
64 
65 template <typename T>
66 class DAGBuilder;
67 
68 template <typename T>
69 class DAGNodeBuilder;
70 
71 ///////////////////////////////////////////////////////////////////////////////
72 // DAGBase<T>
73 ///////////////////////////////////////////////////////////////////////////////
74 
75 // DAGBase is derived by DAG<T> and DAG<void>. It has no public API.
76 template <typename T>
77 class DAGBase {
78  protected:
79   friend DAGBuilder<T>;
80   friend DAGNodeBuilder<T>;
81 
82   using RunContext = detail::DAGRunContext<T>;
83   using Counter = detail::DAGCounter;
84   using NodeIndex = size_t;
85   using Work = typename detail::DAGWork<T>::type;
86   static const constexpr size_t NumReservedNodes = 32;
87   static const constexpr size_t NumReservedNumOuts = 4;
88   static const constexpr size_t InvalidCounterIndex = ~static_cast<size_t>(0);
89   static const constexpr NodeIndex RootIndex = 0;
90   static const constexpr NodeIndex InvalidNodeIndex =
91       ~static_cast<NodeIndex>(0);
92 
93   // DAG work node.
94   struct Node {
95     MARL_NO_EXPORT inline Node() = default;
96     MARL_NO_EXPORT inline Node(Work&& work);
97 
98     // The work to perform for this node in the graph.
99     Work work;
100 
101     // counterIndex if valid, is the index of the counter in the RunContext for
102     // this node. The counter is decremented for each completed dependency task
103     // (ins), and once it reaches 0, this node will be invoked.
104     size_t counterIndex = InvalidCounterIndex;
105 
106     // Indices for all downstream nodes.
107     containers::vector<NodeIndex, NumReservedNumOuts> outs;
108   };
109 
110   // initCounters() allocates and initializes the ctx->coutners from
111   // initialCounters.
112   MARL_NO_EXPORT inline void initCounters(RunContext* ctx,
113                                           Allocator* allocator);
114 
115   // notify() is called each time a dependency task (ins) has completed for the
116   // node with the given index.
117   // If all dependency tasks have completed (or this is the root node) then
118   // notify() returns true and the caller should then call invoke().
119   MARL_NO_EXPORT inline bool notify(RunContext*, NodeIndex);
120 
121   // invoke() calls the work function for the node with the given index, then
122   // calls notify() and possibly invoke() for all the dependee nodes.
123   MARL_NO_EXPORT inline void invoke(RunContext*, NodeIndex, WaitGroup*);
124 
125   // nodes is the full list of the nodes in the graph.
126   // nodes[0] is always the root node, which has no dependencies (ins).
127   containers::vector<Node, NumReservedNodes> nodes;
128 
129   // initialCounters is a list of initial counter values to be copied to
130   // RunContext::counters on DAG<>::run().
131   // initialCounters is indexed by Node::counterIndex, and only contains counts
132   // for nodes that have at least 2 dependencies (ins) - because of this the
133   // number of entries in initialCounters may be fewer than nodes.
134   containers::vector<uint32_t, NumReservedNodes> initialCounters;
135 };
136 
137 template <typename T>
138 DAGBase<T>::Node::Node(Work&& work) : work(std::move(work)) {}
139 
140 template <typename T>
141 void DAGBase<T>::initCounters(RunContext* ctx, Allocator* allocator) {
142   auto numCounters = initialCounters.size();
143   ctx->counters = allocator->make_unique_n<Counter>(numCounters);
144   for (size_t i = 0; i < numCounters; i++) {
145     ctx->counters.get()[i] = {initialCounters[i]};
146   }
147 }
148 
149 template <typename T>
150 bool DAGBase<T>::notify(RunContext* ctx, NodeIndex nodeIdx) {
151   Node* node = &nodes[nodeIdx];
152 
153   // If we have multiple dependencies, decrement the counter and check whether
154   // we've reached 0.
155   if (node->counterIndex == InvalidCounterIndex) {
156     return true;
157   }
158   auto counters = ctx->counters.get();
159   auto counter = --counters[node->counterIndex];
160   return counter == 0;
161 }
162 
163 template <typename T>
164 void DAGBase<T>::invoke(RunContext* ctx, NodeIndex nodeIdx, WaitGroup* wg) {
165   Node* node = &nodes[nodeIdx];
166 
167   // Run this node's work.
168   if (node->work) {
169     ctx->invoke(node->work);
170   }
171 
172   // Then call notify() on all dependees (outs), and invoke() those that
173   // returned true.
174   // We buffer the node to invoke (toInvoke) so we can schedule() all but the
175   // last node to invoke(), and directly call the last invoke() on this thread.
176   // This is done to avoid the overheads of scheduling when a direct call would
177   // suffice.
178   NodeIndex toInvoke = InvalidNodeIndex;
179   for (NodeIndex idx : node->outs) {
180     if (notify(ctx, idx)) {
181       if (toInvoke != InvalidNodeIndex) {
182         wg->add(1);
183         // Schedule while promoting the WaitGroup capture from a pointer
184         // reference to a value. This ensures that the WaitGroup isn't dropped
185         // while in use.
186         schedule(
187             [=](WaitGroup wg) {
188               invoke(ctx, toInvoke, &wg);
189               wg.done();
190             },
191             *wg);
192       }
193       toInvoke = idx;
194     }
195   }
196   if (toInvoke != InvalidNodeIndex) {
197     invoke(ctx, toInvoke, wg);
198   }
199 }
200 
201 ///////////////////////////////////////////////////////////////////////////////
202 // DAGNodeBuilder<T>
203 ///////////////////////////////////////////////////////////////////////////////
204 
205 // DAGNodeBuilder is the builder interface for a DAG node.
206 template <typename T>
207 class DAGNodeBuilder {
208   using NodeIndex = typename DAGBase<T>::NodeIndex;
209 
210  public:
211   // then() builds and returns a new DAG node that will be invoked after this
212   // node has completed.
213   //
214   // F is a function that will be called when the new DAG node is invoked, with
215   // the signature:
216   //   void(T)   when T is not void
217   // or
218   //   void()    when T is void
219   template <typename F>
220   MARL_NO_EXPORT inline DAGNodeBuilder then(F&&);
221 
222  private:
223   friend DAGBuilder<T>;
224   MARL_NO_EXPORT inline DAGNodeBuilder(DAGBuilder<T>*, NodeIndex);
225   DAGBuilder<T>* builder;
226   NodeIndex index;
227 };
228 
229 template <typename T>
230 DAGNodeBuilder<T>::DAGNodeBuilder(DAGBuilder<T>* builder, NodeIndex index)
231     : builder(builder), index(index) {}
232 
233 template <typename T>
234 template <typename F>
235 DAGNodeBuilder<T> DAGNodeBuilder<T>::then(F&& work) {
236   auto node = builder->node(std::move(work));
237   builder->addDependency(*this, node);
238   return node;
239 }
240 
241 ///////////////////////////////////////////////////////////////////////////////
242 // DAGBuilder<T>
243 ///////////////////////////////////////////////////////////////////////////////
244 template <typename T>
245 class DAGBuilder {
246  public:
247   // DAGBuilder constructor
248   MARL_NO_EXPORT inline DAGBuilder(Allocator* allocator = Allocator::Default);
249 
250   // root() returns the root DAG node.
251   MARL_NO_EXPORT inline DAGNodeBuilder<T> root();
252 
253   // node() builds and returns a new DAG node with no initial dependencies.
254   // The returned node must be attached to the graph in order to invoke F or any
255   // of the dependees of this returned node.
256   //
257   // F is a function that will be called when the new DAG node is invoked, with
258   // the signature:
259   //   void(T)   when T is not void
260   // or
261   //   void()    when T is void
262   template <typename F>
263   MARL_NO_EXPORT inline DAGNodeBuilder<T> node(F&& work);
264 
265   // node() builds and returns a new DAG node that depends on all the tasks in
266   // after to be completed before invoking F.
267   //
268   // F is a function that will be called when the new DAG node is invoked, with
269   // the signature:
270   //   void(T)   when T is not void
271   // or
272   //   void()    when T is void
273   template <typename F>
274   MARL_NO_EXPORT inline DAGNodeBuilder<T> node(
275       F&& work,
276       std::initializer_list<DAGNodeBuilder<T>> after);
277 
278   // addDependency() adds parent as dependency on child. All dependencies of
279   // child must have completed before child is invoked.
280   MARL_NO_EXPORT inline void addDependency(DAGNodeBuilder<T> parent,
281                                            DAGNodeBuilder<T> child);
282 
283   // build() constructs and returns the DAG. No other methods of this class may
284   // be called after calling build().
285   MARL_NO_EXPORT inline Allocator::unique_ptr<DAG<T>> build();
286 
287  private:
288   static const constexpr size_t NumReservedNumIns = 4;
289   using Node = typename DAG<T>::Node;
290 
291   // The DAG being built.
292   Allocator::unique_ptr<DAG<T>> dag;
293 
294   // Number of dependencies (ins) for each node in dag->nodes.
295   containers::vector<uint32_t, NumReservedNumIns> numIns;
296 };
297 
298 template <typename T>
299 DAGBuilder<T>::DAGBuilder(Allocator* allocator /* = Allocator::Default */)
300     : dag(allocator->make_unique<DAG<T>>()), numIns(allocator) {
301   // Add root
302   dag->nodes.emplace_back(Node{});
303   numIns.emplace_back(0);
304 }
305 
306 template <typename T>
307 DAGNodeBuilder<T> DAGBuilder<T>::root() {
308   return DAGNodeBuilder<T>{this, DAGBase<T>::RootIndex};
309 }
310 
311 template <typename T>
312 template <typename F>
313 DAGNodeBuilder<T> DAGBuilder<T>::node(F&& work) {
314   return node(std::forward<F>(work), {});
315 }
316 
317 template <typename T>
318 template <typename F>
319 DAGNodeBuilder<T> DAGBuilder<T>::node(
320     F&& work,
321     std::initializer_list<DAGNodeBuilder<T>> after) {
322   MARL_ASSERT(numIns.size() == dag->nodes.size(),
323               "NodeBuilder vectors out of sync");
324   auto index = dag->nodes.size();
325   numIns.emplace_back(0);
326   dag->nodes.emplace_back(Node{std::move(work)});
327   auto node = DAGNodeBuilder<T>{this, index};
328   for (auto in : after) {
329     addDependency(in, node);
330   }
331   return node;
332 }
333 
334 template <typename T>
335 void DAGBuilder<T>::addDependency(DAGNodeBuilder<T> parent,
336                                   DAGNodeBuilder<T> child) {
337   numIns[child.index]++;
338   dag->nodes[parent.index].outs.push_back(child.index);
339 }
340 
341 template <typename T>
342 Allocator::unique_ptr<DAG<T>> DAGBuilder<T>::build() {
343   auto numNodes = dag->nodes.size();
344   MARL_ASSERT(numIns.size() == dag->nodes.size(),
345               "NodeBuilder vectors out of sync");
346   for (size_t i = 0; i < numNodes; i++) {
347     if (numIns[i] > 1) {
348       auto& node = dag->nodes[i];
349       node.counterIndex = dag->initialCounters.size();
350       dag->initialCounters.push_back(numIns[i]);
351     }
352   }
353   return std::move(dag);
354 }
355 
356 ///////////////////////////////////////////////////////////////////////////////
357 // DAG<T>
358 ///////////////////////////////////////////////////////////////////////////////
359 template <typename T = void>
360 class DAG : public DAGBase<T> {
361  public:
362   using Builder = DAGBuilder<T>;
363   using NodeBuilder = DAGNodeBuilder<T>;
364 
365   // run() invokes the function of each node in the graph of the DAG, passing
366   // data to each, starting with the root node. All dependencies need to have
367   // completed their function before dependees will be invoked.
368   MARL_NO_EXPORT inline void run(T& data,
369                                  Allocator* allocator = Allocator::Default);
370 };
371 
372 template <typename T>
373 void DAG<T>::run(T& arg, Allocator* allocator /* = Allocator::Default */) {
374   typename DAGBase<T>::RunContext ctx{arg};
375   this->initCounters(&ctx, allocator);
376   WaitGroup wg;
377   this->invoke(&ctx, this->RootIndex, &wg);
378   wg.wait();
379 }
380 
381 ///////////////////////////////////////////////////////////////////////////////
382 // DAG<void>
383 ///////////////////////////////////////////////////////////////////////////////
384 template <>
385 class DAG<void> : public DAGBase<void> {
386  public:
387   using Builder = DAGBuilder<void>;
388   using NodeBuilder = DAGNodeBuilder<void>;
389 
390   // run() invokes the function of each node in the graph of the DAG, starting
391   // with the root node. All dependencies need to have completed their function
392   // before dependees will be invoked.
393   MARL_NO_EXPORT inline void run(Allocator* allocator = Allocator::Default);
394 };
395 
396 void DAG<void>::run(Allocator* allocator /* = Allocator::Default */) {
397   typename DAGBase<void>::RunContext ctx{};
398   this->initCounters(&ctx, allocator);
399   WaitGroup wg;
400   this->invoke(&ctx, this->RootIndex, &wg);
401   wg.wait();
402 }
403 
404 }  // namespace marl
405 
406 #endif  // marl_dag_h
407