1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/cloud/oauth_client.h"
17
18 #include <fstream>
19
20 #include <openssl/bio.h>
21 #include <openssl/evp.h>
22 #include <openssl/pem.h>
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/base64.h"
25 #include "tensorflow/core/platform/cloud/http_request_fake.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/path.h"
28 #include "tensorflow/core/platform/resource_loader.h"
29 #include "tensorflow/core/platform/scanner.h"
30 #include "tensorflow/core/platform/test.h"
31
32 namespace tensorflow {
33 namespace {
34
TestData()35 string TestData() {
36 return io::JoinPath("tensorflow", "core", "platform", "cloud", "testdata");
37 }
38
// Canned OAuth token-endpoint response used by every test in this file.
// "expires_in" (3920) is added by OAuthClient to the request time, which is
// why the tests expect 100 + 3920 = 4020 and 10000 + 3920 = 13920.
39 constexpr char kTokenJson[] = R"(
40 {
41 "access_token":"WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY",
42 "expires_in":3920,
43 "token_type":"Bearer"
44 })";
45
46 class FakeEnv : public EnvWrapper {
47 public:
FakeEnv()48 FakeEnv() : EnvWrapper(Env::Default()) {}
49
NowSeconds() const50 uint64 NowSeconds() const override { return now; }
51 uint64 now = 10000;
52 };
53
54 } // namespace
55
// Parses the canned token response directly. The expiration timestamp is the
// request timestamp plus "expires_in" from kTokenJson: 100 + 3920 = 4020.
TEST(OAuthClientTest, ParseOAuthResponse) {
  const uint64 request_timestamp = 100;
  string token;
  // Initialized so that no later check can ever read an indeterminate value.
  uint64 expiration_timestamp = 0;
  // Fatal on failure: the checks below are meaningless if parsing failed.
  TF_ASSERT_OK(OAuthClient().ParseOAuthResponse(kTokenJson, request_timestamp,
                                                &token, &expiration_timestamp));
  EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
  EXPECT_EQ(4020, expiration_timestamp);
}
65
// Exchanges an "authorized_user" refresh-token credential for an access token.
// The fake HTTP request is given the expected URI and form-encoded POST body
// (presumably compared against the real request by FakeHttpRequest). FakeEnv
// pins the clock, so the expiration is 10000 + 3920 = 13920.
TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) {
  const string credentials_json = R"(
      {
        "client_id": "test_client_id",
        "client_secret": "@@@test_client_secret@@@",
        "refresh_token": "test_refresh_token",
        "type": "authorized_user"
      })";
  Json::Value json;
  Json::Reader reader;
  ASSERT_TRUE(reader.parse(credentials_json, json));

  std::vector<HttpRequest*> requests({new FakeHttpRequest(
      "Uri: https://www.googleapis.com/oauth2/v3/token\n"
      "Post body: client_id=test_client_id&"
      "client_secret=@@@test_client_secret@@@&"
      "refresh_token=test_refresh_token&grant_type=refresh_token\n",
      kTokenJson)});
  FakeEnv env;
  OAuthClient client(std::unique_ptr<HttpRequest::Factory>(
                         new FakeHttpRequestFactory(&requests)),
                     &env);
  string token;
  // Initialized so that no later check can ever read an indeterminate value.
  uint64 expiration_timestamp = 0;
  // Fatal on failure: the checks below are meaningless without a token.
  TF_ASSERT_OK(client.GetTokenFromRefreshTokenJson(
      json, "https://www.googleapis.com/oauth2/v3/token", &token,
      &expiration_timestamp));
  EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
  EXPECT_EQ(13920, expiration_timestamp);
}
96
TEST(OAuthClientTest,GetTokenFromServiceAccountJson)97 TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
98 std::ifstream credentials(GetDataDependencyFilepath(
99 io::JoinPath(TestData(), "service_account_credentials.json")));
100 ASSERT_TRUE(credentials.is_open());
101 Json::Value json;
102 Json::Reader reader;
103 ASSERT_TRUE(reader.parse(credentials, json));
104
105 string post_body;
106 std::vector<HttpRequest*> requests(
107 {new FakeHttpRequest("Uri: https://www.googleapis.com/oauth2/v3/token\n",
108 kTokenJson, &post_body)});
109 FakeEnv env;
110 OAuthClient client(std::unique_ptr<HttpRequest::Factory>(
111 new FakeHttpRequestFactory(&requests)),
112 &env);
113 string token;
114 uint64 expiration_timestamp;
115 TF_EXPECT_OK(client.GetTokenFromServiceAccountJson(
116 json, "https://www.googleapis.com/oauth2/v3/token",
117 "https://test-token-scope.com", &token, &expiration_timestamp));
118 EXPECT_EQ("WITH_FAKE_ACCESS_TOKEN_TEST_SHOULD_BE_HAPPY", token);
119 EXPECT_EQ(13920, expiration_timestamp);
120
121 // Now look at the JWT claim that was sent to the OAuth server.
122 StringPiece grant_type, assertion;
123 ASSERT_TRUE(strings::Scanner(post_body)
124 .OneLiteral("grant_type=")
125 .RestartCapture()
126 .ScanEscapedUntil('&')
127 .StopCapture()
128 .OneLiteral("&assertion=")
129 .GetResult(&assertion, &grant_type));
130 EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
131 grant_type);
132
133 int last_dot = assertion.rfind('.');
134 string header_dot_claim(assertion.substr(0, last_dot));
135 string signature_encoded(assertion.substr(last_dot + 1));
136
137 // Check that 'signature' signs 'header_dot_claim'.
138
139 // Read the serialized public key.
140 std::ifstream public_key_stream(GetDataDependencyFilepath(
141 io::JoinPath(TestData(), "service_account_public_key.txt")));
142 string public_key_serialized(
143 (std::istreambuf_iterator<char>(public_key_stream)),
144 (std::istreambuf_iterator<char>()));
145
146 // Deserialize the public key.
147 auto bio = BIO_new(BIO_s_mem());
148 RSA* public_key = nullptr;
149 EXPECT_EQ(public_key_serialized.size(),
150 BIO_puts(bio, public_key_serialized.c_str()));
151 public_key = PEM_read_bio_RSA_PUBKEY(bio, nullptr, nullptr, nullptr);
152 EXPECT_TRUE(public_key) << "Could not load the public key from testdata.";
153
154 // Deserialize the signature.
155 string signature;
156 TF_EXPECT_OK(Base64Decode(signature_encoded, &signature));
157
158 // Actually cryptographically verify the signature.
159 const auto md = EVP_sha256();
160 auto md_ctx = EVP_MD_CTX_create();
161 auto key = EVP_PKEY_new();
162 EVP_PKEY_set1_RSA(key, public_key);
163 ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key));
164 ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(),
165 header_dot_claim.size()));
166 ASSERT_EQ(1,
167 EVP_DigestVerifyFinal(
168 md_ctx,
169 const_cast<unsigned char*>(
170 reinterpret_cast<const unsigned char*>(signature.data())),
171 signature.size()));
172
173 // Free all the crypto-related resources.
174 EVP_PKEY_free(key);
175 EVP_MD_CTX_destroy(md_ctx);
176 RSA_free(public_key);
177 BIO_free_all(bio);
178
179 // Now check the content of the header and the claim.
180 int dot = header_dot_claim.find_last_of('.');
181 string header_encoded = header_dot_claim.substr(0, dot);
182 string claim_encoded = header_dot_claim.substr(dot + 1);
183
184 string header, claim;
185 TF_EXPECT_OK(Base64Decode(header_encoded, &header));
186 TF_EXPECT_OK(Base64Decode(claim_encoded, &claim));
187
188 Json::Value header_json, claim_json;
189 EXPECT_TRUE(reader.parse(header, header_json));
190 EXPECT_EQ("RS256", header_json.get("alg", Json::Value::null).asString());
191 EXPECT_EQ("JWT", header_json.get("typ", Json::Value::null).asString());
192 EXPECT_EQ("fake_key_id",
193 header_json.get("kid", Json::Value::null).asString());
194
195 EXPECT_TRUE(reader.parse(claim, claim_json));
196 EXPECT_EQ("fake-test-project.iam.gserviceaccount.com",
197 claim_json.get("iss", Json::Value::null).asString());
198 EXPECT_EQ("https://test-token-scope.com",
199 claim_json.get("scope", Json::Value::null).asString());
200 EXPECT_EQ("https://www.googleapis.com/oauth2/v3/token",
201 claim_json.get("aud", Json::Value::null).asString());
202 EXPECT_EQ(10000, claim_json.get("iat", Json::Value::null).asInt64());
203 EXPECT_EQ(13600, claim_json.get("exp", Json::Value::null).asInt64());
204 }
205 } // namespace tensorflow
206