1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h"
20
21 #include "src/core/lib/debug/trace.h"
22 #include "src/core/lib/event_engine/default_event_engine.h"
23 #include "src/core/lib/iomgr/pollset_set.h"
24 #include "src/core/lib/promise/context.h"
25 #include "src/core/lib/promise/poll.h"
26 #include "src/core/lib/promise/promise.h"
27
28 namespace grpc_core {
29
30 namespace {
31
32 // Amount of time before the token's expiration that we consider it
33 // invalid to account for server processing time and clock skew.
34 constexpr Duration kTokenExpirationAdjustmentDuration = Duration::Seconds(30);
35
36 // Amount of time before the token's expiration that we pre-fetch a new
37 // token. Also determines the timeout for the fetch request.
38 constexpr Duration kTokenRefreshDuration = Duration::Seconds(60);
39
40 } // namespace
41
42 //
43 // TokenFetcherCredentials::Token
44 //
45
Token(Slice token,Timestamp expiration)46 TokenFetcherCredentials::Token::Token(Slice token, Timestamp expiration)
47 : token_(std::move(token)),
48 expiration_(expiration - kTokenExpirationAdjustmentDuration) {}
49
AddTokenToClientInitialMetadata(ClientMetadata & metadata) const50 void TokenFetcherCredentials::Token::AddTokenToClientInitialMetadata(
51 ClientMetadata& metadata) const {
52 metadata.Append(GRPC_AUTHORIZATION_METADATA_KEY, token_.Ref(),
53 [](absl::string_view, const Slice&) { abort(); });
54 }
55
56 //
57 // TokenFetcherCredentials::FetchState::BackoffTimer
58 //
59
BackoffTimer(RefCountedPtr<FetchState> fetch_state,absl::Status status)60 TokenFetcherCredentials::FetchState::BackoffTimer::BackoffTimer(
61 RefCountedPtr<FetchState> fetch_state, absl::Status status)
62 : fetch_state_(std::move(fetch_state)), status_(status) {
63 const Duration delay = fetch_state_->backoff_.NextAttemptDelay();
64 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
65 << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
66 << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
67 << ": starting backoff timer for " << delay;
68 timer_handle_ = fetch_state_->creds_->event_engine().RunAfter(
69 delay, [self = Ref()]() mutable {
70 ApplicationCallbackExecCtx callback_exec_ctx;
71 ExecCtx exec_ctx;
72 self->OnTimer();
73 self.reset();
74 });
75 }
76
Orphan()77 void TokenFetcherCredentials::FetchState::BackoffTimer::Orphan() {
78 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
79 << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
80 << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
81 << ": backoff timer shut down";
82 if (timer_handle_.has_value()) {
83 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
84 << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
85 << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
86 << ": cancelling timer";
87 fetch_state_->creds_->event_engine().Cancel(*timer_handle_);
88 timer_handle_.reset();
89 fetch_state_->ResumeQueuedCalls(
90 absl::CancelledError("credentials shutdown"));
91 }
92 Unref();
93 }
94
OnTimer()95 void TokenFetcherCredentials::FetchState::BackoffTimer::OnTimer() {
96 MutexLock lock(&fetch_state_->creds_->mu_);
97 if (!timer_handle_.has_value()) return;
98 timer_handle_.reset();
99 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
100 << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
101 << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
102 << ": backoff timer fired";
103 auto* self_ptr =
104 absl::get_if<OrphanablePtr<BackoffTimer>>(&fetch_state_->state_);
105 // This condition should always be true, but check to be defensive.
106 if (self_ptr != nullptr && self_ptr->get() == this) {
107 // Reset pointer in fetch_state_, so that subsequent RPCs know that
108 // we're no longer in backoff and they can trigger a new fetch.
109 self_ptr->reset();
110 }
111 }
112
113 //
114 // TokenFetcherCredentials::FetchState
115 //
116
FetchState(WeakRefCountedPtr<TokenFetcherCredentials> creds)117 TokenFetcherCredentials::FetchState::FetchState(
118 WeakRefCountedPtr<TokenFetcherCredentials> creds)
119 : creds_(std::move(creds)),
120 backoff_(BackOff::Options()
121 .set_initial_backoff(Duration::Seconds(1))
122 .set_multiplier(1.6)
123 .set_jitter(creds_->test_only_use_backoff_jitter_ ? 0.2 : 0)
124 .set_max_backoff(Duration::Seconds(120))) {
125 StartFetchAttempt();
126 }
127
Orphan()128 void TokenFetcherCredentials::FetchState::Orphan() {
129 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
130 << "[TokenFetcherCredentials " << creds_.get()
131 << "]: fetch_state=" << this << ": shutting down";
132 // Cancels fetch or backoff timer, if any.
133 state_ = Shutdown{};
134 Unref();
135 }
136
status() const137 absl::Status TokenFetcherCredentials::FetchState::status() const {
138 auto* backoff_ptr = absl::get_if<OrphanablePtr<BackoffTimer>>(&state_);
139 if (backoff_ptr == nullptr || *backoff_ptr == nullptr) {
140 return absl::OkStatus();
141 }
142 return (*backoff_ptr)->status();
143 }
144
StartFetchAttempt()145 void TokenFetcherCredentials::FetchState::StartFetchAttempt() {
146 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
147 << "[TokenFetcherCredentials " << creds_.get()
148 << "]: fetch_state=" << this << ": starting fetch";
149 state_ = creds_->FetchToken(
150 /*deadline=*/Timestamp::Now() + kTokenRefreshDuration,
151 [self = Ref()](absl::StatusOr<RefCountedPtr<Token>> token) mutable {
152 self->TokenFetchComplete(std::move(token));
153 });
154 }
155
TokenFetchComplete(absl::StatusOr<RefCountedPtr<Token>> token)156 void TokenFetcherCredentials::FetchState::TokenFetchComplete(
157 absl::StatusOr<RefCountedPtr<Token>> token) {
158 MutexLock lock(&creds_->mu_);
159 // If we were shut down, clean up.
160 if (absl::holds_alternative<Shutdown>(state_)) {
161 if (token.ok()) token = absl::CancelledError("credentials shutdown");
162 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
163 << "[TokenFetcherCredentials " << creds_.get()
164 << "]: fetch_state=" << this
165 << ": shut down before fetch completed: " << token.status();
166 ResumeQueuedCalls(std::move(token));
167 return;
168 }
169 // If succeeded, update cache in creds object.
170 if (token.ok()) {
171 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
172 << "[TokenFetcherCredentials " << creds_.get()
173 << "]: fetch_state=" << this << ": token fetch succeeded";
174 creds_->token_ = *token;
175 creds_->fetch_state_.reset(); // Orphan ourselves.
176 } else {
177 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
178 << "[TokenFetcherCredentials " << creds_.get()
179 << "]: fetch_state=" << this
180 << ": token fetch failed: " << token.status();
181 // If failed, start backoff timer.
182 state_ =
183 OrphanablePtr<BackoffTimer>(new BackoffTimer(Ref(), token.status()));
184 }
185 ResumeQueuedCalls(std::move(token));
186 }
187
ResumeQueuedCalls(absl::StatusOr<RefCountedPtr<Token>> token)188 void TokenFetcherCredentials::FetchState::ResumeQueuedCalls(
189 absl::StatusOr<RefCountedPtr<Token>> token) {
190 // Invoke callbacks for all pending requests.
191 for (auto& queued_call : queued_calls_) {
192 queued_call->result = token;
193 queued_call->done.store(true, std::memory_order_release);
194 queued_call->waker.Wakeup();
195 grpc_polling_entity_del_from_pollset_set(
196 queued_call->pollent,
197 grpc_polling_entity_pollset_set(&creds_->pollent_));
198 }
199 queued_calls_.clear();
200 }
201
202 RefCountedPtr<TokenFetcherCredentials::QueuedCall>
QueueCall(ClientMetadataHandle initial_metadata)203 TokenFetcherCredentials::FetchState::QueueCall(
204 ClientMetadataHandle initial_metadata) {
205 auto queued_call = MakeRefCounted<QueuedCall>();
206 queued_call->waker = GetContext<Activity>()->MakeNonOwningWaker();
207 queued_call->pollent = GetContext<grpc_polling_entity>();
208 grpc_polling_entity_add_to_pollset_set(
209 queued_call->pollent, grpc_polling_entity_pollset_set(&creds_->pollent_));
210 queued_call->md = std::move(initial_metadata);
211 queued_calls_.insert(queued_call);
212 // If backoff has expired since the last attempt, trigger a new one.
213 auto* backoff_ptr = absl::get_if<OrphanablePtr<BackoffTimer>>(&state_);
214 if (backoff_ptr != nullptr && backoff_ptr->get() == nullptr) {
215 StartFetchAttempt();
216 }
217 return queued_call;
218 }
219
220 //
221 // TokenFetcherCredentials
222 //
223
TokenFetcherCredentials(std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine,bool test_only_use_backoff_jitter)224 TokenFetcherCredentials::TokenFetcherCredentials(
225 std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine,
226 bool test_only_use_backoff_jitter)
227 : event_engine_(
228 event_engine == nullptr
229 ? grpc_event_engine::experimental::GetDefaultEventEngine()
230 : std::move(event_engine)),
231 test_only_use_backoff_jitter_(test_only_use_backoff_jitter),
232 pollent_(grpc_polling_entity_create_from_pollset_set(
233 grpc_pollset_set_create())) {}
234
~TokenFetcherCredentials()235 TokenFetcherCredentials::~TokenFetcherCredentials() {
236 grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_));
237 }
238
Orphaned()239 void TokenFetcherCredentials::Orphaned() {
240 MutexLock lock(&mu_);
241 fetch_state_.reset();
242 }
243
244 ArenaPromise<absl::StatusOr<ClientMetadataHandle>>
GetRequestMetadata(ClientMetadataHandle initial_metadata,const GetRequestMetadataArgs *)245 TokenFetcherCredentials::GetRequestMetadata(
246 ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs*) {
247 RefCountedPtr<QueuedCall> queued_call;
248 {
249 MutexLock lock(&mu_);
250 // If we don't have a cached token or the token is within the
251 // refresh duration, start a new fetch if there isn't a pending one.
252 if ((token_ == nullptr || (token_->ExpirationTime() - Timestamp::Now()) <=
253 kTokenRefreshDuration) &&
254 fetch_state_ == nullptr) {
255 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
256 << "[TokenFetcherCredentials " << this
257 << "]: " << GetContext<Activity>()->DebugTag()
258 << " triggering new token fetch";
259 fetch_state_ = OrphanablePtr<FetchState>(
260 new FetchState(WeakRefAsSubclass<TokenFetcherCredentials>()));
261 }
262 // If we have a cached non-expired token, use it.
263 if (token_ != nullptr &&
264 (token_->ExpirationTime() - Timestamp::Now()) > Duration::Zero()) {
265 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
266 << "[TokenFetcherCredentials " << this
267 << "]: " << GetContext<Activity>()->DebugTag()
268 << " using cached token";
269 token_->AddTokenToClientInitialMetadata(*initial_metadata);
270 return Immediate(std::move(initial_metadata));
271 }
272 // If we're in backoff, fail the call.
273 if (fetch_state_ != nullptr) {
274 absl::Status status = fetch_state_->status();
275 if (!status.ok()) return Immediate(std::move(status));
276 }
277 // If we don't have a cached token, this call will need to be queued.
278 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
279 << "[TokenFetcherCredentials " << this
280 << "]: " << GetContext<Activity>()->DebugTag()
281 << " no cached token; queuing call";
282 queued_call = fetch_state_->QueueCall(std::move(initial_metadata));
283 }
284 return [this, queued_call = std::move(queued_call)]()
285 -> Poll<absl::StatusOr<ClientMetadataHandle>> {
286 if (!queued_call->done.load(std::memory_order_acquire)) {
287 return Pending{};
288 }
289 if (!queued_call->result.ok()) {
290 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
291 << "[TokenFetcherCredentials " << this
292 << "]: " << GetContext<Activity>()->DebugTag()
293 << " token fetch failed; failing call";
294 return queued_call->result.status();
295 }
296 GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
297 << "[TokenFetcherCredentials " << this
298 << "]: " << GetContext<Activity>()->DebugTag()
299 << " token fetch complete; resuming call";
300 (*queued_call->result)->AddTokenToClientInitialMetadata(*queued_call->md);
301 return std::move(queued_call->md);
302 };
303 }
304
305 } // namespace grpc_core
306