1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/http/transport_security_persister.h"
6
7 #include <map>
8 #include <string>
9 #include <vector>
10
11 #include "base/file_util.h"
12 #include "base/files/file_path.h"
13 #include "base/files/scoped_temp_dir.h"
14 #include "base/message_loop/message_loop.h"
15 #include "net/http/transport_security_state.h"
16 #include "testing/gtest/include/gtest/gtest.h"
17
18 using net::TransportSecurityPersister;
19 using net::TransportSecurityState;
20
21 class TransportSecurityPersisterTest : public testing::Test {
22 public:
TransportSecurityPersisterTest()23 TransportSecurityPersisterTest() {
24 }
25
~TransportSecurityPersisterTest()26 virtual ~TransportSecurityPersisterTest() {
27 base::MessageLoopForIO::current()->RunUntilIdle();
28 }
29
SetUp()30 virtual void SetUp() OVERRIDE {
31 ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
32 persister_.reset(new TransportSecurityPersister(
33 &state_,
34 temp_dir_.path(),
35 base::MessageLoopForIO::current()->message_loop_proxy(),
36 false));
37 }
38
39 protected:
40 base::ScopedTempDir temp_dir_;
41 TransportSecurityState state_;
42 scoped_ptr<TransportSecurityPersister> persister_;
43 };
44
// Serializing a freshly constructed state (no entries added by this test) and
// loading the result back must succeed and must not report the state dirty.
TEST_F(TransportSecurityPersisterTest, SerializeData1) {
  std::string output;
  bool dirty;

  ASSERT_TRUE(persister_->SerializeData(&output));
  // Fatal assertion: if LoadEntries() fails, |dirty| may never have been
  // written, and reading the uninitialized bool would be undefined behavior.
  ASSERT_TRUE(persister_->LoadEntries(output, &dirty));
  EXPECT_FALSE(dirty);
}
53
// Adds an HSTS entry for yahoo.com with include_subdomains set, round-trips the
// state through SerializeData()/LoadEntries(), and checks that the domain and
// several depths of subdomain are still forced to HTTPS while the parent "com"
// is not covered.
TEST_F(TransportSecurityPersisterTest, SerializeData2) {
  static const char kYahooDomain[] = "yahoo.com";
  const base::Time expiry =
      base::Time::Now() + base::TimeDelta::FromSeconds(1000);
  TransportSecurityState::DomainState domain_state;

  // Nothing is known about the domain before the entry is added.
  EXPECT_FALSE(state_.GetDomainState(kYahooDomain, true, &domain_state));

  const bool include_subdomains = true;
  state_.AddHSTS(kYahooDomain, expiry, include_subdomains);

  std::string serialized;
  bool dirty;
  EXPECT_TRUE(persister_->SerializeData(&serialized));
  EXPECT_TRUE(persister_->LoadEntries(serialized, &dirty));

  // The domain itself and each deeper subdomain must all be forced to HTTPS.
  const char* const kCoveredHosts[] = {
    kYahooDomain,
    "foo.yahoo.com",
    "foo.bar.yahoo.com",
    "foo.bar.baz.yahoo.com",
  };
  const size_t kNumCoveredHosts =
      sizeof(kCoveredHosts) / sizeof(kCoveredHosts[0]);
  for (size_t n = 0; n < kNumCoveredHosts; ++n) {
    SCOPED_TRACE(kCoveredHosts[n]);
    EXPECT_TRUE(state_.GetDomainState(kCoveredHosts[n], true, &domain_state));
    EXPECT_EQ(TransportSecurityState::DomainState::MODE_FORCE_HTTPS,
              domain_state.upgrade_mode);
  }

  // The parent label is not covered by the entry.
  EXPECT_FALSE(state_.GetDomainState("com", true, &domain_state));
}
85
// Round-trips two HSTS + HPKP entries through SerializeData(), a file on disk,
// and LoadEntries(), then checks the reloaded state against a snapshot taken
// before serialization.
TEST_F(TransportSecurityPersisterTest, SerializeData3) {
  // Add an entry.
  net::HashValue fp1(net::HASH_VALUE_SHA1);
  memset(fp1.data(), 0, fp1.size());
  net::HashValue fp2(net::HASH_VALUE_SHA1);
  memset(fp2.data(), 1, fp2.size());
  base::Time expiry =
      base::Time::Now() + base::TimeDelta::FromSeconds(1000);
  net::HashValueVector dynamic_spki_hashes;
  dynamic_spki_hashes.push_back(fp1);
  dynamic_spki_hashes.push_back(fp2);
  bool include_subdomains = false;
  state_.AddHSTS("www.example.com", expiry, include_subdomains);
  state_.AddHPKP("www.example.com", expiry, include_subdomains,
                 dynamic_spki_hashes);

  // Add another entry. |dynamic_spki_hashes| is not cleared first, so
  // www.example.net is pinned to all four hashes (fill bytes 0, 1, 2 and 3).
  memset(fp1.data(), 2, fp1.size());
  memset(fp2.data(), 3, fp2.size());
  expiry =
      base::Time::Now() + base::TimeDelta::FromSeconds(3000);
  dynamic_spki_hashes.push_back(fp1);
  dynamic_spki_hashes.push_back(fp2);
  state_.AddHSTS("www.example.net", expiry, include_subdomains);
  state_.AddHPKP("www.example.net", expiry, include_subdomains,
                 dynamic_spki_hashes);

  // Save a copy of everything, keyed by the iterator's hostname().
  typedef std::map<std::string, TransportSecurityState::DomainState> StateMap;
  StateMap saved;
  TransportSecurityState::Iterator i(state_);
  while (i.HasNext()) {
    saved[i.hostname()] = i.domain_state();
    i.Advance();
  }

  std::string serialized;
  ASSERT_TRUE(persister_->SerializeData(&serialized));

  // Persist the data to the file. For the test to be fast and not flaky, we
  // just do it directly rather than call persister_->StateIsDirty. (That uses
  // ImportantFileWriter, which has an asynchronous commit interval rather
  // than block.) Use a different basename just for cleanliness.
  base::FilePath path =
      temp_dir_.path().AppendASCII("TransportSecurityPersisterTest");
  // WriteFile() returns the number of bytes written, or -1 on error. -1
  // converts to true, so EXPECT_TRUE could never catch a failed write; compare
  // against the exact size instead.
  const int serialized_size = static_cast<int>(serialized.size());
  ASSERT_EQ(serialized_size,
            file_util::WriteFile(path, serialized.c_str(), serialized_size));

  // Read the data back.
  std::string persisted;
  ASSERT_TRUE(base::ReadFileToString(path, &persisted));
  EXPECT_EQ(serialized, persisted);
  bool dirty;
  // Fatal: |dirty| is only safe to read once LoadEntries() has succeeded.
  ASSERT_TRUE(persister_->LoadEntries(persisted, &dirty));
  EXPECT_FALSE(dirty);

  // Check that the reloaded states match the saved ones: the same set of
  // hostnames and, per hostname, the same upgrade mode and number of pins.
  size_t count = 0;
  TransportSecurityState::Iterator j(state_);
  while (j.HasNext()) {
    StateMap::const_iterator found = saved.find(j.hostname());
    ASSERT_TRUE(found != saved.end());
    const TransportSecurityState::DomainState& loaded = j.domain_state();
    EXPECT_EQ(found->second.upgrade_mode, loaded.upgrade_mode);
    EXPECT_EQ(found->second.dynamic_spki_hashes.size(),
              loaded.dynamic_spki_hashes.size());
    count++;
    j.Advance();
  }
  EXPECT_EQ(saved.size(), count);
}
150
// Loads an old-style JSON blob (one without a creation date) and checks that it
// is accepted and that the load reports the state as dirty.
TEST_F(TransportSecurityPersisterTest, SerializeDataOld) {
  // This is an old-style piece of transport state JSON, which has no creation
  // date.
  std::string output =
      "{ "
      "\"NiyD+3J1r6z1wjl2n1ALBu94Zj9OsEAMo0kCN8js0Uk=\": {"
      "\"expiry\": 1266815027.983453, "
      "\"include_subdomains\": false, "
      "\"mode\": \"strict\" "
      "}"
      "}";
  bool dirty;
  // Fatal assertion: if LoadEntries() fails, |dirty| may be left unwritten and
  // reading it below would be undefined behavior.
  ASSERT_TRUE(persister_->LoadEntries(output, &dirty));
  // NOTE(review): presumably "dirty" signals that the loaded data should be
  // rewritten in the current format — confirm against LoadEntries().
  EXPECT_TRUE(dirty);
}
166
// Exercises DomainState::CheckPublicKeyPins() on a local DomainState, then
// round-trips a single SHA-1 dynamic pin through SerializeData()/LoadEntries()
// and checks that it comes back intact.
TEST_F(TransportSecurityPersisterTest, PublicKeyHashes) {
  TransportSecurityState::DomainState domain_state;
  static const char kTestDomain[] = "example.com";
  EXPECT_FALSE(state_.GetDomainState(kTestDomain, false, &domain_state));
  net::HashValueVector hashes;
  EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes));

  net::HashValue sha1(net::HASH_VALUE_SHA1);
  memset(sha1.data(), '1', sha1.size());
  domain_state.dynamic_spki_hashes.push_back(sha1);

  // A pin is configured but no chain hashes are offered.
  EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes));

  hashes.push_back(sha1);
  EXPECT_TRUE(domain_state.CheckPublicKeyPins(hashes));

  // Corrupt the offered hash; it must no longer match the pin.
  hashes[0].data()[0] = '2';
  EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes));

  const base::Time current_time(base::Time::Now());
  const base::Time expiry = current_time + base::TimeDelta::FromSeconds(1000);
  bool include_subdomains = false;
  state_.AddHSTS(kTestDomain, expiry, include_subdomains);
  state_.AddHPKP(kTestDomain, expiry, include_subdomains,
                 domain_state.dynamic_spki_hashes);
  std::string ser;
  ASSERT_TRUE(persister_->SerializeData(&ser));
  bool dirty;
  ASSERT_TRUE(persister_->LoadEntries(ser, &dirty));

  // Look the domain up into a fresh DomainState. |domain_state| already holds
  // |sha1| from above, so reusing it could let the checks below pass on stale
  // local data rather than on what the lookup produced.
  TransportSecurityState::DomainState reloaded;
  ASSERT_TRUE(state_.GetDomainState(kTestDomain, false, &reloaded));
  // Fatal, so that the [0] accesses below never run on an empty vector.
  ASSERT_EQ(1u, reloaded.dynamic_spki_hashes.size());
  EXPECT_EQ(sha1.tag, reloaded.dynamic_spki_hashes[0].tag);
  EXPECT_EQ(0, memcmp(reloaded.dynamic_spki_hashes[0].data(), sha1.data(),
                      sha1.size()));
}
202