1 /* 2 * Copyright 2022, Google Inc. All rights reserved. 3 * 4 * Redistribution and use in source and binary forms, with or without 5 * modification, are permitted provided that the following conditions are 6 * met: 7 * 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above 11 * copyright notice, this list of conditions and the following disclaimer 12 * in the documentation and/or other materials provided with the 13 * distribution. 14 * 15 * * Neither the name of Google Inc. nor the names of its 16 * contributors may be used to endorse or promote products derived from 17 * this software without specific prior written permission. 18 * 19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 */ 31 32 package com.google.auth.oauth2; 33 34 import com.google.api.client.http.GenericUrl; 35 import com.google.api.client.http.HttpRequest; 36 import com.google.api.client.http.HttpRequestFactory; 37 import com.google.api.client.http.HttpResponse; 38 import com.google.api.client.http.HttpResponseException; 39 import com.google.api.client.http.HttpTransport; 40 import com.google.api.client.http.UrlEncodedContent; 41 import com.google.api.client.http.javanet.NetHttpTransport; 42 import com.google.api.client.json.JsonFactory; 43 import com.google.api.client.json.JsonObjectParser; 44 import com.google.api.client.json.webtoken.JsonWebSignature; 45 import com.google.api.client.json.webtoken.JsonWebToken; 46 import com.google.api.client.util.GenericData; 47 import com.google.auth.http.HttpTransportFactory; 48 import com.google.common.annotations.VisibleForTesting; 49 import com.google.common.base.MoreObjects; 50 import com.google.common.base.Preconditions; 51 import com.google.errorprone.annotations.CanIgnoreReturnValue; 52 import java.io.File; 53 import java.io.FileInputStream; 54 import java.io.FileNotFoundException; 55 import java.io.IOException; 56 import java.io.InputStream; 57 import java.io.ObjectInputStream; 58 import java.net.URI; 59 import java.net.URISyntaxException; 60 import java.security.GeneralSecurityException; 61 import java.security.PrivateKey; 62 import java.util.Date; 63 import java.util.Map; 64 import java.util.Objects; 65 66 public class GdchCredentials extends GoogleCredentials { 67 static final String SUPPORTED_FORMAT_VERSION = "1"; 68 private static final String PARSE_ERROR_PREFIX = "Error parsing token refresh response. "; 69 private static final int DEFAULT_LIFETIME_IN_SECONDS = 3600; 70 71 private final PrivateKey privateKey; 72 private final String privateKeyId; 73 private final String projectId; 74 private final String serviceIdentityName; 75 private final URI tokenServerUri; 76 private final URI apiAudience; 77 private final int lifetime; 78 private final String transportFactoryClassName; 79 private final String caCertPath; 80 private transient HttpTransportFactory transportFactory; 81 82 /** 83 * Internal constructor. 84 * 85 * @param builder A builder for {@link GdchCredentials} See {@link GdchCredentials.Builder}. 86 */ 87 @VisibleForTesting GdchCredentials(GdchCredentials.Builder builder)88 GdchCredentials(GdchCredentials.Builder builder) { 89 this.projectId = Preconditions.checkNotNull(builder.projectId); 90 this.privateKeyId = Preconditions.checkNotNull(builder.privateKeyId); 91 this.privateKey = Preconditions.checkNotNull(builder.privateKey); 92 this.serviceIdentityName = Preconditions.checkNotNull(builder.serviceIdentityName); 93 this.tokenServerUri = Preconditions.checkNotNull(builder.tokenServerUri); 94 this.transportFactory = Preconditions.checkNotNull(builder.transportFactory); 95 this.transportFactoryClassName = this.transportFactory.getClass().getName(); 96 this.caCertPath = builder.caCertPath; 97 this.apiAudience = builder.apiAudience; 98 this.lifetime = builder.lifetime; 99 } 100 101 /** 102 * Create GDCH service account credentials defined by JSON. 103 * 104 * @param json a map from the JSON representing the credentials. 105 * @return the GDCH service account credentials defined by the JSON. 106 * @throws IOException if the credential cannot be created from the JSON. 107 */ fromJson(Map<String, Object> json)108 static GdchCredentials fromJson(Map<String, Object> json) throws IOException { 109 String caCertPath = (String) json.get("ca_cert_path"); 110 return fromJson(json, new TransportFactoryForGdch(caCertPath)); 111 } 112 113 /** 114 * Create GDCH service account credentials defined by JSON. 115 * 116 * @param json a map from the JSON representing the credentials. 117 * @param transportFactory HTTP transport factory, creates the transport used to get access 118 * tokens. 119 * @return the GDCH service account credentials defined by the JSON. 120 * @throws IOException if the credential cannot be created from the JSON. 121 */ 122 @VisibleForTesting fromJson(Map<String, Object> json, HttpTransportFactory transportFactory)123 static GdchCredentials fromJson(Map<String, Object> json, HttpTransportFactory transportFactory) 124 throws IOException { 125 String formatVersion = validateField((String) json.get("format_version"), "format_version"); 126 String projectId = validateField((String) json.get("project"), "project"); 127 String privateKeyId = validateField((String) json.get("private_key_id"), "private_key_id"); 128 String privateKeyPkcs8 = validateField((String) json.get("private_key"), "private_key"); 129 String serviceIdentityName = validateField((String) json.get("name"), "name"); 130 String tokenServerUriStringFromCreds = 131 validateField((String) json.get("token_uri"), "token_uri"); 132 String caCertPath = (String) json.get("ca_cert_path"); 133 134 if (!SUPPORTED_FORMAT_VERSION.equals(formatVersion)) { 135 throw new IOException( 136 String.format("Only format version %s is supported.", SUPPORTED_FORMAT_VERSION)); 137 } 138 139 URI tokenServerUriFromCreds = null; 140 try { 141 tokenServerUriFromCreds = new URI(tokenServerUriStringFromCreds); 142 } catch (URISyntaxException e) { 143 throw new IOException("Token server URI specified in 'token_uri' could not be parsed."); 144 } 145 146 GdchCredentials.Builder builder = 147 GdchCredentials.newBuilder() 148 .setProjectId(projectId) 149 .setPrivateKeyId(privateKeyId) 150 .setTokenServerUri(tokenServerUriFromCreds) 151 .setServiceIdentityName(serviceIdentityName) 152 .setCaCertPath(caCertPath) 153 .setHttpTransportFactory(transportFactory); 154 155 return fromPkcs8(privateKeyPkcs8, builder); 156 } 157 158 /** 159 * Internal constructor. 160 * 161 * @param privateKeyPkcs8 RSA private key object for the service account in PKCS#8 format. 162 * @param builder A builder for GdchCredentials. 163 * @return an instance of GdchCredentials. 164 */ fromPkcs8(String privateKeyPkcs8, GdchCredentials.Builder builder)165 static GdchCredentials fromPkcs8(String privateKeyPkcs8, GdchCredentials.Builder builder) 166 throws IOException { 167 PrivateKey privateKey = OAuth2Utils.privateKeyFromPkcs8(privateKeyPkcs8); 168 builder.setPrivateKey(privateKey); 169 170 return new GdchCredentials(builder); 171 } 172 173 /** 174 * Create a copy of GDCH credentials with the specified audience. 175 * 176 * @param apiAudience The intended audience for GDCH credentials. 177 */ createWithGdchAudience(URI apiAudience)178 public GdchCredentials createWithGdchAudience(URI apiAudience) throws IOException { 179 Preconditions.checkNotNull( 180 apiAudience, "Audience are not configured for GDCH service account credentials."); 181 return this.toBuilder().setGdchAudience(apiAudience).build(); 182 } 183 184 /** 185 * Refresh the OAuth2 access token by getting a new access token using a JSON Web Token (JWT). 186 * 187 * <p>For GDCH credentials, this class creates a self-signed JWT, and sends to the GDCH 188 * authentication endpoint (tokenServerUri) to exchange an access token for the intended api 189 * audience (apiAudience). 190 */ 191 @Override refreshAccessToken()192 public AccessToken refreshAccessToken() throws IOException { 193 Preconditions.checkNotNull( 194 this.apiAudience, 195 "Audience are not configured for GDCH service account. Specify the " 196 + "audience by calling createWithGDCHAudience."); 197 198 JsonFactory jsonFactory = OAuth2Utils.JSON_FACTORY; 199 long currentTime = clock.currentTimeMillis(); 200 String assertion = createAssertion(jsonFactory, currentTime, getApiAudience()); 201 202 GenericData tokenRequest = new GenericData(); 203 tokenRequest.set("grant_type", OAuth2Utils.TOKEN_TYPE_TOKEN_EXCHANGE); 204 tokenRequest.set("assertion", assertion); 205 UrlEncodedContent content = new UrlEncodedContent(tokenRequest); 206 207 HttpRequestFactory requestFactory = transportFactory.create().createRequestFactory(); 208 HttpRequest request = requestFactory.buildPostRequest(new GenericUrl(tokenServerUri), content); 209 210 request.setParser(new JsonObjectParser(jsonFactory)); 211 212 HttpResponse response; 213 String errorTemplate = "Error getting access token for GDCH service account: %s, iss: %s"; 214 215 try { 216 response = request.execute(); 217 } catch (HttpResponseException re) { 218 String message = String.format(errorTemplate, re.getMessage(), getServiceIdentityName()); 219 throw GoogleAuthException.createWithTokenEndpointResponseException(re, message); 220 } catch (IOException e) { 221 throw GoogleAuthException.createWithTokenEndpointIOException( 222 e, String.format(errorTemplate, e.getMessage(), getServiceIdentityName())); 223 } 224 225 GenericData responseData = response.parseAs(GenericData.class); 226 String accessToken = 227 OAuth2Utils.validateString(responseData, "access_token", PARSE_ERROR_PREFIX); 228 int expiresInSeconds = 229 OAuth2Utils.validateInt32(responseData, "expires_in", PARSE_ERROR_PREFIX); 230 long expiresAtMilliseconds = clock.currentTimeMillis() + expiresInSeconds * 1000L; 231 return new AccessToken(accessToken, new Date(expiresAtMilliseconds)); 232 } 233 234 /** 235 * Create a self-signed JWT for GDCH authentication flow. 236 * 237 * <p>The self-signed JWT is used to exchange access token from GDCH authentication 238 * (tokenServerUri), not for API call. It uses the serviceIdentityName as the `iss` and `sub` 239 * claim, and the tokenServerUri as the `aud` claim. The JWT is signed with the privateKey. 240 */ createAssertion(JsonFactory jsonFactory, long currentTime, URI apiAudience)241 String createAssertion(JsonFactory jsonFactory, long currentTime, URI apiAudience) 242 throws IOException { 243 JsonWebSignature.Header header = new JsonWebSignature.Header(); 244 header.setAlgorithm("RS256"); 245 header.setType("JWT"); 246 header.setKeyId(privateKeyId); 247 248 JsonWebToken.Payload payload = new JsonWebToken.Payload(); 249 payload.setIssuer(getIssuerSubjectValue(projectId, serviceIdentityName)); 250 payload.setSubject(getIssuerSubjectValue(projectId, serviceIdentityName)); 251 payload.setIssuedAtTimeSeconds(currentTime / 1000); 252 payload.setExpirationTimeSeconds(currentTime / 1000 + this.lifetime); 253 payload.setAudience(getTokenServerUri().toString()); 254 255 String assertion; 256 try { 257 payload.set("api_audience", apiAudience.toString()); 258 assertion = JsonWebSignature.signUsingRsaSha256(privateKey, jsonFactory, header, payload); 259 } catch (GeneralSecurityException e) { 260 throw new IOException( 261 "Error signing service account access token request with private key.", e); 262 } 263 264 return assertion; 265 } 266 267 /** 268 * Get the issuer and subject value in the format GDCH token server required. 269 * 270 * <p>This value is specific to GDCH and combined parameter used for both `iss` and `sub` fields 271 * in JWT claim. 272 */ 273 @VisibleForTesting getIssuerSubjectValue(String projectId, String serviceIdentityName)274 static String getIssuerSubjectValue(String projectId, String serviceIdentityName) { 275 return String.format("system:serviceaccount:%s:%s", projectId, serviceIdentityName); 276 } 277 getProjectId()278 public final String getProjectId() { 279 return projectId; 280 } 281 getPrivateKeyId()282 public final String getPrivateKeyId() { 283 return privateKeyId; 284 } 285 getPrivateKey()286 public final PrivateKey getPrivateKey() { 287 return privateKey; 288 } 289 getServiceIdentityName()290 public final String getServiceIdentityName() { 291 return serviceIdentityName; 292 } 293 getTokenServerUri()294 public final URI getTokenServerUri() { 295 return tokenServerUri; 296 } 297 getApiAudience()298 public final URI getApiAudience() { 299 return apiAudience; 300 } 301 getTransportFactory()302 public final HttpTransportFactory getTransportFactory() { 303 return transportFactory; 304 } 305 getCaCertPath()306 public final String getCaCertPath() { 307 return caCertPath; 308 } 309 newBuilder()310 public static Builder newBuilder() { 311 return new Builder(); 312 } 313 314 @Override toBuilder()315 public Builder toBuilder() { 316 return new Builder(this); 317 } 318 319 @SuppressWarnings("unused") readObject(ObjectInputStream input)320 private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException { 321 // properly deserialize the transient transportFactory. 322 input.defaultReadObject(); 323 transportFactory = newInstance(transportFactoryClassName); 324 } 325 326 @Override hashCode()327 public int hashCode() { 328 return Objects.hash( 329 projectId, 330 privateKeyId, 331 privateKey, 332 serviceIdentityName, 333 tokenServerUri, 334 transportFactoryClassName, 335 apiAudience, 336 caCertPath, 337 lifetime); 338 } 339 340 @Override toString()341 public String toString() { 342 return MoreObjects.toStringHelper(this) 343 .add("projectId", projectId) 344 .add("privateKeyId", privateKeyId) 345 .add("serviceIdentityName", serviceIdentityName) 346 .add("tokenServerUri", tokenServerUri) 347 .add("transportFactoryClassName", transportFactoryClassName) 348 .add("caCertPath", caCertPath) 349 .add("apiAudience", apiAudience) 350 .add("lifetime", lifetime) 351 .toString(); 352 } 353 354 @Override equals(Object obj)355 public boolean equals(Object obj) { 356 if (!(obj instanceof GdchCredentials)) { 357 return false; 358 } 359 GdchCredentials other = (GdchCredentials) obj; 360 return Objects.equals(this.projectId, other.projectId) 361 && Objects.equals(this.privateKeyId, other.privateKeyId) 362 && Objects.equals(this.privateKey, other.privateKey) 363 && Objects.equals(this.serviceIdentityName, other.serviceIdentityName) 364 && Objects.equals(this.tokenServerUri, other.tokenServerUri) 365 && Objects.equals(this.transportFactoryClassName, other.transportFactoryClassName) 366 && Objects.equals(this.apiAudience, other.apiAudience) 367 && Objects.equals(this.caCertPath, other.caCertPath) 368 && Objects.equals(this.lifetime, other.lifetime); 369 } 370 readStream(File file)371 static InputStream readStream(File file) throws FileNotFoundException { 372 return new FileInputStream(file); 373 } 374 375 public static class Builder extends GoogleCredentials.Builder { 376 private String projectId; 377 private String privateKeyId; 378 private PrivateKey privateKey; 379 private String serviceIdentityName; 380 private URI tokenServerUri; 381 private URI apiAudience; 382 private HttpTransportFactory transportFactory; 383 private String caCertPath; 384 private int lifetime = DEFAULT_LIFETIME_IN_SECONDS; 385 Builder()386 protected Builder() {} 387 Builder(GdchCredentials credentials)388 protected Builder(GdchCredentials credentials) { 389 this.projectId = credentials.projectId; 390 this.privateKeyId = credentials.privateKeyId; 391 this.privateKey = credentials.privateKey; 392 this.serviceIdentityName = credentials.serviceIdentityName; 393 this.tokenServerUri = credentials.tokenServerUri; 394 this.transportFactory = credentials.transportFactory; 395 this.caCertPath = credentials.caCertPath; 396 this.lifetime = credentials.lifetime; 397 } 398 399 @CanIgnoreReturnValue setProjectId(String projectId)400 public Builder setProjectId(String projectId) { 401 this.projectId = projectId; 402 return this; 403 } 404 405 @CanIgnoreReturnValue setPrivateKeyId(String privateKeyId)406 public Builder setPrivateKeyId(String privateKeyId) { 407 this.privateKeyId = privateKeyId; 408 return this; 409 } 410 411 @CanIgnoreReturnValue setPrivateKey(PrivateKey privateKey)412 public Builder setPrivateKey(PrivateKey privateKey) { 413 this.privateKey = privateKey; 414 return this; 415 } 416 417 @CanIgnoreReturnValue setServiceIdentityName(String name)418 public Builder setServiceIdentityName(String name) { 419 this.serviceIdentityName = name; 420 return this; 421 } 422 423 @CanIgnoreReturnValue setTokenServerUri(URI tokenServerUri)424 public Builder setTokenServerUri(URI tokenServerUri) { 425 this.tokenServerUri = tokenServerUri; 426 return this; 427 } 428 429 @CanIgnoreReturnValue setHttpTransportFactory(HttpTransportFactory transportFactory)430 public Builder setHttpTransportFactory(HttpTransportFactory transportFactory) { 431 this.transportFactory = transportFactory; 432 return this; 433 } 434 435 @CanIgnoreReturnValue setCaCertPath(String caCertPath)436 public Builder setCaCertPath(String caCertPath) { 437 this.caCertPath = caCertPath; 438 return this; 439 } 440 441 @CanIgnoreReturnValue setGdchAudience(URI apiAudience)442 public Builder setGdchAudience(URI apiAudience) { 443 this.apiAudience = apiAudience; 444 return this; 445 } 446 getProjectId()447 public String getProjectId() { 448 return projectId; 449 } 450 getPrivateKeyId()451 public String getPrivateKeyId() { 452 return privateKeyId; 453 } 454 getPrivateKey()455 public PrivateKey getPrivateKey() { 456 return privateKey; 457 } 458 getServiceIdentityName()459 public String getServiceIdentityName() { 460 return serviceIdentityName; 461 } 462 getTokenServerUri()463 public URI getTokenServerUri() { 464 return tokenServerUri; 465 } 466 getHttpTransportFactory()467 public HttpTransportFactory getHttpTransportFactory() { 468 return transportFactory; 469 } 470 getCaCertPath()471 public String getCaCertPath() { 472 return caCertPath; 473 } 474 getLifetime()475 public int getLifetime() { 476 return lifetime; 477 } 478 479 @Override build()480 public GdchCredentials build() { 481 return new GdchCredentials(this); 482 } 483 } 484 validateField(String field, String fieldName)485 private static String validateField(String field, String fieldName) throws IOException { 486 if (field == null || field.isEmpty()) { 487 throw new IOException( 488 String.format( 489 "Error reading GDCH service account credential from JSON, %s is misconfigured.", 490 fieldName)); 491 } 492 return field; 493 } 494 495 /* 496 * Internal HttpTransportFactory for GDCH credentials. 497 * 498 * <p> GDCH authentication server could use a self-signed certificate, thus the client could 499 * provide the CA certificate path through the `ca_cert_path` in GDCH JSON file. 500 * 501 * <p> The TransportFactoryForGdch subclass would read the certificate and create a trust store, 502 * then use the trust store to create a transport. 503 * 504 * <p> If the GDCH authentication server uses well known CA certificate, then a regular transport 505 * would be set. 506 */ 507 static class TransportFactoryForGdch implements HttpTransportFactory { 508 HttpTransport transport; 509 TransportFactoryForGdch(String caCertPath)510 public TransportFactoryForGdch(String caCertPath) throws IOException { 511 setTransport(caCertPath); 512 } 513 514 @Override create()515 public HttpTransport create() { 516 return transport; 517 } 518 setTransport(String caCertPath)519 private void setTransport(String caCertPath) throws IOException { 520 if (caCertPath == null || caCertPath.isEmpty()) { 521 this.transport = new NetHttpTransport(); 522 return; 523 } 524 try { 525 InputStream certificateStream = readStream(new File(caCertPath)); 526 this.transport = 527 new NetHttpTransport.Builder().trustCertificatesFromStream(certificateStream).build(); 528 } catch (IOException e) { 529 throw new IOException( 530 String.format( 531 "Error reading certificate file from CA cert path, value '%s': %s", 532 caCertPath, e.getMessage()), 533 e); 534 } catch (GeneralSecurityException e) { 535 throw new IOException("Error initiating transport with certificate stream.", e); 536 } 537 } 538 } 539 } 540