• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/http/transport_security_persister.h"
6 
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <utility>
11 #include <vector>
12 
13 #include "base/base64.h"
14 #include "base/feature_list.h"
15 #include "base/files/file_path.h"
16 #include "base/files/file_util.h"
17 #include "base/functional/bind.h"
18 #include "base/functional/callback.h"
19 #include "base/json/json_reader.h"
20 #include "base/json/json_writer.h"
21 #include "base/location.h"
22 #include "base/task/sequenced_task_runner.h"
23 #include "base/task/single_thread_task_runner.h"
24 #include "base/values.h"
25 #include "net/base/features.h"
26 #include "net/base/network_anonymization_key.h"
27 #include "net/cert/x509_certificate.h"
28 #include "net/http/transport_security_state.h"
29 #include "third_party/abseil-cpp/absl/types/optional.h"
30 
31 namespace net {
32 
33 namespace {
34 
35 // This function converts the binary hashes to a base64 string which we can
36 // include in a JSON file.
HashedDomainToExternalString(const TransportSecurityState::HashedHost & hashed)37 std::string HashedDomainToExternalString(
38     const TransportSecurityState::HashedHost& hashed) {
39   return base::Base64Encode(hashed);
40 }
41 
42 // This inverts |HashedDomainToExternalString|, above. It turns an external
43 // string (from a JSON file) into an internal (binary) array.
ExternalStringToHashedDomain(const std::string & external)44 absl::optional<TransportSecurityState::HashedHost> ExternalStringToHashedDomain(
45     const std::string& external) {
46   TransportSecurityState::HashedHost out;
47   absl::optional<std::vector<uint8_t>> hashed = base::Base64Decode(external);
48   if (!hashed.has_value() || hashed.value().size() != out.size()) {
49     return absl::nullopt;
50   }
51 
52   std::copy_n(hashed.value().begin(), out.size(), out.begin());
53   return out;
54 }
55 
// Version 2 of the on-disk format consists of a single JSON object. The
// top-level dictionary has a "version" entry (an integer) and an "sts" entry
// (an unordered list of dictionaries, each representing cached data for a
// single host). Files written by older code may also contain a legacy
// "expect_ct" entry; it is no longer written and is dropped the next time the
// file is rewritten.
60 
// Key under which the integer format version is stored. Version 1 of the
// format is recognized by the absence of an integer value under this key.
constexpr char kVersionKey[] = "version";
constexpr int kCurrentVersionValue = 2;

// Top-level keys for the list of STS entries and the legacy list of Expect-CT
// entries. The Expect-CT key is never written; if it is found when reading, a
// rewrite is scheduled so that it disappears from disk.
constexpr char kSTSKey[] = "sts";
constexpr char kExpectCTKey[] = "expect_ct";

// Key of the hostname field in each serialized STS dictionary. Its value is
// the hashed hostname as encoded by HashedDomainToExternalString().
constexpr char kHostname[] = "host";

// Keys of the remaining fields in each serialized STS dictionary.
constexpr char kStsIncludeSubdomains[] = "sts_include_subdomains";
constexpr char kStsObserved[] = "sts_observed";
constexpr char kExpiry[] = "expiry";
constexpr char kMode[] = "mode";

// Accepted string values of the "mode" field.
constexpr char kForceHTTPS[] = "force-https";
constexpr char kDefault[] = "default";
84 
LoadState(const base::FilePath & path)85 std::string LoadState(const base::FilePath& path) {
86   std::string result;
87   if (!base::ReadFileToString(path, &result)) {
88     return "";
89   }
90   return result;
91 }
92 
93 // Serializes STS data from |state| to a Value.
SerializeSTSData(const TransportSecurityState * state)94 base::Value::List SerializeSTSData(const TransportSecurityState* state) {
95   base::Value::List sts_list;
96 
97   TransportSecurityState::STSStateIterator sts_iterator(*state);
98   for (; sts_iterator.HasNext(); sts_iterator.Advance()) {
99     const TransportSecurityState::STSState& sts_state =
100         sts_iterator.domain_state();
101 
102     base::Value::Dict serialized;
103     serialized.Set(kHostname,
104                    HashedDomainToExternalString(sts_iterator.hostname()));
105     serialized.Set(kStsIncludeSubdomains, sts_state.include_subdomains);
106     serialized.Set(kStsObserved,
107                    sts_state.last_observed.InSecondsFSinceUnixEpoch());
108     serialized.Set(kExpiry, sts_state.expiry.InSecondsFSinceUnixEpoch());
109 
110     switch (sts_state.upgrade_mode) {
111       case TransportSecurityState::STSState::MODE_FORCE_HTTPS:
112         serialized.Set(kMode, kForceHTTPS);
113         break;
114       case TransportSecurityState::STSState::MODE_DEFAULT:
115         serialized.Set(kMode, kDefault);
116         break;
117     }
118 
119     sts_list.Append(std::move(serialized));
120   }
121   return sts_list;
122 }
123 
124 // Deserializes STS data from a Value created by the above method.
DeserializeSTSData(const base::Value & sts_list,TransportSecurityState * state)125 void DeserializeSTSData(const base::Value& sts_list,
126                         TransportSecurityState* state) {
127   if (!sts_list.is_list())
128     return;
129 
130   base::Time current_time(base::Time::Now());
131 
132   for (const base::Value& sts_entry : sts_list.GetList()) {
133     const base::Value::Dict* sts_dict = sts_entry.GetIfDict();
134     if (!sts_dict)
135       continue;
136 
137     const std::string* hostname = sts_dict->FindString(kHostname);
138     absl::optional<bool> sts_include_subdomains =
139         sts_dict->FindBool(kStsIncludeSubdomains);
140     absl::optional<double> sts_observed = sts_dict->FindDouble(kStsObserved);
141     absl::optional<double> expiry = sts_dict->FindDouble(kExpiry);
142     const std::string* mode = sts_dict->FindString(kMode);
143 
144     if (!hostname || !sts_include_subdomains.has_value() ||
145         !sts_observed.has_value() || !expiry.has_value() || !mode) {
146       continue;
147     }
148 
149     TransportSecurityState::STSState sts_state;
150     sts_state.include_subdomains = *sts_include_subdomains;
151     sts_state.last_observed =
152         base::Time::FromSecondsSinceUnixEpoch(*sts_observed);
153     sts_state.expiry = base::Time::FromSecondsSinceUnixEpoch(*expiry);
154 
155     if (*mode == kForceHTTPS) {
156       sts_state.upgrade_mode =
157           TransportSecurityState::STSState::MODE_FORCE_HTTPS;
158     } else if (*mode == kDefault) {
159       sts_state.upgrade_mode = TransportSecurityState::STSState::MODE_DEFAULT;
160     } else {
161       continue;
162     }
163 
164     if (sts_state.expiry < current_time || !sts_state.ShouldUpgradeToSSL())
165       continue;
166 
167     absl::optional<TransportSecurityState::HashedHost> hashed =
168         ExternalStringToHashedDomain(*hostname);
169     if (!hashed.has_value())
170       continue;
171 
172     state->AddOrUpdateEnabledSTSHosts(hashed.value(), sts_state);
173   }
174 }
175 
OnWriteFinishedTask(scoped_refptr<base::SequencedTaskRunner> task_runner,base::OnceClosure callback,bool result)176 void OnWriteFinishedTask(scoped_refptr<base::SequencedTaskRunner> task_runner,
177                          base::OnceClosure callback,
178                          bool result) {
179   task_runner->PostTask(FROM_HERE, std::move(callback));
180 }
181 
182 }  // namespace
183 
TransportSecurityPersister(TransportSecurityState * state,const scoped_refptr<base::SequencedTaskRunner> & background_runner,const base::FilePath & data_path)184 TransportSecurityPersister::TransportSecurityPersister(
185     TransportSecurityState* state,
186     const scoped_refptr<base::SequencedTaskRunner>& background_runner,
187     const base::FilePath& data_path)
188     : transport_security_state_(state),
189       writer_(data_path, background_runner),
190       foreground_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
191       background_runner_(background_runner) {
192   transport_security_state_->SetDelegate(this);
193 
194   background_runner_->PostTaskAndReplyWithResult(
195       FROM_HERE, base::BindOnce(&LoadState, writer_.path()),
196       base::BindOnce(&TransportSecurityPersister::CompleteLoad,
197                      weak_ptr_factory_.GetWeakPtr()));
198 }
199 
~TransportSecurityPersister()200 TransportSecurityPersister::~TransportSecurityPersister() {
201   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
202 
203   if (writer_.HasPendingWrite())
204     writer_.DoScheduledWrite();
205 
206   transport_security_state_->SetDelegate(nullptr);
207 }
208 
StateIsDirty(TransportSecurityState * state)209 void TransportSecurityPersister::StateIsDirty(TransportSecurityState* state) {
210   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
211   DCHECK_EQ(transport_security_state_, state);
212 
213   writer_.ScheduleWrite(this);
214 }
215 
WriteNow(TransportSecurityState * state,base::OnceClosure callback)216 void TransportSecurityPersister::WriteNow(TransportSecurityState* state,
217                                           base::OnceClosure callback) {
218   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
219   DCHECK_EQ(transport_security_state_, state);
220 
221   writer_.RegisterOnNextWriteCallbacks(
222       base::OnceClosure(),
223       base::BindOnce(
224           &OnWriteFinishedTask, foreground_runner_,
225           base::BindOnce(&TransportSecurityPersister::OnWriteFinished,
226                          weak_ptr_factory_.GetWeakPtr(), std::move(callback))));
227   absl::optional<std::string> data = SerializeData();
228   if (data) {
229     writer_.WriteNow(std::move(data).value());
230   } else {
231     writer_.WriteNow(std::string());
232   }
233 }
234 
OnWriteFinished(base::OnceClosure callback)235 void TransportSecurityPersister::OnWriteFinished(base::OnceClosure callback) {
236   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
237   std::move(callback).Run();
238 }
239 
SerializeData()240 absl::optional<std::string> TransportSecurityPersister::SerializeData() {
241   CHECK(foreground_runner_->RunsTasksInCurrentSequence());
242 
243   base::Value::Dict toplevel;
244   toplevel.Set(kVersionKey, kCurrentVersionValue);
245   toplevel.Set(kSTSKey, SerializeSTSData(transport_security_state_));
246 
247   std::string output;
248   if (!base::JSONWriter::Write(toplevel, &output)) {
249     return absl::nullopt;
250   }
251   return output;
252 }
253 
LoadEntries(const std::string & serialized)254 void TransportSecurityPersister::LoadEntries(const std::string& serialized) {
255   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
256 
257   transport_security_state_->ClearDynamicData();
258   bool contains_legacy_expect_ct_data = false;
259   Deserialize(serialized, transport_security_state_,
260               contains_legacy_expect_ct_data);
261   if (contains_legacy_expect_ct_data) {
262     StateIsDirty(transport_security_state_);
263   }
264 }
265 
Deserialize(const std::string & serialized,TransportSecurityState * state,bool & contains_legacy_expect_ct_data)266 void TransportSecurityPersister::Deserialize(
267     const std::string& serialized,
268     TransportSecurityState* state,
269     bool& contains_legacy_expect_ct_data) {
270   absl::optional<base::Value> value = base::JSONReader::Read(serialized);
271   if (!value || !value->is_dict())
272     return;
273 
274   base::Value::Dict& dict = value->GetDict();
275   absl::optional<int> version = dict.FindInt(kVersionKey);
276 
277   // Stop if the data is out of date (or in the previous format that didn't have
278   // a version number).
279   if (!version || *version != kCurrentVersionValue)
280     return;
281 
282   base::Value* sts_value = dict.Find(kSTSKey);
283   if (sts_value)
284     DeserializeSTSData(*sts_value, state);
285 
286   // If an Expect-CT key is found on deserialization, record this so that a
287   // write can be scheduled to clear it from disk.
288   contains_legacy_expect_ct_data = !!dict.Find(kExpectCTKey);
289 }
290 
CompleteLoad(const std::string & state)291 void TransportSecurityPersister::CompleteLoad(const std::string& state) {
292   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
293 
294   if (state.empty())
295     return;
296 
297   LoadEntries(state);
298 }
299 
300 }  // namespace net
301