1 // Copyright 2020 The Marl Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // marl::DAG<> provides an ahead of time, declarative, directed acyclic 16 // task graph. 17 18 #ifndef marl_dag_h 19 #define marl_dag_h 20 21 #include "containers.h" 22 #include "export.h" 23 #include "memory.h" 24 #include "scheduler.h" 25 #include "waitgroup.h" 26 27 namespace marl { 28 namespace detail { 29 using DAGCounter = std::atomic<uint32_t>; 30 template <typename T> 31 struct DAGRunContext { 32 T data; 33 Allocator::unique_ptr<DAGCounter> counters; 34 35 template <typename F> invokeDAGRunContext36 MARL_NO_EXPORT inline void invoke(F&& f) { 37 f(data); 38 } 39 }; 40 template <> 41 struct DAGRunContext<void> { 42 Allocator::unique_ptr<DAGCounter> counters; 43 44 template <typename F> 45 MARL_NO_EXPORT inline void invoke(F&& f) { 46 f(); 47 } 48 }; 49 template <typename T> 50 struct DAGWork { 51 using type = std::function<void(T)>; 52 }; 53 template <> 54 struct DAGWork<void> { 55 using type = std::function<void()>; 56 }; 57 } // namespace detail 58 59 /////////////////////////////////////////////////////////////////////////////// 60 // Forward declarations 61 /////////////////////////////////////////////////////////////////////////////// 62 template <typename T> 63 class DAG; 64 65 template <typename T> 66 class DAGBuilder; 67 68 template <typename T> 69 class DAGNodeBuilder; 70 71 /////////////////////////////////////////////////////////////////////////////// 72 // DAGBase<T> 73 /////////////////////////////////////////////////////////////////////////////// 74 75 // DAGBase is derived by DAG<T> and DAG<void>. It has no public API. 76 template <typename T> 77 class DAGBase { 78 protected: 79 friend DAGBuilder<T>; 80 friend DAGNodeBuilder<T>; 81 82 using RunContext = detail::DAGRunContext<T>; 83 using Counter = detail::DAGCounter; 84 using NodeIndex = size_t; 85 using Work = typename detail::DAGWork<T>::type; 86 static const constexpr size_t NumReservedNodes = 32; 87 static const constexpr size_t NumReservedNumOuts = 4; 88 static const constexpr size_t InvalidCounterIndex = ~static_cast<size_t>(0); 89 static const constexpr NodeIndex RootIndex = 0; 90 static const constexpr NodeIndex InvalidNodeIndex = 91 ~static_cast<NodeIndex>(0); 92 93 // DAG work node. 94 struct Node { 95 MARL_NO_EXPORT inline Node() = default; 96 MARL_NO_EXPORT inline Node(Work&& work); 97 98 // The work to perform for this node in the graph. 99 Work work; 100 101 // counterIndex if valid, is the index of the counter in the RunContext for 102 // this node. The counter is decremented for each completed dependency task 103 // (ins), and once it reaches 0, this node will be invoked. 104 size_t counterIndex = InvalidCounterIndex; 105 106 // Indices for all downstream nodes. 107 containers::vector<NodeIndex, NumReservedNumOuts> outs; 108 }; 109 110 // initCounters() allocates and initializes the ctx->coutners from 111 // initialCounters. 112 MARL_NO_EXPORT inline void initCounters(RunContext* ctx, 113 Allocator* allocator); 114 115 // notify() is called each time a dependency task (ins) has completed for the 116 // node with the given index. 117 // If all dependency tasks have completed (or this is the root node) then 118 // notify() returns true and the caller should then call invoke(). 119 MARL_NO_EXPORT inline bool notify(RunContext*, NodeIndex); 120 121 // invoke() calls the work function for the node with the given index, then 122 // calls notify() and possibly invoke() for all the dependee nodes. 123 MARL_NO_EXPORT inline void invoke(RunContext*, NodeIndex, WaitGroup*); 124 125 // nodes is the full list of the nodes in the graph. 126 // nodes[0] is always the root node, which has no dependencies (ins). 127 containers::vector<Node, NumReservedNodes> nodes; 128 129 // initialCounters is a list of initial counter values to be copied to 130 // RunContext::counters on DAG<>::run(). 131 // initialCounters is indexed by Node::counterIndex, and only contains counts 132 // for nodes that have at least 2 dependencies (ins) - because of this the 133 // number of entries in initialCounters may be fewer than nodes. 134 containers::vector<uint32_t, NumReservedNodes> initialCounters; 135 }; 136 137 template <typename T> 138 DAGBase<T>::Node::Node(Work&& work) : work(std::move(work)) {} 139 140 template <typename T> 141 void DAGBase<T>::initCounters(RunContext* ctx, Allocator* allocator) { 142 auto numCounters = initialCounters.size(); 143 ctx->counters = allocator->make_unique_n<Counter>(numCounters); 144 for (size_t i = 0; i < numCounters; i++) { 145 ctx->counters.get()[i] = {initialCounters[i]}; 146 } 147 } 148 149 template <typename T> 150 bool DAGBase<T>::notify(RunContext* ctx, NodeIndex nodeIdx) { 151 Node* node = &nodes[nodeIdx]; 152 153 // If we have multiple dependencies, decrement the counter and check whether 154 // we've reached 0. 155 if (node->counterIndex == InvalidCounterIndex) { 156 return true; 157 } 158 auto counters = ctx->counters.get(); 159 auto counter = --counters[node->counterIndex]; 160 return counter == 0; 161 } 162 163 template <typename T> 164 void DAGBase<T>::invoke(RunContext* ctx, NodeIndex nodeIdx, WaitGroup* wg) { 165 Node* node = &nodes[nodeIdx]; 166 167 // Run this node's work. 168 if (node->work) { 169 ctx->invoke(node->work); 170 } 171 172 // Then call notify() on all dependees (outs), and invoke() those that 173 // returned true. 174 // We buffer the node to invoke (toInvoke) so we can schedule() all but the 175 // last node to invoke(), and directly call the last invoke() on this thread. 176 // This is done to avoid the overheads of scheduling when a direct call would 177 // suffice. 178 NodeIndex toInvoke = InvalidNodeIndex; 179 for (NodeIndex idx : node->outs) { 180 if (notify(ctx, idx)) { 181 if (toInvoke != InvalidNodeIndex) { 182 wg->add(1); 183 // Schedule while promoting the WaitGroup capture from a pointer 184 // reference to a value. This ensures that the WaitGroup isn't dropped 185 // while in use. 186 schedule( 187 [=](WaitGroup wg) { 188 invoke(ctx, toInvoke, &wg); 189 wg.done(); 190 }, 191 *wg); 192 } 193 toInvoke = idx; 194 } 195 } 196 if (toInvoke != InvalidNodeIndex) { 197 invoke(ctx, toInvoke, wg); 198 } 199 } 200 201 /////////////////////////////////////////////////////////////////////////////// 202 // DAGNodeBuilder<T> 203 /////////////////////////////////////////////////////////////////////////////// 204 205 // DAGNodeBuilder is the builder interface for a DAG node. 206 template <typename T> 207 class DAGNodeBuilder { 208 using NodeIndex = typename DAGBase<T>::NodeIndex; 209 210 public: 211 // then() builds and returns a new DAG node that will be invoked after this 212 // node has completed. 213 // 214 // F is a function that will be called when the new DAG node is invoked, with 215 // the signature: 216 // void(T) when T is not void 217 // or 218 // void() when T is void 219 template <typename F> 220 MARL_NO_EXPORT inline DAGNodeBuilder then(F&&); 221 222 private: 223 friend DAGBuilder<T>; 224 MARL_NO_EXPORT inline DAGNodeBuilder(DAGBuilder<T>*, NodeIndex); 225 DAGBuilder<T>* builder; 226 NodeIndex index; 227 }; 228 229 template <typename T> 230 DAGNodeBuilder<T>::DAGNodeBuilder(DAGBuilder<T>* builder, NodeIndex index) 231 : builder(builder), index(index) {} 232 233 template <typename T> 234 template <typename F> 235 DAGNodeBuilder<T> DAGNodeBuilder<T>::then(F&& work) { 236 auto node = builder->node(std::move(work)); 237 builder->addDependency(*this, node); 238 return node; 239 } 240 241 /////////////////////////////////////////////////////////////////////////////// 242 // DAGBuilder<T> 243 /////////////////////////////////////////////////////////////////////////////// 244 template <typename T> 245 class DAGBuilder { 246 public: 247 // DAGBuilder constructor 248 MARL_NO_EXPORT inline DAGBuilder(Allocator* allocator = Allocator::Default); 249 250 // root() returns the root DAG node. 251 MARL_NO_EXPORT inline DAGNodeBuilder<T> root(); 252 253 // node() builds and returns a new DAG node with no initial dependencies. 254 // The returned node must be attached to the graph in order to invoke F or any 255 // of the dependees of this returned node. 256 // 257 // F is a function that will be called when the new DAG node is invoked, with 258 // the signature: 259 // void(T) when T is not void 260 // or 261 // void() when T is void 262 template <typename F> 263 MARL_NO_EXPORT inline DAGNodeBuilder<T> node(F&& work); 264 265 // node() builds and returns a new DAG node that depends on all the tasks in 266 // after to be completed before invoking F. 267 // 268 // F is a function that will be called when the new DAG node is invoked, with 269 // the signature: 270 // void(T) when T is not void 271 // or 272 // void() when T is void 273 template <typename F> 274 MARL_NO_EXPORT inline DAGNodeBuilder<T> node( 275 F&& work, 276 std::initializer_list<DAGNodeBuilder<T>> after); 277 278 // addDependency() adds parent as dependency on child. All dependencies of 279 // child must have completed before child is invoked. 280 MARL_NO_EXPORT inline void addDependency(DAGNodeBuilder<T> parent, 281 DAGNodeBuilder<T> child); 282 283 // build() constructs and returns the DAG. No other methods of this class may 284 // be called after calling build(). 285 MARL_NO_EXPORT inline Allocator::unique_ptr<DAG<T>> build(); 286 287 private: 288 static const constexpr size_t NumReservedNumIns = 4; 289 using Node = typename DAG<T>::Node; 290 291 // The DAG being built. 292 Allocator::unique_ptr<DAG<T>> dag; 293 294 // Number of dependencies (ins) for each node in dag->nodes. 295 containers::vector<uint32_t, NumReservedNumIns> numIns; 296 }; 297 298 template <typename T> 299 DAGBuilder<T>::DAGBuilder(Allocator* allocator /* = Allocator::Default */) 300 : dag(allocator->make_unique<DAG<T>>()), numIns(allocator) { 301 // Add root 302 dag->nodes.emplace_back(Node{}); 303 numIns.emplace_back(0); 304 } 305 306 template <typename T> 307 DAGNodeBuilder<T> DAGBuilder<T>::root() { 308 return DAGNodeBuilder<T>{this, DAGBase<T>::RootIndex}; 309 } 310 311 template <typename T> 312 template <typename F> 313 DAGNodeBuilder<T> DAGBuilder<T>::node(F&& work) { 314 return node(std::forward<F>(work), {}); 315 } 316 317 template <typename T> 318 template <typename F> 319 DAGNodeBuilder<T> DAGBuilder<T>::node( 320 F&& work, 321 std::initializer_list<DAGNodeBuilder<T>> after) { 322 MARL_ASSERT(numIns.size() == dag->nodes.size(), 323 "NodeBuilder vectors out of sync"); 324 auto index = dag->nodes.size(); 325 numIns.emplace_back(0); 326 dag->nodes.emplace_back(Node{std::move(work)}); 327 auto node = DAGNodeBuilder<T>{this, index}; 328 for (auto in : after) { 329 addDependency(in, node); 330 } 331 return node; 332 } 333 334 template <typename T> 335 void DAGBuilder<T>::addDependency(DAGNodeBuilder<T> parent, 336 DAGNodeBuilder<T> child) { 337 numIns[child.index]++; 338 dag->nodes[parent.index].outs.push_back(child.index); 339 } 340 341 template <typename T> 342 Allocator::unique_ptr<DAG<T>> DAGBuilder<T>::build() { 343 auto numNodes = dag->nodes.size(); 344 MARL_ASSERT(numIns.size() == dag->nodes.size(), 345 "NodeBuilder vectors out of sync"); 346 for (size_t i = 0; i < numNodes; i++) { 347 if (numIns[i] > 1) { 348 auto& node = dag->nodes[i]; 349 node.counterIndex = dag->initialCounters.size(); 350 dag->initialCounters.push_back(numIns[i]); 351 } 352 } 353 return std::move(dag); 354 } 355 356 /////////////////////////////////////////////////////////////////////////////// 357 // DAG<T> 358 /////////////////////////////////////////////////////////////////////////////// 359 template <typename T = void> 360 class DAG : public DAGBase<T> { 361 public: 362 using Builder = DAGBuilder<T>; 363 using NodeBuilder = DAGNodeBuilder<T>; 364 365 // run() invokes the function of each node in the graph of the DAG, passing 366 // data to each, starting with the root node. All dependencies need to have 367 // completed their function before dependees will be invoked. 368 MARL_NO_EXPORT inline void run(T& data, 369 Allocator* allocator = Allocator::Default); 370 }; 371 372 template <typename T> 373 void DAG<T>::run(T& arg, Allocator* allocator /* = Allocator::Default */) { 374 typename DAGBase<T>::RunContext ctx{arg}; 375 this->initCounters(&ctx, allocator); 376 WaitGroup wg; 377 this->invoke(&ctx, this->RootIndex, &wg); 378 wg.wait(); 379 } 380 381 /////////////////////////////////////////////////////////////////////////////// 382 // DAG<void> 383 /////////////////////////////////////////////////////////////////////////////// 384 template <> 385 class DAG<void> : public DAGBase<void> { 386 public: 387 using Builder = DAGBuilder<void>; 388 using NodeBuilder = DAGNodeBuilder<void>; 389 390 // run() invokes the function of each node in the graph of the DAG, starting 391 // with the root node. All dependencies need to have completed their function 392 // before dependees will be invoked. 393 MARL_NO_EXPORT inline void run(Allocator* allocator = Allocator::Default); 394 }; 395 396 void DAG<void>::run(Allocator* allocator /* = Allocator::Default */) { 397 typename DAGBase<void>::RunContext ctx{}; 398 this->initCounters(&ctx, allocator); 399 WaitGroup wg; 400 this->invoke(&ctx, this->RootIndex, &wg); 401 wg.wait(); 402 } 403 404 } // namespace marl 405 406 #endif // marl_dag_h 407