• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2014 The Guava Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.common.collect;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkNotNull;
21 
22 import com.google.common.annotations.GwtCompatible;
23 import com.google.common.math.IntMath;
24 import java.math.RoundingMode;
25 import java.util.Arrays;
26 import java.util.Collections;
27 import java.util.Comparator;
28 import java.util.Iterator;
29 import java.util.List;
30 import org.checkerframework.checker.nullness.compatqual.NullableDecl;
31 
32 /**
33  * An accumulator that selects the "top" {@code k} elements added to it, relative to a provided
34  * comparator. "Top" can mean the greatest or the lowest elements, specified in the factory used to
35  * create the {@code TopKSelector} instance.
36  *
37  * <p>If your input data is available as an {@link Iterable} or {@link Iterator}, prefer {@link
38  * Ordering#leastOf(Iterable, int)}, which provides the same implementation with an interface
39  * tailored to that use case.
40  *
41  * <p>This uses the same efficient implementation as {@link Ordering#leastOf(Iterable, int)},
42  * offering expected O(n + k log k) performance (worst case O(n log k)) for n calls to {@link
43  * #offer} and a call to {@link #topK}, with O(k) memory. In comparison, quickselect has the same
44  * asymptotics but requires O(n) memory, and a {@code PriorityQueue} implementation takes O(n log
45  * k). In benchmarks, this implementation performs at least as well as either implementation, and
46  * degrades more gracefully for worst-case input.
47  *
48  * <p>The implementation does not necessarily use a <i>stable</i> sorting algorithm; when multiple
49  * equivalent elements are added to it, it is undefined which will come first in the output.
50  *
51  * @author Louis Wasserman
52  */
53 @GwtCompatible final class TopKSelector<T> {
54 
55   /**
56    * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
57    * relative to the natural ordering of the elements, and returns them via {@link #topK} in
58    * ascending order.
59    *
60    * @throws IllegalArgumentException if {@code k < 0}
61    */
least(int k)62   public static <T extends Comparable<? super T>> TopKSelector<T> least(int k) {
63     return least(k, Ordering.natural());
64   }
65 
66   /**
67    * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
68    * relative to the specified comparator, and returns them via {@link #topK} in ascending order.
69    *
70    * @throws IllegalArgumentException if {@code k < 0}
71    */
least(int k, Comparator<? super T> comparator)72   public static <T> TopKSelector<T> least(int k, Comparator<? super T> comparator) {
73     return new TopKSelector<T>(comparator, k);
74   }
75 
76   /**
77    * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
78    * relative to the natural ordering of the elements, and returns them via {@link #topK} in
79    * descending order.
80    *
81    * @throws IllegalArgumentException if {@code k < 0}
82    */
greatest(int k)83   public static <T extends Comparable<? super T>> TopKSelector<T> greatest(int k) {
84     return greatest(k, Ordering.natural());
85   }
86 
87   /**
88    * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
89    * relative to the specified comparator, and returns them via {@link #topK} in descending order.
90    *
91    * @throws IllegalArgumentException if {@code k < 0}
92    */
greatest(int k, Comparator<? super T> comparator)93   public static <T> TopKSelector<T> greatest(int k, Comparator<? super T> comparator) {
94     return new TopKSelector<T>(Ordering.from(comparator).reverse(), k);
95   }
96 
97   private final int k;
98   private final Comparator<? super T> comparator;
99 
100   /*
101    * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates
102    * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the
103    * range [0, k) and ignore the remaining elements.
104    */
105   private final T[] buffer;
106   private int bufferSize;
107 
108   /**
109    * The largest of the lowest k elements we've seen so far relative to this comparator. If
110    * bufferSize ≥ k, then we can ignore any elements greater than this value.
111    */
112   @NullableDecl private T threshold;
113 
TopKSelector(Comparator<? super T> comparator, int k)114   private TopKSelector(Comparator<? super T> comparator, int k) {
115     this.comparator = checkNotNull(comparator, "comparator");
116     this.k = k;
117     checkArgument(k >= 0, "k must be nonnegative, was %s", k);
118     this.buffer = (T[]) new Object[k * 2];
119     this.bufferSize = 0;
120     this.threshold = null;
121   }
122 
123   /**
124    * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes amortized
125    * O(1) time.
126    */
offer(@ullableDecl T elem)127   public void offer(@NullableDecl T elem) {
128     if (k == 0) {
129       return;
130     } else if (bufferSize == 0) {
131       buffer[0] = elem;
132       threshold = elem;
133       bufferSize = 1;
134     } else if (bufferSize < k) {
135       buffer[bufferSize++] = elem;
136       if (comparator.compare(elem, threshold) > 0) {
137         threshold = elem;
138       }
139     } else if (comparator.compare(elem, threshold) < 0) {
140       // Otherwise, we can ignore elem; we've seen k better elements.
141       buffer[bufferSize++] = elem;
142       if (bufferSize == 2 * k) {
143         trim();
144       }
145     }
146   }
147 
148   /**
149    * Quickselects the top k elements from the 2k elements in the buffer. O(k) expected time, O(k log
150    * k) worst case.
151    */
trim()152   private void trim() {
153     int left = 0;
154     int right = 2 * k - 1;
155 
156     int minThresholdPosition = 0;
157     // The leftmost position at which the greatest of the k lower elements
158     // -- the new value of threshold -- might be found.
159 
160     int iterations = 0;
161     int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3;
162     while (left < right) {
163       int pivotIndex = (left + right + 1) >>> 1;
164 
165       int pivotNewIndex = partition(left, right, pivotIndex);
166 
167       if (pivotNewIndex > k) {
168         right = pivotNewIndex - 1;
169       } else if (pivotNewIndex < k) {
170         left = Math.max(pivotNewIndex, left + 1);
171         minThresholdPosition = pivotNewIndex;
172       } else {
173         break;
174       }
175       iterations++;
176       if (iterations >= maxIterations) {
177         // We've already taken O(k log k), let's make sure we don't take longer than O(k log k).
178         Arrays.sort(buffer, left, right, comparator);
179         break;
180       }
181     }
182     bufferSize = k;
183 
184     threshold = buffer[minThresholdPosition];
185     for (int i = minThresholdPosition + 1; i < k; i++) {
186       if (comparator.compare(buffer[i], threshold) > 0) {
187         threshold = buffer[i];
188       }
189     }
190   }
191 
192   /**
193    * Partitions the contents of buffer in the range [left, right] around the pivot element
194    * previously stored in buffer[pivotValue]. Returns the new index of the pivot element,
195    * pivotNewIndex, so that everything in [left, pivotNewIndex] is ≤ pivotValue and everything in
196    * (pivotNewIndex, right] is greater than pivotValue.
197    */
partition(int left, int right, int pivotIndex)198   private int partition(int left, int right, int pivotIndex) {
199     T pivotValue = buffer[pivotIndex];
200     buffer[pivotIndex] = buffer[right];
201 
202     int pivotNewIndex = left;
203     for (int i = left; i < right; i++) {
204       if (comparator.compare(buffer[i], pivotValue) < 0) {
205         swap(pivotNewIndex, i);
206         pivotNewIndex++;
207       }
208     }
209     buffer[right] = buffer[pivotNewIndex];
210     buffer[pivotNewIndex] = pivotValue;
211     return pivotNewIndex;
212   }
213 
swap(int i, int j)214   private void swap(int i, int j) {
215     T tmp = buffer[i];
216     buffer[i] = buffer[j];
217     buffer[j] = tmp;
218   }
219 
220   /**
221    * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
222    * operation takes amortized linear time in the length of {@code elements}.
223    *
224    * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterable}, prefer
225    * {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use case.
226    */
offerAll(Iterable<? extends T> elements)227   public void offerAll(Iterable<? extends T> elements) {
228     offerAll(elements.iterator());
229   }
230 
231   /**
232    * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
233    * operation takes amortized linear time in the length of {@code elements}. The iterator is
234    * consumed after this operation completes.
235    *
236    * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterator}, prefer
237    * {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use case.
238    */
offerAll(Iterator<? extends T> elements)239   public void offerAll(Iterator<? extends T> elements) {
240     while (elements.hasNext()) {
241       offer(elements.next());
242     }
243   }
244 
245   /**
246    * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if
247    * fewer than {@code k} have been offered, in the order specified by the factory used to create
248    * this {@code TopKSelector}.
249    *
250    * <p>The returned list is an unmodifiable copy and will not be affected by further changes to
251    * this {@code TopKSelector}. This method returns in O(k log k) time.
252    */
topK()253   public List<T> topK() {
254     Arrays.sort(buffer, 0, bufferSize, comparator);
255     if (bufferSize > k) {
256       Arrays.fill(buffer, k, buffer.length, null);
257       bufferSize = k;
258       threshold = buffer[k - 1];
259     }
260     // we have to support null elements, so no ImmutableList for us
261     return Collections.unmodifiableList(Arrays.asList(Arrays.copyOf(buffer, bufferSize)));
262   }
263 }
264