• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/device_bound_sessions/session_service_impl.h"
6 
7 #include "base/containers/to_vector.h"
8 #include "base/functional/bind.h"
9 #include "base/task/sequenced_task_runner.h"
10 #include "components/unexportable_keys/unexportable_key_service.h"
11 #include "net/base/schemeful_site.h"
12 #include "net/device_bound_sessions/registration_request_param.h"
13 #include "net/device_bound_sessions/session_store.h"
14 #include "net/url_request/url_request.h"
15 #include "net/url_request/url_request_context.h"
16 
17 namespace net::device_bound_sessions {
18 
19 namespace {
20 
NotifySessionAccess(SessionService::OnAccessCallback callback,const SchemefulSite & site,const Session & session)21 void NotifySessionAccess(SessionService::OnAccessCallback callback,
22                          const SchemefulSite& site,
23                          const Session& session) {
24   if (callback.is_null()) {
25     return;
26   }
27 
28   callback.Run({site, session.id()});
29 }
30 
31 }  // namespace
32 
SessionServiceImpl(unexportable_keys::UnexportableKeyService & key_service,const URLRequestContext * request_context,SessionStore * store)33 SessionServiceImpl::SessionServiceImpl(
34     unexportable_keys::UnexportableKeyService& key_service,
35     const URLRequestContext* request_context,
36     SessionStore* store)
37     : key_service_(key_service),
38       context_(request_context),
39       session_store_(store) {
40   CHECK(context_);
41 }
42 
43 SessionServiceImpl::~SessionServiceImpl() = default;
44 
LoadSessionsAsync()45 void SessionServiceImpl::LoadSessionsAsync() {
46   if (!session_store_) {
47     return;
48   }
49   pending_initialization_ = true;
50   session_store_->LoadSessions(base::BindOnce(
51       &SessionServiceImpl::OnLoadSessionsComplete, weak_factory_.GetWeakPtr()));
52 }
53 
RegisterBoundSession(OnAccessCallback on_access_callback,RegistrationFetcherParam registration_params,const IsolationInfo & isolation_info)54 void SessionServiceImpl::RegisterBoundSession(
55     OnAccessCallback on_access_callback,
56     RegistrationFetcherParam registration_params,
57     const IsolationInfo& isolation_info) {
58   RegistrationFetcher::StartCreateTokenAndFetch(
59       std::move(registration_params), key_service_.get(), context_.get(),
60       isolation_info,
61       base::BindOnce(&SessionServiceImpl::OnRegistrationComplete,
62                      weak_factory_.GetWeakPtr(),
63                      std::move(on_access_callback)));
64 }
65 
OnLoadSessionsComplete(SessionStore::SessionsMap sessions)66 void SessionServiceImpl::OnLoadSessionsComplete(
67     SessionStore::SessionsMap sessions) {
68   unpartitioned_sessions_.merge(sessions);
69   pending_initialization_ = false;
70 
71   std::vector<base::OnceClosure> queued_operations =
72       std::move(queued_operations_);
73   for (base::OnceClosure& closure : queued_operations) {
74     std::move(closure).Run();
75   }
76 }
77 
OnRegistrationComplete(OnAccessCallback on_access_callback,std::optional<RegistrationFetcher::RegistrationCompleteParams> params)78 void SessionServiceImpl::OnRegistrationComplete(
79     OnAccessCallback on_access_callback,
80     std::optional<RegistrationFetcher::RegistrationCompleteParams> params) {
81   if (!params) {
82     return;
83   }
84 
85   auto session = Session::CreateIfValid(std::move(params->params), params->url);
86   if (!session) {
87     return;
88   }
89   session->set_unexportable_key_id(std::move(params->key_id));
90 
91   const SchemefulSite site(url::Origin::Create(params->url));
92   NotifySessionAccess(on_access_callback, site, *session);
93 
94   // Clear the existing session which initiated the registration.
95   if (params->referral_session_identifier) {
96     DeleteSession(site,
97                   Session::Id(std::move(*params->referral_session_identifier)));
98   }
99   AddSession(site, std::move(session));
100 }
101 
102 std::pair<SessionServiceImpl::SessionsMap::iterator,
103           SessionServiceImpl::SessionsMap::iterator>
GetSessionsForSite(const SchemefulSite & site)104 SessionServiceImpl::GetSessionsForSite(const SchemefulSite& site) {
105   const auto now = base::Time::Now();
106   auto [begin, end] = unpartitioned_sessions_.equal_range(site);
107   for (auto it = begin; it != end;) {
108     if (now >= it->second->expiry_date()) {
109       it = DeleteSessionInternal(site, it);
110     } else {
111       it->second->RecordAccess();
112       it++;
113     }
114   }
115 
116   return unpartitioned_sessions_.equal_range(site);
117 }
118 
GetAnySessionRequiringDeferral(URLRequest * request)119 std::optional<Session::Id> SessionServiceImpl::GetAnySessionRequiringDeferral(
120     URLRequest* request) {
121   SchemefulSite site(request->url());
122   auto range = GetSessionsForSite(site);
123   for (auto it = range.first; it != range.second; ++it) {
124     if (it->second->ShouldDeferRequest(request)) {
125       NotifySessionAccess(request->device_bound_session_access_callback(), site,
126                           *it->second);
127       return it->second->id();
128     }
129   }
130 
131   return std::nullopt;
132 }
133 
134 // TODO(kristianm): Actually send the refresh request, for now continue
135 // with sending the deferred request right away.
DeferRequestForRefresh(URLRequest * request,Session::Id session_id,RefreshCompleteCallback restart_callback,RefreshCompleteCallback continue_callback)136 void SessionServiceImpl::DeferRequestForRefresh(
137     URLRequest* request,
138     Session::Id session_id,
139     RefreshCompleteCallback restart_callback,
140     RefreshCompleteCallback continue_callback) {
141   CHECK(restart_callback);
142   CHECK(continue_callback);
143   std::move(continue_callback).Run();
144 }
145 
SetChallengeForBoundSession(OnAccessCallback on_access_callback,const GURL & request_url,const SessionChallengeParam & param)146 void SessionServiceImpl::SetChallengeForBoundSession(
147     OnAccessCallback on_access_callback,
148     const GURL& request_url,
149     const SessionChallengeParam& param) {
150   if (!param.session_id()) {
151     return;
152   }
153 
154   SchemefulSite site(request_url);
155   auto range = GetSessionsForSite(site);
156   for (auto it = range.first; it != range.second; ++it) {
157     if (it->second->id().value() == param.session_id()) {
158       NotifySessionAccess(on_access_callback, site, *it->second);
159       it->second->set_cached_challenge(param.challenge());
160       return;
161     }
162   }
163 }
164 
GetAllSessionsAsync(base::OnceCallback<void (const std::vector<SessionKey> &)> callback)165 void SessionServiceImpl::GetAllSessionsAsync(
166     base::OnceCallback<void(const std::vector<SessionKey>&)> callback) {
167   if (pending_initialization_) {
168     queued_operations_.push_back(base::BindOnce(
169         &SessionServiceImpl::GetAllSessionsAsync,
170         // `base::Unretained` is safe because the callback is stored in
171         // `queued_operations_`, which is owned by `this`.
172         base::Unretained(this), std::move(callback)));
173   } else {
174     std::vector<SessionKey> sessions =
175         base::ToVector(unpartitioned_sessions_, [](const auto& pair) {
176           const auto& [site, session] = pair;
177           return SessionKey(site, session->id());
178         });
179     base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
180         FROM_HERE, base::BindOnce(std::move(callback), std::move(sessions)));
181   }
182 }
183 
GetSessionForTesting(const SchemefulSite & site,const std::string & session_id) const184 Session* SessionServiceImpl::GetSessionForTesting(
185     const SchemefulSite& site,
186     const std::string& session_id) const {
187   // Intentionally do not use `GetSessionsForSite` here so we do not
188   // modify the session during testing.
189   auto range = unpartitioned_sessions_.equal_range(site);
190   for (auto it = range.first; it != range.second; ++it) {
191     if (it->second->id().value() == session_id) {
192       return it->second.get();
193     }
194   }
195 
196   return nullptr;
197 }
198 
AddSession(const SchemefulSite & site,std::unique_ptr<Session> session)199 void SessionServiceImpl::AddSession(const SchemefulSite& site,
200                                     std::unique_ptr<Session> session) {
201   if (session_store_) {
202     session_store_->SaveSession(site, *session);
203   }
204   // TODO(crbug.com/353774923): Enforce unique session ids per site.
205   unpartitioned_sessions_.emplace(site, std::move(session));
206 }
207 
DeleteSession(const SchemefulSite & site,const Session::Id & id)208 void SessionServiceImpl::DeleteSession(const SchemefulSite& site,
209                                        const Session::Id& id) {
210   auto range = unpartitioned_sessions_.equal_range(site);
211   for (auto it = range.first; it != range.second; ++it) {
212     if (it->second->id() == id) {
213       std::ignore = DeleteSessionInternal(site, it);
214       return;
215     }
216   }
217 }
218 
219 SessionServiceImpl::SessionsMap::iterator
DeleteSessionInternal(const SchemefulSite & site,SessionServiceImpl::SessionsMap::iterator it)220 SessionServiceImpl::DeleteSessionInternal(
221     const SchemefulSite& site,
222     SessionServiceImpl::SessionsMap::iterator it) {
223   if (session_store_) {
224     session_store_->DeleteSession(site, it->second->id());
225   }
226 
227   // TODO(crbug.com/353774923): Clear BFCache entries for this session.
228   return unpartitioned_sessions_.erase(it);
229 }
230 
StartSessionRefresh(const Session & session,const IsolationInfo & isolation_info,OnAccessCallback on_access_callback)231 void SessionServiceImpl::StartSessionRefresh(
232     const Session& session,
233     const IsolationInfo& isolation_info,
234     OnAccessCallback on_access_callback) {
235   const Session::KeyIdOrError& key_id = session.unexportable_key_id();
236   if (!key_id.has_value()) {
237     return;
238   }
239 
240   auto request_params = RegistrationRequestParam::Create(session);
241   RegistrationFetcher::StartFetchWithExistingKey(
242       std::move(request_params), key_service_.get(), context_.get(),
243       isolation_info,
244       base::BindOnce(&SessionServiceImpl::OnRegistrationComplete,
245                      weak_factory_.GetWeakPtr(), std::move(on_access_callback)),
246       *key_id);
247 }
248 
249 }  // namespace net::device_bound_sessions
250