• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Guava Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.google.common.collect;
17 
18 import static com.google.common.truth.Truth.assertWithMessage;
19 
20 import com.google.common.annotations.GwtCompatible;
21 import com.google.common.annotations.GwtIncompatible;
22 import com.google.errorprone.annotations.CanIgnoreReturnValue;
23 import java.util.Arrays;
24 import java.util.Collections;
25 import java.util.List;
26 import java.util.Map;
27 import java.util.Set;
28 import java.util.function.BiConsumer;
29 import java.util.function.IntToDoubleFunction;
30 import java.util.function.Supplier;
31 import junit.framework.TestCase;
32 import org.checkerframework.checker.nullness.qual.Nullable;
33 
/**
 * Abstract superclass for tests which verify that a collection keeps controlled worst-case
 * performance when it is hash flooded, i.e. filled with keys that all share one hash code.
 */
38 @GwtCompatible
39 public abstract class AbstractHashFloodingTest<T> extends TestCase {
40   private final List<Construction<T>> constructions;
41   private final IntToDoubleFunction constructionAsymptotics;
42   private final List<QueryOp<T>> queries;
43 
AbstractHashFloodingTest( List<Construction<T>> constructions, IntToDoubleFunction constructionAsymptotics, List<QueryOp<T>> queries)44   AbstractHashFloodingTest(
45       List<Construction<T>> constructions,
46       IntToDoubleFunction constructionAsymptotics,
47       List<QueryOp<T>> queries) {
48     this.constructions = constructions;
49     this.constructionAsymptotics = constructionAsymptotics;
50     this.queries = queries;
51   }
52 
53   /**
54    * A Comparable wrapper around a String which executes callbacks on calls to hashCode, equals, and
55    * compareTo.
56    */
57   private static class CountsHashCodeAndEquals implements Comparable<CountsHashCodeAndEquals> {
58     private final String delegateString;
59     private final Runnable onHashCode;
60     private final Runnable onEquals;
61     private final Runnable onCompareTo;
62 
CountsHashCodeAndEquals( String delegateString, Runnable onHashCode, Runnable onEquals, Runnable onCompareTo)63     CountsHashCodeAndEquals(
64         String delegateString, Runnable onHashCode, Runnable onEquals, Runnable onCompareTo) {
65       this.delegateString = delegateString;
66       this.onHashCode = onHashCode;
67       this.onEquals = onEquals;
68       this.onCompareTo = onCompareTo;
69     }
70 
71     @Override
hashCode()72     public int hashCode() {
73       onHashCode.run();
74       return delegateString.hashCode();
75     }
76 
77     @Override
equals(@ullable Object other)78     public boolean equals(@Nullable Object other) {
79       onEquals.run();
80       return other instanceof CountsHashCodeAndEquals
81           && delegateString.equals(((CountsHashCodeAndEquals) other).delegateString);
82     }
83 
84     @Override
compareTo(CountsHashCodeAndEquals o)85     public int compareTo(CountsHashCodeAndEquals o) {
86       onCompareTo.run();
87       return delegateString.compareTo(o.delegateString);
88     }
89   }
90 
91   /** A holder of counters for calls to hashCode, equals, and compareTo. */
92   private static final class CallsCounter {
93     long hashCode;
94     long equals;
95     long compareTo;
96 
total()97     long total() {
98       return hashCode + equals + compareTo;
99     }
100 
zero()101     void zero() {
102       hashCode = 0;
103       equals = 0;
104       compareTo = 0;
105     }
106   }
107 
108   @FunctionalInterface
109   interface Construction<T> {
110     @CanIgnoreReturnValue
create(List<?> keys)111     abstract T create(List<?> keys);
112 
mapFromKeys( Supplier<Map<Object, Object>> mutableSupplier)113     static Construction<Map<Object, Object>> mapFromKeys(
114         Supplier<Map<Object, Object>> mutableSupplier) {
115       return keys -> {
116         Map<Object, Object> map = mutableSupplier.get();
117         for (Object key : keys) {
118           map.put(key, new Object());
119         }
120         return map;
121       };
122     }
123 
setFromElements(Supplier<Set<Object>> mutableSupplier)124     static Construction<Set<Object>> setFromElements(Supplier<Set<Object>> mutableSupplier) {
125       return elements -> {
126         Set<Object> set = mutableSupplier.get();
127         set.addAll(elements);
128         return set;
129       };
130     }
131   }
132 
133   abstract static class QueryOp<T> {
134     static <T> QueryOp<T> create(
135         String name, BiConsumer<T, Object> queryLambda, IntToDoubleFunction asymptotic) {
136       return new QueryOp<T>() {
137         @Override
138         void apply(T collection, Object query) {
139           queryLambda.accept(collection, query);
140         }
141 
142         @Override
143         double expectedAsymptotic(int n) {
144           return asymptotic.applyAsDouble(n);
145         }
146 
147         @Override
148         public String toString() {
149           return name;
150         }
151       };
152     }
153 
154     static final QueryOp<Map<Object, Object>> MAP_GET =
155         QueryOp.create("Map.get", Map::get, Math::log);
156 
157     @SuppressWarnings("ReturnValueIgnored")
158     static final QueryOp<Set<Object>> SET_CONTAINS =
159         QueryOp.create("Set.contains", Set::contains, Math::log);
160 
161     abstract void apply(T collection, Object query);
162 
163     abstract double expectedAsymptotic(int n);
164   }
165 
166   /**
167    * Returns a list of objects with the same hash code, of size 2^power, counting calls to equals,
168    * hashCode, and compareTo in counter.
169    */
170   static List<CountsHashCodeAndEquals> createAdversarialInput(int power, CallsCounter counter) {
171     String str1 = "Aa";
172     String str2 = "BB";
173     assertEquals(str1.hashCode(), str2.hashCode());
174     List<String> haveSameHashes2 = Arrays.asList(str1, str2);
175     List<CountsHashCodeAndEquals> result =
176         Lists.newArrayList(
177             Lists.transform(
178                 Lists.cartesianProduct(Collections.nCopies(power, haveSameHashes2)),
179                 strs ->
180                     new CountsHashCodeAndEquals(
181                         String.join("", strs),
182                         () -> counter.hashCode++,
183                         () -> counter.equals++,
184                         () -> counter.compareTo++)));
185     assertEquals(
186         result.get(0).delegateString.hashCode(),
187         result.get(result.size() - 1).delegateString.hashCode());
188     return result;
189   }
190 
191   @GwtIncompatible
192   public void testResistsHashFloodingInConstruction() {
193     CallsCounter smallCounter = new CallsCounter();
194     List<CountsHashCodeAndEquals> haveSameHashesSmall = createAdversarialInput(10, smallCounter);
195     int smallSize = haveSameHashesSmall.size();
196 
197     CallsCounter largeCounter = new CallsCounter();
198     List<CountsHashCodeAndEquals> haveSameHashesLarge = createAdversarialInput(15, largeCounter);
199     int largeSize = haveSameHashesLarge.size();
200 
201     for (Construction<T> pathway : constructions) {
202       smallCounter.zero();
203       pathway.create(haveSameHashesSmall);
204       long smallOps = smallCounter.total();
205 
206       largeCounter.zero();
207       pathway.create(haveSameHashesLarge);
208       long largeOps = largeCounter.total();
209 
210       double ratio = (double) largeOps / smallOps;
211       assertWithMessage(
212               "ratio of equals/hashCode/compareTo operations to build with %s entries versus %s"
213                   + " entries",
214               largeSize, smallSize)
215           .that(ratio)
216           .isAtMost(
217               2
218                   * constructionAsymptotics.applyAsDouble(largeSize)
219                   / constructionAsymptotics.applyAsDouble(smallSize));
220       // allow up to 2x wobble in the constant factors
221     }
222   }
223 
224   @GwtIncompatible
225   public void testResistsHashFloodingOnQuery() {
226     CallsCounter smallCounter = new CallsCounter();
227     List<CountsHashCodeAndEquals> haveSameHashesSmall = createAdversarialInput(10, smallCounter);
228     int smallSize = haveSameHashesSmall.size();
229 
230     CallsCounter largeCounter = new CallsCounter();
231     List<CountsHashCodeAndEquals> haveSameHashesLarge = createAdversarialInput(15, largeCounter);
232     int largeSize = haveSameHashesLarge.size();
233 
234     for (QueryOp<T> query : queries) {
235       for (Construction<T> pathway : constructions) {
236         long worstSmallOps = getWorstCaseOps(smallCounter, haveSameHashesSmall, query, pathway);
237         long worstLargeOps = getWorstCaseOps(largeCounter, haveSameHashesLarge, query, pathway);
238 
239         double ratio = (double) worstLargeOps / worstSmallOps;
240         assertWithMessage(
241                 "ratio of equals/hashCode/compareTo operations to query %s with %s entries versus"
242                     + " %s entries",
243                 query, largeSize, smallSize)
244             .that(ratio)
245             .isAtMost(
246                 2 * query.expectedAsymptotic(largeSize) / query.expectedAsymptotic(smallSize));
247         // allow up to 2x wobble in the constant factors
248       }
249     }
250   }
251 
252   private long getWorstCaseOps(
253       CallsCounter counter,
254       List<CountsHashCodeAndEquals> haveSameHashes,
255       QueryOp<T> query,
256       Construction<T> pathway) {
257     T collection = pathway.create(haveSameHashes);
258     long worstOps = 0;
259     for (Object o : haveSameHashes) {
260       counter.zero();
261       query.apply(collection, o);
262       worstOps = Math.max(worstOps, counter.total());
263     }
264     return worstOps;
265   }
266 }
267