• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/http/transport_security_persister.h"
6 
7 #include <map>
8 #include <string>
9 #include <vector>
10 
11 #include "base/file_util.h"
12 #include "base/files/file_path.h"
13 #include "base/files/scoped_temp_dir.h"
14 #include "base/message_loop/message_loop.h"
15 #include "net/http/transport_security_state.h"
16 #include "testing/gtest/include/gtest/gtest.h"
17 
18 using net::TransportSecurityPersister;
19 using net::TransportSecurityState;
20 
21 class TransportSecurityPersisterTest : public testing::Test {
22  public:
TransportSecurityPersisterTest()23   TransportSecurityPersisterTest() {
24   }
25 
~TransportSecurityPersisterTest()26   virtual ~TransportSecurityPersisterTest() {
27     base::MessageLoopForIO::current()->RunUntilIdle();
28   }
29 
SetUp()30   virtual void SetUp() OVERRIDE {
31     ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
32     persister_.reset(new TransportSecurityPersister(
33         &state_,
34         temp_dir_.path(),
35         base::MessageLoopForIO::current()->message_loop_proxy(),
36         false));
37   }
38 
39  protected:
40   base::ScopedTempDir temp_dir_;
41   TransportSecurityState state_;
42   scoped_ptr<TransportSecurityPersister> persister_;
43 };
44 
// An empty state serializes successfully and reloads without being flagged
// as dirty.
TEST_F(TransportSecurityPersisterTest, SerializeData1) {
  std::string output;
  // Start from the opposite of the expected value, so the final check fails
  // if LoadEntries() never writes |dirty|. This also avoids reading an
  // uninitialized bool.
  bool dirty = true;

  EXPECT_TRUE(persister_->SerializeData(&output));
  EXPECT_TRUE(persister_->LoadEntries(output, &dirty));
  EXPECT_FALSE(dirty);
}
53 
// An HSTS entry with include_subdomains survives a serialize/load round trip
// and applies to the domain and its subdomains, but not to its parent.
TEST_F(TransportSecurityPersisterTest, SerializeData2) {
  TransportSecurityState::DomainState domain_state;
  const base::Time current_time(base::Time::Now());
  const base::Time expiry = current_time + base::TimeDelta::FromSeconds(1000);
  static const char kYahooDomain[] = "yahoo.com";

  EXPECT_FALSE(state_.GetDomainState(kYahooDomain, true, &domain_state));

  bool include_subdomains = true;
  state_.AddHSTS(kYahooDomain, expiry, include_subdomains);

  std::string output;
  // Initialized to the opposite of the expected value so the check below is
  // meaningful, and so it is never read uninitialized.
  bool dirty = true;
  EXPECT_TRUE(persister_->SerializeData(&output));
  EXPECT_TRUE(persister_->LoadEntries(output, &dirty));
  // Freshly serialized, unexpired data should reload cleanly (mirrors the
  // round-trip check in SerializeData3).
  EXPECT_FALSE(dirty);

  // gtest convention: EXPECT_EQ(expected, actual).
  EXPECT_TRUE(state_.GetDomainState(kYahooDomain, true, &domain_state));
  EXPECT_EQ(TransportSecurityState::DomainState::MODE_FORCE_HTTPS,
            domain_state.upgrade_mode);
  EXPECT_TRUE(state_.GetDomainState("foo.yahoo.com", true, &domain_state));
  EXPECT_EQ(TransportSecurityState::DomainState::MODE_FORCE_HTTPS,
            domain_state.upgrade_mode);
  EXPECT_TRUE(state_.GetDomainState("foo.bar.yahoo.com", true, &domain_state));
  EXPECT_EQ(TransportSecurityState::DomainState::MODE_FORCE_HTTPS,
            domain_state.upgrade_mode);
  EXPECT_TRUE(state_.GetDomainState("foo.bar.baz.yahoo.com", true,
                                    &domain_state));
  EXPECT_EQ(TransportSecurityState::DomainState::MODE_FORCE_HTTPS,
            domain_state.upgrade_mode);
  EXPECT_FALSE(state_.GetDomainState("com", true, &domain_state));
}
85 
TEST_F(TransportSecurityPersisterTest,SerializeData3)86 TEST_F(TransportSecurityPersisterTest, SerializeData3) {
87   // Add an entry.
88   net::HashValue fp1(net::HASH_VALUE_SHA1);
89   memset(fp1.data(), 0, fp1.size());
90   net::HashValue fp2(net::HASH_VALUE_SHA1);
91   memset(fp2.data(), 1, fp2.size());
92   base::Time expiry =
93       base::Time::Now() + base::TimeDelta::FromSeconds(1000);
94   net::HashValueVector dynamic_spki_hashes;
95   dynamic_spki_hashes.push_back(fp1);
96   dynamic_spki_hashes.push_back(fp2);
97   bool include_subdomains = false;
98   state_.AddHSTS("www.example.com", expiry, include_subdomains);
99   state_.AddHPKP("www.example.com", expiry, include_subdomains,
100                  dynamic_spki_hashes);
101 
102   // Add another entry.
103   memset(fp1.data(), 2, fp1.size());
104   memset(fp2.data(), 3, fp2.size());
105   expiry =
106       base::Time::Now() + base::TimeDelta::FromSeconds(3000);
107   dynamic_spki_hashes.push_back(fp1);
108   dynamic_spki_hashes.push_back(fp2);
109   state_.AddHSTS("www.example.net", expiry, include_subdomains);
110   state_.AddHPKP("www.example.net", expiry, include_subdomains,
111                  dynamic_spki_hashes);
112 
113   // Save a copy of everything.
114   std::map<std::string, TransportSecurityState::DomainState> saved;
115   TransportSecurityState::Iterator i(state_);
116   while (i.HasNext()) {
117     saved[i.hostname()] = i.domain_state();
118     i.Advance();
119   }
120 
121   std::string serialized;
122   EXPECT_TRUE(persister_->SerializeData(&serialized));
123 
124   // Persist the data to the file. For the test to be fast and not flaky, we
125   // just do it directly rather than call persister_->StateIsDirty. (That uses
126   // ImportantFileWriter, which has an asynchronous commit interval rather
127   // than block.) Use a different basename just for cleanliness.
128   base::FilePath path =
129       temp_dir_.path().AppendASCII("TransportSecurityPersisterTest");
130   EXPECT_TRUE(file_util::WriteFile(path, serialized.c_str(),
131                                    serialized.size()));
132 
133   // Read the data back.
134   std::string persisted;
135   EXPECT_TRUE(base::ReadFileToString(path, &persisted));
136   EXPECT_EQ(persisted, serialized);
137   bool dirty;
138   EXPECT_TRUE(persister_->LoadEntries(persisted, &dirty));
139   EXPECT_FALSE(dirty);
140 
141   // Check that states are the same as saved.
142   size_t count = 0;
143   TransportSecurityState::Iterator j(state_);
144   while (j.HasNext()) {
145     count++;
146     j.Advance();
147   }
148   EXPECT_EQ(count, saved.size());
149 }
150 
// Loading legacy JSON that lacks a creation date must succeed and flag the
// state as dirty.
TEST_F(TransportSecurityPersisterTest, SerializeDataOld) {
  // This is an old-style piece of transport state JSON, which has no creation
  // date.
  std::string output =
      "{ "
      "\"NiyD+3J1r6z1wjl2n1ALBu94Zj9OsEAMo0kCN8js0Uk=\": {"
      "\"expiry\": 1266815027.983453, "
      "\"include_subdomains\": false, "
      "\"mode\": \"strict\" "
      "}"
      "}";
  // Start from false so EXPECT_TRUE(dirty) only passes if LoadEntries()
  // actually sets it. Otherwise an uninitialized value could pass by chance.
  bool dirty = false;
  EXPECT_TRUE(persister_->LoadEntries(output, &dirty));
  EXPECT_TRUE(dirty);
}
166 
// Exercises DomainState::CheckPublicKeyPins() against a single SHA-1 pin.
// Then checks that the pin survives a serialize/load round trip through the
// persister.
TEST_F(TransportSecurityPersisterTest, PublicKeyHashes) {
  static const char kTestDomain[] = "example.com";
  TransportSecurityState::DomainState domain_state;
  EXPECT_FALSE(state_.GetDomainState(kTestDomain, false, &domain_state));

  // With no pins and no presented hashes, the check fails.
  net::HashValueVector candidate_hashes;
  EXPECT_FALSE(domain_state.CheckPublicKeyPins(candidate_hashes));

  // Pin a SHA-1 value whose bytes are all '1'.
  net::HashValue pinned(net::HASH_VALUE_SHA1);
  memset(pinned.data(), '1', pinned.size());
  domain_state.dynamic_spki_hashes.push_back(pinned);

  // The pin exists but nothing is presented: still a failure.
  EXPECT_FALSE(domain_state.CheckPublicKeyPins(candidate_hashes));

  // Presenting the pinned hash succeeds...
  candidate_hashes.push_back(pinned);
  EXPECT_TRUE(domain_state.CheckPublicKeyPins(candidate_hashes));

  // ...and corrupting its first byte makes it fail again.
  candidate_hashes.front().data()[0] = '2';
  EXPECT_FALSE(domain_state.CheckPublicKeyPins(candidate_hashes));

  // Register HSTS + HPKP for the domain, then round-trip it.
  const base::Time expiry =
      base::Time::Now() + base::TimeDelta::FromSeconds(1000);
  const bool include_subdomains = false;
  state_.AddHSTS(kTestDomain, expiry, include_subdomains);
  state_.AddHPKP(kTestDomain, expiry, include_subdomains,
                 domain_state.dynamic_spki_hashes);

  std::string serialized;
  EXPECT_TRUE(persister_->SerializeData(&serialized));
  bool dirty = false;
  EXPECT_TRUE(persister_->LoadEntries(serialized, &dirty));

  // The reloaded entry carries exactly the one pin, byte-for-byte.
  EXPECT_TRUE(state_.GetDomainState(kTestDomain, false, &domain_state));
  EXPECT_EQ(1u, domain_state.dynamic_spki_hashes.size());
  const net::HashValue& reloaded = domain_state.dynamic_spki_hashes[0];
  EXPECT_EQ(pinned.tag, reloaded.tag);
  EXPECT_EQ(0, memcmp(reloaded.data(), pinned.data(), pinned.size()));
}
202