• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.awssdk.services;
17 
18 import static java.util.Collections.emptyList;
19 import static java.util.Collections.emptyMap;
20 import static java.util.Collections.singletonList;
21 import static java.util.Collections.singletonMap;
22 import static org.assertj.core.api.Assertions.assertThat;
23 import static org.mockito.ArgumentMatchers.eq;
24 import static org.mockito.Mockito.mock;
25 import static org.mockito.Mockito.times;
26 import static software.amazon.awssdk.profiles.ProfileFile.Type.CONFIGURATION;
27 
28 import java.lang.reflect.Field;
29 import java.net.URI;
30 import java.time.Duration;
31 import java.util.HashMap;
32 import java.util.List;
33 import java.util.Map;
34 import java.util.Optional;
35 import java.util.concurrent.CompletableFuture;
36 import java.util.concurrent.ScheduledExecutorService;
37 import java.util.concurrent.atomic.AtomicInteger;
38 import java.util.function.BiConsumer;
39 import java.util.function.Supplier;
40 import java.util.stream.Stream;
41 import org.junit.jupiter.api.Disabled;
42 import org.junit.jupiter.params.ParameterizedTest;
43 import org.junit.jupiter.params.provider.MethodSource;
44 import org.mockito.Mockito;
45 import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
46 import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
47 import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
48 import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
49 import software.amazon.awssdk.awscore.client.config.AwsClientOption;
50 import software.amazon.awssdk.core.CompressionConfiguration;
51 import software.amazon.awssdk.core.RequestOverrideConfiguration;
52 import software.amazon.awssdk.core.SdkPlugin;
53 import software.amazon.awssdk.core.client.builder.SdkClientBuilder;
54 import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
55 import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption;
56 import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
57 import software.amazon.awssdk.core.client.config.SdkClientOption;
58 import software.amazon.awssdk.core.interceptor.Context;
59 import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
60 import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
61 import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
62 import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
63 import software.amazon.awssdk.core.retry.RetryMode;
64 import software.amazon.awssdk.core.retry.RetryPolicy;
65 import software.amazon.awssdk.endpoints.Endpoint;
66 import software.amazon.awssdk.http.SdkHttpClient;
67 import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme;
68 import software.amazon.awssdk.http.auth.scheme.NoAuthAuthScheme;
69 import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme;
70 import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption;
71 import software.amazon.awssdk.http.auth.spi.signer.HttpSigner;
72 import software.amazon.awssdk.identity.spi.IdentityProvider;
73 import software.amazon.awssdk.identity.spi.IdentityProviders;
74 import software.amazon.awssdk.metrics.MetricPublisher;
75 import software.amazon.awssdk.profiles.Profile;
76 import software.amazon.awssdk.profiles.ProfileFile;
77 import software.amazon.awssdk.profiles.ProfileProperty;
78 import software.amazon.awssdk.regions.Region;
79 import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
80 import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClientBuilder;
81 import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonServiceClientConfiguration;
82 import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeProvider;
83 import software.amazon.awssdk.services.protocolrestjson.endpoints.ProtocolRestJsonEndpointProvider;
84 import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
85 import software.amazon.awssdk.utils.ImmutableMap;
86 import software.amazon.awssdk.utils.Lazy;
87 import software.amazon.awssdk.utils.StringInputStream;
88 
89 /**
90  * Verify that configuration changes made by plugins are reflected in the SDK client configuration used by the request, and
91  * that the plugin can see all SDK configuration options.
92  */
93 public class SdkPluginTest {
94     private static final AwsCredentialsProvider DEFAULT_CREDENTIALS = () -> AwsBasicCredentials.create("akid", "skid");
95 
testCases()96     public static Stream<TestCase<?>> testCases() {
97         Map<String, AuthScheme<?>> defaultAuthSchemes =
98             ImmutableMap.of(AwsV4AuthScheme.SCHEME_ID, AwsV4AuthScheme.create(),
99                             NoAuthAuthScheme.SCHEME_ID, NoAuthAuthScheme.create());
100         Map<String, AuthScheme<?>> nonDefaultAuthSchemes = new HashMap<>(defaultAuthSchemes);
101         nonDefaultAuthSchemes.put(CustomAuthScheme.SCHEME_ID, new CustomAuthScheme());
102 
103         ScheduledExecutorService mockScheduledExecutor = mock(ScheduledExecutorService.class);
104         MetricPublisher mockMetricPublisher = mock(MetricPublisher.class);
105 
106         String profileFileContent =
107             "[default]\n"
108             + ProfileProperty.USE_FIPS_ENDPOINT + " = true"
109             + "[profile some-profile]\n"
110             + ProfileProperty.USE_FIPS_ENDPOINT + " = false";
111 
112         ProfileFile nonDefaultProfileFile =
113             ProfileFile.builder()
114                        .type(CONFIGURATION)
115                        .content(new StringInputStream(profileFileContent))
116                        .build();
117 
118         return Stream.of(
119             new TestCase<URI>("endpointOverride")
120                 .nonDefaultValue(URI.create("https://example.aws"))
121                 .clientSetter(SdkClientBuilder::endpointOverride)
122                 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::endpointOverride)
123                 .pluginValidator((c, v) -> assertThat(v).isEqualTo(c.endpointOverride()))
124                 .beforeTransmissionValidator((r, a, v) -> {
125                     assertThat(v).isEqualTo(removePathAndQueryString(r.httpRequest().getUri()));
126                 }),
127             new TestCase<ProtocolRestJsonEndpointProvider>("endpointProvider")
128                 .defaultValue(ProtocolRestJsonEndpointProvider.defaultProvider())
129                 .nonDefaultValue(a -> CompletableFuture.completedFuture(Endpoint.builder()
130                                                                                 .url(URI.create("https://example.aws"))
131                                                                                 .build()))
132                 .clientSetter(ProtocolRestJsonClientBuilder::endpointProvider)
133                 .requestSetter(RequestOverrideConfiguration.Builder::endpointProvider)
134                 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::endpointProvider)
135                 .pluginValidator((c, v) -> assertThat(c.endpointProvider()).isEqualTo(v))
136                 .beforeTransmissionValidator((r, a, v) -> {
137                     assertThat(removePathAndQueryString(r.httpRequest().getUri()))
138                         .isEqualTo(v.resolveEndpoint(x -> {}).join().url());
139                 }),
140             new TestCase<Map<String, AuthScheme<?>>>("authSchemes")
141                 .defaultValue(defaultAuthSchemes)
142                 .nonDefaultValue(nonDefaultAuthSchemes)
143                 .clientSetter((b, v) -> v.forEach((x, scheme) -> b.putAuthScheme(scheme)))
144                 .pluginSetter((b, v) -> v.forEach((x, scheme) -> b.putAuthScheme(scheme)))
145                 .pluginValidator((c, v) -> v.forEach((id, s) -> assertThat(c.authSchemes()).containsEntry(id, s)))
146                 .beforeTransmissionValidator((r, a, v) -> v.forEach((id, s) -> {
147                     assertThat(a.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES)).containsEntry(id, s);
148                 })),
149             new TestCase<Region>("region")
150                 .defaultValue(Region.US_WEST_2)
151                 .nonDefaultValue(Region.US_EAST_1)
152                 .clientSetter(AwsClientBuilder::region)
153                 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::region)
154                 .pluginValidator((c, v) -> assertThat(c.region()).isEqualTo(v))
155                 .beforeTransmissionValidator((r, a, v) -> {
156                     assertThat(r.httpRequest()
157                                 .firstMatchingHeader("Authorization")).get()
158                                                                       .asString()
159                                                                       .contains(v.id());
160                     assertThat(r.httpRequest().getUri().getHost()).contains(v.id());
161                 }),
162             new TestCase<AwsCredentialsProvider>("credentialsProvider")
163                 .defaultValue(DEFAULT_CREDENTIALS)
164                 .nonDefaultValue(DEFAULT_CREDENTIALS::resolveCredentials)
165                 .clientSetter(AwsClientBuilder::credentialsProvider)
166                 .requestSetter(AwsRequestOverrideConfiguration.Builder::credentialsProvider)
167                 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::credentialsProvider)
168                 .pluginValidator((c, v) -> assertThat(c.credentialsProvider()).isEqualTo(v))
169                 .beforeTransmissionValidator((r, a, v) -> {
170                     assertThat(r.httpRequest()
171                                 .firstMatchingHeader("Authorization")).get()
172                                                                       .asString()
173                                                                       .contains(v.resolveCredentials().accessKeyId());
174                 }),
175             new TestCase<ProtocolRestJsonAuthSchemeProvider>("authSchemeProvider")
176                 .defaultValue(ProtocolRestJsonAuthSchemeProvider.defaultProvider())
177                 .nonDefaultValue(p -> singletonList(AuthSchemeOption.builder().schemeId(NoAuthAuthScheme.SCHEME_ID).build()))
178                 .clientSetter(ProtocolRestJsonClientBuilder::authSchemeProvider)
179                 .pluginSetter(ProtocolRestJsonServiceClientConfiguration.Builder::authSchemeProvider)
180                 .pluginValidator((c, v) -> assertThat(c.authSchemeProvider()).isEqualTo(v))
181                 .beforeTransmissionValidator((r, a, v) -> {
182                     assertThat(a.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)
183                                 .authSchemeOption()
184                                 .schemeId()).isEqualTo(NoAuthAuthScheme.SCHEME_ID);
185                     assertThat(r.httpRequest().firstMatchingHeader("Authorization")).isNotPresent();
186                 }),
187             new TestCase<Map<String, List<String>>>("override.headers")
188                 .defaultValue(emptyMap())
189                 .nonDefaultValue(singletonMap("foo", singletonList("bar")))
190                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.headers(v)))
191                 .requestSetter(AwsRequestOverrideConfiguration.Builder::headers)
192                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.headers(v))))
193                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().headers()).isEqualTo(v))
194                 .beforeTransmissionValidator((r, a, v) -> {
195                     v.forEach((key, value) -> assertThat(r.httpRequest().headers().get(key)).isEqualTo(value));
196                 }),
197             new TestCase<RetryPolicy>("override.retryPolicy")
198                 .defaultValue(RetryPolicy.defaultRetryPolicy())
199                 .nonDefaultValue(RetryPolicy.builder(RetryMode.STANDARD).numRetries(1).build())
200                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.retryPolicy(v)))
201                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.retryPolicy(v))))
202                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().retryPolicy().get().numRetries())
203                     .isEqualTo(v.numRetries()))
204                 .beforeTransmissionValidator((r, a, v) -> {
205                     assertThat(r.httpRequest().firstMatchingHeader("amz-sdk-request"))
206                         .hasValue("attempt=1; max=" + (v.numRetries() + 1));
207                 }),
208             new TestCase<List<ExecutionInterceptor>>("override.executionInterceptors")
209                 .defaultValue(emptyList())
210                 .nonDefaultValue(singletonList(new FlagSettingInterceptor()))
211                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.executionInterceptors(v)))
212                 .pluginSetter((b, v) -> {
213                     b.overrideConfiguration(b.overrideConfiguration().copy(c -> v.forEach(c::addExecutionInterceptor)));
214                 })
215                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().executionInterceptors()).containsAll(v))
216                 .beforeTransmissionValidator((r, a, v) -> {
217                     if (v.stream().anyMatch(i -> i instanceof FlagSettingInterceptor)) {
218                         assertThat(a.getAttribute(FlagSettingInterceptor.FLAG)).isEqualTo(true);
219                     } else {
220                         assertThat(a.getAttribute(FlagSettingInterceptor.FLAG)).isNull();
221                     }
222                 }),
223             new TestCase<ScheduledExecutorService>("override.scheduledExecutorService")
224                 .defaultValue(null)
225                 .nonDefaultValue(mockScheduledExecutor)
226                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.scheduledExecutorService(v)))
227                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.scheduledExecutorService(v))))
228                 .pluginValidator((c, v) -> {
229                     Optional<ScheduledExecutorService> executor = c.overrideConfiguration().scheduledExecutorService();
230                     if (v != null) {
231                         // The SDK should decorate the non-default-value. Ensure that's what happened.
232                         Runnable runnable = () -> {};
233                         v.submit(runnable);
234                         assertThat(v).isEqualTo(mockScheduledExecutor);
235                         Mockito.verify(v, times(1)).submit(eq(runnable));
236                     } else {
237                         // Null means we're using the default, and the default should be specified by the runtime.
238                         assertThat(executor).isPresent();
239                     }
240                 })
241                 .clientConfigurationValidator((c, v) -> {
242                     ScheduledExecutorService configuredService = c.option(SdkClientOption.SCHEDULED_EXECUTOR_SERVICE);
243                     if (mockScheduledExecutor.equals(v)) {
244                         // The SDK should decorate the non-default-value. Ensure that's what happened.
245                         Runnable runnable = () -> {};
246                         configuredService.submit(runnable);
247                         assertThat(v).isEqualTo(mockScheduledExecutor);
248                         Mockito.verify(v, times(1)).submit(eq(runnable));
249                     } else {
250                         assertThat(configuredService).isNotNull();
251                     }
252                 }),
253             new TestCase<Map<SdkAdvancedClientOption<?>, ?>>("override.advancedOptions")
254                 .defaultValue(emptyMap())
255                 .nonDefaultValue(singletonMap(SdkAdvancedClientOption.USER_AGENT_PREFIX, "foo"))
256                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.advancedOptions(v)))
257                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> {
258                     v.forEach((option, value) -> unsafePutOption(c, option, value));
259                 })))
260                 .pluginValidator((c, v) -> {
261                     v.forEach((o, ov) -> assertThat(c.overrideConfiguration().advancedOption(o).orElse(null)).isEqualTo(ov));
262                 })
263                 .clientConfigurationValidator((c, v) -> v.forEach((o, ov) -> assertThat(c.option(o)).isEqualTo(ov))),
264             new TestCase<Duration>("override.apiCallTimeout")
265                 .defaultValue(null)
266                 .nonDefaultValue(Duration.ofSeconds(5))
267                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.apiCallTimeout(v)))
268                 .requestSetter(AwsRequestOverrideConfiguration.Builder::apiCallTimeout)
269                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.apiCallTimeout(v))))
270                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().apiCallTimeout().orElse(null)).isEqualTo(v))
271                 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.API_CALL_TIMEOUT)).isEqualTo(v)),
272             new TestCase<Duration>("override.apiCallAttemptTimeout")
273                 .defaultValue(null)
274                 .nonDefaultValue(Duration.ofSeconds(3))
275                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.apiCallAttemptTimeout(v)))
276                 .requestSetter(AwsRequestOverrideConfiguration.Builder::apiCallAttemptTimeout)
277                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.apiCallAttemptTimeout(v))))
278                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().apiCallAttemptTimeout().orElse(null)).isEqualTo(v))
279                 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.API_CALL_ATTEMPT_TIMEOUT)).isEqualTo(v)),
280             new TestCase<Supplier<ProfileFile>>("override.defaultProfileFileSupplier")
281                 .defaultValue(new Lazy<>(ProfileFile::defaultProfileFile)::getValue)
282                 .nonDefaultValue(() -> nonDefaultProfileFile)
283                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.defaultProfileFileSupplier(v)))
284                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.defaultProfileFileSupplier(v))))
285                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().defaultProfileFileSupplier().get().get()).isEqualTo(v.get()))
286                 .clientConfigurationValidator((c, v) -> {
287                     Supplier<ProfileFile> supplier = c.option(SdkClientOption.PROFILE_FILE_SUPPLIER);
288                     assertThat(supplier.get()).isEqualTo(v.get());
289 
290                     Optional<Profile> defaultProfile = v.get().profile("default");
291                     defaultProfile.ifPresent(profile -> {
292                         profile.booleanProperty(ProfileProperty.USE_FIPS_ENDPOINT).ifPresent(d -> {
293                             assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isEqualTo(d);
294                         });
295                     });
296                     if (!defaultProfile.isPresent()) {
297                         assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isIn(null, false);
298                     }
299                 }),
300             new TestCase<ProfileFile>("override.defaultProfileFile")
301                 .defaultValue(ProfileFile.defaultProfileFile())
302                 .nonDefaultValue(nonDefaultProfileFile)
303                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.defaultProfileFile(v)))
304                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.defaultProfileFile(v))))
305                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().defaultProfileFile()).hasValue(v))
306                 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.PROFILE_FILE)).isEqualTo(v)),
307             new TestCase<String>("override.defaultProfileName")
308                 .defaultValue("default")
309                 .nonDefaultValue("some-profile")
310                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.defaultProfileName(v)
311                                                                       .defaultProfileFile(nonDefaultProfileFile)))
312                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.defaultProfileName(v)
313                                                                                                      .defaultProfileFile(nonDefaultProfileFile))))
314                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().defaultProfileName().orElse(null)).isEqualTo(v))
315                 .clientConfigurationValidator((c, v) -> {
316                     assertThat(c.option(SdkClientOption.PROFILE_NAME)).isEqualTo(v);
317                     ProfileFile profileFile = c.option(SdkClientOption.PROFILE_FILE_SUPPLIER).get();
318 
319                     Optional<Profile> configuredProfile = profileFile.profile(v);
320                     configuredProfile.ifPresent(profile -> {
321                         profile.booleanProperty(ProfileProperty.USE_FIPS_ENDPOINT).ifPresent(d -> {
322                             assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isEqualTo(d);
323                         });
324                     });
325                     if (!configuredProfile.isPresent()) {
326                         assertThat(c.option(AwsClientOption.FIPS_ENDPOINT_ENABLED)).isIn(null, false);
327                     }
328                 }),
329             new TestCase<List<MetricPublisher>>("override.metricPublishers")
330                 .defaultValue(emptyList())
331                 .nonDefaultValue(singletonList(mockMetricPublisher))
332                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.metricPublishers(v)))
333                 .requestSetter(AwsRequestOverrideConfiguration.Builder::metricPublishers)
334                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.metricPublishers(v))))
335                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().metricPublishers()).isEqualTo(v))
336                 .clientConfigurationValidator((c, v) -> {
337                     assertThat(c.option(SdkClientOption.METRIC_PUBLISHERS)).containsAll(v);
338                 }),
339             new TestCase<ExecutionAttributes>("override.executionAttributes")
340                 .defaultValue(new ExecutionAttributes())
341                 .nonDefaultValue(new ExecutionAttributes().putAttribute(FlagSettingInterceptor.FLAG, true))
342                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.executionAttributes(v)))
343                 .requestSetter(AwsRequestOverrideConfiguration.Builder::executionAttributes)
344                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.executionAttributes(v))))
345                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().executionAttributes()).isEqualTo(v))
346                 .beforeTransmissionValidator((r, a, v) -> {
347                     assertThat(a.getAttribute(FlagSettingInterceptor.FLAG)).isTrue();
348                 }),
349             new TestCase<CompressionConfiguration>("override.compressionConfiguration")
350                 .defaultValue(CompressionConfiguration.builder()
351                                                       .requestCompressionEnabled(true)
352                                                       .minimumCompressionThresholdInBytes(10_240)
353                                                       .build())
354                 .nonDefaultValue(CompressionConfiguration.builder()
355                                                          .requestCompressionEnabled(true)
356                                                          .minimumCompressionThresholdInBytes(1)
357                                                          .build())
358                 .clientSetter((b, v) -> b.overrideConfiguration(c -> c.compressionConfiguration(v)))
359                 .requestSetter(AwsRequestOverrideConfiguration.Builder::compressionConfiguration)
360                 .pluginSetter((b, v) -> b.overrideConfiguration(b.overrideConfiguration().copy(c -> c.compressionConfiguration(v))))
361                 .pluginValidator((c, v) -> assertThat(c.overrideConfiguration().compressionConfiguration().orElse(null)).isEqualTo(v))
362                 .clientConfigurationValidator((c, v) -> assertThat(c.option(SdkClientOption.COMPRESSION_CONFIGURATION)).isEqualTo(v))
363         );
364     }
365 
unsafePutOption(ClientOverrideConfiguration.Builder config, SdkAdvancedClientOption<T> option, Object value)366     private static <T> void unsafePutOption(ClientOverrideConfiguration.Builder config,
367                                             SdkAdvancedClientOption<T> option,
368                                             Object value) {
369         config.putAdvancedOption(option, option.convertValue(value));
370 
371     }
372 
373     @ParameterizedTest
374     @MethodSource("testCases")
validateTestCaseData(TestCase<T> testCase)375     public <T> void validateTestCaseData(TestCase<T> testCase) {
376         assertThat(testCase.defaultValue).isNotEqualTo(testCase.nonDefaultValue);
377     }
378 
379     @ParameterizedTest
380     @MethodSource("testCases")
clientPluginSeesDefaultValue(TestCase<T> testCase)381     public <T> void clientPluginSeesDefaultValue(TestCase<T> testCase) {
382         ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder();
383 
384         AtomicInteger timesCalled = new AtomicInteger(0);
385         SdkPlugin plugin = config -> {
386             ProtocolRestJsonServiceClientConfiguration.Builder conf =
387                 (ProtocolRestJsonServiceClientConfiguration.Builder) config;
388             testCase.pluginValidator.accept(conf, testCase.defaultValue);
389             timesCalled.incrementAndGet();
390         };
391 
392         ProtocolRestJsonClient client = clientBuilder.addPlugin(plugin).build();
393         if (testCase.clientConfigurationValidator != null) {
394             testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue);
395         }
396         assertThat(timesCalled).hasValue(1);
397     }
398 
399     @ParameterizedTest
400     @MethodSource("testCases")
requestPluginSeesDefaultValue(TestCase<T> testCase)401     public <T> void requestPluginSeesDefaultValue(TestCase<T> testCase) {
402         ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder();
403 
404         AtomicInteger timesCalled = new AtomicInteger(0);
405         SdkPlugin plugin = config -> {
406             ProtocolRestJsonServiceClientConfiguration.Builder conf =
407                 (ProtocolRestJsonServiceClientConfiguration.Builder) config;
408             testCase.pluginValidator.accept(conf, testCase.defaultValue);
409             timesCalled.incrementAndGet();
410         };
411 
412         ProtocolRestJsonClient client = clientBuilder.httpClient(succeedingHttpClient()).build();
413         if (testCase.clientConfigurationValidator != null) {
414             testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue);
415         }
416         client.allTypes(r -> r.overrideConfiguration(c -> c.addPlugin(plugin)));
417         assertThat(timesCalled).hasValue(1);
418     }
419 
420     @ParameterizedTest
421     @MethodSource("testCases")
clientPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase)422     public <T> void clientPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase) {
423         ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder();
424         testCase.clientSetter.accept(clientBuilder, testCase.nonDefaultValue);
425 
426         AtomicInteger timesCalled = new AtomicInteger(0);
427         SdkPlugin plugin = config -> {
428             ProtocolRestJsonServiceClientConfiguration.Builder conf =
429                 (ProtocolRestJsonServiceClientConfiguration.Builder) config;
430             testCase.pluginValidator.accept(conf, testCase.nonDefaultValue);
431             timesCalled.incrementAndGet();
432         };
433 
434         ProtocolRestJsonClient client = clientBuilder.addPlugin(plugin).build();
435 
436         if (testCase.clientConfigurationValidator != null) {
437             testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.nonDefaultValue);
438         }
439 
440         assertThat(timesCalled).hasValue(1);
441     }
442 
443     @ParameterizedTest
444     @MethodSource("testCases")
requestPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase)445     public <T> void requestPluginSeesCustomerClientConfiguredValue(TestCase<T> testCase) {
446         ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder();
447         testCase.clientSetter.accept(clientBuilder, testCase.nonDefaultValue);
448 
449         AtomicInteger timesCalled = new AtomicInteger(0);
450         SdkPlugin plugin = config -> {
451             ProtocolRestJsonServiceClientConfiguration.Builder conf =
452                 (ProtocolRestJsonServiceClientConfiguration.Builder) config;
453             testCase.pluginValidator.accept(conf, testCase.nonDefaultValue);
454             timesCalled.incrementAndGet();
455         };
456 
457         ProtocolRestJsonClient client = clientBuilder.httpClient(succeedingHttpClient()).build();
458         if (testCase.clientConfigurationValidator != null) {
459             testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.nonDefaultValue);
460         }
461         client.allTypes(r -> r.overrideConfiguration(c -> c.addPlugin(plugin)));
462         assertThat(timesCalled).hasValue(1);
463     }
464 
465     @ParameterizedTest
466     @MethodSource("testCases")
467     @Disabled("Request-level values are currently higher-priority than plugin settings.") // TODO(sra-identity-auth)
requestPluginSeesCustomerRequestConfiguredValue(TestCase<T> testCase)468     public <T> void requestPluginSeesCustomerRequestConfiguredValue(TestCase<T> testCase) {
469         if (testCase.requestSetter == null) {
470             System.out.println("No request setting available.");
471             return;
472         }
473 
474         ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder();
475 
476         AtomicInteger timesCalled = new AtomicInteger(0);
477         SdkPlugin plugin = config -> {
478             ProtocolRestJsonServiceClientConfiguration.Builder conf =
479                 (ProtocolRestJsonServiceClientConfiguration.Builder) config;
480             testCase.pluginValidator.accept(conf, testCase.nonDefaultValue);
481             timesCalled.incrementAndGet();
482         };
483 
484         AwsRequestOverrideConfiguration overrideConfig =
485             AwsRequestOverrideConfiguration.builder()
486                                            .addPlugin(plugin)
487                                            .applyMutation(c -> testCase.requestSetter.accept(c, testCase.nonDefaultValue))
488                                            .build();
489 
490         ProtocolRestJsonClient client = clientBuilder.httpClient(succeedingHttpClient()).build();
491 
492         if (testCase.clientConfigurationValidator != null) {
493             testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue);
494         }
495 
496         client.allTypes(r -> r.overrideConfiguration(overrideConfig));
497         assertThat(timesCalled).hasValue(1);
498     }
499 
500     @ParameterizedTest
501     @MethodSource("testCases")
clientPluginSetValueIsUsed(TestCase<T> testCase)502     public <T> void clientPluginSetValueIsUsed(TestCase<T> testCase) {
503         ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder();
504         testCase.clientSetter.accept(clientBuilder, testCase.defaultValue);
505 
506         AtomicInteger timesPluginCalled = new AtomicInteger(0);
507         SdkPlugin plugin = config -> {
508             timesPluginCalled.incrementAndGet();
509             ProtocolRestJsonServiceClientConfiguration.Builder conf =
510                 (ProtocolRestJsonServiceClientConfiguration.Builder) config;
511             testCase.pluginSetter.accept(conf, testCase.nonDefaultValue);
512         };
513 
514         AtomicInteger timesInterceptorCalled = new AtomicInteger(0);
515         ExecutionInterceptor validatingInterceptor = new ExecutionInterceptor() {
516             @Override
517             public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
518                 timesInterceptorCalled.getAndIncrement();
519                 if (testCase.beforeTransmissionValidator != null) {
520                     testCase.beforeTransmissionValidator.accept(context, executionAttributes, testCase.nonDefaultValue);
521                 }
522             }
523         };
524 
525         ProtocolRestJsonClient client =
526             clientBuilder.httpClient(succeedingHttpClient())
527                          .addPlugin(plugin)
528                          .overrideConfiguration(c -> c.addExecutionInterceptor(validatingInterceptor))
529                          .build();
530 
531         if (testCase.clientConfigurationValidator != null) {
532             testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.nonDefaultValue);
533         }
534 
535         client.allTypes();
536 
537         assertThat(timesPluginCalled).hasValue(1);
538         assertThat(timesInterceptorCalled).hasValueGreaterThanOrEqualTo(1);
539     }
540 
541     @ParameterizedTest
542     @MethodSource("testCases")
requestPluginSetValueIsUsed(TestCase<T> testCase)543     public <T> void requestPluginSetValueIsUsed(TestCase<T> testCase) {
544         ProtocolRestJsonClientBuilder clientBuilder = defaultClientBuilder();
545         testCase.clientSetter.accept(clientBuilder, testCase.defaultValue);
546 
547         AtomicInteger timesPluginCalled = new AtomicInteger(0);
548         SdkPlugin plugin = config -> {
549             timesPluginCalled.incrementAndGet();
550             ProtocolRestJsonServiceClientConfiguration.Builder conf =
551                 (ProtocolRestJsonServiceClientConfiguration.Builder) config;
552             testCase.pluginSetter.accept(conf, testCase.nonDefaultValue);
553         };
554 
555         AtomicInteger timesInterceptorCalled = new AtomicInteger(0);
556         ExecutionInterceptor validatingInterceptor = new ExecutionInterceptor() {
557             @Override
558             public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
559                 timesInterceptorCalled.incrementAndGet();
560                 if (testCase.beforeTransmissionValidator != null) {
561                     testCase.beforeTransmissionValidator.accept(context, executionAttributes, testCase.nonDefaultValue);
562                 }
563             }
564         };
565 
566         AwsRequestOverrideConfiguration requestConfig =
567             AwsRequestOverrideConfiguration.builder()
568                                            .addPlugin(plugin)
569                                            .applyMutation(c -> {
570                                                // TODO(sra-identity-auth): request-level plugins should override request-level
571                                                // configuration
572                                                // if (testCase.requestSetter != null) {
573                                                //     testCase.requestSetter.accept(c, testCase.defaultValue);
574                                                // }
575                                            })
576                                            .build();
577 
578         ProtocolRestJsonClient client =
579             clientBuilder.httpClient(succeedingHttpClient())
580                          .overrideConfiguration(c -> c.addExecutionInterceptor(validatingInterceptor))
581                          .build();
582 
583         if (testCase.clientConfigurationValidator != null) {
584             testCase.clientConfigurationValidator.accept(extractClientConfiguration(client), testCase.defaultValue);
585         }
586 
587         client.allTypes(r -> r.overrideConfiguration(requestConfig));
588 
589         assertThat(timesPluginCalled).hasValue(1);
590         assertThat(timesInterceptorCalled).hasValueGreaterThanOrEqualTo(1);
591     }
592 
defaultClientBuilder()593     private static ProtocolRestJsonClientBuilder defaultClientBuilder() {
594         return ProtocolRestJsonClient.builder().region(Region.US_WEST_2).credentialsProvider(DEFAULT_CREDENTIALS);
595     }
596 
extractClientConfiguration(ProtocolRestJsonClient client)597     private SdkClientConfiguration extractClientConfiguration(ProtocolRestJsonClient client) {
598         try {
599             // Naughty, but we need to be able to verify some things that can't be easily observed with unprotected means.
600             Class<? extends ProtocolRestJsonClient> clientClass = client.getClass();
601             Field configField = clientClass.getDeclaredField("clientConfiguration");
602             configField.setAccessible(true);
603             return (SdkClientConfiguration) configField.get(client);
604         } catch (Exception e) {
605             throw new RuntimeException(e);
606         }
607     }
608 
609     static class TestCase<T> {
610         private final String configName;
611         T defaultValue;
612         T nonDefaultValue;
613         BiConsumer<ProtocolRestJsonClientBuilder, T> clientSetter;
614         BiConsumer<AwsRequestOverrideConfiguration.Builder, T> requestSetter;
615         BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginSetter;
616 
617         BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginValidator;
618         BiConsumer<SdkClientConfiguration, T> clientConfigurationValidator;
619         TriConsumer<Context.BeforeTransmission, ExecutionAttributes, T> beforeTransmissionValidator;
620 
TestCase(String configName)621         TestCase(String configName) {
622             this.configName = configName;
623         }
624 
defaultValue(T defaultValue)625         public TestCase<T> defaultValue(T defaultValue) {
626             this.defaultValue = defaultValue;
627             return this;
628         }
629 
nonDefaultValue(T nonDefaultValue)630         public TestCase<T> nonDefaultValue(T nonDefaultValue) {
631             this.nonDefaultValue = nonDefaultValue;
632             return this;
633         }
634 
clientSetter(BiConsumer<ProtocolRestJsonClientBuilder, T> clientSetter)635         public TestCase<T> clientSetter(BiConsumer<ProtocolRestJsonClientBuilder, T> clientSetter) {
636             this.clientSetter = clientSetter;
637             return this;
638         }
639 
requestSetter(BiConsumer<AwsRequestOverrideConfiguration.Builder, T> requestSetter)640         public TestCase<T> requestSetter(BiConsumer<AwsRequestOverrideConfiguration.Builder, T> requestSetter) {
641             this.requestSetter = requestSetter;
642             return this;
643         }
644 
pluginSetter(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginSetter)645         public TestCase<T> pluginSetter(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginSetter) {
646             this.pluginSetter = pluginSetter;
647             return this;
648         }
649 
pluginValidator(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginValidator)650         public TestCase<T> pluginValidator(BiConsumer<ProtocolRestJsonServiceClientConfiguration.Builder, T> pluginValidator) {
651             this.pluginValidator = pluginValidator;
652             return this;
653         }
654 
clientConfigurationValidator(BiConsumer<SdkClientConfiguration, T> clientConfigurationValidator)655         public TestCase<T> clientConfigurationValidator(BiConsumer<SdkClientConfiguration, T> clientConfigurationValidator) {
656             this.clientConfigurationValidator = clientConfigurationValidator;
657             return this;
658         }
659 
beforeTransmissionValidator(TriConsumer<Context.BeforeTransmission, ExecutionAttributes, T> beforeTransmissionValidator)660         public TestCase<T> beforeTransmissionValidator(TriConsumer<Context.BeforeTransmission, ExecutionAttributes, T> beforeTransmissionValidator) {
661             this.beforeTransmissionValidator = beforeTransmissionValidator;
662             return this;
663         }
664 
665         @Override
toString()666         public String toString() {
667             return configName;
668         }
669     }
670 
succeedingHttpClient()671     private static SdkHttpClient succeedingHttpClient() {
672         MockSyncHttpClient client = new MockSyncHttpClient();
673         client.stubNextResponse200();
674         return client;
675     }
676 
removePathAndQueryString(URI uri)677     private static URI removePathAndQueryString(URI uri) {
678         String uriString = uri.toString();
679         return URI.create(uriString.substring(0, uriString.indexOf('/', "https://".length())));
680     }
681 
/**
 * An operation that accepts three arguments and returns no result; the three-argument counterpart of
 * {@link java.util.function.BiConsumer}.
 *
 * @param <T> type of the first argument
 * @param <U> type of the second argument
 * @param <V> type of the third argument
 */
