1 /* 2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"). 5 * You may not use this file except in compliance with the License. 6 * A copy of the License is located at 7 * 8 * http://aws.amazon.com/apache2.0 9 * 10 * or in the "license" file accompanying this file. This file is distributed 11 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 * express or implied. See the License for the specific language governing 13 * permissions and limitations under the License. 14 */ 15 16 package software.amazon.awssdk.services; 17 18 import static java.util.Collections.emptyList; 19 import static java.util.Collections.emptyMap; 20 import static java.util.Collections.singletonList; 21 import static java.util.Collections.singletonMap; 22 import static org.assertj.core.api.Assertions.assertThat; 23 import static org.mockito.ArgumentMatchers.eq; 24 import static org.mockito.Mockito.mock; 25 import static org.mockito.Mockito.times; 26 import static software.amazon.awssdk.profiles.ProfileFile.Type.CONFIGURATION; 27 28 import java.lang.reflect.Field; 29 import java.net.URI; 30 import java.time.Duration; 31 import java.util.HashMap; 32 import java.util.List; 33 import java.util.Map; 34 import java.util.Optional; 35 import java.util.concurrent.CompletableFuture; 36 import java.util.concurrent.ScheduledExecutorService; 37 import java.util.concurrent.atomic.AtomicInteger; 38 import java.util.function.BiConsumer; 39 import java.util.function.Supplier; 40 import java.util.stream.Stream; 41 import org.junit.jupiter.api.Disabled; 42 import org.junit.jupiter.params.ParameterizedTest; 43 import org.junit.jupiter.params.provider.MethodSource; 44 import org.mockito.Mockito; 45 import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; 46 import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; 47 import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; 48 import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder; 49 import software.amazon.awssdk.awscore.client.config.AwsClientOption; 50 import software.amazon.awssdk.core.CompressionConfiguration; 51 import software.amazon.awssdk.core.RequestOverrideConfiguration; 52 import software.amazon.awssdk.core.SdkPlugin; 53 import software.amazon.awssdk.core.client.builder.SdkClientBuilder; 54 import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; 55 import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption; 56 import software.amazon.awssdk.core.client.config.SdkClientConfiguration; 57 import software.amazon.awssdk.core.client.config.SdkClientOption; 58 import software.amazon.awssdk.core.interceptor.Context; 59 import software.amazon.awssdk.core.interceptor.ExecutionAttribute; 60 import software.amazon.awssdk.core.interceptor.ExecutionAttributes; 61 import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; 62 import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; 63 import software.amazon.awssdk.core.retry.RetryMode; 64 import software.amazon.awssdk.core.retry.RetryPolicy; 65 import software.amazon.awssdk.endpoints.Endpoint; 66 import software.amazon.awssdk.http.SdkHttpClient; 67 import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; 68 import software.amazon.awssdk.http.auth.scheme.NoAuthAuthScheme; 69 import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; 70 import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; 71 import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; 72 import software.amazon.awssdk.identity.spi.IdentityProvider; 73 import software.amazon.awssdk.identity.spi.IdentityProviders; 74 import software.amazon.awssdk.metrics.MetricPublisher; 75 import software.amazon.awssdk.profiles.Profile; 76 import software.amazon.awssdk.profiles.ProfileFile; 77 import software.amazon.awssdk.profiles.ProfileProperty; 78 import software.amazon.awssdk.regions.Region; 79 import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient; 80 import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClientBuilder; 81 import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonServiceClientConfiguration; 82 import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeProvider; 83 import software.amazon.awssdk.services.protocolrestjson.endpoints.ProtocolRestJsonEndpointProvider; 84 import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient; 85 import software.amazon.awssdk.utils.ImmutableMap; 86 import software.amazon.awssdk.utils.Lazy; 87 import software.amazon.awssdk.utils.StringInputStream; 88 89 /** 90 * Verify that configuration changes made by plugins are reflected in the SDK client configuration used by the request, and 91 * that the plugin can see all SDK configuration options. 92 */ 93 public class SdkPluginTest { 94 private static final AwsCredentialsProvider DEFAULT_CREDENTIALS = () -> AwsBasicCredentials.create("akid", "skid"); 95 testCases()96 public static Stream<TestCase<?>> testCases() { 97 Map<String, AuthScheme<?>> defaultAuthSchemes = 98 ImmutableMap.of(AwsV4AuthScheme.SCHEME_ID, AwsV4AuthScheme.create(), 99 NoAuthAuthScheme.SCHEME_ID, NoAuthAuthScheme.create()); 100 Map<String, AuthScheme<?>> nonDefaultAuthSchemes = new HashMap<>(defaultAuthSchemes); 101 nonDefaultAuthSchemes.put(CustomAuthScheme.SCHEME_ID, new CustomAuthScheme()); 102 103 ScheduledExecutorService mockScheduledExecutor = mock(ScheduledExecutorService.class); 104 MetricPublisher mockMetricPublisher = mock(MetricPublisher.class); 105 106 String profileFileContent = 107 "[default]\n" 108 + ProfileProperty.USE_FIPS_ENDPOINT + " = true" 109 + "[profile some-profile]\n" 110 + ProfileProperty.USE_FIPS_ENDPOINT + " = false"; 111 112 ProfileFile nonDefaultProfileFile = 113 ProfileFile.builder() 114 .type(CONFIGURATION) 115 .content(new StringInputStream(profileFileContent)) 116 .build(); 117 118 return Stream.of( 119 new TestCase<URI>("endpointOverride") 120 .nonDefaultValue(URI.create("https://example.aws")) 121 .clientSetter(SdkClientBuilder::endpointOverride) 122 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::endpointOverride) 123 .pluginValidator((c, v) -> assertThat(v).isEqualTo(c.endpointOverride())) 124 .beforeTransmissionValidator((r, a, v) -> { 125 assertThat(v).isEqualTo(removePathAndQueryString(r.httpRequest().getUri())); 126 }), 127 new TestCase<ProtocolRestJsonEndpointProvider>("endpointProvider") 128 .defaultValue(ProtocolRestJsonEndpointProvider.defaultProvider()) 129 .nonDefaultValue(a -> CompletableFuture.completedFuture(Endpoint.builder() 130 .url(URI.create("https://example.aws")) 131 .build())) 132 .clientSetter(ProtocolRestJsonClientBuilder::endpointProvider) 133 .requestSetter(RequestOverrideConfiguration.Builder::endpointProvider) 134 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::endpointProvider) 135 .pluginValidator((c, v) -> assertThat(c.endpointProvider()).isEqualTo(v)) 136 .beforeTransmissionValidator((r, a, v) -> { 137 assertThat(removePathAndQueryString(r.httpRequest().getUri())) 138 .isEqualTo(v.resolveEndpoint(x -> {}).join().url()); 139 }), 140 new TestCase<Map<String, AuthScheme<?>>>("authSchemes") 141 .defaultValue(defaultAuthSchemes) 142 .nonDefaultValue(nonDefaultAuthSchemes) 143 .clientSetter((b, v) -> v.forEach((x, scheme) -> b.putAuthScheme(scheme))) 144 .pluginSetter((b, v) -> v.forEach((x, scheme) -> b.putAuthScheme(scheme))) 145 .pluginValidator((c, v) -> v.forEach((id, s) -> assertThat(c.authSchemes()).containsEntry(id, s))) 146 .beforeTransmissionValidator((r, a, v) -> v.forEach((id, s) -> { 147 assertThat(a.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES)).containsEntry(id, s); 148 })), 149 new TestCase<Region>("region") 150 .defaultValue(Region.US_WEST_2) 151 .nonDefaultValue(Region.US_EAST_1) 152 .clientSetter(AwsClientBuilder::region) 153 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::region) 154 .pluginValidator((c, v) -> assertThat(c.region()).isEqualTo(v)) 155 .beforeTransmissionValidator((r, a, v) -> { 156 assertThat(r.httpRequest() 157 .firstMatchingHeader("Authorization")).get() 158 .asString() 159 .contains(v.id()); 160 assertThat(r.httpRequest().getUri().getHost()).contains(v.id()); 161 }), 162 new TestCase<AwsCredentialsProvider>("credentialsProvider") 163 .defaultValue(DEFAULT_CREDENTIALS) 164 .nonDefaultValue(DEFAULT_CREDENTIALS::resolveCredentials) 165 .clientSetter(AwsClientBuilder::credentialsProvider) 166 .requestSetter(AwsRequestOverrideConfiguration.Builder::credentialsProvider) 167 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::credentialsProvider) 168 .pluginValidator((c, v) -> assertThat(c.credentialsProvider()).isEqualTo(v)) 169 .beforeTransmissionValidator((r, a, v) -> { 170 assertThat(r.httpRequest() 171 .firstMatchingHeader("Authorization")).get() 172 .asString() 173 .contains(v.resolveCredentials().accessKeyId()); 174 }), 175 new TestCase<ProtocolRestJsonAuthSchemeProvider>("authSchemeProvider") 176 .defaultValue(ProtocolRestJsonAuthSchemeProvider.defaultProvider()) 177 .nonDefaultValue(p -> singletonList(AuthSchemeOption.builder().schemeId(NoAuthAuthScheme.SCHEME_ID).build())) 178 .clientSetter(ProtocolRestJsonClientBuilder::authSchemeProvider) 179 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::authSchemeProvider) 180 .pluginValidator((c, v) -> assertThat(c.authSchemeProvider()).isEqualTo(v)) 181 .beforeTransmissionValidator((r, a, v) -> { 182 assertThat(a.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME) 183 .authSchemeOption() 184 .schemeId()).isEqualTo(NoAuthAuthScheme.SCHEME_ID); 185 assertThat(r.httpRequest().firstMatchingHeader("Authorization")).isNotPresent(); 186 }), 187 new TestCase<Map<String, List<String>>>("override.headers") 188 .defaultValue(emptyMap()) 189 .nonDefaultValue(singletonMap("foo", singletonList("bar"))) 190 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.headers(v))) 191 .requestSetter(AwsRequestOverrideConfiguration.Builder::headers) 192 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.headers(v)))) 193 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().headers()).isEqualTo(v)) 194 .beforeTransmissionValidator((r, a, v) -> { 195 v.forEach((key, value) -> assertThat(r.httpRequest().headers().get(key)).isEqualTo(value)); 196 }), 197 new TestCase<RetryPolicy>("override.retryPolicy") 198 .defaultValue(RetryPolicy.defaultRetryPolicy()) 199 .nonDefaultValue(RetryPolicy.builder(RetryMode.STANDARD).numRetries(1).build()) 200 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.retryPolicy(v))) 201 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.retryPolicy(v)))) 202 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().retryPolicy().get().numRetries()) 203 .isEqualTo(v.numRetries())) 204 .beforeTransmissionValidator((r, a, v) -> { 205 assertThat(r.httpRequest().firstMatchingHeader("amz-sdk-request")) 206 .hasValue("attempt=1; max=" + (v.numRetries() + 1)); 207 }), 208 new TestCase<List<ExecutionInterceptor>>("override.executionInterceptors") 209 .defaultValue(emptyList()) 210 .nonDefaultValue(singletonList(new FlagSettingInterceptor())) 211 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.executionInterceptors(v))) 212 .pluginSetter((b, v) -> { 213 b.overrideConfiguration(b.overrideConfiguration().copy(c -> v.forEach(c::addExecutionInterceptor))); 214 }) 215 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().executionInterceptors()).containsAll(v)) 216 .beforeTransmissionValidator((r, a, v) -> { 217 if (v.stream().anyMatch(i -> i instanceof FlagSettingInterceptor)) { 218 assertThat(a.getAttribute(FlagSettingInterceptor.FLAG)).isEqualTo(true); 219 } else { 220 assertThat(a.getAttribute(FlagSettingInterceptor.FLAG)).isNull(); 221 } 222 }), 223 new TestCase<ScheduledExecutorService>("override.scheduledExecutorService") 224 .defaultValue(null) 225 .nonDefaultValue(mockScheduledExecutor) 226 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.scheduledExecutorService(v))) 227 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.scheduledExecutorService(v)))) 228 .pluginValidator((c, v) -> { 229 Optional<ScheduledExecutorService> executor = c.overrideConfiguration().scheduledExecutorService(); 230 if (v != null) { 231 // The SDK should decorate the non-default-value. Ensure that's what happened. 232 Runnable runnable = () -> {}; 233 v.submit(runnable); 234 assertThat(v).isEqualTo(mockScheduledExecutor); 235 Mockito.verify(v, times(1)).submit(eq(runnable)); 236 } else { 237 // Null means we're using the default, and the default should be specified by the runtime. 238 assertThat(executor).isPresent(); 239 } 240 }) 241 .clientConfigurationValidator((c, v) -> { 242 ScheduledExecutorService configuredService = c.option(SdkClientOption.SCHEDULED_EXECUTOR_SERVICE); 243 if (mockScheduledExecutor.equals(v)) { 244 // The SDK should decorate the non-default-value. Ensure that's what happened. 245 Runnable runnable = () -> {}; 246 configuredService.submit(runnable); 247 assertThat(v).isEqualTo(mockScheduledExecutor); 248 Mockito.verify(v, times(1)).submit(eq(runnable)); 249 } else { 250 assertThat(configuredService).isNotNull(); 251 } 252 }), 253 new TestCase<Map<SdkAdvancedClientOption<?>, ?>>("override.advancedOptions") 254 .defaultValue(emptyMap()) 255 .nonDefaultValue(singletonMap(SdkAdvancedClientOption.USER_AGENT_PREFIX, "foo")) 256 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.advancedOptions(v))) 257 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> { 258 v.forEach((option, value) -> unsafePutOption(c, option, value)); 259 }))) 260 .pluginValidator((c, v) -> { 261 v.forEach((o, ov) -> assertThat(c.overrideConfiguration().advancedOption(o).orElse(null)).isEqualTo(ov)); 262 }) 263 .clientConfigurationValidator((c, v) -> v.forEach((o, ov) -> assertThat(c.option(o)).isEqualTo(ov))), 264 new TestCase<Duration>("override.apiCallTimeout") 265 .defaultValue(null) 266 .nonDefaultValue(Duration.ofSeconds(5)) 267 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.apiCallTimeout(v))) 268 .requestSetter(AwsRequestOverrideConfiguration.Builder::apiCallTimeout) 269 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.apiCallTimeout(v)))) 270 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().apiCallTimeout().orElse(null)).isEqualTo(v)) 271 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.API_CALL_TIMEOUT)).isEqualTo(v)), 272 new TestCase<Duration>("override.apiCallAttemptTimeout") 273 .defaultValue(null) 274 .nonDefaultValue(Duration.ofSeconds(3)) 275 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.apiCallAttemptTimeout(v))) 276 .requestSetter(AwsRequestOverrideConfiguration.Builder::apiCallAttemptTimeout) 277 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.apiCallAttemptTimeout(v)))) 278 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().apiCallAttemptTimeout().orElse(null)).isEqualTo(v)) 279 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.API_CALL_ATTEMPT_TIMEOUT)).isEqualTo(v)), 280 new TestCase<Supplier<ProfileFile>>("override.defaultProfileFileSupplier") 281 .defaultValue(new Lazy<>(ProfileFile::defaultProfileFile)::getValue) 282 .nonDefaultValue(() -> nonDefaultProfileFile) 283 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.defaultProfileFileSupplier(v))) 284 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.defaultProfileFileSupplier(v)))) 285 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().defaultProfileFileSupplier().get().get()).isEqualTo(v.get())) 286 .clientConfigurationValidator((c, v) -> { 287 Supplier<ProfileFile> supplier = c.option(SdkClientOption.PROFILE_FILE_SUPPLIER); 288 assertThat(supplier.get()).isEqualTo(v.get()); 289 290 Optional<Profile> defaultProfile = v.get().profile("default"); 291 defaultProfile.ifPresent(profile -> { 292 profile.booleanProperty(ProfileProperty.USE_FIPS_ENDPOINT).ifPresent(d -> { 293 assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isEqualTo(d); 294 }); 295 }); 296 if (!defaultProfile.isPresent()) { 297 assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isIn(null, false); 298 } 299 }), 300 new TestCase<ProfileFile>("override.defaultProfileFile") 301 .defaultValue(ProfileFile.defaultProfileFile()) 302 .nonDefaultValue(nonDefaultProfileFile) 303 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.defaultProfileFile(v))) 304 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.defaultProfileFile(v)))) 305 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().defaultProfileFile()).hasValue(v)) 306 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.PROFILE_FILE)).isEqualTo(v)), 307 new TestCase<String>("override.defaultProfileName") 308 .defaultValue("default") 309 .nonDefaultValue("some-profile") 310 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.defaultProfileName(v) 311 .defaultProfileFile(nonDefaultProfileFile))) 312 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.defaultProfileName(v) 313 .defaultProfileFile(nonDefaultProfileFile)))) 314 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().defaultProfileName().orElse(null)).isEqualTo(v)) 315 .clientConfigurationValidator((c, v) -> { 316 assertThat(c.option(SdkClientOption.PROFILE_NAME)).isEqualTo(v); 317 ProfileFile profileFile = c.option(SdkClientOption.PROFILE_FILE_SUPPLIER).get(); 318 319 Optional<Profile> configuredProfile = profileFile.profile(v); 320 configuredProfile.ifPresent(profile -> { 321 profile.booleanProperty(ProfileProperty.USE_FIPS_ENDPOINT).ifPresent(d -> { 322 assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isEqualTo(d); 323 }); 324 }); 325 if (!configuredProfile.isPresent()) { 326 assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isIn(null, false); 327 } 328 }), 329 new TestCase<List<MetricPublisher>>("override.metricPublishers") 330 .defaultValue(emptyList()) 331 .nonDefaultValue(singletonList(mockMetricPublisher)) 332 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.metricPublishers(v))) 333 .requestSetter(AwsRequestOverrideConfiguration.Builder::metricPublishers) 334 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.metricPublishers(v)))) 335 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().metricPublishers()).isEqualTo(v)) 336 .clientConfigurationValidator((c, v) -> { 337 assertThat(c.option(SdkClientOption.METRIC_PUBLISHERS)).containsAll(v); 338 }), 339 new TestCase<ExecutionAttributes>("override.executionAttributes") 340 .defaultValue(new ExecutionAttributes()) 341 .nonDefaultValue(new ExecutionAttributes().putAttribute(FlagSettingInterceptor.FLAG, true)) 342 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.executionAttributes(v))) 343 .requestSetter(AwsRequestOverrideConfiguration.Builder::executionAttributes) 344 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.executionAttributes(v)))) 345 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().executionAttributes()).isEqualTo(v)) 346 .beforeTransmissionValidator((r, a, v) -> { 347 assertThat(a.getAttribute(FlagSettingInterceptor.FLAG)).isTrue(); 348 }), 349 new TestCase<CompressionConfiguration>("override.compressionConfiguration") 350 .defaultValue(CompressionConfiguration.builder() 351 .requestCompressionEnabled(true) 352 .minimumCompressionThresholdInBytes(10_240) 353 .build()) 354 .nonDefaultValue(CompressionConfiguration.builder() 355 .requestCompressionEnabled(true) 356 .minimumCompressionThresholdInBytes(1) 357 .build()) 358 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.compressionConfiguration(v))) 359 .requestSetter(AwsRequestOverrideConfiguration.Builder::compressionConfiguration) 360 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.compressionConfiguration(v)))) 361 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().compressionConfiguration().orElse(null)).isEqualTo(v)) 362 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.COMPRESSION_CONFIGURATION)).isEqualTo(v)) 363 ); 364 } 365 unsafePutOption(ClientOverrideConfiguration.Builder config, SdkAdvancedClientOption<T> option, Object value)366 private static <T> void unsafePutOption(ClientOverrideConfiguration.Builder config, 367 SdkAdvancedClientOption<T> option, 368 Object value) { 369 config.putAdvancedOption(option, option.convertValue(value)); 370 371 } 372 373 @ParameterizedTest 374 @MethodSource("testCases") validateTestCaseData(TestCase<T> testCase)375 public <T> void validateTestCaseData(TestCase<T> testCase) { 376 assertThat(testCase.defaultValue).isNotEqualTo(testCase.nonDefaultValue); 377 } 378 379 @ParameterizedTest 380 @MethodSource("testCases") clientPluginSeesDefaultValue(TestCase<T> testCase)381 public <T> void clientPluginSeesDefaultValue(TestCase<T> testCase) { 382 ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder(); 383 384 AtomicInteger timesCalled = new AtomicInteger(0); 385 SdkPlugin plugin = config -> { 386 ProtocolRestJsonServiceClientConfiguration.Builder conf = 387 (ProtocolRestJsonServiceClientConfiguration.Builder) config; 388 testCase.pluginValidator.accept(conf, testCase.defaultValue); 389 timesCalled.incrementAndGet(); 390 }; 391 392 ProtocolRestJsonClient client = clientBuilder.addPlugin(plugin).build(); 393 if (testCase.clientConfigurationValidator != null) { 394 testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue); 395 } 396 assertThat(timesCalled).hasValue(1); 397 } 398 399 @ParameterizedTest 400 @MethodSource("testCases") requestPluginSeesDefaultValue(TestCase<T> testCase)401 public <T> void requestPluginSeesDefaultValue(TestCase<T> testCase) { 402 ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder(); 403 404 AtomicInteger timesCalled = new AtomicInteger(0); 405 SdkPlugin plugin = config -> { 406 ProtocolRestJsonServiceClientConfiguration.Builder conf = 407 (ProtocolRestJsonServiceClientConfiguration.Builder) config; 408 testCase.pluginValidator.accept(conf, testCase.defaultValue); 409 timesCalled.incrementAndGet(); 410 }; 411 412 ProtocolRestJsonClient client = clientBuilder.httpClient(succeedingHttpClient()).build(); 413 if (testCase.clientConfigurationValidator != null) { 414 testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue); 415 } 416 client.allTypes(r -> r.overrideConfiguration(c -> c.addPlugin(plugin))); 417 assertThat(timesCalled).hasValue(1); 418 } 419 420 @ParameterizedTest 421 @MethodSource("testCases") clientPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase)422 public <T> void clientPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase) { 423 ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder(); 424 testCase.clientSetter.accept(clientBuilder, testCase.nonDefaultValue); 425 426 AtomicInteger timesCalled = new AtomicInteger(0); 427 SdkPlugin plugin = config -> { 428 ProtocolRestJsonServiceClientConfiguration.Builder conf = 429 (ProtocolRestJsonServiceClientConfiguration.Builder) config; 430 testCase.pluginValidator.accept(conf, testCase.nonDefaultValue); 431 timesCalled.incrementAndGet(); 432 }; 433 434 ProtocolRestJsonClient client = clientBuilder.addPlugin(plugin).build(); 435 436 if (testCase.clientConfigurationValidator != null) { 437 testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.nonDefaultValue); 438 } 439 440 assertThat(timesCalled).hasValue(1); 441 } 442 443 @ParameterizedTest 444 @MethodSource("testCases") requestPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase)445 public <T> void requestPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase) { 446 ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder(); 447 testCase.clientSetter.accept(clientBuilder, testCase.nonDefaultValue); 448 449 AtomicInteger timesCalled = new AtomicInteger(0); 450 SdkPlugin plugin = config -> { 451 ProtocolRestJsonServiceClientConfiguration.Builder conf = 452 (ProtocolRestJsonServiceClientConfiguration.Builder) config; 453 testCase.pluginValidator.accept(conf, testCase.nonDefaultValue); 454 timesCalled.incrementAndGet(); 455 }; 456 457 ProtocolRestJsonClient client = clientBuilder.httpClient(succeedingHttpClient()).build(); 458 if (testCase.clientConfigurationValidator != null) { 459 testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.nonDefaultValue); 460 } 461 client.allTypes(r -> r.overrideConfiguration(c -> c.addPlugin(plugin))); 462 assertThat(timesCalled).hasValue(1); 463 } 464 465 @ParameterizedTest 466 @MethodSource("testCases") 467 @Disabled("Request-level values are currently higher-priority than plugin settings.") // TODO(sra-identity-auth) requestPluginSeesCustomerRequestConfiguredValue(TestCase<T> testCase)468 public <T> void requestPluginSeesCustomerRequestConfiguredValue(TestCase<T> testCase) { 469 if (testCase.requestSetter == null) { 470 System.out.println("No request setting available."); 471 return; 472 } 473 474 ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder(); 475 476 AtomicInteger timesCalled = new AtomicInteger(0); 477 SdkPlugin plugin = config -> { 478 ProtocolRestJsonServiceClientConfiguration.Builder conf = 479 (ProtocolRestJsonServiceClientConfiguration.Builder) config; 480 testCase.pluginValidator.accept(conf, testCase.nonDefaultValue); 481 timesCalled.incrementAndGet(); 482 }; 483 484 AwsRequestOverrideConfiguration overrideConfig = 485 AwsRequestOverrideConfiguration.builder() 486 .addPlugin(plugin) 487 .applyMutation(c -> testCase.requestSetter.accept(c, testCase.nonDefaultValue)) 488 .build(); 489 490 ProtocolRestJsonClient client = clientBuilder.httpClient(succeedingHttpClient()).build(); 491 492 if (testCase.clientConfigurationValidator != null) { 493 testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue); 494 } 495 496 client.allTypes(r -> r.overrideConfiguration(overrideConfig)); 497 assertThat(timesCalled).hasValue(1); 498 } 499 500 @ParameterizedTest 501 @MethodSource("testCases") clientPluginSetValueIsUsed(TestCase<T> testCase)502 public <T> void clientPluginSetValueIsUsed(TestCase<T> testCase) { 503 ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder(); 504 testCase.clientSetter.accept(clientBuilder, testCase.defaultValue); 505 506 AtomicInteger timesPluginCalled = new AtomicInteger(0); 507 SdkPlugin plugin = config -> { 508 timesPluginCalled.incrementAndGet(); 509 ProtocolRestJsonServiceClientConfiguration.Builder conf = 510 (ProtocolRestJsonServiceClientConfiguration.Builder) config; 511 testCase.pluginSetter.accept(conf, testCase.nonDefaultValue); 512 }; 513 514 AtomicInteger timesInterceptorCalled = new AtomicInteger(0); 515 ExecutionInterceptor validatingInterceptor = new ExecutionInterceptor() { 516 @Override 517 public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { 518 timesInterceptorCalled.getAndIncrement(); 519 if (testCase.beforeTransmissionValidator != null) { 520 testCase.beforeTransmissionValidator.accept(context, executionAttributes, testCase.nonDefaultValue); 521 } 522 } 523 }; 524 525 ProtocolRestJsonClient client = 526 clientBuilder.httpClient(succeedingHttpClient()) 527 .addPlugin(plugin) 528 .overrideConfiguration(c -> c.addExecutionInterceptor(validatingInterceptor)) 529 .build(); 530 531 if (testCase.clientConfigurationValidator != null) { 532 testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.nonDefaultValue); 533 } 534 535 client.allTypes(); 536 537 assertThat(timesPluginCalled).hasValue(1); 538 assertThat(timesInterceptorCalled).hasValueGreaterThanOrEqualTo(1); 539 } 540 541 @ParameterizedTest 542 @MethodSource("testCases") requestPluginSetValueIsUsed(TestCase<T> testCase)543 public <T> void requestPluginSetValueIsUsed(TestCase<T> testCase) { 544 ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder(); 545 testCase.clientSetter.accept(clientBuilder, testCase.defaultValue); 546 547 AtomicInteger timesPluginCalled = new AtomicInteger(0); 548 SdkPlugin plugin = config -> { 549 timesPluginCalled.incrementAndGet(); 550 ProtocolRestJsonServiceClientConfiguration.Builder conf = 551 (ProtocolRestJsonServiceClientConfiguration.Builder) config; 552 testCase.pluginSetter.accept(conf, testCase.nonDefaultValue); 553 }; 554 555 AtomicInteger timesInterceptorCalled = new AtomicInteger(0); 556 ExecutionInterceptor validatingInterceptor = new ExecutionInterceptor() { 557 @Override 558 public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { 559 timesInterceptorCalled.incrementAndGet(); 560 if (testCase.beforeTransmissionValidator != null) { 561 testCase.beforeTransmissionValidator.accept(context, executionAttributes, testCase.nonDefaultValue); 562 } 563 } 564 }; 565 566 AwsRequestOverrideConfiguration requestConfig = 567 AwsRequestOverrideConfiguration.builder() 568 .addPlugin(plugin) 569 .applyMutation(c -> { 570 // TODO(sra-identity-auth): request-level plugins should override request-level 571 // configuration 572 // if (testCase.requestSetter != null) { 573 // testCase.requestSetter.accept(c, testCase.defaultValue); 574 // } 575 }) 576 .build(); 577 578 ProtocolRestJsonClient client = 579 clientBuilder.httpClient(succeedingHttpClient()) 580 .overrideConfiguration(c -> c.addExecutionInterceptor(validatingInterceptor)) 581 .build(); 582 583 if (testCase.clientConfigurationValidator != null) { 584 testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue); 585 } 586 587 client.allTypes(r -> r.overrideConfiguration(requestConfig)); 588 589 assertThat(timesPluginCalled).hasValue(1); 590 assertThat(timesInterceptorCalled).hasValueGreaterThanOrEqualTo(1); 591 } 592 defaultClientBuilder()593 private static ProtocolRestJsonClientBuilder defaultClientBuilder() { 594 return ProtocolRestJsonClient.builder().region(Region.US_WEST_2).credentialsProvider(DEFAULT_CREDENTIALS); 595 } 596 extractClientConfiguration(ProtocolRestJsonClient client)597 private SdkClientConfiguration extractClientConfiguration(ProtocolRestJsonClient client) { 598 try { 599 // Naughty, but we need to be able to verify some things that can't be easily observed with unprotected means. 600 Class<? extends ProtocolRestJsonClient> clientClass = client.getClass(); 601 Field configField = clientClass.getDeclaredField("clientConfiguration"); 602 configField.setAccessible(true); 603 return (SdkClientConfiguration) configField.get(client); 604 } catch (Exception e) { 605 throw new RuntimeException(e); 606 } 607 } 608 609 static class TestCase<T> { 610 private final String configName; 611 T defaultValue; 612 T nonDefaultValue; 613 BiConsumer<ProtocolRestJsonClientBuilder, T> clientSetter; 614 BiConsumer<AwsRequestOverrideConfiguration.Builder, T> requestSetter; 615 BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginSetter; 616 617 BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginValidator; 618 BiConsumer<SdkClientConfiguration, T> clientConfigurationValidator; 619 TriConsumer<Context.BeforeTransmission, ExecutionAttributes, T> beforeTransmissionValidator; 620 TestCase(String configName)621 TestCase(String configName) { 622 this.configName = configName; 623 } 624 defaultValue(T defaultValue)625 public TestCase<T> defaultValue(T defaultValue) { 626 this.defaultValue = defaultValue; 627 return this; 628 } 629 nonDefaultValue(T nonDefaultValue)630 public TestCase<T> nonDefaultValue(T nonDefaultValue) { 631 this.nonDefaultValue = nonDefaultValue; 632 return this; 633 } 634 clientSetter(BiConsumer<ProtocolRestJsonClientBuilder, T> clientSetter)635 public TestCase<T> clientSetter(BiConsumer<ProtocolRestJsonClientBuilder, T> clientSetter) { 636 this.clientSetter = clientSetter; 637 return this; 638 } 639 requestSetter(BiConsumer<AwsRequestOverrideConfiguration.Builder, T> requestSetter)640 public TestCase<T> requestSetter(BiConsumer<AwsRequestOverrideConfiguration.Builder, T> requestSetter) { 641 this.requestSetter = requestSetter; 642 return this; 643 } 644 pluginSetter(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginSetter)645 public TestCase<T> pluginSetter(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginSetter) { 646 this.pluginSetter = pluginSetter; 647 return this; 648 } 649 pluginValidator(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginValidator)650 public TestCase<T> pluginValidator(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginValidator) { 651 this.pluginValidator = pluginValidator; 652 return this; 653 } 654 clientConfigurationValidator(BiConsumer<SdkClientConfiguration, T> clientConfigurationValidator)655 public TestCase<T> clientConfigurationValidator(BiConsumer<SdkClientConfiguration, T> clientConfigurationValidator) { 656 this.clientConfigurationValidator = clientConfigurationValidator; 657 return this; 658 } 659 beforeTransmissionValidator(TriConsumer<Context.BeforeTransmission, ExecutionAttributes, T> beforeTransmissionValidator)660 public TestCase<T> beforeTransmissionValidator(TriConsumer<Context.BeforeTransmission, ExecutionAttributes, T> beforeTransmissionValidator) { 661 this.beforeTransmissionValidator = beforeTransmissionValidator; 662 return this; 663 } 664 665 @Override toString()666 public String toString() { 667 return configName; 668 } 669 } 670 succeedingHttpClient()671 private static SdkHttpClient succeedingHttpClient() { 672 MockSyncHttpClient client = new MockSyncHttpClient(); 673 client.stubNextResponse200(); 674 return client; 675 } 676 removePathAndQueryString(URI uri)677 private static URI removePathAndQueryString(URI uri) { 678 String uriString = uri.toString(); 679 return URI.create(uriString.substring(0, uriString.indexOf('/', "https://".length()))); 680 } 681 682 public interface TriConsumer<T, U, V> { accept(T t, U u, V v)683 void accept(T t, U u, V v); 684 } 685 686 private static class CustomAuthScheme implements AuthScheme<NoAuthAuthScheme.AnonymousIdentity> { 687 private static final String SCHEME_ID = "foo"; 688 private static final AuthScheme<NoAuthAuthScheme.AnonymousIdentity> DELEGATE = NoAuthAuthScheme.create(); 689 690 @Override schemeId()691 public String schemeId() { 692 return SCHEME_ID; 693 } 694 695 @Override identityProvider(IdentityProviders providers)696 public IdentityProvider<NoAuthAuthScheme.AnonymousIdentity> identityProvider(IdentityProviders providers) { 697 return DELEGATE.identityProvider(providers); 698 } 699 700 @Override signer()701 public HttpSigner<NoAuthAuthScheme.AnonymousIdentity> signer() { 702 return DELEGATE.signer(); 703 } 704 } 705 706 private static class FlagSettingInterceptor implements ExecutionInterceptor { 707 private static final ExecutionAttribute<Boolean> FLAG = new ExecutionAttribute<>("InterceptorAdded"); 708 709 @Override beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes)710 public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { 711 executionAttributes.putAttribute(FLAG, true); 712 } 713 } 714 } 715