/*
 * Copyright (C) 2014 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.common.collect;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;

import com.google.common.annotations.GwtCompatible;
import com.google.common.math.IntMath;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import javax.annotation.CheckForNull;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * An accumulator that selects the "top" {@code k} elements added to it, relative to a provided
 * comparator. "Top" can mean the greatest or the lowest elements, specified in the factory used to
 * create the {@code TopKSelector} instance.
 *
 * <p>If your input data is available as a {@link Stream}, prefer passing {@link
 * Comparators#least(int)} to {@link Stream#collect(java.util.stream.Collector)}. If it is available
 * as an {@link Iterable} or {@link Iterator}, prefer {@link Ordering#leastOf(Iterable, int)}.
 *
 * <p>This uses the same efficient implementation as {@link Ordering#leastOf(Iterable, int)},
 * offering expected O(n + k log k) performance (worst case O(n log k)) for n calls to {@link
 * #offer} and a call to {@link #topK}, with O(k) memory. In comparison, quickselect has the same
 * asymptotics but requires O(n) memory, and a {@code PriorityQueue} implementation takes O(n log
 * k). In benchmarks, this implementation performs at least as well as either implementation, and
 * degrades more gracefully for worst-case input.
 *
 * <p>The implementation does not necessarily use a <i>stable</i> sorting algorithm; when multiple
 * equivalent elements are added to it, it is undefined which will come first in the output.
 *
 * @author Louis Wasserman
 */
