1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/device_bound_sessions/session_service_impl.h"
6
7 #include "base/containers/to_vector.h"
8 #include "base/functional/bind.h"
9 #include "base/task/sequenced_task_runner.h"
10 #include "components/unexportable_keys/unexportable_key_service.h"
11 #include "net/base/schemeful_site.h"
12 #include "net/device_bound_sessions/registration_request_param.h"
13 #include "net/device_bound_sessions/session_store.h"
14 #include "net/url_request/url_request.h"
15 #include "net/url_request/url_request_context.h"
16
17 namespace net::device_bound_sessions {
18
19 namespace {
20
NotifySessionAccess(SessionService::OnAccessCallback callback,const SchemefulSite & site,const Session & session)21 void NotifySessionAccess(SessionService::OnAccessCallback callback,
22 const SchemefulSite& site,
23 const Session& session) {
24 if (callback.is_null()) {
25 return;
26 }
27
28 callback.Run({site, session.id()});
29 }
30
31 } // namespace
32
SessionServiceImpl(unexportable_keys::UnexportableKeyService & key_service,const URLRequestContext * request_context,SessionStore * store)33 SessionServiceImpl::SessionServiceImpl(
34 unexportable_keys::UnexportableKeyService& key_service,
35 const URLRequestContext* request_context,
36 SessionStore* store)
37 : key_service_(key_service),
38 context_(request_context),
39 session_store_(store) {
40 CHECK(context_);
41 }
42
43 SessionServiceImpl::~SessionServiceImpl() = default;
44
LoadSessionsAsync()45 void SessionServiceImpl::LoadSessionsAsync() {
46 if (!session_store_) {
47 return;
48 }
49 pending_initialization_ = true;
50 session_store_->LoadSessions(base::BindOnce(
51 &SessionServiceImpl::OnLoadSessionsComplete, weak_factory_.GetWeakPtr()));
52 }
53
RegisterBoundSession(OnAccessCallback on_access_callback,RegistrationFetcherParam registration_params,const IsolationInfo & isolation_info)54 void SessionServiceImpl::RegisterBoundSession(
55 OnAccessCallback on_access_callback,
56 RegistrationFetcherParam registration_params,
57 const IsolationInfo& isolation_info) {
58 RegistrationFetcher::StartCreateTokenAndFetch(
59 std::move(registration_params), key_service_.get(), context_.get(),
60 isolation_info,
61 base::BindOnce(&SessionServiceImpl::OnRegistrationComplete,
62 weak_factory_.GetWeakPtr(),
63 std::move(on_access_callback)));
64 }
65
OnLoadSessionsComplete(SessionStore::SessionsMap sessions)66 void SessionServiceImpl::OnLoadSessionsComplete(
67 SessionStore::SessionsMap sessions) {
68 unpartitioned_sessions_.merge(sessions);
69 pending_initialization_ = false;
70
71 std::vector<base::OnceClosure> queued_operations =
72 std::move(queued_operations_);
73 for (base::OnceClosure& closure : queued_operations) {
74 std::move(closure).Run();
75 }
76 }
77
OnRegistrationComplete(OnAccessCallback on_access_callback,std::optional<RegistrationFetcher::RegistrationCompleteParams> params)78 void SessionServiceImpl::OnRegistrationComplete(
79 OnAccessCallback on_access_callback,
80 std::optional<RegistrationFetcher::RegistrationCompleteParams> params) {
81 if (!params) {
82 return;
83 }
84
85 auto session = Session::CreateIfValid(std::move(params->params), params->url);
86 if (!session) {
87 return;
88 }
89 session->set_unexportable_key_id(std::move(params->key_id));
90
91 const SchemefulSite site(url::Origin::Create(params->url));
92 NotifySessionAccess(on_access_callback, site, *session);
93
94 // Clear the existing session which initiated the registration.
95 if (params->referral_session_identifier) {
96 DeleteSession(site,
97 Session::Id(std::move(*params->referral_session_identifier)));
98 }
99 AddSession(site, std::move(session));
100 }
101
102 std::pair<SessionServiceImpl::SessionsMap::iterator,
103 SessionServiceImpl::SessionsMap::iterator>
GetSessionsForSite(const SchemefulSite & site)104 SessionServiceImpl::GetSessionsForSite(const SchemefulSite& site) {
105 const auto now = base::Time::Now();
106 auto [begin, end] = unpartitioned_sessions_.equal_range(site);
107 for (auto it = begin; it != end;) {
108 if (now >= it->second->expiry_date()) {
109 it = DeleteSessionInternal(site, it);
110 } else {
111 it->second->RecordAccess();
112 it++;
113 }
114 }
115
116 return unpartitioned_sessions_.equal_range(site);
117 }
118
GetAnySessionRequiringDeferral(URLRequest * request)119 std::optional<Session::Id> SessionServiceImpl::GetAnySessionRequiringDeferral(
120 URLRequest* request) {
121 SchemefulSite site(request->url());
122 auto range = GetSessionsForSite(site);
123 for (auto it = range.first; it != range.second; ++it) {
124 if (it->second->ShouldDeferRequest(request)) {
125 NotifySessionAccess(request->device_bound_session_access_callback(), site,
126 *it->second);
127 return it->second->id();
128 }
129 }
130
131 return std::nullopt;
132 }
133
134 // TODO(kristianm): Actually send the refresh request, for now continue
135 // with sending the deferred request right away.
DeferRequestForRefresh(URLRequest * request,Session::Id session_id,RefreshCompleteCallback restart_callback,RefreshCompleteCallback continue_callback)136 void SessionServiceImpl::DeferRequestForRefresh(
137 URLRequest* request,
138 Session::Id session_id,
139 RefreshCompleteCallback restart_callback,
140 RefreshCompleteCallback continue_callback) {
141 CHECK(restart_callback);
142 CHECK(continue_callback);
143 std::move(continue_callback).Run();
144 }
145
SetChallengeForBoundSession(OnAccessCallback on_access_callback,const GURL & request_url,const SessionChallengeParam & param)146 void SessionServiceImpl::SetChallengeForBoundSession(
147 OnAccessCallback on_access_callback,
148 const GURL& request_url,
149 const SessionChallengeParam& param) {
150 if (!param.session_id()) {
151 return;
152 }
153
154 SchemefulSite site(request_url);
155 auto range = GetSessionsForSite(site);
156 for (auto it = range.first; it != range.second; ++it) {
157 if (it->second->id().value() == param.session_id()) {
158 NotifySessionAccess(on_access_callback, site, *it->second);
159 it->second->set_cached_challenge(param.challenge());
160 return;
161 }
162 }
163 }
164
GetAllSessionsAsync(base::OnceCallback<void (const std::vector<SessionKey> &)> callback)165 void SessionServiceImpl::GetAllSessionsAsync(
166 base::OnceCallback<void(const std::vector<SessionKey>&)> callback) {
167 if (pending_initialization_) {
168 queued_operations_.push_back(base::BindOnce(
169 &SessionServiceImpl::GetAllSessionsAsync,
170 // `base::Unretained` is safe because the callback is stored in
171 // `queued_operations_`, which is owned by `this`.
172 base::Unretained(this), std::move(callback)));
173 } else {
174 std::vector<SessionKey> sessions =
175 base::ToVector(unpartitioned_sessions_, [](const auto& pair) {
176 const auto& [site, session] = pair;
177 return SessionKey(site, session->id());
178 });
179 base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
180 FROM_HERE, base::BindOnce(std::move(callback), std::move(sessions)));
181 }
182 }
183
GetSessionForTesting(const SchemefulSite & site,const std::string & session_id) const184 Session* SessionServiceImpl::GetSessionForTesting(
185 const SchemefulSite& site,
186 const std::string& session_id) const {
187 // Intentionally do not use `GetSessionsForSite` here so we do not
188 // modify the session during testing.
189 auto range = unpartitioned_sessions_.equal_range(site);
190 for (auto it = range.first; it != range.second; ++it) {
191 if (it->second->id().value() == session_id) {
192 return it->second.get();
193 }
194 }
195
196 return nullptr;
197 }
198
AddSession(const SchemefulSite & site,std::unique_ptr<Session> session)199 void SessionServiceImpl::AddSession(const SchemefulSite& site,
200 std::unique_ptr<Session> session) {
201 if (session_store_) {
202 session_store_->SaveSession(site, *session);
203 }
204 // TODO(crbug.com/353774923): Enforce unique session ids per site.
205 unpartitioned_sessions_.emplace(site, std::move(session));
206 }
207
DeleteSession(const SchemefulSite & site,const Session::Id & id)208 void SessionServiceImpl::DeleteSession(const SchemefulSite& site,
209 const Session::Id& id) {
210 auto range = unpartitioned_sessions_.equal_range(site);
211 for (auto it = range.first; it != range.second; ++it) {
212 if (it->second->id() == id) {
213 std::ignore = DeleteSessionInternal(site, it);
214 return;
215 }
216 }
217 }
218
219 SessionServiceImpl::SessionsMap::iterator
DeleteSessionInternal(const SchemefulSite & site,SessionServiceImpl::SessionsMap::iterator it)220 SessionServiceImpl::DeleteSessionInternal(
221 const SchemefulSite& site,
222 SessionServiceImpl::SessionsMap::iterator it) {
223 if (session_store_) {
224 session_store_->DeleteSession(site, it->second->id());
225 }
226
227 // TODO(crbug.com/353774923): Clear BFCache entries for this session.
228 return unpartitioned_sessions_.erase(it);
229 }
230
StartSessionRefresh(const Session & session,const IsolationInfo & isolation_info,OnAccessCallback on_access_callback)231 void SessionServiceImpl::StartSessionRefresh(
232 const Session& session,
233 const IsolationInfo& isolation_info,
234 OnAccessCallback on_access_callback) {
235 const Session::KeyIdOrError& key_id = session.unexportable_key_id();
236 if (!key_id.has_value()) {
237 return;
238 }
239
240 auto request_params = RegistrationRequestParam::Create(session);
241 RegistrationFetcher::StartFetchWithExistingKey(
242 std::move(request_params), key_service_.get(), context_.get(),
243 isolation_info,
244 base::BindOnce(&SessionServiceImpl::OnRegistrationComplete,
245 weak_factory_.GetWeakPtr(), std::move(on_access_callback)),
246 *key_id);
247 }
248
249 } // namespace net::device_bound_sessions
250