• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2019 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 package com.google.common.collect;
17 
18 import static com.google.common.truth.Truth.assertWithMessage;
19 
20 import com.google.common.annotations.GwtIncompatible;
21 import com.google.errorprone.annotations.CanIgnoreReturnValue;
22 import java.util.Arrays;
23 import java.util.Collections;
24 import java.util.List;
25 import java.util.Map;
26 import java.util.Set;
27 import java.util.function.BiConsumer;
28 import java.util.function.IntToDoubleFunction;
29 import java.util.function.Supplier;
30 import junit.framework.TestCase;
31 import org.checkerframework.checker.nullness.qual.Nullable;
32 
33 /**
34  * Abstract superclass for tests that hash flooding a collection has controlled worst-case
35  * performance.
36  */
37 @GwtIncompatible
38 public abstract class AbstractHashFloodingTest<T> extends TestCase {
39   private final List<Construction<T>> constructions;
40   private final IntToDoubleFunction constructionAsymptotics;
41   private final List<QueryOp<T>> queries;
42 
AbstractHashFloodingTest( List<Construction<T>> constructions, IntToDoubleFunction constructionAsymptotics, List<QueryOp<T>> queries)43   AbstractHashFloodingTest(
44       List<Construction<T>> constructions,
45       IntToDoubleFunction constructionAsymptotics,
46       List<QueryOp<T>> queries) {
47     this.constructions = constructions;
48     this.constructionAsymptotics = constructionAsymptotics;
49     this.queries = queries;
50   }
51 
52   /**
53    * A Comparable wrapper around a String which executes callbacks on calls to hashCode, equals, and
54    * compareTo.
55    */
56   private static class CountsHashCodeAndEquals implements Comparable<CountsHashCodeAndEquals> {
57     private final String delegateString;
58     private final Runnable onHashCode;
59     private final Runnable onEquals;
60     private final Runnable onCompareTo;
61 
CountsHashCodeAndEquals( String delegateString, Runnable onHashCode, Runnable onEquals, Runnable onCompareTo)62     CountsHashCodeAndEquals(
63         String delegateString, Runnable onHashCode, Runnable onEquals, Runnable onCompareTo) {
64       this.delegateString = delegateString;
65       this.onHashCode = onHashCode;
66       this.onEquals = onEquals;
67       this.onCompareTo = onCompareTo;
68     }
69 
70     @Override
hashCode()71     public int hashCode() {
72       onHashCode.run();
73       return delegateString.hashCode();
74     }
75 
76     @Override
equals(@ullable Object other)77     public boolean equals(@Nullable Object other) {
78       onEquals.run();
79       return other instanceof CountsHashCodeAndEquals
80           && delegateString.equals(((CountsHashCodeAndEquals) other).delegateString);
81     }
82 
83     @Override
compareTo(CountsHashCodeAndEquals o)84     public int compareTo(CountsHashCodeAndEquals o) {
85       onCompareTo.run();
86       return delegateString.compareTo(o.delegateString);
87     }
88   }
89 
90   /** A holder of counters for calls to hashCode, equals, and compareTo. */
91   private static final class CallsCounter {
92     long hashCode;
93     long equals;
94     long compareTo;
95 
total()96     long total() {
97       return hashCode + equals + compareTo;
98     }
99 
zero()100     void zero() {
101       hashCode = 0;
102       equals = 0;
103       compareTo = 0;
104     }
105   }
106 
107   @FunctionalInterface
108   interface Construction<T> {
109     @CanIgnoreReturnValue
create(List<?> keys)110     abstract T create(List<?> keys);
111 
mapFromKeys( Supplier<Map<Object, Object>> mutableSupplier)112     static Construction<Map<Object, Object>> mapFromKeys(
113         Supplier<Map<Object, Object>> mutableSupplier) {
114       return keys -> {
115         Map<Object, Object> map = mutableSupplier.get();
116         for (Object key : keys) {
117           map.put(key, new Object());
118         }
119         return map;
120       };
121     }
122 
setFromElements(Supplier<Set<Object>> mutableSupplier)123     static Construction<Set<Object>> setFromElements(Supplier<Set<Object>> mutableSupplier) {
124       return elements -> {
125         Set<Object> set = mutableSupplier.get();
126         set.addAll(elements);
127         return set;
128       };
129     }
130   }
131 
132   abstract static class QueryOp<T> {
133     static <T> QueryOp<T> create(
134         String name, BiConsumer<T, Object> queryLambda, IntToDoubleFunction asymptotic) {
135       return new QueryOp<T>() {
136         @Override
137         void apply(T collection, Object query) {
138           queryLambda.accept(collection, query);
139         }
140 
141         @Override
142         double expectedAsymptotic(int n) {
143           return asymptotic.applyAsDouble(n);
144         }
145 
146         @Override
147         public String toString() {
148           return name;
149         }
150       };
151     }
152 
153     static final QueryOp<Map<Object, Object>> MAP_GET =
154         QueryOp.create("Map.get", Map::get, Math::log);
155 
156     @SuppressWarnings("ReturnValueIgnored")
157     static final QueryOp<Set<Object>> SET_CONTAINS =
158         QueryOp.create("Set.contains", Set::contains, Math::log);
159 
160     abstract void apply(T collection, Object query);
161 
162     abstract double expectedAsymptotic(int n);
163   }
164 
165   /**
166    * Returns a list of objects with the same hash code, of size 2^power, counting calls to equals,
167    * hashCode, and compareTo in counter.
168    */
169   static List<CountsHashCodeAndEquals> createAdversarialInput(int power, CallsCounter counter) {
170     String str1 = "Aa";
171     String str2 = "BB";
172     assertEquals(str1.hashCode(), str2.hashCode());
173     List<String> haveSameHashes2 = Arrays.asList(str1, str2);
174     List<CountsHashCodeAndEquals> result =
175         Lists.newArrayList(
176             Lists.transform(
177                 Lists.cartesianProduct(Collections.nCopies(power, haveSameHashes2)),
178                 strs ->
179                     new CountsHashCodeAndEquals(
180                         String.join("", strs),
181                         () -> counter.hashCode++,
182                         () -> counter.equals++,
183                         () -> counter.compareTo++)));
184     assertEquals(
185         result.get(0).delegateString.hashCode(),
186         result.get(result.size() - 1).delegateString.hashCode());
187     return result;
188   }
189 
190   public void testResistsHashFloodingInConstruction() {
191     CallsCounter smallCounter = new CallsCounter();
192     List<CountsHashCodeAndEquals> haveSameHashesSmall = createAdversarialInput(10, smallCounter);
193     int smallSize = haveSameHashesSmall.size();
194 
195     CallsCounter largeCounter = new CallsCounter();
196     List<CountsHashCodeAndEquals> haveSameHashesLarge = createAdversarialInput(15, largeCounter);
197     int largeSize = haveSameHashesLarge.size();
198 
199     for (Construction<T> pathway : constructions) {
200       smallCounter.zero();
201       pathway.create(haveSameHashesSmall);
202       long smallOps = smallCounter.total();
203 
204       largeCounter.zero();
205       pathway.create(haveSameHashesLarge);
206       long largeOps = largeCounter.total();
207 
208       double ratio = (double) largeOps / smallOps;
209       assertWithMessage(
210               "ratio of equals/hashCode/compareTo operations to build with %s entries versus %s"
211                   + " entries",
212               largeSize, smallSize)
213           .that(ratio)
214           .isAtMost(
215               2
216                   * constructionAsymptotics.applyAsDouble(largeSize)
217                   / constructionAsymptotics.applyAsDouble(smallSize));
218       // allow up to 2x wobble in the constant factors
219     }
220   }
221 
222   public void testResistsHashFloodingOnQuery() {
223     CallsCounter smallCounter = new CallsCounter();
224     List<CountsHashCodeAndEquals> haveSameHashesSmall = createAdversarialInput(10, smallCounter);
225     int smallSize = haveSameHashesSmall.size();
226 
227     CallsCounter largeCounter = new CallsCounter();
228     List<CountsHashCodeAndEquals> haveSameHashesLarge = createAdversarialInput(15, largeCounter);
229     int largeSize = haveSameHashesLarge.size();
230 
231     for (QueryOp<T> query : queries) {
232       for (Construction<T> pathway : constructions) {
233         long worstSmallOps = getWorstCaseOps(smallCounter, haveSameHashesSmall, query, pathway);
234         long worstLargeOps = getWorstCaseOps(largeCounter, haveSameHashesLarge, query, pathway);
235 
236         double ratio = (double) worstLargeOps / worstSmallOps;
237         assertWithMessage(
238                 "ratio of equals/hashCode/compareTo operations to query %s with %s entries versus"
239                     + " %s entries",
240                 query, largeSize, smallSize)
241             .that(ratio)
242             .isAtMost(
243                 2 * query.expectedAsymptotic(largeSize) / query.expectedAsymptotic(smallSize));
244         // allow up to 2x wobble in the constant factors
245       }
246     }
247   }
248 
249   private long getWorstCaseOps(
250       CallsCounter counter,
251       List<CountsHashCodeAndEquals> haveSameHashes,
252       QueryOp<T> query,
253       Construction<T> pathway) {
254     T collection = pathway.create(haveSameHashes);
255     long worstOps = 0;
256     for (Object o : haveSameHashes) {
257       counter.zero();
258       query.apply(collection, o);
259       worstOps = Math.max(worstOps, counter.total());
260     }
261     return worstOps;
262   }
263 }
264