• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2007 The Guava Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.common.collect;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkState;
21 import static com.google.common.collect.CollectPreconditions.checkNonnegative;
22 import static com.google.common.collect.CollectPreconditions.checkRemove;
23 
24 import com.google.common.annotations.GwtCompatible;
25 import com.google.common.annotations.GwtIncompatible;
26 import com.google.common.base.MoreObjects;
27 import com.google.common.primitives.Ints;
28 import com.google.errorprone.annotations.CanIgnoreReturnValue;
29 import java.io.IOException;
30 import java.io.ObjectInputStream;
31 import java.io.ObjectOutputStream;
32 import java.io.Serializable;
33 import java.util.Comparator;
34 import java.util.ConcurrentModificationException;
35 import java.util.Iterator;
36 import java.util.NoSuchElementException;
37 import org.checkerframework.checker.nullness.compatqual.NullableDecl;
38 
39 /**
40  * A multiset which maintains the ordering of its elements, according to either their natural order
41  * or an explicit {@link Comparator}. In all cases, this implementation uses {@link
42  * Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to determine
43  * equivalence of instances.
44  *
45  * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
46  * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the {@link
47  * java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
48  *
49  * <p>See the Guava User Guide article on <a href=
50  * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset"> {@code
51  * Multiset}</a>.
52  *
53  * @author Louis Wasserman
54  * @author Jared Levy
55  * @since 2.0
56  */
57 @GwtCompatible(emulated = true)
58 public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
59 
60   /**
61    * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
62    * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
63    * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
64    * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
65    * user attempts to add an element to the multiset that violates this constraint (for example, the
66    * user attempts to add a string element to a set whose elements are integers), the {@code
67    * add(Object)} call will throw a {@code ClassCastException}.
68    *
69    * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
70    * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
71    */
create()72   public static <E extends Comparable> TreeMultiset<E> create() {
73     return new TreeMultiset<E>(Ordering.natural());
74   }
75 
76   /**
77    * Creates a new, empty multiset, sorted according to the specified comparator. All elements
78    * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
79    * {@code comparator.compare(e1, e2)} must not throw a {@code ClassCastException} for any elements
80    * {@code e1} and {@code e2} in the multiset. If the user attempts to add an element to the
81    * multiset that violates this constraint, the {@code add(Object)} call will throw a {@code
82    * ClassCastException}.
83    *
84    * @param comparator the comparator that will be used to sort this multiset. A null value
85    *     indicates that the elements' <i>natural ordering</i> should be used.
86    */
87   @SuppressWarnings("unchecked")
create(@ullableDecl Comparator<? super E> comparator)88   public static <E> TreeMultiset<E> create(@NullableDecl Comparator<? super E> comparator) {
89     return (comparator == null)
90         ? new TreeMultiset<E>((Comparator) Ordering.natural())
91         : new TreeMultiset<E>(comparator);
92   }
93 
94   /**
95    * Creates an empty multiset containing the given initial elements, sorted according to the
96    * elements' natural order.
97    *
98    * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
99    *
100    * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
101    * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
102    */
create(Iterable<? extends E> elements)103   public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
104     TreeMultiset<E> multiset = create();
105     Iterables.addAll(multiset, elements);
106     return multiset;
107   }
108 
109   private final transient Reference<AvlNode<E>> rootReference;
110   private final transient GeneralRange<E> range;
111   private final transient AvlNode<E> header;
112 
TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink)113   TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
114     super(range.comparator());
115     this.rootReference = rootReference;
116     this.range = range;
117     this.header = endLink;
118   }
119 
TreeMultiset(Comparator<? super E> comparator)120   TreeMultiset(Comparator<? super E> comparator) {
121     super(comparator);
122     this.range = GeneralRange.all(comparator);
123     this.header = new AvlNode<E>(null, 1);
124     successor(header, header);
125     this.rootReference = new Reference<>();
126   }
127 
128   /** A function which can be summed across a subtree. */
129   private enum Aggregate {
130     SIZE {
131       @Override
nodeAggregate(AvlNode<?> node)132       int nodeAggregate(AvlNode<?> node) {
133         return node.elemCount;
134       }
135 
136       @Override
treeAggregate(@ullableDecl AvlNode<?> root)137       long treeAggregate(@NullableDecl AvlNode<?> root) {
138         return (root == null) ? 0 : root.totalCount;
139       }
140     },
141     DISTINCT {
142       @Override
nodeAggregate(AvlNode<?> node)143       int nodeAggregate(AvlNode<?> node) {
144         return 1;
145       }
146 
147       @Override
treeAggregate(@ullableDecl AvlNode<?> root)148       long treeAggregate(@NullableDecl AvlNode<?> root) {
149         return (root == null) ? 0 : root.distinctElements;
150       }
151     };
152 
nodeAggregate(AvlNode<?> node)153     abstract int nodeAggregate(AvlNode<?> node);
154 
treeAggregate(@ullableDecl AvlNode<?> root)155     abstract long treeAggregate(@NullableDecl AvlNode<?> root);
156   }
157 
aggregateForEntries(Aggregate aggr)158   private long aggregateForEntries(Aggregate aggr) {
159     AvlNode<E> root = rootReference.get();
160     long total = aggr.treeAggregate(root);
161     if (range.hasLowerBound()) {
162       total -= aggregateBelowRange(aggr, root);
163     }
164     if (range.hasUpperBound()) {
165       total -= aggregateAboveRange(aggr, root);
166     }
167     return total;
168   }
169 
aggregateBelowRange(Aggregate aggr, @NullableDecl AvlNode<E> node)170   private long aggregateBelowRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
171     if (node == null) {
172       return 0;
173     }
174     int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
175     if (cmp < 0) {
176       return aggregateBelowRange(aggr, node.left);
177     } else if (cmp == 0) {
178       switch (range.getLowerBoundType()) {
179         case OPEN:
180           return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
181         case CLOSED:
182           return aggr.treeAggregate(node.left);
183         default:
184           throw new AssertionError();
185       }
186     } else {
187       return aggr.treeAggregate(node.left)
188           + aggr.nodeAggregate(node)
189           + aggregateBelowRange(aggr, node.right);
190     }
191   }
192 
aggregateAboveRange(Aggregate aggr, @NullableDecl AvlNode<E> node)193   private long aggregateAboveRange(Aggregate aggr, @NullableDecl AvlNode<E> node) {
194     if (node == null) {
195       return 0;
196     }
197     int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
198     if (cmp > 0) {
199       return aggregateAboveRange(aggr, node.right);
200     } else if (cmp == 0) {
201       switch (range.getUpperBoundType()) {
202         case OPEN:
203           return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
204         case CLOSED:
205           return aggr.treeAggregate(node.right);
206         default:
207           throw new AssertionError();
208       }
209     } else {
210       return aggr.treeAggregate(node.right)
211           + aggr.nodeAggregate(node)
212           + aggregateAboveRange(aggr, node.left);
213     }
214   }
215 
216   @Override
size()217   public int size() {
218     return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
219   }
220 
221   @Override
distinctElements()222   int distinctElements() {
223     return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
224   }
225 
distinctElements(@ullableDecl AvlNode<?> node)226   static int distinctElements(@NullableDecl AvlNode<?> node) {
227     return (node == null) ? 0 : node.distinctElements;
228   }
229 
230   @Override
count(@ullableDecl Object element)231   public int count(@NullableDecl Object element) {
232     try {
233       @SuppressWarnings("unchecked")
234       E e = (E) element;
235       AvlNode<E> root = rootReference.get();
236       if (!range.contains(e) || root == null) {
237         return 0;
238       }
239       return root.count(comparator(), e);
240     } catch (ClassCastException | NullPointerException e) {
241       return 0;
242     }
243   }
244 
245   @CanIgnoreReturnValue
246   @Override
add(@ullableDecl E element, int occurrences)247   public int add(@NullableDecl E element, int occurrences) {
248     checkNonnegative(occurrences, "occurrences");
249     if (occurrences == 0) {
250       return count(element);
251     }
252     checkArgument(range.contains(element));
253     AvlNode<E> root = rootReference.get();
254     if (root == null) {
255       comparator().compare(element, element);
256       AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
257       successor(header, newRoot, header);
258       rootReference.checkAndSet(root, newRoot);
259       return 0;
260     }
261     int[] result = new int[1]; // used as a mutable int reference to hold result
262     AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
263     rootReference.checkAndSet(root, newRoot);
264     return result[0];
265   }
266 
267   @CanIgnoreReturnValue
268   @Override
remove(@ullableDecl Object element, int occurrences)269   public int remove(@NullableDecl Object element, int occurrences) {
270     checkNonnegative(occurrences, "occurrences");
271     if (occurrences == 0) {
272       return count(element);
273     }
274     AvlNode<E> root = rootReference.get();
275     int[] result = new int[1]; // used as a mutable int reference to hold result
276     AvlNode<E> newRoot;
277     try {
278       @SuppressWarnings("unchecked")
279       E e = (E) element;
280       if (!range.contains(e) || root == null) {
281         return 0;
282       }
283       newRoot = root.remove(comparator(), e, occurrences, result);
284     } catch (ClassCastException | NullPointerException e) {
285       return 0;
286     }
287     rootReference.checkAndSet(root, newRoot);
288     return result[0];
289   }
290 
291   @CanIgnoreReturnValue
292   @Override
setCount(@ullableDecl E element, int count)293   public int setCount(@NullableDecl E element, int count) {
294     checkNonnegative(count, "count");
295     if (!range.contains(element)) {
296       checkArgument(count == 0);
297       return 0;
298     }
299 
300     AvlNode<E> root = rootReference.get();
301     if (root == null) {
302       if (count > 0) {
303         add(element, count);
304       }
305       return 0;
306     }
307     int[] result = new int[1]; // used as a mutable int reference to hold result
308     AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
309     rootReference.checkAndSet(root, newRoot);
310     return result[0];
311   }
312 
313   @CanIgnoreReturnValue
314   @Override
setCount(@ullableDecl E element, int oldCount, int newCount)315   public boolean setCount(@NullableDecl E element, int oldCount, int newCount) {
316     checkNonnegative(newCount, "newCount");
317     checkNonnegative(oldCount, "oldCount");
318     checkArgument(range.contains(element));
319 
320     AvlNode<E> root = rootReference.get();
321     if (root == null) {
322       if (oldCount == 0) {
323         if (newCount > 0) {
324           add(element, newCount);
325         }
326         return true;
327       } else {
328         return false;
329       }
330     }
331     int[] result = new int[1]; // used as a mutable int reference to hold result
332     AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
333     rootReference.checkAndSet(root, newRoot);
334     return result[0] == oldCount;
335   }
336 
337   @Override
clear()338   public void clear() {
339     if (!range.hasLowerBound() && !range.hasUpperBound()) {
340       // We can do this in O(n) rather than removing one by one, which could force rebalancing.
341       for (AvlNode<E> current = header.succ; current != header; ) {
342         AvlNode<E> next = current.succ;
343 
344         current.elemCount = 0;
345         // Also clear these fields so that one deleted Entry doesn't retain all elements.
346         current.left = null;
347         current.right = null;
348         current.pred = null;
349         current.succ = null;
350 
351         current = next;
352       }
353       successor(header, header);
354       rootReference.clear();
355     } else {
356       // TODO(cpovirk): Perhaps we can optimize in this case, too?
357       Iterators.clear(entryIterator());
358     }
359   }
360 
wrapEntry(final AvlNode<E> baseEntry)361   private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
362     return new Multisets.AbstractEntry<E>() {
363       @Override
364       public E getElement() {
365         return baseEntry.getElement();
366       }
367 
368       @Override
369       public int getCount() {
370         int result = baseEntry.getCount();
371         if (result == 0) {
372           return count(getElement());
373         } else {
374           return result;
375         }
376       }
377     };
378   }
379 
380   /** Returns the first node in the tree that is in range. */
381   @NullableDecl
382   private AvlNode<E> firstNode() {
383     AvlNode<E> root = rootReference.get();
384     if (root == null) {
385       return null;
386     }
387     AvlNode<E> node;
388     if (range.hasLowerBound()) {
389       E endpoint = range.getLowerEndpoint();
390       node = rootReference.get().ceiling(comparator(), endpoint);
391       if (node == null) {
392         return null;
393       }
394       if (range.getLowerBoundType() == BoundType.OPEN
395           && comparator().compare(endpoint, node.getElement()) == 0) {
396         node = node.succ;
397       }
398     } else {
399       node = header.succ;
400     }
401     return (node == header || !range.contains(node.getElement())) ? null : node;
402   }
403 
404   @NullableDecl
405   private AvlNode<E> lastNode() {
406     AvlNode<E> root = rootReference.get();
407     if (root == null) {
408       return null;
409     }
410     AvlNode<E> node;
411     if (range.hasUpperBound()) {
412       E endpoint = range.getUpperEndpoint();
413       node = rootReference.get().floor(comparator(), endpoint);
414       if (node == null) {
415         return null;
416       }
417       if (range.getUpperBoundType() == BoundType.OPEN
418           && comparator().compare(endpoint, node.getElement()) == 0) {
419         node = node.pred;
420       }
421     } else {
422       node = header.pred;
423     }
424     return (node == header || !range.contains(node.getElement())) ? null : node;
425   }
426 
427   @Override
428   Iterator<E> elementIterator() {
429     return Multisets.elementIterator(entryIterator());
430   }
431 
432   @Override
433   Iterator<Entry<E>> entryIterator() {
434     return new Iterator<Entry<E>>() {
435       AvlNode<E> current = firstNode();
436       @NullableDecl Entry<E> prevEntry;
437 
438       @Override
439       public boolean hasNext() {
440         if (current == null) {
441           return false;
442         } else if (range.tooHigh(current.getElement())) {
443           current = null;
444           return false;
445         } else {
446           return true;
447         }
448       }
449 
450       @Override
451       public Entry<E> next() {
452         if (!hasNext()) {
453           throw new NoSuchElementException();
454         }
455         Entry<E> result = wrapEntry(current);
456         prevEntry = result;
457         if (current.succ == header) {
458           current = null;
459         } else {
460           current = current.succ;
461         }
462         return result;
463       }
464 
465       @Override
466       public void remove() {
467         checkRemove(prevEntry != null);
468         setCount(prevEntry.getElement(), 0);
469         prevEntry = null;
470       }
471     };
472   }
473 
474   @Override
475   Iterator<Entry<E>> descendingEntryIterator() {
476     return new Iterator<Entry<E>>() {
477       AvlNode<E> current = lastNode();
478       Entry<E> prevEntry = null;
479 
480       @Override
481       public boolean hasNext() {
482         if (current == null) {
483           return false;
484         } else if (range.tooLow(current.getElement())) {
485           current = null;
486           return false;
487         } else {
488           return true;
489         }
490       }
491 
492       @Override
493       public Entry<E> next() {
494         if (!hasNext()) {
495           throw new NoSuchElementException();
496         }
497         Entry<E> result = wrapEntry(current);
498         prevEntry = result;
499         if (current.pred == header) {
500           current = null;
501         } else {
502           current = current.pred;
503         }
504         return result;
505       }
506 
507       @Override
508       public void remove() {
509         checkRemove(prevEntry != null);
510         setCount(prevEntry.getElement(), 0);
511         prevEntry = null;
512       }
513     };
514   }
515 
516   @Override
517   public Iterator<E> iterator() {
518     return Multisets.iteratorImpl(this);
519   }
520 
521   @Override
522   public SortedMultiset<E> headMultiset(@NullableDecl E upperBound, BoundType boundType) {
523     return new TreeMultiset<E>(
524         rootReference,
525         range.intersect(GeneralRange.upTo(comparator(), upperBound, boundType)),
526         header);
527   }
528 
529   @Override
530   public SortedMultiset<E> tailMultiset(@NullableDecl E lowerBound, BoundType boundType) {
531     return new TreeMultiset<E>(
532         rootReference,
533         range.intersect(GeneralRange.downTo(comparator(), lowerBound, boundType)),
534         header);
535   }
536 
537   private static final class Reference<T> {
538     @NullableDecl private T value;
539 
540     @NullableDecl
541     public T get() {
542       return value;
543     }
544 
545     public void checkAndSet(@NullableDecl T expected, T newValue) {
546       if (value != expected) {
547         throw new ConcurrentModificationException();
548       }
549       value = newValue;
550     }
551 
552     void clear() {
553       value = null;
554     }
555   }
556 
557   private static final class AvlNode<E> {
558     @NullableDecl private final E elem;
559 
560     // elemCount is 0 iff this node has been deleted.
561     private int elemCount;
562 
563     private int distinctElements;
564     private long totalCount;
565     private int height;
566     @NullableDecl private AvlNode<E> left;
567     @NullableDecl private AvlNode<E> right;
568     @NullableDecl private AvlNode<E> pred;
569     @NullableDecl private AvlNode<E> succ;
570 
571     AvlNode(@NullableDecl E elem, int elemCount) {
572       checkArgument(elemCount > 0);
573       this.elem = elem;
574       this.elemCount = elemCount;
575       this.totalCount = elemCount;
576       this.distinctElements = 1;
577       this.height = 1;
578       this.left = null;
579       this.right = null;
580     }
581 
582     public int count(Comparator<? super E> comparator, E e) {
583       int cmp = comparator.compare(e, elem);
584       if (cmp < 0) {
585         return (left == null) ? 0 : left.count(comparator, e);
586       } else if (cmp > 0) {
587         return (right == null) ? 0 : right.count(comparator, e);
588       } else {
589         return elemCount;
590       }
591     }
592 
593     private AvlNode<E> addRightChild(E e, int count) {
594       right = new AvlNode<E>(e, count);
595       successor(this, right, succ);
596       height = Math.max(2, height);
597       distinctElements++;
598       totalCount += count;
599       return this;
600     }
601 
602     private AvlNode<E> addLeftChild(E e, int count) {
603       left = new AvlNode<E>(e, count);
604       successor(pred, left, this);
605       height = Math.max(2, height);
606       distinctElements++;
607       totalCount += count;
608       return this;
609     }
610 
611     AvlNode<E> add(Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
612       /*
613        * It speeds things up considerably to unconditionally add count to totalCount here,
614        * but that destroys failure atomicity in the case of count overflow. =(
615        */
616       int cmp = comparator.compare(e, elem);
617       if (cmp < 0) {
618         AvlNode<E> initLeft = left;
619         if (initLeft == null) {
620           result[0] = 0;
621           return addLeftChild(e, count);
622         }
623         int initHeight = initLeft.height;
624 
625         left = initLeft.add(comparator, e, count, result);
626         if (result[0] == 0) {
627           distinctElements++;
628         }
629         this.totalCount += count;
630         return (left.height == initHeight) ? this : rebalance();
631       } else if (cmp > 0) {
632         AvlNode<E> initRight = right;
633         if (initRight == null) {
634           result[0] = 0;
635           return addRightChild(e, count);
636         }
637         int initHeight = initRight.height;
638 
639         right = initRight.add(comparator, e, count, result);
640         if (result[0] == 0) {
641           distinctElements++;
642         }
643         this.totalCount += count;
644         return (right.height == initHeight) ? this : rebalance();
645       }
646 
647       // adding count to me!  No rebalance possible.
648       result[0] = elemCount;
649       long resultCount = (long) elemCount + count;
650       checkArgument(resultCount <= Integer.MAX_VALUE);
651       this.elemCount += count;
652       this.totalCount += count;
653       return this;
654     }
655 
656     AvlNode<E> remove(
657         Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
658       int cmp = comparator.compare(e, elem);
659       if (cmp < 0) {
660         AvlNode<E> initLeft = left;
661         if (initLeft == null) {
662           result[0] = 0;
663           return this;
664         }
665 
666         left = initLeft.remove(comparator, e, count, result);
667 
668         if (result[0] > 0) {
669           if (count >= result[0]) {
670             this.distinctElements--;
671             this.totalCount -= result[0];
672           } else {
673             this.totalCount -= count;
674           }
675         }
676         return (result[0] == 0) ? this : rebalance();
677       } else if (cmp > 0) {
678         AvlNode<E> initRight = right;
679         if (initRight == null) {
680           result[0] = 0;
681           return this;
682         }
683 
684         right = initRight.remove(comparator, e, count, result);
685 
686         if (result[0] > 0) {
687           if (count >= result[0]) {
688             this.distinctElements--;
689             this.totalCount -= result[0];
690           } else {
691             this.totalCount -= count;
692           }
693         }
694         return rebalance();
695       }
696 
697       // removing count from me!
698       result[0] = elemCount;
699       if (count >= elemCount) {
700         return deleteMe();
701       } else {
702         this.elemCount -= count;
703         this.totalCount -= count;
704         return this;
705       }
706     }
707 
708     AvlNode<E> setCount(
709         Comparator<? super E> comparator, @NullableDecl E e, int count, int[] result) {
710       int cmp = comparator.compare(e, elem);
711       if (cmp < 0) {
712         AvlNode<E> initLeft = left;
713         if (initLeft == null) {
714           result[0] = 0;
715           return (count > 0) ? addLeftChild(e, count) : this;
716         }
717 
718         left = initLeft.setCount(comparator, e, count, result);
719 
720         if (count == 0 && result[0] != 0) {
721           this.distinctElements--;
722         } else if (count > 0 && result[0] == 0) {
723           this.distinctElements++;
724         }
725 
726         this.totalCount += count - result[0];
727         return rebalance();
728       } else if (cmp > 0) {
729         AvlNode<E> initRight = right;
730         if (initRight == null) {
731           result[0] = 0;
732           return (count > 0) ? addRightChild(e, count) : this;
733         }
734 
735         right = initRight.setCount(comparator, e, count, result);
736 
737         if (count == 0 && result[0] != 0) {
738           this.distinctElements--;
739         } else if (count > 0 && result[0] == 0) {
740           this.distinctElements++;
741         }
742 
743         this.totalCount += count - result[0];
744         return rebalance();
745       }
746 
747       // setting my count
748       result[0] = elemCount;
749       if (count == 0) {
750         return deleteMe();
751       }
752       this.totalCount += count - elemCount;
753       this.elemCount = count;
754       return this;
755     }
756 
757     AvlNode<E> setCount(
758         Comparator<? super E> comparator,
759         @NullableDecl E e,
760         int expectedCount,
761         int newCount,
762         int[] result) {
763       int cmp = comparator.compare(e, elem);
764       if (cmp < 0) {
765         AvlNode<E> initLeft = left;
766         if (initLeft == null) {
767           result[0] = 0;
768           if (expectedCount == 0 && newCount > 0) {
769             return addLeftChild(e, newCount);
770           }
771           return this;
772         }
773 
774         left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
775 
776         if (result[0] == expectedCount) {
777           if (newCount == 0 && result[0] != 0) {
778             this.distinctElements--;
779           } else if (newCount > 0 && result[0] == 0) {
780             this.distinctElements++;
781           }
782           this.totalCount += newCount - result[0];
783         }
784         return rebalance();
785       } else if (cmp > 0) {
786         AvlNode<E> initRight = right;
787         if (initRight == null) {
788           result[0] = 0;
789           if (expectedCount == 0 && newCount > 0) {
790             return addRightChild(e, newCount);
791           }
792           return this;
793         }
794 
795         right = initRight.setCount(comparator, e, expectedCount, newCount, result);
796 
797         if (result[0] == expectedCount) {
798           if (newCount == 0 && result[0] != 0) {
799             this.distinctElements--;
800           } else if (newCount > 0 && result[0] == 0) {
801             this.distinctElements++;
802           }
803           this.totalCount += newCount - result[0];
804         }
805         return rebalance();
806       }
807 
808       // setting my count
809       result[0] = elemCount;
810       if (expectedCount == elemCount) {
811         if (newCount == 0) {
812           return deleteMe();
813         }
814         this.totalCount += newCount - elemCount;
815         this.elemCount = newCount;
816       }
817       return this;
818     }
819 
820     private AvlNode<E> deleteMe() {
821       int oldElemCount = this.elemCount;
822       this.elemCount = 0;
823       successor(pred, succ);
824       if (left == null) {
825         return right;
826       } else if (right == null) {
827         return left;
828       } else if (left.height >= right.height) {
829         AvlNode<E> newTop = pred;
830         // newTop is the maximum node in my left subtree
831         newTop.left = left.removeMax(newTop);
832         newTop.right = right;
833         newTop.distinctElements = distinctElements - 1;
834         newTop.totalCount = totalCount - oldElemCount;
835         return newTop.rebalance();
836       } else {
837         AvlNode<E> newTop = succ;
838         newTop.right = right.removeMin(newTop);
839         newTop.left = left;
840         newTop.distinctElements = distinctElements - 1;
841         newTop.totalCount = totalCount - oldElemCount;
842         return newTop.rebalance();
843       }
844     }
845 
846     // Removes the minimum node from this subtree to be reused elsewhere
847     private AvlNode<E> removeMin(AvlNode<E> node) {
848       if (left == null) {
849         return right;
850       } else {
851         left = left.removeMin(node);
852         distinctElements--;
853         totalCount -= node.elemCount;
854         return rebalance();
855       }
856     }
857 
858     // Removes the maximum node from this subtree to be reused elsewhere
859     private AvlNode<E> removeMax(AvlNode<E> node) {
860       if (right == null) {
861         return left;
862       } else {
863         right = right.removeMax(node);
864         distinctElements--;
865         totalCount -= node.elemCount;
866         return rebalance();
867       }
868     }
869 
870     private void recomputeMultiset() {
871       this.distinctElements =
872           1 + TreeMultiset.distinctElements(left) + TreeMultiset.distinctElements(right);
873       this.totalCount = elemCount + totalCount(left) + totalCount(right);
874     }
875 
876     private void recomputeHeight() {
877       this.height = 1 + Math.max(height(left), height(right));
878     }
879 
880     private void recompute() {
881       recomputeMultiset();
882       recomputeHeight();
883     }
884 
885     private AvlNode<E> rebalance() {
886       switch (balanceFactor()) {
887         case -2:
888           if (right.balanceFactor() > 0) {
889             right = right.rotateRight();
890           }
891           return rotateLeft();
892         case 2:
893           if (left.balanceFactor() < 0) {
894             left = left.rotateLeft();
895           }
896           return rotateRight();
897         default:
898           recomputeHeight();
899           return this;
900       }
901     }
902 
903     private int balanceFactor() {
904       return height(left) - height(right);
905     }
906 
907     private AvlNode<E> rotateLeft() {
908       checkState(right != null);
909       AvlNode<E> newTop = right;
910       this.right = newTop.left;
911       newTop.left = this;
912       newTop.totalCount = this.totalCount;
913       newTop.distinctElements = this.distinctElements;
914       this.recompute();
915       newTop.recomputeHeight();
916       return newTop;
917     }
918 
919     private AvlNode<E> rotateRight() {
920       checkState(left != null);
921       AvlNode<E> newTop = left;
922       this.left = newTop.right;
923       newTop.right = this;
924       newTop.totalCount = this.totalCount;
925       newTop.distinctElements = this.distinctElements;
926       this.recompute();
927       newTop.recomputeHeight();
928       return newTop;
929     }
930 
931     private static long totalCount(@NullableDecl AvlNode<?> node) {
932       return (node == null) ? 0 : node.totalCount;
933     }
934 
935     private static int height(@NullableDecl AvlNode<?> node) {
936       return (node == null) ? 0 : node.height;
937     }
938 
939     @NullableDecl
940     private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
941       int cmp = comparator.compare(e, elem);
942       if (cmp < 0) {
943         return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
944       } else if (cmp == 0) {
945         return this;
946       } else {
947         return (right == null) ? null : right.ceiling(comparator, e);
948       }
949     }
950 
951     @NullableDecl
952     private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
953       int cmp = comparator.compare(e, elem);
954       if (cmp > 0) {
955         return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
956       } else if (cmp == 0) {
957         return this;
958       } else {
959         return (left == null) ? null : left.floor(comparator, e);
960       }
961     }
962 
963     E getElement() {
964       return elem;
965     }
966 
967     int getCount() {
968       return elemCount;
969     }
970 
971     @Override
972     public String toString() {
973       return Multisets.immutableEntry(getElement(), getCount()).toString();
974     }
975   }
976 
977   private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
978     a.succ = b;
979     b.pred = a;
980   }
981 
982   private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
983     successor(a, b);
984     successor(b, c);
985   }
986 
987   /*
988    * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
989    * calls the comparator to compare the two keys. If that change is made,
990    * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
991    */
992 
993   /**
994    * @serialData the comparator, the number of distinct elements, the first element, its count, the
995    *     second element, its count, and so on
996    */
997   @GwtIncompatible // java.io.ObjectOutputStream
998   private void writeObject(ObjectOutputStream stream) throws IOException {
999     stream.defaultWriteObject();
1000     stream.writeObject(elementSet().comparator());
1001     Serialization.writeMultiset(this, stream);
1002   }
1003 
1004   @GwtIncompatible // java.io.ObjectInputStream
1005   private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
1006     stream.defaultReadObject();
1007     @SuppressWarnings("unchecked")
1008     // reading data stored by writeObject
1009     Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
1010     Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
1011     Serialization.getFieldSetter(TreeMultiset.class, "range")
1012         .set(this, GeneralRange.all(comparator));
1013     Serialization.getFieldSetter(TreeMultiset.class, "rootReference")
1014         .set(this, new Reference<AvlNode<E>>());
1015     AvlNode<E> header = new AvlNode<E>(null, 1);
1016     Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
1017     successor(header, header);
1018     Serialization.populateMultiset(this, stream);
1019   }
1020 
1021   @GwtIncompatible // not needed in emulated source
1022   private static final long serialVersionUID = 1;
1023 }
1024