/*
 * Copyright (C) 2014 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.google.common.collect;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkNotNull;
21 import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
22 
23 import com.google.common.annotations.GwtCompatible;
24 import com.google.common.math.IntMath;
25 import java.math.RoundingMode;
26 import java.util.Arrays;
27 import java.util.Collections;
28 import java.util.Comparator;
29 import java.util.Iterator;
30 import java.util.List;
31 import java.util.stream.Stream;
32 import javax.annotation.CheckForNull;
33 import org.checkerframework.checker.nullness.qual.Nullable;
34 
35 /**
36  * An accumulator that selects the "top" {@code k} elements added to it, relative to a provided
37  * comparator. "Top" can mean the greatest or the lowest elements, specified in the factory used to
38  * create the {@code TopKSelector} instance.
39  *
40  * <p>If your input data is available as a {@link Stream}, prefer passing {@link
41  * Comparators#least(int)} to {@link Stream#collect(java.util.stream.Collector)}. If it is available
42  * as an {@link Iterable} or {@link Iterator}, prefer {@link Ordering#leastOf(Iterable, int)}.
43  *
44  * <p>This uses the same efficient implementation as {@link Ordering#leastOf(Iterable, int)},
45  * offering expected O(n + k log k) performance (worst case O(n log k)) for n calls to {@link
46  * #offer} and a call to {@link #topK}, with O(k) memory. In comparison, quickselect has the same
47  * asymptotics but requires O(n) memory, and a {@code PriorityQueue} implementation takes O(n log
48  * k). In benchmarks, this implementation performs at least as well as either implementation, and
49  * degrades more gracefully for worst-case input.
50  *
51  * <p>The implementation does not necessarily use a <i>stable</i> sorting algorithm; when multiple
52  * equivalent elements are added to it, it is undefined which will come first in the output.
53  *
54  * @author Louis Wasserman
55  */
56 @GwtCompatible
57 @ElementTypesAreNonnullByDefault
58 final class TopKSelector<
59     T extends @Nullable Object> {
60 
61   /**
62    * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
63    * relative to the natural ordering of the elements, and returns them via {@link #topK} in
64    * ascending order.
65    *
66    * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
67    */
least(int k)68   public static <T extends Comparable<? super T>> TopKSelector<T> least(int k) {
69     return least(k, Ordering.natural());
70   }
71 
72   /**
73    * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
74    * relative to the specified comparator, and returns them via {@link #topK} in ascending order.
75    *
76    * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
77    */
least( int k, Comparator<? super T> comparator)78   public static <T extends @Nullable Object> TopKSelector<T> least(
79       int k, Comparator<? super T> comparator) {
80     return new TopKSelector<T>(comparator, k);
81   }
82 
83   /**
84    * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
85    * relative to the natural ordering of the elements, and returns them via {@link #topK} in
86    * descending order.
87    *
88    * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
89    */
greatest(int k)90   public static <T extends Comparable<? super T>> TopKSelector<T> greatest(int k) {
91     return greatest(k, Ordering.natural());
92   }
93 
94   /**
95    * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
96    * relative to the specified comparator, and returns them via {@link #topK} in descending order.
97    *
98    * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2}
99    */
greatest( int k, Comparator<? super T> comparator)100   public static <T extends @Nullable Object> TopKSelector<T> greatest(
101       int k, Comparator<? super T> comparator) {
102     return new TopKSelector<T>(Ordering.from(comparator).reverse(), k);
103   }
104 
105   private final int k;
106   private final Comparator<? super T> comparator;
107 
108   /*
109    * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates
110    * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the
111    * range [0, k) and ignore the remaining elements.
112    */
113   private final @Nullable T[] buffer;
114   private int bufferSize;
115 
116   /**
117    * The largest of the lowest k elements we've seen so far relative to this comparator. If
118    * bufferSize ≥ k, then we can ignore any elements greater than this value.
119    */
120   @CheckForNull private T threshold;
121 
TopKSelector(Comparator<? super T> comparator, int k)122   private TopKSelector(Comparator<? super T> comparator, int k) {
123     this.comparator = checkNotNull(comparator, "comparator");
124     this.k = k;
125     checkArgument(k >= 0, "k (%s) must be >= 0", k);
126     checkArgument(k <= Integer.MAX_VALUE / 2, "k (%s) must be <= Integer.MAX_VALUE / 2", k);
127     this.buffer = (T[]) new Object[IntMath.checkedMultiply(k, 2)];
128     this.bufferSize = 0;
129     this.threshold = null;
130   }
131 
132   /**
133    * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes amortized
134    * O(1) time.
135    */
offer(@arametricNullness T elem)136   public void offer(@ParametricNullness T elem) {
137     if (k == 0) {
138       return;
139     } else if (bufferSize == 0) {
140       buffer[0] = elem;
141       threshold = elem;
142       bufferSize = 1;
143     } else if (bufferSize < k) {
144       buffer[bufferSize++] = elem;
145       // uncheckedCastNullableTToT is safe because bufferSize > 0.
146       if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) > 0) {
147         threshold = elem;
148       }
149       // uncheckedCastNullableTToT is safe because bufferSize > 0.
150     } else if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) < 0) {
151       // Otherwise, we can ignore elem; we've seen k better elements.
152       buffer[bufferSize++] = elem;
153       if (bufferSize == 2 * k) {
154         trim();
155       }
156     }
157   }
158 
159   /**
160    * Quickselects the top k elements from the 2k elements in the buffer. O(k) expected time, O(k log
161    * k) worst case.
162    */
trim()163   private void trim() {
164     int left = 0;
165     int right = 2 * k - 1;
166 
167     int minThresholdPosition = 0;
168     // The leftmost position at which the greatest of the k lower elements
169     // -- the new value of threshold -- might be found.
170 
171     int iterations = 0;
172     int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3;
173     while (left < right) {
174       int pivotIndex = (left + right + 1) >>> 1;
175 
176       int pivotNewIndex = partition(left, right, pivotIndex);
177 
178       if (pivotNewIndex > k) {
179         right = pivotNewIndex - 1;
180       } else if (pivotNewIndex < k) {
181         left = Math.max(pivotNewIndex, left + 1);
182         minThresholdPosition = pivotNewIndex;
183       } else {
184         break;
185       }
186       iterations++;
187       if (iterations >= maxIterations) {
188         // We've already taken O(k log k), let's make sure we don't take longer than O(k log k).
189         Arrays.sort(buffer, left, right + 1, comparator);
190         break;
191       }
192     }
193     bufferSize = k;
194 
195     threshold = uncheckedCastNullableTToT(buffer[minThresholdPosition]);
196     for (int i = minThresholdPosition + 1; i < k; i++) {
197       if (comparator.compare(
198               uncheckedCastNullableTToT(buffer[i]), uncheckedCastNullableTToT(threshold))
199           > 0) {
200         threshold = buffer[i];
201       }
202     }
203   }
204 
205   /**
206    * Partitions the contents of buffer in the range [left, right] around the pivot element
207    * previously stored in buffer[pivotValue]. Returns the new index of the pivot element,
208    * pivotNewIndex, so that everything in [left, pivotNewIndex] is ≤ pivotValue and everything in
209    * (pivotNewIndex, right] is greater than pivotValue.
210    */
partition(int left, int right, int pivotIndex)211   private int partition(int left, int right, int pivotIndex) {
212     T pivotValue = uncheckedCastNullableTToT(buffer[pivotIndex]);
213     buffer[pivotIndex] = buffer[right];
214 
215     int pivotNewIndex = left;
216     for (int i = left; i < right; i++) {
217       if (comparator.compare(uncheckedCastNullableTToT(buffer[i]), pivotValue) < 0) {
218         swap(pivotNewIndex, i);
219         pivotNewIndex++;
220       }
221     }
222     buffer[right] = buffer[pivotNewIndex];
223     buffer[pivotNewIndex] = pivotValue;
224     return pivotNewIndex;
225   }
226 
swap(int i, int j)227   private void swap(int i, int j) {
228     T tmp = buffer[i];
229     buffer[i] = buffer[j];
230     buffer[j] = tmp;
231   }
232 
combine(TopKSelector<T> other)233   TopKSelector<T> combine(TopKSelector<T> other) {
234     for (int i = 0; i < other.bufferSize; i++) {
235       this.offer(uncheckedCastNullableTToT(other.buffer[i]));
236     }
237     return this;
238   }
239 
240   /**
241    * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
242    * operation takes amortized linear time in the length of {@code elements}.
243    *
244    * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterable}, prefer
245    * {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use case.
246    */
offerAll(Iterable<? extends T> elements)247   public void offerAll(Iterable<? extends T> elements) {
248     offerAll(elements.iterator());
249   }
250 
251   /**
252    * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
253    * operation takes amortized linear time in the length of {@code elements}. The iterator is
254    * consumed after this operation completes.
255    *
256    * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterator}, prefer
257    * {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use case.
258    */
offerAll(Iterator<? extends T> elements)259   public void offerAll(Iterator<? extends T> elements) {
260     while (elements.hasNext()) {
261       offer(elements.next());
262     }
263   }
264 
265   /**
266    * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if
267    * fewer than {@code k} have been offered, in the order specified by the factory used to create
268    * this {@code TopKSelector}.
269    *
270    * <p>The returned list is an unmodifiable copy and will not be affected by further changes to
271    * this {@code TopKSelector}. This method returns in O(k log k) time.
272    */
topK()273   public List<T> topK() {
274     Arrays.sort(buffer, 0, bufferSize, comparator);
275     if (bufferSize > k) {
276       Arrays.fill(buffer, k, buffer.length, null);
277       bufferSize = k;
278       threshold = buffer[k - 1];
279     }
280     // we have to support null elements, so no ImmutableList for us
281     return Collections.unmodifiableList(Arrays.asList(Arrays.copyOf(buffer, bufferSize)));
282   }
283 }
284