• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include "src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h"
20 
21 #include "src/core/lib/debug/trace.h"
22 #include "src/core/lib/event_engine/default_event_engine.h"
23 #include "src/core/lib/iomgr/pollset_set.h"
24 #include "src/core/lib/promise/context.h"
25 #include "src/core/lib/promise/poll.h"
26 #include "src/core/lib/promise/promise.h"
27 
28 namespace grpc_core {
29 
30 namespace {
31 
32 // Amount of time before the token's expiration that we consider it
33 // invalid to account for server processing time and clock skew.
34 constexpr Duration kTokenExpirationAdjustmentDuration = Duration::Seconds(30);
35 
36 // Amount of time before the token's expiration that we pre-fetch a new
37 // token.  Also determines the timeout for the fetch request.
38 constexpr Duration kTokenRefreshDuration = Duration::Seconds(60);
39 
40 }  // namespace
41 
42 //
43 // TokenFetcherCredentials::Token
44 //
45 
Token(Slice token,Timestamp expiration)46 TokenFetcherCredentials::Token::Token(Slice token, Timestamp expiration)
47     : token_(std::move(token)),
48       expiration_(expiration - kTokenExpirationAdjustmentDuration) {}
49 
AddTokenToClientInitialMetadata(ClientMetadata & metadata) const50 void TokenFetcherCredentials::Token::AddTokenToClientInitialMetadata(
51     ClientMetadata& metadata) const {
52   metadata.Append(GRPC_AUTHORIZATION_METADATA_KEY, token_.Ref(),
53                   [](absl::string_view, const Slice&) { abort(); });
54 }
55 
56 //
57 // TokenFetcherCredentials::FetchState::BackoffTimer
58 //
59 
BackoffTimer(RefCountedPtr<FetchState> fetch_state,absl::Status status)60 TokenFetcherCredentials::FetchState::BackoffTimer::BackoffTimer(
61     RefCountedPtr<FetchState> fetch_state, absl::Status status)
62     : fetch_state_(std::move(fetch_state)), status_(status) {
63   const Duration delay = fetch_state_->backoff_.NextAttemptDelay();
64   GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
65       << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
66       << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
67       << ": starting backoff timer for " << delay;
68   timer_handle_ = fetch_state_->creds_->event_engine().RunAfter(
69       delay, [self = Ref()]() mutable {
70         ApplicationCallbackExecCtx callback_exec_ctx;
71         ExecCtx exec_ctx;
72         self->OnTimer();
73         self.reset();
74       });
75 }
76 
Orphan()77 void TokenFetcherCredentials::FetchState::BackoffTimer::Orphan() {
78   GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
79       << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
80       << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
81       << ": backoff timer shut down";
82   if (timer_handle_.has_value()) {
83     GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
84         << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
85         << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
86         << ": cancelling timer";
87     fetch_state_->creds_->event_engine().Cancel(*timer_handle_);
88     timer_handle_.reset();
89     fetch_state_->ResumeQueuedCalls(
90         absl::CancelledError("credentials shutdown"));
91   }
92   Unref();
93 }
94 
OnTimer()95 void TokenFetcherCredentials::FetchState::BackoffTimer::OnTimer() {
96   MutexLock lock(&fetch_state_->creds_->mu_);
97   if (!timer_handle_.has_value()) return;
98   timer_handle_.reset();
99   GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
100       << "[TokenFetcherCredentials " << fetch_state_->creds_.get()
101       << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this
102       << ": backoff timer fired";
103   auto* self_ptr =
104       absl::get_if<OrphanablePtr<BackoffTimer>>(&fetch_state_->state_);
105   // This condition should always be true, but check to be defensive.
106   if (self_ptr != nullptr && self_ptr->get() == this) {
107     // Reset pointer in fetch_state_, so that subsequent RPCs know that
108     // we're no longer in backoff and they can trigger a new fetch.
109     self_ptr->reset();
110   }
111 }
112 
113 //
114 // TokenFetcherCredentials::FetchState
115 //
116 
FetchState(WeakRefCountedPtr<TokenFetcherCredentials> creds)117 TokenFetcherCredentials::FetchState::FetchState(
118     WeakRefCountedPtr<TokenFetcherCredentials> creds)
119     : creds_(std::move(creds)),
120       backoff_(BackOff::Options()
121                    .set_initial_backoff(Duration::Seconds(1))
122                    .set_multiplier(1.6)
123                    .set_jitter(creds_->test_only_use_backoff_jitter_ ? 0.2 : 0)
124                    .set_max_backoff(Duration::Seconds(120))) {
125   StartFetchAttempt();
126 }
127 
Orphan()128 void TokenFetcherCredentials::FetchState::Orphan() {
129   GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
130       << "[TokenFetcherCredentials " << creds_.get()
131       << "]: fetch_state=" << this << ": shutting down";
132   // Cancels fetch or backoff timer, if any.
133   state_ = Shutdown{};
134   Unref();
135 }
136 
status() const137 absl::Status TokenFetcherCredentials::FetchState::status() const {
138   auto* backoff_ptr = absl::get_if<OrphanablePtr<BackoffTimer>>(&state_);
139   if (backoff_ptr == nullptr || *backoff_ptr == nullptr) {
140     return absl::OkStatus();
141   }
142   return (*backoff_ptr)->status();
143 }
144 
StartFetchAttempt()145 void TokenFetcherCredentials::FetchState::StartFetchAttempt() {
146   GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
147       << "[TokenFetcherCredentials " << creds_.get()
148       << "]: fetch_state=" << this << ": starting fetch";
149   state_ = creds_->FetchToken(
150       /*deadline=*/Timestamp::Now() + kTokenRefreshDuration,
151       [self = Ref()](absl::StatusOr<RefCountedPtr<Token>> token) mutable {
152         self->TokenFetchComplete(std::move(token));
153       });
154 }
155 
TokenFetchComplete(absl::StatusOr<RefCountedPtr<Token>> token)156 void TokenFetcherCredentials::FetchState::TokenFetchComplete(
157     absl::StatusOr<RefCountedPtr<Token>> token) {
158   MutexLock lock(&creds_->mu_);
159   // If we were shut down, clean up.
160   if (absl::holds_alternative<Shutdown>(state_)) {
161     if (token.ok()) token = absl::CancelledError("credentials shutdown");
162     GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
163         << "[TokenFetcherCredentials " << creds_.get()
164         << "]: fetch_state=" << this
165         << ": shut down before fetch completed: " << token.status();
166     ResumeQueuedCalls(std::move(token));
167     return;
168   }
169   // If succeeded, update cache in creds object.
170   if (token.ok()) {
171     GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
172         << "[TokenFetcherCredentials " << creds_.get()
173         << "]: fetch_state=" << this << ": token fetch succeeded";
174     creds_->token_ = *token;
175     creds_->fetch_state_.reset();  // Orphan ourselves.
176   } else {
177     GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
178         << "[TokenFetcherCredentials " << creds_.get()
179         << "]: fetch_state=" << this
180         << ": token fetch failed: " << token.status();
181     // If failed, start backoff timer.
182     state_ =
183         OrphanablePtr<BackoffTimer>(new BackoffTimer(Ref(), token.status()));
184   }
185   ResumeQueuedCalls(std::move(token));
186 }
187 
ResumeQueuedCalls(absl::StatusOr<RefCountedPtr<Token>> token)188 void TokenFetcherCredentials::FetchState::ResumeQueuedCalls(
189     absl::StatusOr<RefCountedPtr<Token>> token) {
190   // Invoke callbacks for all pending requests.
191   for (auto& queued_call : queued_calls_) {
192     queued_call->result = token;
193     queued_call->done.store(true, std::memory_order_release);
194     queued_call->waker.Wakeup();
195     grpc_polling_entity_del_from_pollset_set(
196         queued_call->pollent,
197         grpc_polling_entity_pollset_set(&creds_->pollent_));
198   }
199   queued_calls_.clear();
200 }
201 
202 RefCountedPtr<TokenFetcherCredentials::QueuedCall>
QueueCall(ClientMetadataHandle initial_metadata)203 TokenFetcherCredentials::FetchState::QueueCall(
204     ClientMetadataHandle initial_metadata) {
205   auto queued_call = MakeRefCounted<QueuedCall>();
206   queued_call->waker = GetContext<Activity>()->MakeNonOwningWaker();
207   queued_call->pollent = GetContext<grpc_polling_entity>();
208   grpc_polling_entity_add_to_pollset_set(
209       queued_call->pollent, grpc_polling_entity_pollset_set(&creds_->pollent_));
210   queued_call->md = std::move(initial_metadata);
211   queued_calls_.insert(queued_call);
212   // If backoff has expired since the last attempt, trigger a new one.
213   auto* backoff_ptr = absl::get_if<OrphanablePtr<BackoffTimer>>(&state_);
214   if (backoff_ptr != nullptr && backoff_ptr->get() == nullptr) {
215     StartFetchAttempt();
216   }
217   return queued_call;
218 }
219 
220 //
221 // TokenFetcherCredentials
222 //
223 
TokenFetcherCredentials(std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine,bool test_only_use_backoff_jitter)224 TokenFetcherCredentials::TokenFetcherCredentials(
225     std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine,
226     bool test_only_use_backoff_jitter)
227     : event_engine_(
228           event_engine == nullptr
229               ? grpc_event_engine::experimental::GetDefaultEventEngine()
230               : std::move(event_engine)),
231       test_only_use_backoff_jitter_(test_only_use_backoff_jitter),
232       pollent_(grpc_polling_entity_create_from_pollset_set(
233           grpc_pollset_set_create())) {}
234 
~TokenFetcherCredentials()235 TokenFetcherCredentials::~TokenFetcherCredentials() {
236   grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_));
237 }
238 
Orphaned()239 void TokenFetcherCredentials::Orphaned() {
240   MutexLock lock(&mu_);
241   fetch_state_.reset();
242 }
243 
244 ArenaPromise<absl::StatusOr<ClientMetadataHandle>>
GetRequestMetadata(ClientMetadataHandle initial_metadata,const GetRequestMetadataArgs *)245 TokenFetcherCredentials::GetRequestMetadata(
246     ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs*) {
247   RefCountedPtr<QueuedCall> queued_call;
248   {
249     MutexLock lock(&mu_);
250     // If we don't have a cached token or the token is within the
251     // refresh duration, start a new fetch if there isn't a pending one.
252     if ((token_ == nullptr || (token_->ExpirationTime() - Timestamp::Now()) <=
253                                   kTokenRefreshDuration) &&
254         fetch_state_ == nullptr) {
255       GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
256           << "[TokenFetcherCredentials " << this
257           << "]: " << GetContext<Activity>()->DebugTag()
258           << " triggering new token fetch";
259       fetch_state_ = OrphanablePtr<FetchState>(
260           new FetchState(WeakRefAsSubclass<TokenFetcherCredentials>()));
261     }
262     // If we have a cached non-expired token, use it.
263     if (token_ != nullptr &&
264         (token_->ExpirationTime() - Timestamp::Now()) > Duration::Zero()) {
265       GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
266           << "[TokenFetcherCredentials " << this
267           << "]: " << GetContext<Activity>()->DebugTag()
268           << " using cached token";
269       token_->AddTokenToClientInitialMetadata(*initial_metadata);
270       return Immediate(std::move(initial_metadata));
271     }
272     // If we're in backoff, fail the call.
273     if (fetch_state_ != nullptr) {
274       absl::Status status = fetch_state_->status();
275       if (!status.ok()) return Immediate(std::move(status));
276     }
277     // If we don't have a cached token, this call will need to be queued.
278     GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
279         << "[TokenFetcherCredentials " << this
280         << "]: " << GetContext<Activity>()->DebugTag()
281         << " no cached token; queuing call";
282     queued_call = fetch_state_->QueueCall(std::move(initial_metadata));
283   }
284   return [this, queued_call = std::move(queued_call)]()
285              -> Poll<absl::StatusOr<ClientMetadataHandle>> {
286     if (!queued_call->done.load(std::memory_order_acquire)) {
287       return Pending{};
288     }
289     if (!queued_call->result.ok()) {
290       GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
291           << "[TokenFetcherCredentials " << this
292           << "]: " << GetContext<Activity>()->DebugTag()
293           << " token fetch failed; failing call";
294       return queued_call->result.status();
295     }
296     GRPC_TRACE_LOG(token_fetcher_credentials, INFO)
297         << "[TokenFetcherCredentials " << this
298         << "]: " << GetContext<Activity>()->DebugTag()
299         << " token fetch complete; resuming call";
300     (*queued_call->result)->AddTokenToClientInitialMetadata(*queued_call->md);
301     return std::move(queued_call->md);
302   };
303 }
304 
305 }  // namespace grpc_core
306