• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
16 
17 package com.google.crypto.tink.internal;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static java.nio.charset.StandardCharsets.UTF_8;
21 import static java.util.concurrent.TimeUnit.SECONDS;
22 
23 import com.google.crypto.tink.InsecureSecretKeyAccess;
24 import com.google.crypto.tink.Key;
25 import com.google.crypto.tink.Parameters;
26 import com.google.crypto.tink.SecretKeyAccess;
27 import com.google.crypto.tink.util.Bytes;
28 import com.google.errorprone.annotations.Immutable;
29 import java.nio.ByteBuffer;
30 import java.security.GeneralSecurityException;
31 import java.util.ArrayList;
32 import java.util.List;
33 import java.util.concurrent.ExecutorService;
34 import java.util.concurrent.Executors;
35 import java.util.concurrent.Future;
36 import javax.annotation.Nullable;
37 import org.junit.Test;
38 import org.junit.runner.RunWith;
39 import org.junit.runners.JUnit4;
40 
41 /** Thread safety tests for {@link MutableSerializationRegistry}. */
42 @RunWith(JUnit4.class)
43 public final class MutableSerializationRegistryMultithreadTest {
44   private static final SecretKeyAccess ACCESS = InsecureSecretKeyAccess.get();
45 
46   private static final Bytes A_1 = Bytes.copyFrom("0".getBytes(UTF_8));
47   private static final Bytes A_2 = Bytes.copyFrom("1".getBytes(UTF_8));
48   private static final Bytes B_1 = Bytes.copyFrom("1".getBytes(UTF_8));
49   private static final Bytes B_2 = Bytes.copyFrom("2".getBytes(UTF_8));
50 
51   @Immutable
52   private static final class TestParameters1 extends Parameters {
53     @Override
hasIdRequirement()54     public boolean hasIdRequirement() {
55       return false;
56     }
57   }
58 
59   @Immutable
60   private static final class TestParameters2 extends Parameters {
61     @Override
hasIdRequirement()62     public boolean hasIdRequirement() {
63       return false;
64     }
65   }
66 
67   @Immutable
68   private static final class TestKey1 extends Key {
69     @Override
getParameters()70     public Parameters getParameters() {
71       throw new UnsupportedOperationException("Not needed in test");
72     }
73 
74     @Override
75     @Nullable
getIdRequirementOrNull()76     public Integer getIdRequirementOrNull() {
77       throw new UnsupportedOperationException("Not needed in test");
78     }
79 
80     @Override
equalsKey(Key other)81     public boolean equalsKey(Key other) {
82       throw new UnsupportedOperationException("Not needed in test");
83     }
84   }
85 
86   @Immutable
87   private static final class TestKey2 extends Key {
88     @Override
getParameters()89     public Parameters getParameters() {
90       throw new UnsupportedOperationException("Not needed in test");
91     }
92 
93     @Override
94     @Nullable
getIdRequirementOrNull()95     public Integer getIdRequirementOrNull() {
96       throw new UnsupportedOperationException("Not needed in test");
97     }
98 
99     @Override
equalsKey(Key other)100     public boolean equalsKey(Key other) {
101       throw new UnsupportedOperationException("Not needed in test");
102     }
103   }
104 
105   @Immutable
106   private static final class TestSerializationA implements Serialization {
TestSerializationA(Bytes objectIdentifier)107     public TestSerializationA(Bytes objectIdentifier) {
108       this.objectIdentifier = objectIdentifier;
109     }
110 
111     private final Bytes objectIdentifier;
112 
113     @Override
getObjectIdentifier()114     public Bytes getObjectIdentifier() {
115       return objectIdentifier;
116     }
117   }
118 
119   @Immutable
120   private static final class TestSerializationB implements Serialization {
TestSerializationB(Bytes objectIdentifier)121     public TestSerializationB(Bytes objectIdentifier) {
122       this.objectIdentifier = objectIdentifier;
123     }
124 
125     private final Bytes objectIdentifier;
126 
127     @Override
getObjectIdentifier()128     public Bytes getObjectIdentifier() {
129       return objectIdentifier;
130     }
131   }
132 
serializeKey1ToA(TestKey1 key, @Nullable SecretKeyAccess access)133   private static TestSerializationA serializeKey1ToA(TestKey1 key, @Nullable SecretKeyAccess access)
134       throws GeneralSecurityException {
135     SecretKeyAccess.requireAccess(access);
136     return new TestSerializationA(A_1);
137   }
138 
serializeKey2ToA(TestKey2 key, @Nullable SecretKeyAccess access)139   private static TestSerializationA serializeKey2ToA(TestKey2 key, @Nullable SecretKeyAccess access)
140       throws GeneralSecurityException {
141     SecretKeyAccess.requireAccess(access);
142     return new TestSerializationA(A_2);
143   }
144 
serializeKey1ToB(TestKey1 key, @Nullable SecretKeyAccess access)145   private static TestSerializationB serializeKey1ToB(TestKey1 key, @Nullable SecretKeyAccess access)
146       throws GeneralSecurityException {
147     SecretKeyAccess.requireAccess(access);
148     return new TestSerializationB(B_1);
149   }
150 
serializeKey2ToB(TestKey2 key, @Nullable SecretKeyAccess access)151   private static TestSerializationB serializeKey2ToB(TestKey2 key, @Nullable SecretKeyAccess access)
152       throws GeneralSecurityException {
153     SecretKeyAccess.requireAccess(access);
154     return new TestSerializationB(B_2);
155   }
156 
parseAToKey1( TestSerializationA serialization, @Nullable SecretKeyAccess access)157   private static Key parseAToKey1(
158       TestSerializationA serialization, @Nullable SecretKeyAccess access)
159       throws GeneralSecurityException {
160     if (!A_1.equals(serialization.getObjectIdentifier())) {
161       throw new GeneralSecurityException("Wrong object identifier");
162     }
163     SecretKeyAccess.requireAccess(access);
164     return new TestKey1();
165   }
166 
parseBToKey1( TestSerializationB serialization, @Nullable SecretKeyAccess access)167   private static Key parseBToKey1(
168       TestSerializationB serialization, @Nullable SecretKeyAccess access)
169       throws GeneralSecurityException {
170     if (!B_1.equals(serialization.getObjectIdentifier())) {
171       throw new GeneralSecurityException("Wrong object identifier");
172     }
173     SecretKeyAccess.requireAccess(access);
174     return new TestKey1();
175   }
176 
serializeParameters1ToA(TestParameters1 parameters)177   private static TestSerializationA serializeParameters1ToA(TestParameters1 parameters)
178       throws GeneralSecurityException {
179     return new TestSerializationA(A_1);
180   }
181 
serializeParameters2ToA(TestParameters2 parameters)182   private static TestSerializationA serializeParameters2ToA(TestParameters2 parameters)
183       throws GeneralSecurityException {
184     return new TestSerializationA(A_2);
185   }
186 
serializeParameters1ToB(TestParameters1 parameters)187   private static TestSerializationB serializeParameters1ToB(TestParameters1 parameters)
188       throws GeneralSecurityException {
189     return new TestSerializationB(B_1);
190   }
191 
serializeParameters2ToB(TestParameters2 parameters)192   private static TestSerializationB serializeParameters2ToB(TestParameters2 parameters)
193       throws GeneralSecurityException {
194     return new TestSerializationB(B_2);
195   }
196 
parseAToParameters1(TestSerializationA serialization)197   private static Parameters parseAToParameters1(TestSerializationA serialization)
198       throws GeneralSecurityException {
199     if (!A_1.equals(serialization.getObjectIdentifier())) {
200       throw new GeneralSecurityException("Wrong object identifier");
201     }
202     return new TestParameters1();
203   }
204 
parseBToParameters1(TestSerializationB serialization)205   private static Parameters parseBToParameters1(TestSerializationB serialization)
206       throws GeneralSecurityException {
207     if (!B_1.equals(serialization.getObjectIdentifier())) {
208       throw new GeneralSecurityException("Wrong object identifier");
209     }
210     return new TestParameters1();
211   }
212 
213   private static final int REPETITIONS = 1000;
214 
215   @Test
registerAndParseAndSerializeInParallel_works()216   public void registerAndParseAndSerializeInParallel_works() throws Exception {
217     MutableSerializationRegistry registry = new MutableSerializationRegistry();
218     ExecutorService threadPool = Executors.newFixedThreadPool(4);
219     List<Future<?>> futures = new ArrayList<>();
220     registry.registerKeySerializer(
221         KeySerializer.create(
222             MutableSerializationRegistryMultithreadTest::serializeKey1ToA,
223             TestKey1.class,
224             TestSerializationA.class));
225     registry.registerKeyParser(
226         KeyParser.create(
227             MutableSerializationRegistryMultithreadTest::parseAToKey1,
228             A_1,
229             TestSerializationA.class));
230     registry.registerParametersSerializer(
231         ParametersSerializer.create(
232             MutableSerializationRegistryMultithreadTest::serializeParameters1ToA,
233             TestParameters1.class,
234             TestSerializationA.class));
235     registry.registerParametersParser(
236         ParametersParser.create(
237             MutableSerializationRegistryMultithreadTest::parseAToParameters1,
238             A_1,
239             TestSerializationA.class));
240 
241     futures.add(
242         threadPool.submit(
243             () -> {
244               try {
245                 for (int i = 0; i < REPETITIONS; ++i) {
246                   registry.registerKeyParser(
247                       KeyParser.create(
248                           MutableSerializationRegistryMultithreadTest::parseAToKey1,
249                           Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()),
250                           TestSerializationA.class));
251                 }
252               } catch (GeneralSecurityException e) {
253                 throw new RuntimeException(e);
254               }
255             }));
256     futures.add(
257         threadPool.submit(
258             () -> {
259               try {
260                 // This thread mainly wants to do a key serializer registration, but we only have
261                 // one of those, since each needs either a new serialization class, or a new key
262                 // class. So first do a few parsing registrations to mix things up.
263                 for (int i = 0; i < REPETITIONS / 2; ++i) {
264                   registry.registerKeyParser(
265                       KeyParser.create(
266                           MutableSerializationRegistryMultithreadTest::parseBToKey1,
267                           Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()),
268                           TestSerializationB.class));
269                 }
270                 registry.registerKeySerializer(
271                     KeySerializer.create(
272                         MutableSerializationRegistryMultithreadTest::serializeKey2ToA,
273                         TestKey2.class,
274                         TestSerializationA.class));
275                 registry.registerKeySerializer(
276                     KeySerializer.create(
277                         MutableSerializationRegistryMultithreadTest::serializeKey2ToB,
278                         TestKey2.class,
279                         TestSerializationB.class));
280                 registry.registerKeySerializer(
281                     KeySerializer.create(
282                         MutableSerializationRegistryMultithreadTest::serializeKey1ToB,
283                         TestKey1.class,
284                         TestSerializationB.class));
285               } catch (GeneralSecurityException e) {
286                 throw new RuntimeException(e);
287               }
288             }));
289     futures.add(
290         threadPool.submit(
291             () -> {
292               try {
293                 for (int i = 0; i < REPETITIONS; ++i) {
294                   Object unused = registry.parseKey(new TestSerializationA(A_1), ACCESS);
295                 }
296               } catch (GeneralSecurityException e) {
297                 throw new RuntimeException(e);
298               }
299             }));
300     futures.add(
301         threadPool.submit(
302             () -> {
303               try {
304                 for (int i = 0; i < REPETITIONS; ++i) {
305                   Object unused =
306                       registry.serializeKey(new TestKey1(), TestSerializationA.class, ACCESS);
307                 }
308               } catch (GeneralSecurityException e) {
309                 throw new RuntimeException(e);
310               }
311             }));
312     // =============================== More threads doing the same thing, this time for parameters.
313     futures.add(
314         threadPool.submit(
315             () -> {
316               try {
317                 for (int i = 0; i < REPETITIONS; ++i) {
318                   registry.registerParametersParser(
319                       ParametersParser.create(
320                           MutableSerializationRegistryMultithreadTest::parseAToParameters1,
321                           Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()),
322                           TestSerializationA.class));
323                 }
324               } catch (GeneralSecurityException e) {
325                 throw new RuntimeException(e);
326               }
327             }));
328     futures.add(
329         threadPool.submit(
330             () -> {
331               try {
332                 // This thread mainly wants to do a key serializer registration, but we only have
333                 // one of those, since each needs either a new serialization class, or a new key
334                 // class. So first do a few parsing registrations to mix things up.
335                 for (int i = 0; i < REPETITIONS / 2; ++i) {
336                   registry.registerParametersParser(
337                       ParametersParser.create(
338                           MutableSerializationRegistryMultithreadTest::parseBToParameters1,
339                           Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()),
340                           TestSerializationB.class));
341                 }
342                 registry.registerParametersSerializer(
343                     ParametersSerializer.create(
344                         MutableSerializationRegistryMultithreadTest::serializeParameters2ToA,
345                         TestParameters2.class,
346                         TestSerializationA.class));
347                 registry.registerParametersSerializer(
348                     ParametersSerializer.create(
349                         MutableSerializationRegistryMultithreadTest::serializeParameters2ToB,
350                         TestParameters2.class,
351                         TestSerializationB.class));
352                 registry.registerParametersSerializer(
353                     ParametersSerializer.create(
354                         MutableSerializationRegistryMultithreadTest::serializeParameters1ToB,
355                         TestParameters1.class,
356                         TestSerializationB.class));
357               } catch (GeneralSecurityException e) {
358                 throw new RuntimeException(e);
359               }
360             }));
361 
362     futures.add(
363         threadPool.submit(
364             () -> {
365               try {
366                 for (int i = 0; i < REPETITIONS; ++i) {
367                   Object unused = registry.parseParameters(new TestSerializationA(A_1));
368                 }
369               } catch (GeneralSecurityException e) {
370                 throw new RuntimeException(e);
371               }
372             }));
373     futures.add(
374         threadPool.submit(
375             () -> {
376               try {
377                 for (int i = 0; i < REPETITIONS; ++i) {
378                   Object unused =
379                       registry.serializeParameters(new TestParameters1(), TestSerializationA.class);
380                 }
381               } catch (GeneralSecurityException e) {
382                 throw new RuntimeException(e);
383               }
384             }));
385 
386     threadPool.shutdown();
387     assertThat(threadPool.awaitTermination(300, SECONDS)).isTrue();
388     for (int i = 0; i < futures.size(); ++i) {
389       futures.get(i).get(); // This will throw an exception if the thread threw an exception.
390     }
391   }
392 }
393