@GwtCompatible
@ElementTypesAreNonnullByDefault
final class TopKSelector<
    T extends @Nullable Object> {

  /**
   * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
   * relative to the natural ordering of the elements, and returns them via {@link #topK} in
   * ascending order.
   *
   * @param k the maximum number of elements to retain
   * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
   */
  public static <T extends Comparable<? super T>> TopKSelector<T> least(int k) {
    return least(k, Ordering.natural());
  }

  /**
   * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
   * relative to the specified comparator, and returns them via {@link #topK} in ascending order.
   *
   * @param k the maximum number of elements to retain
   * @param comparator the ordering that defines "lowest"
   * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
   */
  public static <T extends @Nullable Object> TopKSelector<T> least(
      int k, Comparator<? super T> comparator) {
    return new TopKSelector<T>(comparator, k);
  }

  /**
   * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
   * relative to the natural ordering of the elements, and returns them via {@link #topK} in
   * descending order.
   *
   * @param k the maximum number of elements to retain
   * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
   */
  public static <T extends Comparable<? super T>> TopKSelector<T> greatest(int k) {
    return greatest(k, Ordering.natural());
  }

  /**
   * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
   * relative to the specified comparator, and returns them via {@link #topK} in descending order.
   *
   * @param k the maximum number of elements to retain
   * @param comparator the ordering that defines "greatest"
   * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
   */
  public static <T extends @Nullable Object> TopKSelector<T> greatest(
      int k, Comparator<? super T> comparator) {
    // "Greatest" is implemented as "least" under the reversed ordering.
    return new TopKSelector<T>(Ordering.from(comparator).reverse(), k);
  }

  private final int k;
  private final Comparator<? super T> comparator;

  /*
   * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates
   * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the
   * range [0, k) and ignore the remaining elements.
   */
  private final @Nullable T[] buffer;
  private int bufferSize;

  /**
   * The largest of the lowest k elements we've seen so far relative to this comparator. If
   * bufferSize ≥ k, then we can ignore any elements greater than this value.
   */
  @CheckForNull private T threshold;

  /**
   * Creates a selector that keeps the {@code k} lowest elements according to {@code comparator},
   * backed by a buffer of {@code 2 * k} slots.
   *
   * @throws NullPointerException if {@code comparator} is null
   * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
   */
  @SuppressWarnings("unchecked") // the Object[] is private and only ever read back as T elements
  private TopKSelector(Comparator<? super T> comparator, int k) {
    this.comparator = checkNotNull(comparator, "comparator");
    this.k = k;
    checkArgument(k >= 0, "k (%s) must be >= 0", k);
    checkArgument(k <= Integer.MAX_VALUE / 2, "k (%s) must be <= Integer.MAX_VALUE / 2", k);
    this.buffer = (T[]) new Object[IntMath.checkedMultiply(k, 2)];
    this.bufferSize = 0;
    this.threshold = null;
  }

  /**
   * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes amortized
   * O(1) time.
   *
   * @param elem the candidate element
   */
  public void offer(@ParametricNullness T elem) {
    if (k == 0) {
      // Nothing can ever be selected, and the buffer has length 0.
      return;
    } else if (bufferSize == 0) {
      // First element: it is trivially the largest candidate seen so far.
      buffer[0] = elem;
      threshold = elem;
      bufferSize = 1;
    } else if (bufferSize < k) {
      // Still fewer than k candidates: accept unconditionally and track the maximum.
      buffer[bufferSize++] = elem;
      // uncheckedCastNullableTToT is safe because bufferSize > 0.
      if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) > 0) {
        threshold = elem;
      }
      // uncheckedCastNullableTToT is safe because bufferSize > 0.
    } else if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) < 0) {
      // Otherwise, we can ignore elem; we've seen k better elements.
      buffer[bufferSize++] = elem;
      if (bufferSize == 2 * k) {
        // The buffer is full: compact it back down to the best k candidates.
        trim();
      }
    }
  }

  /**
   * Quickselects the top k elements from the 2k elements in the buffer. O(k) expected time, O(k log
   * k) worst case.
   *
   * <p>On return, {@code bufferSize == k}, the range {@code [0, k)} of the buffer holds the k
   * lowest elements (in unspecified order), and {@code threshold} is the greatest of them.
   */
  private void trim() {
    int left = 0;
    int right = 2 * k - 1;

    int minThresholdPosition = 0;
    // The leftmost position at which the greatest of the k lower elements
    // -- the new value of threshold -- might be found.

    int iterations = 0;
    // Cap on quickselect rounds; once reached we fall back to sorting the remaining range.
    int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3;
    while (left < right) {
      // Unsigned shift keeps the midpoint correct even if left + right + 1 overflows int.
      int pivotIndex = (left + right + 1) >>> 1;

      int pivotNewIndex = partition(left, right, pivotIndex);

      if (pivotNewIndex > k) {
        right = pivotNewIndex - 1;
      } else if (pivotNewIndex < k) {
        // Always advance by at least one slot so the loop makes progress even when the pivot
        // lands at the left edge (e.g. many equal elements).
        left = Math.max(pivotNewIndex, left + 1);
        // Everything left of the pivot is strictly smaller than it, so the maximum of [0, k)
        // cannot lie before this position.
        minThresholdPosition = pivotNewIndex;
      } else {
        break;
      }
      iterations++;
      if (iterations >= maxIterations) {
        @SuppressWarnings("nullness") // safe because we pass sort() a range that contains real Ts
        T[] castBuffer = (T[]) buffer;
        // We've already taken O(k log k), let's make sure we don't take longer than O(k log k).
        Arrays.sort(castBuffer, left, right + 1, comparator);
        break;
      }
    }
    bufferSize = k;

    // Recompute threshold as the maximum of buffer[minThresholdPosition, k).
    threshold = uncheckedCastNullableTToT(buffer[minThresholdPosition]);
    for (int i = minThresholdPosition + 1; i < k; i++) {
      if (comparator.compare(
              uncheckedCastNullableTToT(buffer[i]), uncheckedCastNullableTToT(threshold))
          > 0) {
        threshold = buffer[i];
      }
    }
  }

  /**
   * Partitions the contents of buffer in the range [left, right] around the pivot element
   * previously stored in buffer[pivotIndex]. Returns the new index of the pivot element,
   * pivotNewIndex, so that everything in [left, pivotNewIndex) is strictly less than the pivot
   * value and everything in (pivotNewIndex, right] is greater than or equal to the pivot value.
   *
   * @param left the first index of the range to partition (inclusive)
   * @param right the last index of the range to partition (inclusive)
   * @param pivotIndex the index, within [left, right], of the element to partition around
   * @return the index at which the pivot element ends up
   */
  private int partition(int left, int right, int pivotIndex) {
    T pivotValue = uncheckedCastNullableTToT(buffer[pivotIndex]);
    // Park the rightmost element in the pivot's slot; buffer[right] is rewritten after the loop.
    buffer[pivotIndex] = buffer[right];

    int pivotNewIndex = left;
    for (int i = left; i < right; i++) {
      if (comparator.compare(uncheckedCastNullableTToT(buffer[i]), pivotValue) < 0) {
        swap(pivotNewIndex, i);
        pivotNewIndex++;
      }
    }
    // Move the first not-less-than-pivot element to the end and drop the pivot into place.
    buffer[right] = buffer[pivotNewIndex];
    buffer[pivotNewIndex] = pivotValue;
    return pivotNewIndex;
  }

  /** Exchanges the buffer elements at indices {@code i} and {@code j}. */
  private void swap(int i, int j) {
    T tmp = buffer[i];
    buffer[i] = buffer[j];
    buffer[j] = tmp;
  }

  /**
   * Offers every candidate currently buffered in {@code other} to this selector.
   *
   * @param other the selector whose buffered candidates are merged into this one
   * @return this selector
   */
  TopKSelector<T> combine(TopKSelector<T> other) {
    for (int i = 0; i < other.bufferSize; i++) {
      this.offer(uncheckedCastNullableTToT(other.buffer[i]));
    }
    return this;
  }

  /**
   * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
   * operation takes amortized linear time in the length of {@code elements}.
   *
   * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterable}, prefer
   * {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use case.
   *
   * @param elements the candidates to offer
   */
  public void offerAll(Iterable<? extends T> elements) {
    offerAll(elements.iterator());
  }

  /**
   * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
   * operation takes amortized linear time in the length of {@code elements}. The iterator is
   * consumed after this operation completes.
   *
   * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterator}, prefer
   * {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use case.
   *
   * @param elements the candidates to offer
   */
  public void offerAll(Iterator<? extends T> elements) {
    while (elements.hasNext()) {
      offer(elements.next());
    }
  }

  /**
   * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if
   * fewer than {@code k} have been offered, in the order specified by the factory used to create
   * this {@code TopKSelector}.
   *
   * <p>The returned list is an unmodifiable copy and will not be affected by further changes to
   * this {@code TopKSelector}. This method returns in O(k log k) time.
   *
   * @return an unmodifiable list of at most {@code k} elements, sorted by this selector's
   *     comparator
   */
  public List<T> topK() {
    @SuppressWarnings("nullness") // safe because we pass sort() a range that contains real Ts
    T[] castBuffer = (T[]) buffer;
    Arrays.sort(castBuffer, 0, bufferSize, comparator);
    if (bufferSize > k) {
      // Between trims the buffer may hold more than k candidates; discard the surplus and
      // refresh threshold from the now-sorted prefix.
      Arrays.fill(buffer, k, buffer.length, null);
      bufferSize = k;
      threshold = buffer[k - 1];
    }
    // Up to bufferSize, all elements of buffer are real Ts (not null unless T includes null)
    T[] topK = Arrays.copyOf(castBuffer, bufferSize);
    // we have to support null elements, so no ImmutableList for us
    return Collections.unmodifiableList(Arrays.asList(topK));
  }
}