@FunctionalInterface
public interface TriConsumer<T, U, V> {
    /**
     * Performs this operation on the given arguments.
     *
     * @param t the first argument
     * @param u the second argument
     * @param v the third argument
     */
    void accept(T t, U u, V v);
}
685 
686     private static class CustomAuthScheme implements AuthScheme<NoAuthAuthScheme.AnonymousIdentity> {
687         private static final String SCHEME_ID = "foo";
688         private static final AuthScheme<NoAuthAuthScheme.AnonymousIdentity> DELEGATE = NoAuthAuthScheme.create();
689 
690         @Override
schemeId()691         public String schemeId() {
692             return SCHEME_ID;
693         }
694 
695         @Override
identityProvider(IdentityProviders providers)696         public IdentityProvider<NoAuthAuthScheme.AnonymousIdentity> identityProvider(IdentityProviders providers) {
697             return DELEGATE.identityProvider(providers);
698         }
699 
700         @Override
signer()701         public HttpSigner<NoAuthAuthScheme.AnonymousIdentity> signer() {
702             return DELEGATE.signer();
703         }
704     }
705 
706     private static class FlagSettingInterceptor implements ExecutionInterceptor {
707         private static final ExecutionAttribute<Boolean> FLAG = new ExecutionAttribute<>("InterceptorAdded");
708 
709         @Override
beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes)710         public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
711             executionAttributes.putAttribute(FLAG, true);
712         }
713     }
714 }
715