• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/http/transport_security_persister.h"
6 
7 #include "base/base64.h"
8 #include "base/bind.h"
9 #include "base/files/file_path.h"
10 #include "base/files/file_util.h"
11 #include "base/json/json_reader.h"
12 #include "base/json/json_writer.h"
13 #include "base/message_loop/message_loop.h"
14 #include "base/message_loop/message_loop_proxy.h"
15 #include "base/sequenced_task_runner.h"
16 #include "base/task_runner_util.h"
17 #include "base/values.h"
18 #include "crypto/sha2.h"
19 #include "net/cert/x509_certificate.h"
20 #include "net/http/transport_security_state.h"
21 
22 using net::HashValue;
23 using net::HashValueTag;
24 using net::HashValueVector;
25 using net::TransportSecurityState;
26 
27 namespace {
28 
SPKIHashesToListValue(const HashValueVector & hashes)29 base::ListValue* SPKIHashesToListValue(const HashValueVector& hashes) {
30   base::ListValue* pins = new base::ListValue;
31   for (size_t i = 0; i != hashes.size(); i++)
32     pins->Append(new base::StringValue(hashes[i].ToString()));
33   return pins;
34 }
35 
SPKIHashesFromListValue(const base::ListValue & pins,HashValueVector * hashes)36 void SPKIHashesFromListValue(const base::ListValue& pins,
37                              HashValueVector* hashes) {
38   size_t num_pins = pins.GetSize();
39   for (size_t i = 0; i < num_pins; ++i) {
40     std::string type_and_base64;
41     HashValue fingerprint;
42     if (pins.GetString(i, &type_and_base64) &&
43         fingerprint.FromString(type_and_base64)) {
44       hashes->push_back(fingerprint);
45     }
46   }
47 }
48 
49 // This function converts the binary hashes to a base64 string which we can
50 // include in a JSON file.
HashedDomainToExternalString(const std::string & hashed)51 std::string HashedDomainToExternalString(const std::string& hashed) {
52   std::string out;
53   base::Base64Encode(hashed, &out);
54   return out;
55 }
56 
57 // This inverts |HashedDomainToExternalString|, above. It turns an external
58 // string (from a JSON file) into an internal (binary) string.
ExternalStringToHashedDomain(const std::string & external)59 std::string ExternalStringToHashedDomain(const std::string& external) {
60   std::string out;
61   if (!base::Base64Decode(external, &out) ||
62       out.size() != crypto::kSHA256Length) {
63     return std::string();
64   }
65 
66   return out;
67 }
68 
// Per-entry JSON keys that are both written and read.
const char kStsIncludeSubdomains[] = "sts_include_subdomains";
const char kPkpIncludeSubdomains[] = "pkp_include_subdomains";
const char kMode[] = "mode";
const char kExpiry[] = "expiry";
const char kDynamicSPKIHashesExpiry[] = "dynamic_spki_hashes_expiry";
const char kDynamicSPKIHashes[] = "dynamic_spki_hashes";
const char kStsObserved[] = "sts_observed";
const char kPkpObserved[] = "pkp_observed";

// Legacy per-entry keys: only consulted as fallbacks when reading.
const char kIncludeSubdomains[] = "include_subdomains";
const char kCreated[] = "created";

// Values stored under kMode. kForceHTTPS and kDefault are the ones written;
// kStrict and kPinningOnly are accepted when reading but never written.
const char kForceHTTPS[] = "force-https";
const char kDefault[] = "default";
const char kStrict[] = "strict";
const char kPinningOnly[] = "pinning-only";
83 
LoadState(const base::FilePath & path)84 std::string LoadState(const base::FilePath& path) {
85   std::string result;
86   if (!base::ReadFileToString(path, &result)) {
87     return "";
88   }
89   return result;
90 }
91 
92 }  // namespace
93 
94 
95 namespace net {
96 
TransportSecurityPersister(TransportSecurityState * state,const base::FilePath & profile_path,const scoped_refptr<base::SequencedTaskRunner> & background_runner,bool readonly)97 TransportSecurityPersister::TransportSecurityPersister(
98     TransportSecurityState* state,
99     const base::FilePath& profile_path,
100     const scoped_refptr<base::SequencedTaskRunner>& background_runner,
101     bool readonly)
102     : transport_security_state_(state),
103       writer_(profile_path.AppendASCII("TransportSecurity"), background_runner),
104       foreground_runner_(base::MessageLoop::current()->message_loop_proxy()),
105       background_runner_(background_runner),
106       readonly_(readonly),
107       weak_ptr_factory_(this) {
108   transport_security_state_->SetDelegate(this);
109 
110   base::PostTaskAndReplyWithResult(
111       background_runner_.get(),
112       FROM_HERE,
113       base::Bind(&::LoadState, writer_.path()),
114       base::Bind(&TransportSecurityPersister::CompleteLoad,
115                  weak_ptr_factory_.GetWeakPtr()));
116 }
117 
~TransportSecurityPersister()118 TransportSecurityPersister::~TransportSecurityPersister() {
119   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
120 
121   if (writer_.HasPendingWrite())
122     writer_.DoScheduledWrite();
123 
124   transport_security_state_->SetDelegate(NULL);
125 }
126 
StateIsDirty(TransportSecurityState * state)127 void TransportSecurityPersister::StateIsDirty(
128     TransportSecurityState* state) {
129   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
130   DCHECK_EQ(transport_security_state_, state);
131 
132   if (!readonly_)
133     writer_.ScheduleWrite(this);
134 }
135 
SerializeData(std::string * output)136 bool TransportSecurityPersister::SerializeData(std::string* output) {
137   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
138 
139   base::DictionaryValue toplevel;
140   base::Time now = base::Time::Now();
141   TransportSecurityState::Iterator state(*transport_security_state_);
142   for (; state.HasNext(); state.Advance()) {
143     const std::string& hostname = state.hostname();
144     const TransportSecurityState::DomainState& domain_state =
145         state.domain_state();
146 
147     base::DictionaryValue* serialized = new base::DictionaryValue;
148     serialized->SetBoolean(kStsIncludeSubdomains,
149                            domain_state.sts.include_subdomains);
150     serialized->SetBoolean(kPkpIncludeSubdomains,
151                            domain_state.pkp.include_subdomains);
152     serialized->SetDouble(kStsObserved,
153                           domain_state.sts.last_observed.ToDoubleT());
154     serialized->SetDouble(kPkpObserved,
155                           domain_state.pkp.last_observed.ToDoubleT());
156     serialized->SetDouble(kExpiry, domain_state.sts.expiry.ToDoubleT());
157     serialized->SetDouble(kDynamicSPKIHashesExpiry,
158                           domain_state.pkp.expiry.ToDoubleT());
159 
160     switch (domain_state.sts.upgrade_mode) {
161       case TransportSecurityState::DomainState::MODE_FORCE_HTTPS:
162         serialized->SetString(kMode, kForceHTTPS);
163         break;
164       case TransportSecurityState::DomainState::MODE_DEFAULT:
165         serialized->SetString(kMode, kDefault);
166         break;
167       default:
168         NOTREACHED() << "DomainState with unknown mode";
169         delete serialized;
170         continue;
171     }
172 
173     if (now < domain_state.pkp.expiry) {
174       serialized->Set(kDynamicSPKIHashes,
175                       SPKIHashesToListValue(domain_state.pkp.spki_hashes));
176     }
177 
178     toplevel.Set(HashedDomainToExternalString(hostname), serialized);
179   }
180 
181   base::JSONWriter::WriteWithOptions(&toplevel,
182                                      base::JSONWriter::OPTIONS_PRETTY_PRINT,
183                                      output);
184   return true;
185 }
186 
LoadEntries(const std::string & serialized,bool * dirty)187 bool TransportSecurityPersister::LoadEntries(const std::string& serialized,
188                                              bool* dirty) {
189   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
190 
191   transport_security_state_->ClearDynamicData();
192   return Deserialize(serialized, dirty, transport_security_state_);
193 }
194 
195 // static
Deserialize(const std::string & serialized,bool * dirty,TransportSecurityState * state)196 bool TransportSecurityPersister::Deserialize(const std::string& serialized,
197                                              bool* dirty,
198                                              TransportSecurityState* state) {
199   scoped_ptr<base::Value> value(base::JSONReader::Read(serialized));
200   base::DictionaryValue* dict_value = NULL;
201   if (!value.get() || !value->GetAsDictionary(&dict_value))
202     return false;
203 
204   const base::Time current_time(base::Time::Now());
205   bool dirtied = false;
206 
207   for (base::DictionaryValue::Iterator i(*dict_value);
208        !i.IsAtEnd(); i.Advance()) {
209     const base::DictionaryValue* parsed = NULL;
210     if (!i.value().GetAsDictionary(&parsed)) {
211       LOG(WARNING) << "Could not parse entry " << i.key() << "; skipping entry";
212       continue;
213     }
214 
215     TransportSecurityState::DomainState domain_state;
216 
217     // kIncludeSubdomains is a legacy synonym for kStsIncludeSubdomains and
218     // kPkpIncludeSubdomains. Parse at least one of these properties,
219     // preferably the new ones.
220     bool include_subdomains = false;
221     bool parsed_include_subdomains = parsed->GetBoolean(kIncludeSubdomains,
222                                                         &include_subdomains);
223     domain_state.sts.include_subdomains = include_subdomains;
224     domain_state.pkp.include_subdomains = include_subdomains;
225     if (parsed->GetBoolean(kStsIncludeSubdomains, &include_subdomains)) {
226       domain_state.sts.include_subdomains = include_subdomains;
227       parsed_include_subdomains = true;
228     }
229     if (parsed->GetBoolean(kPkpIncludeSubdomains, &include_subdomains)) {
230       domain_state.pkp.include_subdomains = include_subdomains;
231       parsed_include_subdomains = true;
232     }
233 
234     std::string mode_string;
235     double expiry = 0;
236     if (!parsed_include_subdomains ||
237         !parsed->GetString(kMode, &mode_string) ||
238         !parsed->GetDouble(kExpiry, &expiry)) {
239       LOG(WARNING) << "Could not parse some elements of entry " << i.key()
240                    << "; skipping entry";
241       continue;
242     }
243 
244     // Don't fail if this key is not present.
245     double dynamic_spki_hashes_expiry = 0;
246     parsed->GetDouble(kDynamicSPKIHashesExpiry,
247                       &dynamic_spki_hashes_expiry);
248 
249     const base::ListValue* pins_list = NULL;
250     if (parsed->GetList(kDynamicSPKIHashes, &pins_list)) {
251       SPKIHashesFromListValue(*pins_list, &domain_state.pkp.spki_hashes);
252     }
253 
254     if (mode_string == kForceHTTPS || mode_string == kStrict) {
255       domain_state.sts.upgrade_mode =
256           TransportSecurityState::DomainState::MODE_FORCE_HTTPS;
257     } else if (mode_string == kDefault || mode_string == kPinningOnly) {
258       domain_state.sts.upgrade_mode =
259           TransportSecurityState::DomainState::MODE_DEFAULT;
260     } else {
261       LOG(WARNING) << "Unknown TransportSecurityState mode string "
262                    << mode_string << " found for entry " << i.key()
263                    << "; skipping entry";
264       continue;
265     }
266 
267     domain_state.sts.expiry = base::Time::FromDoubleT(expiry);
268     domain_state.pkp.expiry =
269         base::Time::FromDoubleT(dynamic_spki_hashes_expiry);
270 
271     double sts_observed;
272     double pkp_observed;
273     if (parsed->GetDouble(kStsObserved, &sts_observed)) {
274       domain_state.sts.last_observed = base::Time::FromDoubleT(sts_observed);
275     } else if (parsed->GetDouble(kCreated, &sts_observed)) {
276       // kCreated is a legacy synonym for both kStsObserved and kPkpObserved.
277       domain_state.sts.last_observed = base::Time::FromDoubleT(sts_observed);
278     } else {
279       // We're migrating an old entry with no observation date. Make sure we
280       // write the new date back in a reasonable time frame.
281       dirtied = true;
282       domain_state.sts.last_observed = base::Time::Now();
283     }
284     if (parsed->GetDouble(kPkpObserved, &pkp_observed)) {
285       domain_state.pkp.last_observed = base::Time::FromDoubleT(pkp_observed);
286     } else if (parsed->GetDouble(kCreated, &pkp_observed)) {
287       domain_state.pkp.last_observed = base::Time::FromDoubleT(pkp_observed);
288     } else {
289       dirtied = true;
290       domain_state.pkp.last_observed = base::Time::Now();
291     }
292 
293     if (domain_state.sts.expiry <= current_time &&
294         domain_state.pkp.expiry <= current_time) {
295       // Make sure we dirty the state if we drop an entry.
296       dirtied = true;
297       continue;
298     }
299 
300     std::string hashed = ExternalStringToHashedDomain(i.key());
301     if (hashed.empty()) {
302       dirtied = true;
303       continue;
304     }
305 
306     state->AddOrUpdateEnabledHosts(hashed, domain_state);
307   }
308 
309   *dirty = dirtied;
310   return true;
311 }
312 
CompleteLoad(const std::string & state)313 void TransportSecurityPersister::CompleteLoad(const std::string& state) {
314   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
315 
316   if (state.empty())
317     return;
318 
319   bool dirty = false;
320   if (!LoadEntries(state, &dirty)) {
321     LOG(ERROR) << "Failed to deserialize state: " << state;
322     return;
323   }
324   if (dirty)
325     StateIsDirty(transport_security_state_);
326 }
327 
328 }  // namespace net
329