/*
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */

package java.util;

import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.function.BinaryOperator;
import java.util.function.DoubleBinaryOperator;
import java.util.function.IntBinaryOperator;
import java.util.function.LongBinaryOperator;

/**
 * ForkJoin tasks to perform Arrays.parallelPrefix operations.
 *
 * @author Doug Lea
 * @since 1.8
 */
class ArrayPrefixHelpers {
    private ArrayPrefixHelpers() {} // non-instantiable

    /*
     * Parallel prefix (aka cumulate, scan) task classes
     * are based loosely on Guy Blelloch's original
     * algorithm (http://www.cs.cmu.edu/~scandal/alg/scan.html):
     *  Keep dividing by two to threshold segment size, and then:
     *   Pass 1: Create tree of partial sums for each segment
     *   Pass 2: For each segment, cumulate with offset of left sibling
     *
     * This version improves performance within FJ framework mainly by
     * allowing the second pass of ready left-hand sides to proceed
     * even if some right-hand side first passes are still executing.
     * It also combines first and second pass for leftmost segment,
     * and skips the first pass for rightmost segment (whose result is
     * not needed for second pass).  It similarly manages to avoid
     * requiring that users supply an identity basis for accumulations
     * by tracking those segments/subtasks for which the first
     * existing element is used as base.
     *
     * Managing this relies on ORing some bits in the pendingCount for
     * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
     * main phase bit. When false, segments compute only their sum.
     * When true, they cumulate array elements. CUMULATE is set at
     * root at beginning of second pass and then propagated down. But
     * it may also be set earlier for subtrees with lo==0 (the left
     * spine of tree). SUMMED is a one bit join count. For leafs, it
     * is set when summed. For internal nodes, it becomes true when
     * one child is summed.  When the second child finishes summing,
     * it then moves up tree to trigger the cumulate phase. FINISHED
     * is also a one bit join count. For leafs, it is set when
     * cumulated. For internal nodes, it becomes true when one child
     * is cumulated.  When the second child finishes cumulating, it
     * then moves up tree, completing at the root.
     *
     * To better exploit locality and reduce overhead, the compute
     * method loops starting with the current task, moving if possible
     * to one of its subtasks rather than forking.
     *
     * As usual for this sort of utility, there are 4 versions, that
     * are simple copy/paste/adapt variants of each other.  (The
     * double and int versions differ from long version solely by
     * replacing "long" (with case-matching)).
     */

    // see above
    /** Phase bit: when set in a task's pendingCount, the task cumulates rather than only sums. */
    static final int CUMULATE = 1;
    /** One-bit join count for the summing pass. */
    static final int SUMMED   = 2;
    /** One-bit join count for the cumulating pass. */
    static final int FINISHED = 4;

    /** The smallest subtask array partition size to use as threshold */
    static final int MIN_PARTITION = 16;

