1 // Copyright 2015 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/ssl/ssl_client_session_cache.h"
11
12 #include <tuple>
13 #include <utility>
14
15 #include "base/containers/flat_set.h"
16 #include "base/time/clock.h"
17 #include "base/time/default_clock.h"
18 #include "third_party/boringssl/src/include/openssl/ssl.h"
19
20 namespace net {
21
22 namespace {
23
24 // Returns a tuple of references to fields of |key|, for comparison purposes.
TieKeyFields(const SSLClientSessionCache::Key & key)25 auto TieKeyFields(const SSLClientSessionCache::Key& key) {
26 return std::tie(key.server, key.dest_ip_addr, key.network_anonymization_key,
27 key.privacy_mode);
28 }
29
30 } // namespace
31
32 SSLClientSessionCache::Key::Key() = default;
33 SSLClientSessionCache::Key::Key(const Key& other) = default;
34 SSLClientSessionCache::Key::Key(Key&& other) = default;
35 SSLClientSessionCache::Key::~Key() = default;
36 SSLClientSessionCache::Key& SSLClientSessionCache::Key::operator=(
37 const Key& other) = default;
38 SSLClientSessionCache::Key& SSLClientSessionCache::Key::operator=(Key&& other) =
39 default;
40
operator ==(const Key & other) const41 bool SSLClientSessionCache::Key::operator==(const Key& other) const {
42 return TieKeyFields(*this) == TieKeyFields(other);
43 }
44
operator <(const Key & other) const45 bool SSLClientSessionCache::Key::operator<(const Key& other) const {
46 return TieKeyFields(*this) < TieKeyFields(other);
47 }
48
SSLClientSessionCache(const Config & config)49 SSLClientSessionCache::SSLClientSessionCache(const Config& config)
50 : clock_(base::DefaultClock::GetInstance()),
51 config_(config),
52 cache_(config.max_entries) {
53 memory_pressure_listener_ = std::make_unique<base::MemoryPressureListener>(
54 FROM_HERE, base::BindRepeating(&SSLClientSessionCache::OnMemoryPressure,
55 base::Unretained(this)));
56 }
57
~SSLClientSessionCache()58 SSLClientSessionCache::~SSLClientSessionCache() {
59 Flush();
60 }
61
size() const62 size_t SSLClientSessionCache::size() const {
63 return cache_.size();
64 }
65
Lookup(const Key & cache_key)66 bssl::UniquePtr<SSL_SESSION> SSLClientSessionCache::Lookup(
67 const Key& cache_key) {
68 // Expire stale sessions.
69 lookups_since_flush_++;
70 if (lookups_since_flush_ >= config_.expiration_check_count) {
71 lookups_since_flush_ = 0;
72 FlushExpiredSessions();
73 }
74
75 auto iter = cache_.Get(cache_key);
76 if (iter == cache_.end())
77 return nullptr;
78
79 time_t now = clock_->Now().ToTimeT();
80 bssl::UniquePtr<SSL_SESSION> session = iter->second.Pop();
81 if (iter->second.ExpireSessions(now))
82 cache_.Erase(iter);
83
84 if (IsExpired(session.get(), now))
85 session = nullptr;
86
87 return session;
88 }
89
Insert(const Key & cache_key,bssl::UniquePtr<SSL_SESSION> session)90 void SSLClientSessionCache::Insert(const Key& cache_key,
91 bssl::UniquePtr<SSL_SESSION> session) {
92 auto iter = cache_.Get(cache_key);
93 if (iter == cache_.end())
94 iter = cache_.Put(cache_key, Entry());
95 iter->second.Push(std::move(session));
96 }
97
ClearEarlyData(const Key & cache_key)98 void SSLClientSessionCache::ClearEarlyData(const Key& cache_key) {
99 auto iter = cache_.Get(cache_key);
100 if (iter != cache_.end()) {
101 for (auto& session : iter->second.sessions) {
102 if (session) {
103 session.reset(SSL_SESSION_copy_without_early_data(session.get()));
104 }
105 }
106 }
107 }
108
FlushForServers(const base::flat_set<HostPortPair> & servers)109 void SSLClientSessionCache::FlushForServers(
110 const base::flat_set<HostPortPair>& servers) {
111 auto iter = cache_.begin();
112 while (iter != cache_.end()) {
113 if (servers.contains(iter->first.server)) {
114 iter = cache_.Erase(iter);
115 } else {
116 ++iter;
117 }
118 }
119 }
120
Flush()121 void SSLClientSessionCache::Flush() {
122 cache_.Clear();
123 }
124
SetClockForTesting(base::Clock * clock)125 void SSLClientSessionCache::SetClockForTesting(base::Clock* clock) {
126 clock_ = clock;
127 }
128
IsExpired(SSL_SESSION * session,time_t now)129 bool SSLClientSessionCache::IsExpired(SSL_SESSION* session, time_t now) {
130 if (now < 0)
131 return true;
132 uint64_t now_u64 = static_cast<uint64_t>(now);
133
134 // now_u64 may be slightly behind because of differences in how
135 // time is calculated at this layer versus BoringSSL.
136 // Add a second of wiggle room to account for this.
137 return now_u64 < SSL_SESSION_get_time(session) - 1 ||
138 now_u64 >=
139 SSL_SESSION_get_time(session) + SSL_SESSION_get_timeout(session);
140 }
141
142 SSLClientSessionCache::Entry::Entry() = default;
143 SSLClientSessionCache::Entry::Entry(Entry&&) = default;
144 SSLClientSessionCache::Entry::~Entry() = default;
145
Push(bssl::UniquePtr<SSL_SESSION> session)146 void SSLClientSessionCache::Entry::Push(bssl::UniquePtr<SSL_SESSION> session) {
147 if (sessions[0] != nullptr &&
148 SSL_SESSION_should_be_single_use(sessions[0].get())) {
149 sessions[1] = std::move(sessions[0]);
150 }
151 sessions[0] = std::move(session);
152 }
153
Pop()154 bssl::UniquePtr<SSL_SESSION> SSLClientSessionCache::Entry::Pop() {
155 if (sessions[0] == nullptr)
156 return nullptr;
157 bssl::UniquePtr<SSL_SESSION> session = bssl::UpRef(sessions[0]);
158 if (SSL_SESSION_should_be_single_use(session.get())) {
159 sessions[0] = std::move(sessions[1]);
160 sessions[1] = nullptr;
161 }
162 return session;
163 }
164
ExpireSessions(time_t now)165 bool SSLClientSessionCache::Entry::ExpireSessions(time_t now) {
166 if (sessions[0] == nullptr)
167 return true;
168
169 if (SSLClientSessionCache::IsExpired(sessions[0].get(), now)) {
170 return true;
171 }
172
173 if (sessions[1] != nullptr &&
174 SSLClientSessionCache::IsExpired(sessions[1].get(), now)) {
175 sessions[1] = nullptr;
176 }
177
178 return false;
179 }
180
FlushExpiredSessions()181 void SSLClientSessionCache::FlushExpiredSessions() {
182 time_t now = clock_->Now().ToTimeT();
183 auto iter = cache_.begin();
184 while (iter != cache_.end()) {
185 if (iter->second.ExpireSessions(now)) {
186 iter = cache_.Erase(iter);
187 } else {
188 ++iter;
189 }
190 }
191 }
192
OnMemoryPressure(base::MemoryPressureListener::MemoryPressureLevel memory_pressure_level)193 void SSLClientSessionCache::OnMemoryPressure(
194 base::MemoryPressureListener::MemoryPressureLevel memory_pressure_level) {
195 switch (memory_pressure_level) {
196 case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_NONE:
197 break;
198 case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_MODERATE:
199 FlushExpiredSessions();
200 break;
201 case base::MemoryPressureListener::MEMORY_PRESSURE_LEVEL_CRITICAL:
202 Flush();
203 break;
204 }
205 }
206
207 } // namespace net
208