1 // Copyright 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 //////////////////////////////////////////////////////////////////////////////// 16 17 package com.google.crypto.tink.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static java.nio.charset.StandardCharsets.UTF_8; 21 import static java.util.concurrent.TimeUnit.SECONDS; 22 23 import com.google.crypto.tink.InsecureSecretKeyAccess; 24 import com.google.crypto.tink.Key; 25 import com.google.crypto.tink.Parameters; 26 import com.google.crypto.tink.SecretKeyAccess; 27 import com.google.crypto.tink.util.Bytes; 28 import com.google.errorprone.annotations.Immutable; 29 import java.nio.ByteBuffer; 30 import java.security.GeneralSecurityException; 31 import java.util.ArrayList; 32 import java.util.List; 33 import java.util.concurrent.ExecutorService; 34 import java.util.concurrent.Executors; 35 import java.util.concurrent.Future; 36 import javax.annotation.Nullable; 37 import org.junit.Test; 38 import org.junit.runner.RunWith; 39 import org.junit.runners.JUnit4; 40 41 /** Thread safety tests for {@link MutableSerializationRegistry}. */ 42 @RunWith(JUnit4.class) 43 public final class MutableSerializationRegistryMultithreadTest { 44 private static final SecretKeyAccess ACCESS = InsecureSecretKeyAccess.get(); 45 46 private static final Bytes A_1 = Bytes.copyFrom("0".getBytes(UTF_8)); 47 private static final Bytes A_2 = Bytes.copyFrom("1".getBytes(UTF_8)); 48 private static final Bytes B_1 = Bytes.copyFrom("1".getBytes(UTF_8)); 49 private static final Bytes B_2 = Bytes.copyFrom("2".getBytes(UTF_8)); 50 51 @Immutable 52 private static final class TestParameters1 extends Parameters { 53 @Override hasIdRequirement()54 public boolean hasIdRequirement() { 55 return false; 56 } 57 } 58 59 @Immutable 60 private static final class TestParameters2 extends Parameters { 61 @Override hasIdRequirement()62 public boolean hasIdRequirement() { 63 return false; 64 } 65 } 66 67 @Immutable 68 private static final class TestKey1 extends Key { 69 @Override getParameters()70 public Parameters getParameters() { 71 throw new UnsupportedOperationException("Not needed in test"); 72 } 73 74 @Override 75 @Nullable getIdRequirementOrNull()76 public Integer getIdRequirementOrNull() { 77 throw new UnsupportedOperationException("Not needed in test"); 78 } 79 80 @Override equalsKey(Key other)81 public boolean equalsKey(Key other) { 82 throw new UnsupportedOperationException("Not needed in test"); 83 } 84 } 85 86 @Immutable 87 private static final class TestKey2 extends Key { 88 @Override getParameters()89 public Parameters getParameters() { 90 throw new UnsupportedOperationException("Not needed in test"); 91 } 92 93 @Override 94 @Nullable getIdRequirementOrNull()95 public Integer getIdRequirementOrNull() { 96 throw new UnsupportedOperationException("Not needed in test"); 97 } 98 99 @Override equalsKey(Key other)100 public boolean equalsKey(Key other) { 101 throw new UnsupportedOperationException("Not needed in test"); 102 } 103 } 104 105 @Immutable 106 private static final class TestSerializationA implements Serialization { TestSerializationA(Bytes objectIdentifier)107 public TestSerializationA(Bytes objectIdentifier) { 108 this.objectIdentifier = objectIdentifier; 109 } 110 111 private final Bytes objectIdentifier; 112 113 @Override getObjectIdentifier()114 public Bytes getObjectIdentifier() { 115 return objectIdentifier; 116 } 117 } 118 119 @Immutable 120 private static final class TestSerializationB implements Serialization { TestSerializationB(Bytes objectIdentifier)121 public TestSerializationB(Bytes objectIdentifier) { 122 this.objectIdentifier = objectIdentifier; 123 } 124 125 private final Bytes objectIdentifier; 126 127 @Override getObjectIdentifier()128 public Bytes getObjectIdentifier() { 129 return objectIdentifier; 130 } 131 } 132 serializeKey1ToA(TestKey1 key, @Nullable SecretKeyAccess access)133 private static TestSerializationA serializeKey1ToA(TestKey1 key, @Nullable SecretKeyAccess access) 134 throws GeneralSecurityException { 135 SecretKeyAccess.requireAccess(access); 136 return new TestSerializationA(A_1); 137 } 138 serializeKey2ToA(TestKey2 key, @Nullable SecretKeyAccess access)139 private static TestSerializationA serializeKey2ToA(TestKey2 key, @Nullable SecretKeyAccess access) 140 throws GeneralSecurityException { 141 SecretKeyAccess.requireAccess(access); 142 return new TestSerializationA(A_2); 143 } 144 serializeKey1ToB(TestKey1 key, @Nullable SecretKeyAccess access)145 private static TestSerializationB serializeKey1ToB(TestKey1 key, @Nullable SecretKeyAccess access) 146 throws GeneralSecurityException { 147 SecretKeyAccess.requireAccess(access); 148 return new TestSerializationB(B_1); 149 } 150 serializeKey2ToB(TestKey2 key, @Nullable SecretKeyAccess access)151 private static TestSerializationB serializeKey2ToB(TestKey2 key, @Nullable SecretKeyAccess access) 152 throws GeneralSecurityException { 153 SecretKeyAccess.requireAccess(access); 154 return new TestSerializationB(B_2); 155 } 156 parseAToKey1( TestSerializationA serialization, @Nullable SecretKeyAccess access)157 private static Key parseAToKey1( 158 TestSerializationA serialization, @Nullable SecretKeyAccess access) 159 throws GeneralSecurityException { 160 if (!A_1.equals(serialization.getObjectIdentifier())) { 161 throw new GeneralSecurityException("Wrong object identifier"); 162 } 163 SecretKeyAccess.requireAccess(access); 164 return new TestKey1(); 165 } 166 parseBToKey1( TestSerializationB serialization, @Nullable SecretKeyAccess access)167 private static Key parseBToKey1( 168 TestSerializationB serialization, @Nullable SecretKeyAccess access) 169 throws GeneralSecurityException { 170 if (!B_1.equals(serialization.getObjectIdentifier())) { 171 throw new GeneralSecurityException("Wrong object identifier"); 172 } 173 SecretKeyAccess.requireAccess(access); 174 return new TestKey1(); 175 } 176 serializeParameters1ToA(TestParameters1 parameters)177 private static TestSerializationA serializeParameters1ToA(TestParameters1 parameters) 178 throws GeneralSecurityException { 179 return new TestSerializationA(A_1); 180 } 181 serializeParameters2ToA(TestParameters2 parameters)182 private static TestSerializationA serializeParameters2ToA(TestParameters2 parameters) 183 throws GeneralSecurityException { 184 return new TestSerializationA(A_2); 185 } 186 serializeParameters1ToB(TestParameters1 parameters)187 private static TestSerializationB serializeParameters1ToB(TestParameters1 parameters) 188 throws GeneralSecurityException { 189 return new TestSerializationB(B_1); 190 } 191 serializeParameters2ToB(TestParameters2 parameters)192 private static TestSerializationB serializeParameters2ToB(TestParameters2 parameters) 193 throws GeneralSecurityException { 194 return new TestSerializationB(B_2); 195 } 196 parseAToParameters1(TestSerializationA serialization)197 private static Parameters parseAToParameters1(TestSerializationA serialization) 198 throws GeneralSecurityException { 199 if (!A_1.equals(serialization.getObjectIdentifier())) { 200 throw new GeneralSecurityException("Wrong object identifier"); 201 } 202 return new TestParameters1(); 203 } 204 parseBToParameters1(TestSerializationB serialization)205 private static Parameters parseBToParameters1(TestSerializationB serialization) 206 throws GeneralSecurityException { 207 if (!B_1.equals(serialization.getObjectIdentifier())) { 208 throw new GeneralSecurityException("Wrong object identifier"); 209 } 210 return new TestParameters1(); 211 } 212 213 private static final int REPETITIONS = 1000; 214 215 @Test registerAndParseAndSerializeInParallel_works()216 public void registerAndParseAndSerializeInParallel_works() throws Exception { 217 MutableSerializationRegistry registry = new MutableSerializationRegistry(); 218 ExecutorService threadPool = Executors.newFixedThreadPool(4); 219 List<Future<?>> futures = new ArrayList<>(); 220 registry.registerKeySerializer( 221 KeySerializer.create( 222 MutableSerializationRegistryMultithreadTest::serializeKey1ToA, 223 TestKey1.class, 224 TestSerializationA.class)); 225 registry.registerKeyParser( 226 KeyParser.create( 227 MutableSerializationRegistryMultithreadTest::parseAToKey1, 228 A_1, 229 TestSerializationA.class)); 230 registry.registerParametersSerializer( 231 ParametersSerializer.create( 232 MutableSerializationRegistryMultithreadTest::serializeParameters1ToA, 233 TestParameters1.class, 234 TestSerializationA.class)); 235 registry.registerParametersParser( 236 ParametersParser.create( 237 MutableSerializationRegistryMultithreadTest::parseAToParameters1, 238 A_1, 239 TestSerializationA.class)); 240 241 futures.add( 242 threadPool.submit( 243 () -> { 244 try { 245 for (int i = 0; i < REPETITIONS; ++i) { 246 registry.registerKeyParser( 247 KeyParser.create( 248 MutableSerializationRegistryMultithreadTest::parseAToKey1, 249 Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()), 250 TestSerializationA.class)); 251 } 252 } catch (GeneralSecurityException e) { 253 throw new RuntimeException(e); 254 } 255 })); 256 futures.add( 257 threadPool.submit( 258 () -> { 259 try { 260 // This thread mainly wants to do a key serializer registration, but we only have 261 // one of those, since each needs either a new serialization class, or a new key 262 // class. So first do a few parsing registrations to mix things up. 263 for (int i = 0; i < REPETITIONS / 2; ++i) { 264 registry.registerKeyParser( 265 KeyParser.create( 266 MutableSerializationRegistryMultithreadTest::parseBToKey1, 267 Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()), 268 TestSerializationB.class)); 269 } 270 registry.registerKeySerializer( 271 KeySerializer.create( 272 MutableSerializationRegistryMultithreadTest::serializeKey2ToA, 273 TestKey2.class, 274 TestSerializationA.class)); 275 registry.registerKeySerializer( 276 KeySerializer.create( 277 MutableSerializationRegistryMultithreadTest::serializeKey2ToB, 278 TestKey2.class, 279 TestSerializationB.class)); 280 registry.registerKeySerializer( 281 KeySerializer.create( 282 MutableSerializationRegistryMultithreadTest::serializeKey1ToB, 283 TestKey1.class, 284 TestSerializationB.class)); 285 } catch (GeneralSecurityException e) { 286 throw new RuntimeException(e); 287 } 288 })); 289 futures.add( 290 threadPool.submit( 291 () -> { 292 try { 293 for (int i = 0; i < REPETITIONS; ++i) { 294 Object unused = registry.parseKey(new TestSerializationA(A_1), ACCESS); 295 } 296 } catch (GeneralSecurityException e) { 297 throw new RuntimeException(e); 298 } 299 })); 300 futures.add( 301 threadPool.submit( 302 () -> { 303 try { 304 for (int i = 0; i < REPETITIONS; ++i) { 305 Object unused = 306 registry.serializeKey(new TestKey1(), TestSerializationA.class, ACCESS); 307 } 308 } catch (GeneralSecurityException e) { 309 throw new RuntimeException(e); 310 } 311 })); 312 // =============================== More threads doing the same thing, this time for parameters. 313 futures.add( 314 threadPool.submit( 315 () -> { 316 try { 317 for (int i = 0; i < REPETITIONS; ++i) { 318 registry.registerParametersParser( 319 ParametersParser.create( 320 MutableSerializationRegistryMultithreadTest::parseAToParameters1, 321 Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()), 322 TestSerializationA.class)); 323 } 324 } catch (GeneralSecurityException e) { 325 throw new RuntimeException(e); 326 } 327 })); 328 futures.add( 329 threadPool.submit( 330 () -> { 331 try { 332 // This thread mainly wants to do a key serializer registration, but we only have 333 // one of those, since each needs either a new serialization class, or a new key 334 // class. So first do a few parsing registrations to mix things up. 335 for (int i = 0; i < REPETITIONS / 2; ++i) { 336 registry.registerParametersParser( 337 ParametersParser.create( 338 MutableSerializationRegistryMultithreadTest::parseBToParameters1, 339 Bytes.copyFrom(ByteBuffer.allocate(4).putInt(i).array()), 340 TestSerializationB.class)); 341 } 342 registry.registerParametersSerializer( 343 ParametersSerializer.create( 344 MutableSerializationRegistryMultithreadTest::serializeParameters2ToA, 345 TestParameters2.class, 346 TestSerializationA.class)); 347 registry.registerParametersSerializer( 348 ParametersSerializer.create( 349 MutableSerializationRegistryMultithreadTest::serializeParameters2ToB, 350 TestParameters2.class, 351 TestSerializationB.class)); 352 registry.registerParametersSerializer( 353 ParametersSerializer.create( 354 MutableSerializationRegistryMultithreadTest::serializeParameters1ToB, 355 TestParameters1.class, 356 TestSerializationB.class)); 357 } catch (GeneralSecurityException e) { 358 throw new RuntimeException(e); 359 } 360 })); 361 362 futures.add( 363 threadPool.submit( 364 () -> { 365 try { 366 for (int i = 0; i < REPETITIONS; ++i) { 367 Object unused = registry.parseParameters(new TestSerializationA(A_1)); 368 } 369 } catch (GeneralSecurityException e) { 370 throw new RuntimeException(e); 371 } 372 })); 373 futures.add( 374 threadPool.submit( 375 () -> { 376 try { 377 for (int i = 0; i < REPETITIONS; ++i) { 378 Object unused = 379 registry.serializeParameters(new TestParameters1(), TestSerializationA.class); 380 } 381 } catch (GeneralSecurityException e) { 382 throw new RuntimeException(e); 383 } 384 })); 385 386 threadPool.shutdown(); 387 assertThat(threadPool.awaitTermination(300, SECONDS)).isTrue(); 388 for (int i = 0; i < futures.size(); ++i) { 389 futures.get(i).get(); // This will throw an exception if the thread threw an exception. 390 } 391 } 392 } 393