• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/http/transport_security_persister.h"
6 
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <optional>
11 #include <utility>
12 #include <vector>
13 
14 #include "base/base64.h"
15 #include "base/feature_list.h"
16 #include "base/files/file_path.h"
17 #include "base/files/file_util.h"
18 #include "base/functional/bind.h"
19 #include "base/functional/callback.h"
20 #include "base/json/json_reader.h"
21 #include "base/json/json_writer.h"
22 #include "base/location.h"
23 #include "base/metrics/field_trial_params.h"
24 #include "base/task/sequenced_task_runner.h"
25 #include "base/task/single_thread_task_runner.h"
26 #include "base/time/time.h"
27 #include "base/values.h"
28 #include "net/base/features.h"
29 #include "net/base/network_anonymization_key.h"
30 #include "net/cert/x509_certificate.h"
31 #include "net/http/transport_security_state.h"
32 
33 namespace net {
34 
35 BASE_FEATURE(kTransportSecurityFileWriterSchedule,
36              "TransportSecurityFileWriterSchedule",
37              base::FEATURE_ENABLED_BY_DEFAULT);
38 
39 namespace {
40 
41 // From kDefaultCommitInterval in base/files/important_file_writer.cc.
42 // kTransportSecurityFileWriterScheduleCommitInterval won't set the commit
43 // interval to less than this, for performance.
44 constexpr base::TimeDelta kMinCommitInterval = base::Seconds(10);
45 
46 // Max safe commit interval for the ImportantFileWriter.
47 constexpr base::TimeDelta kMaxCommitInterval = base::Minutes(10);
48 
49 // Overrides the default commit interval for the ImportantFileWriter.
50 const base::FeatureParam<base::TimeDelta> kCommitIntervalParam(
51     &kTransportSecurityFileWriterSchedule,
52     "commit_interval",
53     kMinCommitInterval);
54 
55 constexpr const char* kHistogramSuffix = "TransportSecurityPersister";
56 
57 // This function converts the binary hashes to a base64 string which we can
58 // include in a JSON file.
HashedDomainToExternalString(const TransportSecurityState::HashedHost & hashed)59 std::string HashedDomainToExternalString(
60     const TransportSecurityState::HashedHost& hashed) {
61   return base::Base64Encode(hashed);
62 }
63 
64 // This inverts |HashedDomainToExternalString|, above. It turns an external
65 // string (from a JSON file) into an internal (binary) array.
ExternalStringToHashedDomain(const std::string & external)66 std::optional<TransportSecurityState::HashedHost> ExternalStringToHashedDomain(
67     const std::string& external) {
68   TransportSecurityState::HashedHost out;
69   std::optional<std::vector<uint8_t>> hashed = base::Base64Decode(external);
70   if (!hashed.has_value() || hashed.value().size() != out.size()) {
71     return std::nullopt;
72   }
73 
74   std::copy_n(hashed.value().begin(), out.size(), out.begin());
75   return out;
76 }
77 
// Version 2 of the on-disk format consists of a single JSON object. The
// top-level dictionary written by this code has two entries: "version", an
// integer, and "sts", an unordered list of dictionaries, each representing
// cached STS data for a single host. Files written by older code may also
// contain a legacy "expect_ct" entry. Its contents are ignored on read, and its
// presence causes a write to be scheduled, which drops it from disk.

// Stored in serialized dictionary values to distinguish incompatible versions.
// Version 1 is distinguished by the lack of an integer version value.
constexpr char kVersionKey[] = "version";
constexpr int kCurrentVersionValue = 2;

// Keys in the top-level serialized dictionary. kSTSKey holds the list of STS
// entries. kExpectCTKey is legacy: it is never written and is only checked on
// read so that stale Expect-CT data can be purged by the next write.
constexpr char kSTSKey[] = "sts";
constexpr char kExpectCTKey[] = "expect_ct";

// Hostname entry, used in serialized STS dictionaries. Value is produced by
// passing hashed hostname strings to HashedDomainToExternalString().
constexpr char kHostname[] = "host";

// Key values in serialized STS entries.
constexpr char kStsIncludeSubdomains[] = "sts_include_subdomains";
constexpr char kStsObserved[] = "sts_observed";
constexpr char kExpiry[] = "expiry";
constexpr char kMode[] = "mode";

// Values for "mode" used in serialized STS entries.
constexpr char kForceHTTPS[] = "force-https";
constexpr char kDefault[] = "default";
106 
LoadState(const base::FilePath & path)107 std::string LoadState(const base::FilePath& path) {
108   std::string result;
109   if (!base::ReadFileToString(path, &result)) {
110     return "";
111   }
112   return result;
113 }
114 
115 // Serializes STS data from |state| to a Value.
SerializeSTSData(const TransportSecurityState * state)116 base::Value::List SerializeSTSData(const TransportSecurityState* state) {
117   base::Value::List sts_list;
118 
119   TransportSecurityState::STSStateIterator sts_iterator(*state);
120   for (; sts_iterator.HasNext(); sts_iterator.Advance()) {
121     const TransportSecurityState::STSState& sts_state =
122         sts_iterator.domain_state();
123 
124     base::Value::Dict serialized;
125     serialized.Set(kHostname,
126                    HashedDomainToExternalString(sts_iterator.hostname()));
127     serialized.Set(kStsIncludeSubdomains, sts_state.include_subdomains);
128     serialized.Set(kStsObserved,
129                    sts_state.last_observed.InSecondsFSinceUnixEpoch());
130     serialized.Set(kExpiry, sts_state.expiry.InSecondsFSinceUnixEpoch());
131 
132     switch (sts_state.upgrade_mode) {
133       case TransportSecurityState::STSState::MODE_FORCE_HTTPS:
134         serialized.Set(kMode, kForceHTTPS);
135         break;
136       case TransportSecurityState::STSState::MODE_DEFAULT:
137         serialized.Set(kMode, kDefault);
138         break;
139     }
140 
141     sts_list.Append(std::move(serialized));
142   }
143   return sts_list;
144 }
145 
146 // Deserializes STS data from a Value created by the above method.
DeserializeSTSData(const base::Value & sts_list,TransportSecurityState * state)147 void DeserializeSTSData(const base::Value& sts_list,
148                         TransportSecurityState* state) {
149   if (!sts_list.is_list())
150     return;
151 
152   base::Time current_time(base::Time::Now());
153 
154   for (const base::Value& sts_entry : sts_list.GetList()) {
155     const base::Value::Dict* sts_dict = sts_entry.GetIfDict();
156     if (!sts_dict)
157       continue;
158 
159     const std::string* hostname = sts_dict->FindString(kHostname);
160     std::optional<bool> sts_include_subdomains =
161         sts_dict->FindBool(kStsIncludeSubdomains);
162     std::optional<double> sts_observed = sts_dict->FindDouble(kStsObserved);
163     std::optional<double> expiry = sts_dict->FindDouble(kExpiry);
164     const std::string* mode = sts_dict->FindString(kMode);
165 
166     if (!hostname || !sts_include_subdomains.has_value() ||
167         !sts_observed.has_value() || !expiry.has_value() || !mode) {
168       continue;
169     }
170 
171     TransportSecurityState::STSState sts_state;
172     sts_state.include_subdomains = *sts_include_subdomains;
173     sts_state.last_observed =
174         base::Time::FromSecondsSinceUnixEpoch(*sts_observed);
175     sts_state.expiry = base::Time::FromSecondsSinceUnixEpoch(*expiry);
176 
177     if (*mode == kForceHTTPS) {
178       sts_state.upgrade_mode =
179           TransportSecurityState::STSState::MODE_FORCE_HTTPS;
180     } else if (*mode == kDefault) {
181       sts_state.upgrade_mode = TransportSecurityState::STSState::MODE_DEFAULT;
182     } else {
183       continue;
184     }
185 
186     if (sts_state.expiry < current_time || !sts_state.ShouldUpgradeToSSL())
187       continue;
188 
189     std::optional<TransportSecurityState::HashedHost> hashed =
190         ExternalStringToHashedDomain(*hostname);
191     if (!hashed.has_value())
192       continue;
193 
194     state->AddOrUpdateEnabledSTSHosts(hashed.value(), sts_state);
195   }
196 }
197 
OnWriteFinishedTask(scoped_refptr<base::SequencedTaskRunner> task_runner,base::OnceClosure callback,bool result)198 void OnWriteFinishedTask(scoped_refptr<base::SequencedTaskRunner> task_runner,
199                          base::OnceClosure callback,
200                          bool result) {
201   task_runner->PostTask(FROM_HERE, std::move(callback));
202 }
203 
204 }  // namespace
205 
TransportSecurityPersister(TransportSecurityState * state,const scoped_refptr<base::SequencedTaskRunner> & background_runner,const base::FilePath & data_path)206 TransportSecurityPersister::TransportSecurityPersister(
207     TransportSecurityState* state,
208     const scoped_refptr<base::SequencedTaskRunner>& background_runner,
209     const base::FilePath& data_path)
210     : transport_security_state_(state),
211       writer_(data_path,
212               background_runner,
213               GetCommitInterval(),
214               kHistogramSuffix),
215       foreground_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
216       background_runner_(background_runner) {
217   transport_security_state_->SetDelegate(this);
218 
219   background_runner_->PostTaskAndReplyWithResult(
220       FROM_HERE, base::BindOnce(&LoadState, writer_.path()),
221       base::BindOnce(&TransportSecurityPersister::CompleteLoad,
222                      weak_ptr_factory_.GetWeakPtr()));
223 }
224 
~TransportSecurityPersister()225 TransportSecurityPersister::~TransportSecurityPersister() {
226   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
227 
228   if (writer_.HasPendingWrite())
229     writer_.DoScheduledWrite();
230 
231   transport_security_state_->SetDelegate(nullptr);
232 }
233 
StateIsDirty(TransportSecurityState * state)234 void TransportSecurityPersister::StateIsDirty(TransportSecurityState* state) {
235   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
236   DCHECK_EQ(transport_security_state_, state);
237 
238   writer_.ScheduleWrite(this);
239 }
240 
WriteNow(TransportSecurityState * state,base::OnceClosure callback)241 void TransportSecurityPersister::WriteNow(TransportSecurityState* state,
242                                           base::OnceClosure callback) {
243   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
244   DCHECK_EQ(transport_security_state_, state);
245 
246   writer_.RegisterOnNextWriteCallbacks(
247       base::OnceClosure(),
248       base::BindOnce(
249           &OnWriteFinishedTask, foreground_runner_,
250           base::BindOnce(&TransportSecurityPersister::OnWriteFinished,
251                          weak_ptr_factory_.GetWeakPtr(), std::move(callback))));
252   std::optional<std::string> data = SerializeData();
253   if (data) {
254     writer_.WriteNow(std::move(data).value());
255   } else {
256     writer_.WriteNow(std::string());
257   }
258 }
259 
OnWriteFinished(base::OnceClosure callback)260 void TransportSecurityPersister::OnWriteFinished(base::OnceClosure callback) {
261   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
262   std::move(callback).Run();
263 }
264 
SerializeData()265 std::optional<std::string> TransportSecurityPersister::SerializeData() {
266   CHECK(foreground_runner_->RunsTasksInCurrentSequence());
267 
268   base::Value::Dict toplevel;
269   toplevel.Set(kVersionKey, kCurrentVersionValue);
270   toplevel.Set(kSTSKey, SerializeSTSData(transport_security_state_));
271 
272   std::string output;
273   if (!base::JSONWriter::Write(toplevel, &output)) {
274     return std::nullopt;
275   }
276   return output;
277 }
278 
LoadEntries(const std::string & serialized)279 void TransportSecurityPersister::LoadEntries(const std::string& serialized) {
280   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
281 
282   transport_security_state_->ClearDynamicData();
283   bool contains_legacy_expect_ct_data = false;
284   Deserialize(serialized, transport_security_state_,
285               contains_legacy_expect_ct_data);
286   if (contains_legacy_expect_ct_data) {
287     StateIsDirty(transport_security_state_);
288   }
289 }
290 
291 // static
GetCommitInterval()292 base::TimeDelta TransportSecurityPersister::GetCommitInterval() {
293   return std::clamp(kCommitIntervalParam.Get(), kMinCommitInterval,
294                     kMaxCommitInterval);
295 }
296 
Deserialize(const std::string & serialized,TransportSecurityState * state,bool & contains_legacy_expect_ct_data)297 void TransportSecurityPersister::Deserialize(
298     const std::string& serialized,
299     TransportSecurityState* state,
300     bool& contains_legacy_expect_ct_data) {
301   std::optional<base::Value> value = base::JSONReader::Read(serialized);
302   if (!value || !value->is_dict())
303     return;
304 
305   base::Value::Dict& dict = value->GetDict();
306   std::optional<int> version = dict.FindInt(kVersionKey);
307 
308   // Stop if the data is out of date (or in the previous format that didn't have
309   // a version number).
310   if (!version || *version != kCurrentVersionValue)
311     return;
312 
313   base::Value* sts_value = dict.Find(kSTSKey);
314   if (sts_value)
315     DeserializeSTSData(*sts_value, state);
316 
317   // If an Expect-CT key is found on deserialization, record this so that a
318   // write can be scheduled to clear it from disk.
319   contains_legacy_expect_ct_data = !!dict.Find(kExpectCTKey);
320 }
321 
CompleteLoad(const std::string & state)322 void TransportSecurityPersister::CompleteLoad(const std::string& state) {
323   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
324 
325   if (state.empty())
326     return;
327 
328   LoadEntries(state);
329 }
330 
331 }  // namespace net
332