• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/cloud/oauth_client.h"
17 
18 #include <fstream>
19 
20 #include <openssl/bio.h>
21 #include <openssl/evp.h>
22 #include <openssl/pem.h>
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/base64.h"
25 #include "tensorflow/core/platform/cloud/http_request_fake.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/path.h"
28 #include "tensorflow/core/platform/resource_loader.h"
29 #include "tensorflow/core/platform/scanner.h"
30 #include "tensorflow/core/platform/test.h"
31 
32 namespace tensorflow {
33 namespace {
34 
TestData()35 string TestData() {
36   return io::JoinPath("tensorflow", "core", "platform", "cloud", "testdata");
37 }
38 
// Canned JSON token payload shared by the tests below. It carries the fake
// access token, an "expires_in" of 3920 and a "Bearer" token type.
constexpr char kTokenJson[] =
    R"json(
    {
      "access_token":"WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY",
      "expires_in":3920,
      "token_type":"Bearer"
    })json";
45 
46 class FakeEnv : public EnvWrapper {
47  public:
FakeEnv()48   FakeEnv() : EnvWrapper(Env::Default()) {}
49 
NowSeconds() const50   uint64 NowSeconds() const override { return now; }
51   uint64 now = 10000;
52 };
53 
54 }  // namespace
55 
// Parses kTokenJson with a request timestamp of 100 and checks the extracted
// access token. The expected expiration, 4020, is the request timestamp (100)
// plus the "expires_in" value carried by kTokenJson (3920).
TEST(OAuthClientTest, ParseOAuthResponse) {
  const uint64 kRequestTimestamp = 100;
  uint64 expiration_timestamp;
  string token;

  OAuthClient client;
  TF_EXPECT_OK(client.ParseOAuthResponse(kTokenJson, kRequestTimestamp, &token,
                                         &expiration_timestamp));

  EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
  EXPECT_EQ(4020, expiration_timestamp);
}
65 
// Obtains a token from "authorized_user" (refresh token) credentials through a
// fake HTTP request factory and checks the returned token and expiration.
// The expected expiration, 13920, is FakeEnv's clock (10000) plus the
// "expires_in" value carried by kTokenJson (3920).
TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) {
  // Credentials handed to the client, parsed into a Json::Value first.
  const string refresh_credentials = R"(
      {
        "client_id": "test_client_id",
        "client_secret": "@@@test_client_secret@@@",
        "refresh_token": "test_refresh_token",
        "type": "authorized_user"
      })";
  Json::Reader json_reader;
  Json::Value credentials;
  ASSERT_TRUE(json_reader.parse(refresh_credentials, credentials));

  // One fake request: the first argument is presumably the request text the
  // fake expects to see and kTokenJson the response it serves (confirm against
  // FakeHttpRequest).
  std::vector<HttpRequest*> fake_requests;
  fake_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/oauth2/v3/token\n"
      "Post body: client_id=test_client_id&"
      "client_secret=@@@test_client_secret@@@&"
      "refresh_token=test_refresh_token&grant_type=refresh_token\n",
      kTokenJson));

  FakeEnv fake_env;
  OAuthClient client(std::unique_ptr<HttpRequest::Factory>(
                         new FakeHttpRequestFactory(&fake_requests)),
                     &fake_env);

  uint64 expiration_timestamp;
  string token;
  TF_EXPECT_OK(client.GetTokenFromRefreshTokenJson(
      credentials, "https://www.googleapis.com/oauth2/v3/token", &token,
      &expiration_timestamp));

  EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
  EXPECT_EQ(13920, expiration_timestamp);
}
96 
TEST(OAuthClientTest,GetTokenFromServiceAccountJson)97 TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
98   std::ifstream credentials(GetDataDependencyFilepath(
99       io::JoinPath(TestData(), "service_account_credentials.json")));
100   ASSERT_TRUE(credentials.is_open());
101   Json::Value json;
102   Json::Reader reader;
103   ASSERT_TRUE(reader.parse(credentials, json));
104 
105   string post_body;
106   std::vector<HttpRequest*> requests(
107       {new FakeHttpRequest("Uri: https://www.googleapis.com/oauth2/v3/token\n",
108                            kTokenJson, &post_body)});
109   FakeEnv env;
110   OAuthClient client(std::unique_ptr<HttpRequest::Factory>(
111                          new FakeHttpRequestFactory(&requests)),
112                      &env);
113   string token;
114   uint64 expiration_timestamp;
115   TF_EXPECT_OK(client.GetTokenFromServiceAccountJson(
116       json, "https://www.googleapis.com/oauth2/v3/token",
117       "https://test-token-scope.com", &token, &expiration_timestamp));
118   EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
119   EXPECT_EQ(13920, expiration_timestamp);
120 
121   // Now look at the JWT claim that was sent to the OAuth server.
122   StringPiece grant_type, assertion;
123   ASSERT_TRUE(strings::Scanner(post_body)
124                   .OneLiteral("grant_type=")
125                   .RestartCapture()
126                   .ScanEscapedUntil('&')
127                   .StopCapture()
128                   .OneLiteral("&assertion=")
129                   .GetResult(&assertion, &grant_type));
130   EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
131             grant_type);
132 
133   int last_dot = assertion.rfind('.');
134   string header_dot_claim(assertion.substr(0, last_dot));
135   string signature_encoded(assertion.substr(last_dot + 1));
136 
137   // Check that 'signature' signs 'header_dot_claim'.
138 
139   // Read the serialized public key.
140   std::ifstream public_key_stream(GetDataDependencyFilepath(
141       io::JoinPath(TestData(), "service_account_public_key.txt")));
142   string public_key_serialized(
143       (std::istreambuf_iterator<char>(public_key_stream)),
144       (std::istreambuf_iterator<char>()));
145 
146   // Deserialize the public key.
147   auto bio = BIO_new(BIO_s_mem());
148   RSA* public_key = nullptr;
149   EXPECT_EQ(public_key_serialized.size(),
150             BIO_puts(bio, public_key_serialized.c_str()));
151   public_key = PEM_read_bio_RSA_PUBKEY(bio, nullptr, nullptr, nullptr);
152   EXPECT_TRUE(public_key) << "Could not load the public key from testdata.";
153 
154   // Deserialize the signature.
155   string signature;
156   TF_EXPECT_OK(Base64Decode(signature_encoded, &signature));
157 
158   // Actually cryptographically verify the signature.
159   const auto md = EVP_sha256();
160   auto md_ctx = EVP_MD_CTX_create();
161   auto key = EVP_PKEY_new();
162   EVP_PKEY_set1_RSA(key, public_key);
163   ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key));
164   ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(),
165                                       header_dot_claim.size()));
166   ASSERT_EQ(1,
167             EVP_DigestVerifyFinal(
168                 md_ctx,
169                 const_cast<unsigned char*>(
170                     reinterpret_cast<const unsigned char*>(signature.data())),
171                 signature.size()));
172 
173   // Free all the crypto-related resources.
174   EVP_PKEY_free(key);
175   EVP_MD_CTX_destroy(md_ctx);
176   RSA_free(public_key);
177   BIO_free_all(bio);
178 
179   // Now check the content of the header and the claim.
180   int dot = header_dot_claim.find_last_of('.');
181   string header_encoded = header_dot_claim.substr(0, dot);
182   string claim_encoded = header_dot_claim.substr(dot + 1);
183 
184   string header, claim;
185   TF_EXPECT_OK(Base64Decode(header_encoded, &header));
186   TF_EXPECT_OK(Base64Decode(claim_encoded, &claim));
187 
188   Json::Value header_json, claim_json;
189   EXPECT_TRUE(reader.parse(header, header_json));
190   EXPECT_EQ("RS256", header_json.get("alg", Json::Value::null).asString());
191   EXPECT_EQ("JWT", header_json.get("typ", Json::Value::null).asString());
192   EXPECT_EQ("fake_key_id",
193             header_json.get("kid", Json::Value::null).asString());
194 
195   EXPECT_TRUE(reader.parse(claim, claim_json));
196   EXPECT_EQ("fake-test-project.iam.gserviceaccount.com",
197             claim_json.get("iss", Json::Value::null).asString());
198   EXPECT_EQ("https://test-token-scope.com",
199             claim_json.get("scope", Json::Value::null).asString());
200   EXPECT_EQ("https://www.googleapis.com/oauth2/v3/token",
201             claim_json.get("aud", Json::Value::null).asString());
202   EXPECT_EQ(10000, claim_json.get("iat", Json::Value::null).asInt64());
203   EXPECT_EQ(13600, claim_json.get("exp", Json::Value::null).asInt64());
204 }
205 }  // namespace tensorflow
206