1 /* 2 * Copyright (C) 2014 The Guava Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.common.collect; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 import static com.google.common.base.Preconditions.checkNotNull; 21 import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT; 22 23 import com.google.common.annotations.GwtCompatible; 24 import com.google.common.math.IntMath; 25 import java.math.RoundingMode; 26 import java.util.Arrays; 27 import java.util.Collections; 28 import java.util.Comparator; 29 import java.util.Iterator; 30 import java.util.List; 31 import java.util.stream.Stream; 32 import javax.annotation.CheckForNull; 33 import org.checkerframework.checker.nullness.qual.Nullable; 34 35 /** 36 * An accumulator that selects the "top" {@code k} elements added to it, relative to a provided 37 * comparator. "Top" can mean the greatest or the lowest elements, specified in the factory used to 38 * create the {@code TopKSelector} instance. 39 * 40 * <p>If your input data is available as a {@link Stream}, prefer passing {@link 41 * Comparators#least(int)} to {@link Stream#collect(java.util.stream.Collector)}. If it is available 42 * as an {@link Iterable} or {@link Iterator}, prefer {@link Ordering#leastOf(Iterable, int)}. 43 * 44 * <p>This uses the same efficient implementation as {@link Ordering#leastOf(Iterable, int)}, 45 * offering expected O(n + k log k) performance (worst case O(n log k)) for n calls to {@link 46 * #offer} and a call to {@link #topK}, with O(k) memory. In comparison, quickselect has the same 47 * asymptotics but requires O(n) memory, and a {@code PriorityQueue} implementation takes O(n log 48 * k). In benchmarks, this implementation performs at least as well as either implementation, and 49 * degrades more gracefully for worst-case input. 50 * 51 * <p>The implementation does not necessarily use a <i>stable</i> sorting algorithm; when multiple 52 * equivalent elements are added to it, it is undefined which will come first in the output. 53 * 54 * @author Louis Wasserman 55 */ 56 @GwtCompatible 57 @ElementTypesAreNonnullByDefault 58 final class TopKSelector< 59 T extends @Nullable Object> { 60 61 /** 62 * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it, 63 * relative to the natural ordering of the elements, and returns them via {@link #topK} in 64 * ascending order. 65 * 66 * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} 67 */ least(int k)68 public static <T extends Comparable<? super T>> TopKSelector<T> least(int k) { 69 return least(k, Ordering.natural()); 70 } 71 72 /** 73 * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it, 74 * relative to the specified comparator, and returns them via {@link #topK} in ascending order. 75 * 76 * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} 77 */ least( int k, Comparator<? super T> comparator)78 public static <T extends @Nullable Object> TopKSelector<T> least( 79 int k, Comparator<? super T> comparator) { 80 return new TopKSelector<T>(comparator, k); 81 } 82 83 /** 84 * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it, 85 * relative to the natural ordering of the elements, and returns them via {@link #topK} in 86 * descending order. 87 * 88 * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} 89 */ greatest(int k)90 public static <T extends Comparable<? super T>> TopKSelector<T> greatest(int k) { 91 return greatest(k, Ordering.natural()); 92 } 93 94 /** 95 * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it, 96 * relative to the specified comparator, and returns them via {@link #topK} in descending order. 97 * 98 * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} 99 */ greatest( int k, Comparator<? super T> comparator)100 public static <T extends @Nullable Object> TopKSelector<T> greatest( 101 int k, Comparator<? super T> comparator) { 102 return new TopKSelector<T>(Ordering.from(comparator).reverse(), k); 103 } 104 105 private final int k; 106 private final Comparator<? super T> comparator; 107 108 /* 109 * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates 110 * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the 111 * range [0, k) and ignore the remaining elements. 112 */ 113 private final @Nullable T[] buffer; 114 private int bufferSize; 115 116 /** 117 * The largest of the lowest k elements we've seen so far relative to this comparator. If 118 * bufferSize ≥ k, then we can ignore any elements greater than this value. 119 */ 120 @CheckForNull private T threshold; 121 TopKSelector(Comparator<? super T> comparator, int k)122 private TopKSelector(Comparator<? super T> comparator, int k) { 123 this.comparator = checkNotNull(comparator, "comparator"); 124 this.k = k; 125 checkArgument(k >= 0, "k (%s) must be >= 0", k); 126 checkArgument(k <= Integer.MAX_VALUE / 2, "k (%s) must be <= Integer.MAX_VALUE / 2", k); 127 this.buffer = (T[]) new Object[IntMath.checkedMultiply(k, 2)]; 128 this.bufferSize = 0; 129 this.threshold = null; 130 } 131 132 /** 133 * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes amortized 134 * O(1) time. 135 */ offer(@arametricNullness T elem)136 public void offer(@ParametricNullness T elem) { 137 if (k == 0) { 138 return; 139 } else if (bufferSize == 0) { 140 buffer[0] = elem; 141 threshold = elem; 142 bufferSize = 1; 143 } else if (bufferSize < k) { 144 buffer[bufferSize++] = elem; 145 // uncheckedCastNullableTToT is safe because bufferSize > 0. 146 if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) > 0) { 147 threshold = elem; 148 } 149 // uncheckedCastNullableTToT is safe because bufferSize > 0. 150 } else if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) < 0) { 151 // Otherwise, we can ignore elem; we've seen k better elements. 152 buffer[bufferSize++] = elem; 153 if (bufferSize == 2 * k) { 154 trim(); 155 } 156 } 157 } 158 159 /** 160 * Quickselects the top k elements from the 2k elements in the buffer. O(k) expected time, O(k log 161 * k) worst case. 162 */ trim()163 private void trim() { 164 int left = 0; 165 int right = 2 * k - 1; 166 167 int minThresholdPosition = 0; 168 // The leftmost position at which the greatest of the k lower elements 169 // -- the new value of threshold -- might be found. 170 171 int iterations = 0; 172 int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3; 173 while (left < right) { 174 int pivotIndex = (left + right + 1) >>> 1; 175 176 int pivotNewIndex = partition(left, right, pivotIndex); 177 178 if (pivotNewIndex > k) { 179 right = pivotNewIndex - 1; 180 } else if (pivotNewIndex < k) { 181 left = Math.max(pivotNewIndex, left + 1); 182 minThresholdPosition = pivotNewIndex; 183 } else { 184 break; 185 } 186 iterations++; 187 if (iterations >= maxIterations) { 188 // We've already taken O(k log k), let's make sure we don't take longer than O(k log k). 189 Arrays.sort(buffer, left, right + 1, comparator); 190 break; 191 } 192 } 193 bufferSize = k; 194 195 threshold = uncheckedCastNullableTToT(buffer[minThresholdPosition]); 196 for (int i = minThresholdPosition + 1; i < k; i++) { 197 if (comparator.compare( 198 uncheckedCastNullableTToT(buffer[i]), uncheckedCastNullableTToT(threshold)) 199 > 0) { 200 threshold = buffer[i]; 201 } 202 } 203 } 204 205 /** 206 * Partitions the contents of buffer in the range [left, right] around the pivot element 207 * previously stored in buffer[pivotValue]. Returns the new index of the pivot element, 208 * pivotNewIndex, so that everything in [left, pivotNewIndex] is ≤ pivotValue and everything in 209 * (pivotNewIndex, right] is greater than pivotValue. 210 */ partition(int left, int right, int pivotIndex)211 private int partition(int left, int right, int pivotIndex) { 212 T pivotValue = uncheckedCastNullableTToT(buffer[pivotIndex]); 213 buffer[pivotIndex] = buffer[right]; 214 215 int pivotNewIndex = left; 216 for (int i = left; i < right; i++) { 217 if (comparator.compare(uncheckedCastNullableTToT(buffer[i]), pivotValue) < 0) { 218 swap(pivotNewIndex, i); 219 pivotNewIndex++; 220 } 221 } 222 buffer[right] = buffer[pivotNewIndex]; 223 buffer[pivotNewIndex] = pivotValue; 224 return pivotNewIndex; 225 } 226 swap(int i, int j)227 private void swap(int i, int j) { 228 T tmp = buffer[i]; 229 buffer[i] = buffer[j]; 230 buffer[j] = tmp; 231 } 232 combine(TopKSelector<T> other)233 TopKSelector<T> combine(TopKSelector<T> other) { 234 for (int i = 0; i < other.bufferSize; i++) { 235 this.offer(uncheckedCastNullableTToT(other.buffer[i])); 236 } 237 return this; 238 } 239 240 /** 241 * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This 242 * operation takes amortized linear time in the length of {@code elements}. 243 * 244 * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterable}, prefer 245 * {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use case. 246 */ offerAll(Iterable<? extends T> elements)247 public void offerAll(Iterable<? extends T> elements) { 248 offerAll(elements.iterator()); 249 } 250 251 /** 252 * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This 253 * operation takes amortized linear time in the length of {@code elements}. The iterator is 254 * consumed after this operation completes. 255 * 256 * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterator}, prefer 257 * {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use case. 258 */ offerAll(Iterator<? extends T> elements)259 public void offerAll(Iterator<? extends T> elements) { 260 while (elements.hasNext()) { 261 offer(elements.next()); 262 } 263 } 264 265 /** 266 * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if 267 * fewer than {@code k} have been offered, in the order specified by the factory used to create 268 * this {@code TopKSelector}. 269 * 270 * <p>The returned list is an unmodifiable copy and will not be affected by further changes to 271 * this {@code TopKSelector}. This method returns in O(k log k) time. 272 */ topK()273 public List<T> topK() { 274 Arrays.sort(buffer, 0, bufferSize, comparator); 275 if (bufferSize > k) { 276 Arrays.fill(buffer, k, buffer.length, null); 277 bufferSize = k; 278 threshold = buffer[k - 1]; 279 } 280 // we have to support null elements, so no ImmutableList for us 281 return Collections.unmodifiableList(Arrays.asList(Arrays.copyOf(buffer, bufferSize))); 282 } 283 } 284