    /**
     * Prefix task over an object array. Each task covers the half-open
     * segment {@code [lo, hi)} of {@code array}; the whole operation
     * covers {@code [origin, fence)}.
     */
    static final class CumulateTask<T> extends CountedCompleter<Void> {
        /** Array being cumulated in place. */
        final T[] array;
        /** Combining function, applied as {@code function.apply(accumulated, element)}. */
        final BinaryOperator<T> function;
        /** Children; both are created together on the first visit of an internal node. */
        CumulateTask<T> left, right;
        /** {@code in}: accumulation of everything left of this segment; {@code out}: this segment's own sum. */
        T in, out;
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor. The root spans the whole requested range,
         * so {@code origin == lo} and {@code fence == hi}.
         *
         * @param parent   completer of this task (may be null)
         * @param function the combining function
         * @param array    the array to cumulate in place
         * @param lo       inclusive start index of the range
         * @param hi       exclusive end index of the range
         */
        public CumulateTask(CumulateTask<T> parent,
                            BinaryOperator<T> function,
                            T[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            // Leaf size: range length divided by 8x the common pool
            // parallelism, but never less than MIN_PARTITION.
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /**
         * Subtask constructor. Copies the overall range
         * ({@code origin}, {@code fence}) and {@code threshold} from the
         * creating task and records this subtask's own segment.
         */
        CumulateTask(CumulateTask<T> parent, BinaryOperator<T> function,
                     T[] array, int origin, int fence, int threshold,
                     int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the sum and/or cumulate phase for this task, looping from
         * the current task down into a subtask, or up into the completer,
         * instead of recursing. Phase bookkeeping is kept in the
         * pendingCount using the CUMULATE, SUMMED and FINISHED bits.
         *
         * @throws NullPointerException if the function or array is null
         */
        @Override
        public final void compute() {
            final BinaryOperator<T> fn;
            final T[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            CumulateTask<T> t = this;
            // The loop exits if the current segment is not within the array.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {
                    CumulateTask<T> lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Split in half: fork the right child, continue
                        // in this thread with the left child.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new CumulateTask<T>(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new CumulateTask<T>(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        // Push the incoming base down to the children and
                        // try to switch each child into the cumulate phase.
                        T pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            T lout = lt.out;
                            // Leftmost segment has no incoming base.
                            rt.in = (l == org ? lout :
                                     fn.apply(pin, lout));
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;           // already set elsewhere
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;               // already set elsewhere
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;           // fork right, run left
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed here
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // CUMULATE set: cumulate now. Otherwise sum only,
                        // except the leftmost leaf which does both at once.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    T sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.apply(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.apply(sum, a[i]);
                    }
                    else
                        // Rightmost segment: no summing is done; the parent
                        // ignores this child's out (see rt.hi == fnc below).
                        sum = t.in;
                    t.out = sum;
                    for (CumulateTask<T> par;;) {             // propagate
                        @SuppressWarnings("unchecked") CumulateTask<T> partmp
                            = (CumulateTask<T>)t.getCompleter();
                        if ((par = partmp) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; CumulateTask<T> lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                T lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.apply(lout, rt.out));
                            }
                            // A parent on the left spine that is not yet
                            // cumulating can start its second pass now.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        private static final long serialVersionUID = 5293554502939613543L;
    }

    /**
     * Prefix task over a {@code long} array; a copy/adapt variant of
     * {@link CumulateTask} using {@link LongBinaryOperator}.
     */
    static final class LongCumulateTask extends CountedCompleter<Void> {
        /** Array being cumulated in place. */
        final long[] array;
        /** Combining function, applied as {@code applyAsLong(accumulated, element)}. */
        final LongBinaryOperator function;
        /** Children; both are created together on the first visit of an internal node. */
        LongCumulateTask left, right;
        /** {@code in}: accumulation of everything left of this segment; {@code out}: this segment's own sum. */
        long in, out;
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor. The root spans the whole requested range,
         * so {@code origin == lo} and {@code fence == hi}.
         *
         * @param parent   completer of this task (may be null)
         * @param function the combining function
         * @param array    the array to cumulate in place
         * @param lo       inclusive start index of the range
         * @param hi       exclusive end index of the range
         */
        public LongCumulateTask(LongCumulateTask parent,
                                LongBinaryOperator function,
                                long[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            // Leaf size: range length divided by 8x the common pool
            // parallelism, but never less than MIN_PARTITION.
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /**
         * Subtask constructor. Copies the overall range
         * ({@code origin}, {@code fence}) and {@code threshold} from the
         * creating task and records this subtask's own segment.
         */
        LongCumulateTask(LongCumulateTask parent, LongBinaryOperator function,
                         long[] array, int origin, int fence, int threshold,
                         int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the sum and/or cumulate phase for this task, looping from
         * the current task down into a subtask, or up into the completer,
         * instead of recursing. Phase bookkeeping is kept in the
         * pendingCount using the CUMULATE, SUMMED and FINISHED bits.
         *
         * @throws NullPointerException if the function or array is null
         */
        @Override
        public final void compute() {
            final LongBinaryOperator fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            LongCumulateTask t = this;
            // The loop exits if the current segment is not within the array.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {
                    LongCumulateTask lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Split in half: fork the right child, continue
                        // in this thread with the left child.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new LongCumulateTask(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new LongCumulateTask(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        // Push the incoming base down to the children and
                        // try to switch each child into the cumulate phase.
                        long pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            long lout = lt.out;
                            // Leftmost segment has no incoming base.
                            rt.in = (l == org ? lout :
                                     fn.applyAsLong(pin, lout));
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;           // already set elsewhere
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;               // already set elsewhere
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;           // fork right, run left
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed here
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // CUMULATE set: cumulate now. Otherwise sum only,
                        // except the leftmost leaf which does both at once.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    long sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.applyAsLong(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.applyAsLong(sum, a[i]);
                    }
                    else
                        // Rightmost segment: no summing is done; the parent
                        // ignores this child's out (see rt.hi == fnc below).
                        sum = t.in;
                    t.out = sum;
                    for (LongCumulateTask par;;) {            // propagate
                        if ((par = (LongCumulateTask)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; LongCumulateTask lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                long lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.applyAsLong(lout, rt.out));
                            }
                            // A parent on the left spine that is not yet
                            // cumulating can start its second pass now.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        private static final long serialVersionUID = -5074099945909284273L;
    }

    /**
     * Prefix task over a {@code double} array; a copy/adapt variant of
     * {@link LongCumulateTask} using {@link DoubleBinaryOperator}.
     */
    static final class DoubleCumulateTask extends CountedCompleter<Void> {
        /** Array being cumulated in place. */
        final double[] array;
        /** Combining function, applied as {@code applyAsDouble(accumulated, element)}. */
        final DoubleBinaryOperator function;
        /** Children; both are created together on the first visit of an internal node. */
        DoubleCumulateTask left, right;
        /** {@code in}: accumulation of everything left of this segment; {@code out}: this segment's own sum. */
        double in, out;
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor. The root spans the whole requested range,
         * so {@code origin == lo} and {@code fence == hi}.
         *
         * @param parent   completer of this task (may be null)
         * @param function the combining function
         * @param array    the array to cumulate in place
         * @param lo       inclusive start index of the range
         * @param hi       exclusive end index of the range
         */
        public DoubleCumulateTask(DoubleCumulateTask parent,
                                  DoubleBinaryOperator function,
                                  double[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            // Leaf size: range length divided by 8x the common pool
            // parallelism, but never less than MIN_PARTITION.
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /**
         * Subtask constructor. Copies the overall range
         * ({@code origin}, {@code fence}) and {@code threshold} from the
         * creating task and records this subtask's own segment.
         */
        DoubleCumulateTask(DoubleCumulateTask parent, DoubleBinaryOperator function,
                           double[] array, int origin, int fence, int threshold,
                           int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the sum and/or cumulate phase for this task, looping from
         * the current task down into a subtask, or up into the completer,
         * instead of recursing. Phase bookkeeping is kept in the
         * pendingCount using the CUMULATE, SUMMED and FINISHED bits.
         *
         * @throws NullPointerException if the function or array is null
         */
        @Override
        public final void compute() {
            final DoubleBinaryOperator fn;
            final double[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            DoubleCumulateTask t = this;
            // The loop exits if the current segment is not within the array.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {
                    DoubleCumulateTask lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Split in half: fork the right child, continue
                        // in this thread with the left child.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new DoubleCumulateTask(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new DoubleCumulateTask(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        // Push the incoming base down to the children and
                        // try to switch each child into the cumulate phase.
                        double pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            double lout = lt.out;
                            // Leftmost segment has no incoming base.
                            rt.in = (l == org ? lout :
                                     fn.applyAsDouble(pin, lout));
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;           // already set elsewhere
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;               // already set elsewhere
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;           // fork right, run left
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed here
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // CUMULATE set: cumulate now. Otherwise sum only,
                        // except the leftmost leaf which does both at once.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    double sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.applyAsDouble(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.applyAsDouble(sum, a[i]);
                    }
                    else
                        // Rightmost segment: no summing is done; the parent
                        // ignores this child's out (see rt.hi == fnc below).
                        sum = t.in;
                    t.out = sum;
                    for (DoubleCumulateTask par;;) {          // propagate
                        if ((par = (DoubleCumulateTask)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; DoubleCumulateTask lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                double lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.applyAsDouble(lout, rt.out));
                            }
                            // A parent on the left spine that is not yet
                            // cumulating can start its second pass now.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        private static final long serialVersionUID = -586947823794232033L;
    }

    /**
     * Prefix task over an {@code int} array; a copy/adapt variant of
     * {@link LongCumulateTask} using {@link IntBinaryOperator}.
     */
    static final class IntCumulateTask extends CountedCompleter<Void> {
        /** Array being cumulated in place. */
        final int[] array;
        /** Combining function, applied as {@code applyAsInt(accumulated, element)}. */
        final IntBinaryOperator function;
        /** Children; both are created together on the first visit of an internal node. */
        IntCumulateTask left, right;
        /** {@code in}: accumulation of everything left of this segment; {@code out}: this segment's own sum. */
        int in, out;
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor. The root spans the whole requested range,
         * so {@code origin == lo} and {@code fence == hi}.
         *
         * @param parent   completer of this task (may be null)
         * @param function the combining function
         * @param array    the array to cumulate in place
         * @param lo       inclusive start index of the range
         * @param hi       exclusive end index of the range
         */
        public IntCumulateTask(IntCumulateTask parent,
                               IntBinaryOperator function,
                               int[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            // Leaf size: range length divided by 8x the common pool
            // parallelism, but never less than MIN_PARTITION.
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /**
         * Subtask constructor. Copies the overall range
         * ({@code origin}, {@code fence}) and {@code threshold} from the
         * creating task and records this subtask's own segment.
         */
        IntCumulateTask(IntCumulateTask parent, IntBinaryOperator function,
                        int[] array, int origin, int fence, int threshold,
                        int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the sum and/or cumulate phase for this task, looping from
         * the current task down into a subtask, or up into the completer,
         * instead of recursing. Phase bookkeeping is kept in the
         * pendingCount using the CUMULATE, SUMMED and FINISHED bits.
         *
         * @throws NullPointerException if the function or array is null
         */
        @Override
        public final void compute() {
            final IntBinaryOperator fn;
            final int[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            IntCumulateTask t = this;
            // The loop exits if the current segment is not within the array.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {
                    IntCumulateTask lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Split in half: fork the right child, continue
                        // in this thread with the left child.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new IntCumulateTask(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new IntCumulateTask(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        // Push the incoming base down to the children and
                        // try to switch each child into the cumulate phase.
                        int pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            int lout = lt.out;
                            // Leftmost segment has no incoming base.
                            rt.in = (l == org ? lout :
                                     fn.applyAsInt(pin, lout));
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;           // already set elsewhere
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;               // already set elsewhere
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;           // fork right, run left
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed here
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // CUMULATE set: cumulate now. Otherwise sum only,
                        // except the leftmost leaf which does both at once.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    int sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.applyAsInt(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.applyAsInt(sum, a[i]);
                    }
                    else
                        // Rightmost segment: no summing is done; the parent
                        // ignores this child's out (see rt.hi == fnc below).
                        sum = t.in;
                    t.out = sum;
                    for (IntCumulateTask par;;) {             // propagate
                        if ((par = (IntCumulateTask)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; IntCumulateTask lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                int lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.applyAsInt(lout, rt.out));
                            }
                            // A parent on the left spine that is not yet
                            // cumulating can start its second pass now.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        private static final long serialVersionUID = 3731755594596840961L;
    }
}