1 // Copyright 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 //////////////////////////////////////////////////////////////////////////////// 16 17 package com.google.crypto.tink.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static java.util.concurrent.TimeUnit.SECONDS; 21 22 import com.google.crypto.tink.Key; 23 import com.google.crypto.tink.Parameters; 24 import com.google.errorprone.annotations.Immutable; 25 import java.security.GeneralSecurityException; 26 import java.util.ArrayList; 27 import java.util.List; 28 import java.util.concurrent.ExecutorService; 29 import java.util.concurrent.Executors; 30 import java.util.concurrent.Future; 31 import javax.annotation.Nullable; 32 import org.junit.Test; 33 import org.junit.runner.RunWith; 34 import org.junit.runners.JUnit4; 35 36 @RunWith(JUnit4.class) 37 public class MutablePrimitiveRegistryMultithreadTest { 38 @Immutable 39 private static final class TestKey1 extends Key { 40 @Override getParameters()41 public Parameters getParameters() { 42 throw new UnsupportedOperationException("Not needed in test"); 43 } 44 45 @Override 46 @Nullable getIdRequirementOrNull()47 public Integer getIdRequirementOrNull() { 48 throw new UnsupportedOperationException("Not needed in test"); 49 } 50 51 @Override equalsKey(Key other)52 public boolean equalsKey(Key other) { 53 throw new UnsupportedOperationException("Not needed in test"); 54 } 55 } 56 57 @Immutable 58 private static final class TestKey2 extends Key { 59 @Override getParameters()60 public Parameters getParameters() { 61 throw new UnsupportedOperationException("Not needed in test"); 62 } 63 64 @Override 65 @Nullable getIdRequirementOrNull()66 public Integer getIdRequirementOrNull() { 67 throw new UnsupportedOperationException("Not needed in test"); 68 } 69 70 @Override equalsKey(Key other)71 public boolean equalsKey(Key other) { 72 throw new UnsupportedOperationException("Not needed in test"); 73 } 74 } 75 76 @Immutable 77 private static final class TestPrimitiveA { TestPrimitiveA()78 public TestPrimitiveA() {} 79 } 80 81 @Immutable 82 private static final class TestPrimitiveB { TestPrimitiveB()83 public TestPrimitiveB() {} 84 } 85 getPrimitiveAKey1( MutablePrimitiveRegistryMultithreadTest.TestKey1 key)86 private static MutablePrimitiveRegistryMultithreadTest.TestPrimitiveA getPrimitiveAKey1( 87 MutablePrimitiveRegistryMultithreadTest.TestKey1 key) { 88 return new MutablePrimitiveRegistryMultithreadTest.TestPrimitiveA(); 89 } 90 getPrimitiveAKey2( MutablePrimitiveRegistryMultithreadTest.TestKey2 key)91 private static MutablePrimitiveRegistryMultithreadTest.TestPrimitiveA getPrimitiveAKey2( 92 MutablePrimitiveRegistryMultithreadTest.TestKey2 key) { 93 return new MutablePrimitiveRegistryMultithreadTest.TestPrimitiveA(); 94 } 95 getPrimitiveBKey1( MutablePrimitiveRegistryMultithreadTest.TestKey1 key)96 private static MutablePrimitiveRegistryMultithreadTest.TestPrimitiveB getPrimitiveBKey1( 97 MutablePrimitiveRegistryMultithreadTest.TestKey1 key) { 98 return new MutablePrimitiveRegistryMultithreadTest.TestPrimitiveB(); 99 } 100 getPrimitiveBKey2( MutablePrimitiveRegistryMultithreadTest.TestKey2 key)101 private static MutablePrimitiveRegistryMultithreadTest.TestPrimitiveB getPrimitiveBKey2( 102 MutablePrimitiveRegistryMultithreadTest.TestKey2 key) { 103 return new MutablePrimitiveRegistryMultithreadTest.TestPrimitiveB(); 104 } 105 106 private static final int REPETITIONS = 10000; 107 private static final int THREAD_NUMBER = 12; 108 109 @Test registerAndGetPrimitivesInParallel_works()110 public void registerAndGetPrimitivesInParallel_works() throws Exception { 111 MutablePrimitiveRegistry registry = new MutablePrimitiveRegistry(); 112 ExecutorService threadPool = Executors.newFixedThreadPool(THREAD_NUMBER); 113 List<Future<?>> futures = new ArrayList<>(); 114 registry.registerPrimitiveConstructor( 115 PrimitiveConstructor.create( 116 MutablePrimitiveRegistryMultithreadTest::getPrimitiveAKey1, 117 TestKey1.class, 118 TestPrimitiveA.class)); 119 120 // It's questionable how mixed up things are gonna be with so few registrations but 121 // registering many constructors would require around square as many of both key and 122 // primitive test classes created, and that's gonna be a serious code bloat. 123 futures.add( 124 threadPool.submit( 125 () -> { 126 try { 127 registry.registerPrimitiveConstructor( 128 PrimitiveConstructor.create( 129 MutablePrimitiveRegistryMultithreadTest::getPrimitiveAKey2, 130 TestKey2.class, 131 TestPrimitiveA.class)); 132 } catch (GeneralSecurityException e) { 133 throw new RuntimeException(e); 134 } 135 })); 136 futures.add( 137 threadPool.submit( 138 () -> { 139 try { 140 registry.registerPrimitiveConstructor( 141 PrimitiveConstructor.create( 142 MutablePrimitiveRegistryMultithreadTest::getPrimitiveBKey1, 143 TestKey1.class, 144 TestPrimitiveB.class)); 145 registry.registerPrimitiveConstructor( 146 PrimitiveConstructor.create( 147 MutablePrimitiveRegistryMultithreadTest::getPrimitiveBKey2, 148 TestKey2.class, 149 TestPrimitiveB.class)); 150 } catch (GeneralSecurityException e) { 151 throw new RuntimeException(e); 152 } 153 })); 154 // Thread pool size - the number of registration threads. 155 for (int i = 0; i < THREAD_NUMBER - 2; ++i) { 156 futures.add( 157 threadPool.submit( 158 () -> { 159 try { 160 for (int j = 0; j < REPETITIONS; ++j) { 161 TestPrimitiveA unused = 162 registry.getPrimitive(new TestKey1(), TestPrimitiveA.class); 163 } 164 } catch (GeneralSecurityException e) { 165 throw new RuntimeException(e); 166 } 167 })); 168 } 169 170 threadPool.shutdown(); 171 assertThat(threadPool.awaitTermination(300, SECONDS)).isTrue(); 172 for (Future<?> future : futures) { 173 future.get(); // This will throw an exception if the thread threw an exception. 174 } 175 } 176 } 177