1 /* 2 * Copyright (C) 2014 The Guava Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.common.collect; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 import static com.google.common.base.Preconditions.checkNotNull; 21 22 import com.google.common.annotations.GwtCompatible; 23 import com.google.common.math.IntMath; 24 import java.math.RoundingMode; 25 import java.util.Arrays; 26 import java.util.Collections; 27 import java.util.Comparator; 28 import java.util.Iterator; 29 import java.util.List; 30 import org.checkerframework.checker.nullness.compatqual.NullableDecl; 31 32 /** 33 * An accumulator that selects the "top" {@code k} elements added to it, relative to a provided 34 * comparator. "Top" can mean the greatest or the lowest elements, specified in the factory used to 35 * create the {@code TopKSelector} instance. 36 * 37 * <p>If your input data is available as an {@link Iterable} or {@link Iterator}, prefer {@link 38 * Ordering#leastOf(Iterable, int)}, which provides the same implementation with an interface 39 * tailored to that use case. 40 * 41 * <p>This uses the same efficient implementation as {@link Ordering#leastOf(Iterable, int)}, 42 * offering expected O(n + k log k) performance (worst case O(n log k)) for n calls to {@link 43 * #offer} and a call to {@link #topK}, with O(k) memory. In comparison, quickselect has the same 44 * asymptotics but requires O(n) memory, and a {@code PriorityQueue} implementation takes O(n log 45 * k). In benchmarks, this implementation performs at least as well as either implementation, and 46 * degrades more gracefully for worst-case input. 47 * 48 * <p>The implementation does not necessarily use a <i>stable</i> sorting algorithm; when multiple 49 * equivalent elements are added to it, it is undefined which will come first in the output. 50 * 51 * @author Louis Wasserman 52 */ 53 @GwtCompatible final class TopKSelector<T> { 54 55 /** 56 * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it, 57 * relative to the natural ordering of the elements, and returns them via {@link #topK} in 58 * ascending order. 59 * 60 * @throws IllegalArgumentException if {@code k < 0} 61 */ least(int k)62 public static <T extends Comparable<? super T>> TopKSelector<T> least(int k) { 63 return least(k, Ordering.natural()); 64 } 65 66 /** 67 * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it, 68 * relative to the specified comparator, and returns them via {@link #topK} in ascending order. 69 * 70 * @throws IllegalArgumentException if {@code k < 0} 71 */ least(int k, Comparator<? super T> comparator)72 public static <T> TopKSelector<T> least(int k, Comparator<? super T> comparator) { 73 return new TopKSelector<T>(comparator, k); 74 } 75 76 /** 77 * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it, 78 * relative to the natural ordering of the elements, and returns them via {@link #topK} in 79 * descending order. 80 * 81 * @throws IllegalArgumentException if {@code k < 0} 82 */ greatest(int k)83 public static <T extends Comparable<? super T>> TopKSelector<T> greatest(int k) { 84 return greatest(k, Ordering.natural()); 85 } 86 87 /** 88 * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it, 89 * relative to the specified comparator, and returns them via {@link #topK} in descending order. 90 * 91 * @throws IllegalArgumentException if {@code k < 0} 92 */ greatest(int k, Comparator<? super T> comparator)93 public static <T> TopKSelector<T> greatest(int k, Comparator<? super T> comparator) { 94 return new TopKSelector<T>(Ordering.from(comparator).reverse(), k); 95 } 96 97 private final int k; 98 private final Comparator<? super T> comparator; 99 100 /* 101 * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates 102 * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the 103 * range [0, k) and ignore the remaining elements. 104 */ 105 private final T[] buffer; 106 private int bufferSize; 107 108 /** 109 * The largest of the lowest k elements we've seen so far relative to this comparator. If 110 * bufferSize ≥ k, then we can ignore any elements greater than this value. 111 */ 112 @NullableDecl private T threshold; 113 TopKSelector(Comparator<? super T> comparator, int k)114 private TopKSelector(Comparator<? super T> comparator, int k) { 115 this.comparator = checkNotNull(comparator, "comparator"); 116 this.k = k; 117 checkArgument(k >= 0, "k must be nonnegative, was %s", k); 118 this.buffer = (T[]) new Object[k * 2]; 119 this.bufferSize = 0; 120 this.threshold = null; 121 } 122 123 /** 124 * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes amortized 125 * O(1) time. 126 */ offer(@ullableDecl T elem)127 public void offer(@NullableDecl T elem) { 128 if (k == 0) { 129 return; 130 } else if (bufferSize == 0) { 131 buffer[0] = elem; 132 threshold = elem; 133 bufferSize = 1; 134 } else if (bufferSize < k) { 135 buffer[bufferSize++] = elem; 136 if (comparator.compare(elem, threshold) > 0) { 137 threshold = elem; 138 } 139 } else if (comparator.compare(elem, threshold) < 0) { 140 // Otherwise, we can ignore elem; we've seen k better elements. 141 buffer[bufferSize++] = elem; 142 if (bufferSize == 2 * k) { 143 trim(); 144 } 145 } 146 } 147 148 /** 149 * Quickselects the top k elements from the 2k elements in the buffer. O(k) expected time, O(k log 150 * k) worst case. 151 */ trim()152 private void trim() { 153 int left = 0; 154 int right = 2 * k - 1; 155 156 int minThresholdPosition = 0; 157 // The leftmost position at which the greatest of the k lower elements 158 // -- the new value of threshold -- might be found. 159 160 int iterations = 0; 161 int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3; 162 while (left < right) { 163 int pivotIndex = (left + right + 1) >>> 1; 164 165 int pivotNewIndex = partition(left, right, pivotIndex); 166 167 if (pivotNewIndex > k) { 168 right = pivotNewIndex - 1; 169 } else if (pivotNewIndex < k) { 170 left = Math.max(pivotNewIndex, left + 1); 171 minThresholdPosition = pivotNewIndex; 172 } else { 173 break; 174 } 175 iterations++; 176 if (iterations >= maxIterations) { 177 // We've already taken O(k log k), let's make sure we don't take longer than O(k log k). 178 Arrays.sort(buffer, left, right, comparator); 179 break; 180 } 181 } 182 bufferSize = k; 183 184 threshold = buffer[minThresholdPosition]; 185 for (int i = minThresholdPosition + 1; i < k; i++) { 186 if (comparator.compare(buffer[i], threshold) > 0) { 187 threshold = buffer[i]; 188 } 189 } 190 } 191 192 /** 193 * Partitions the contents of buffer in the range [left, right] around the pivot element 194 * previously stored in buffer[pivotValue]. Returns the new index of the pivot element, 195 * pivotNewIndex, so that everything in [left, pivotNewIndex] is ≤ pivotValue and everything in 196 * (pivotNewIndex, right] is greater than pivotValue. 197 */ partition(int left, int right, int pivotIndex)198 private int partition(int left, int right, int pivotIndex) { 199 T pivotValue = buffer[pivotIndex]; 200 buffer[pivotIndex] = buffer[right]; 201 202 int pivotNewIndex = left; 203 for (int i = left; i < right; i++) { 204 if (comparator.compare(buffer[i], pivotValue) < 0) { 205 swap(pivotNewIndex, i); 206 pivotNewIndex++; 207 } 208 } 209 buffer[right] = buffer[pivotNewIndex]; 210 buffer[pivotNewIndex] = pivotValue; 211 return pivotNewIndex; 212 } 213 swap(int i, int j)214 private void swap(int i, int j) { 215 T tmp = buffer[i]; 216 buffer[i] = buffer[j]; 217 buffer[j] = tmp; 218 } 219 220 /** 221 * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This 222 * operation takes amortized linear time in the length of {@code elements}. 223 * 224 * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterable}, prefer 225 * {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use case. 226 */ offerAll(Iterable<? extends T> elements)227 public void offerAll(Iterable<? extends T> elements) { 228 offerAll(elements.iterator()); 229 } 230 231 /** 232 * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This 233 * operation takes amortized linear time in the length of {@code elements}. The iterator is 234 * consumed after this operation completes. 235 * 236 * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterator}, prefer 237 * {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use case. 238 */ offerAll(Iterator<? extends T> elements)239 public void offerAll(Iterator<? extends T> elements) { 240 while (elements.hasNext()) { 241 offer(elements.next()); 242 } 243 } 244 245 /** 246 * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if 247 * fewer than {@code k} have been offered, in the order specified by the factory used to create 248 * this {@code TopKSelector}. 249 * 250 * <p>The returned list is an unmodifiable copy and will not be affected by further changes to 251 * this {@code TopKSelector}. This method returns in O(k log k) time. 252 */ topK()253 public List<T> topK() { 254 Arrays.sort(buffer, 0, bufferSize, comparator); 255 if (bufferSize > k) { 256 Arrays.fill(buffer, k, buffer.length, null); 257 bufferSize = k; 258 threshold = buffer[k - 1]; 259 } 260 // we have to support null elements, so no ImmutableList for us 261 return Collections.unmodifiableList(Arrays.asList(Arrays.copyOf(buffer, bufferSize))); 262 } 263 } 264