1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/ssl/default_channel_id_store.h"
6
7 #include <map>
8 #include <string>
9 #include <vector>
10
11 #include "base/bind.h"
12 #include "base/compiler_specific.h"
13 #include "base/logging.h"
14 #include "base/memory/scoped_ptr.h"
15 #include "base/message_loop/message_loop.h"
16 #include "net/base/net_errors.h"
17 #include "testing/gtest/include/gtest/gtest.h"
18
19 namespace net {
20
21 namespace {
22
// Completion callback that bumps |*counter| by one each time it runs, letting a
// test observe whether (and how often) the store reported completion.
void CallCounter(int* counter) {
  *counter += 1;
}
26
GetChannelIDCallbackNotCalled(int err,const std::string & server_identifier,base::Time expiration_time,const std::string & private_key_result,const std::string & cert_result)27 void GetChannelIDCallbackNotCalled(int err,
28 const std::string& server_identifier,
29 base::Time expiration_time,
30 const std::string& private_key_result,
31 const std::string& cert_result) {
32 ADD_FAILURE() << "Unexpected callback execution.";
33 }
34
35 class AsyncGetChannelIDHelper {
36 public:
AsyncGetChannelIDHelper()37 AsyncGetChannelIDHelper() : called_(false) {}
38
Callback(int err,const std::string & server_identifier,base::Time expiration_time,const std::string & private_key_result,const std::string & cert_result)39 void Callback(int err,
40 const std::string& server_identifier,
41 base::Time expiration_time,
42 const std::string& private_key_result,
43 const std::string& cert_result) {
44 err_ = err;
45 server_identifier_ = server_identifier;
46 expiration_time_ = expiration_time;
47 private_key_ = private_key_result;
48 cert_ = cert_result;
49 called_ = true;
50 }
51
52 int err_;
53 std::string server_identifier_;
54 base::Time expiration_time_;
55 std::string private_key_;
56 std::string cert_;
57 bool called_;
58 };
59
GetAllCallback(ChannelIDStore::ChannelIDList * dest,const ChannelIDStore::ChannelIDList & result)60 void GetAllCallback(
61 ChannelIDStore::ChannelIDList* dest,
62 const ChannelIDStore::ChannelIDList& result) {
63 *dest = result;
64 }
65
// In-memory stand-in for DefaultChannelIDStore's backing store. Entries live
// in a map keyed by server identifier. Load() posts copies of them to the
// current message loop instead of returning them inline, so the store under
// test sees its load complete asynchronously.
class MockPersistentStore
    : public DefaultChannelIDStore::PersistentStore {
 public:
  MockPersistentStore();

  // DefaultChannelIDStore::PersistentStore implementation.
  virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE;
  virtual void AddChannelID(
      const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
  virtual void DeleteChannelID(
      const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
  virtual void SetForceKeepSessionState() OVERRIDE;

 protected:
  // Not public: the tests hold instances through scoped_refptr.
  virtual ~MockPersistentStore();

 private:
  // Server identifier -> stored channel ID.
  typedef std::map<std::string, DefaultChannelIDStore::ChannelID>
      ChannelIDMap;

  ChannelIDMap channel_ids_;
};
88
MockPersistentStore()89 MockPersistentStore::MockPersistentStore() {}
90
Load(const LoadedCallback & loaded_callback)91 void MockPersistentStore::Load(const LoadedCallback& loaded_callback) {
92 scoped_ptr<ScopedVector<DefaultChannelIDStore::ChannelID> >
93 channel_ids(new ScopedVector<DefaultChannelIDStore::ChannelID>());
94 ChannelIDMap::iterator it;
95
96 for (it = channel_ids_.begin(); it != channel_ids_.end(); ++it) {
97 channel_ids->push_back(
98 new DefaultChannelIDStore::ChannelID(it->second));
99 }
100
101 base::MessageLoop::current()->PostTask(
102 FROM_HERE, base::Bind(loaded_callback, base::Passed(&channel_ids)));
103 }
104
AddChannelID(const DefaultChannelIDStore::ChannelID & channel_id)105 void MockPersistentStore::AddChannelID(
106 const DefaultChannelIDStore::ChannelID& channel_id) {
107 channel_ids_[channel_id.server_identifier()] = channel_id;
108 }
109
DeleteChannelID(const DefaultChannelIDStore::ChannelID & channel_id)110 void MockPersistentStore::DeleteChannelID(
111 const DefaultChannelIDStore::ChannelID& channel_id) {
112 channel_ids_.erase(channel_id.server_identifier());
113 }
114
SetForceKeepSessionState()115 void MockPersistentStore::SetForceKeepSessionState() {}
116
~MockPersistentStore()117 MockPersistentStore::~MockPersistentStore() {}
118
119 } // namespace
120
// Checks three things:
// - entries from the persistent store only appear once the posted load has
//   run;
// - a SetChannelID() issued before that point is queued and applied afterwards;
// - SetChannelID() becomes synchronous once loading is done.
TEST(DefaultChannelIDStoreTest, TestLoading) {
  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);

  // Seed the backing store with two entries that Load() will hand back.
  persistent_store->AddChannelID(
      DefaultChannelIDStore::ChannelID(
          "google.com",
          base::Time(),
          base::Time(),
          "a", "b"));
  persistent_store->AddChannelID(
      DefaultChannelIDStore::ChannelID(
          "verisign.com",
          base::Time(),
          base::Time(),
          "c", "d"));

  // Make sure channel_ids load properly.
  DefaultChannelIDStore store(persistent_store.get());
  // Load has not occurred yet.
  EXPECT_EQ(0, store.GetChannelIDCount());
  // "verisign.com" is already in the backing store, so the count below stays
  // at 2 rather than growing to 3.
  store.SetChannelID(
      "verisign.com",
      base::Time(),
      base::Time(),
      "e", "f");
  // Wait for load & queued set task.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(2, store.GetChannelIDCount());
  store.SetChannelID(
      "twitter.com",
      base::Time(),
      base::Time(),
      "g", "h");
  // Set should be synchronous now that load is done.
  EXPECT_EQ(3, store.GetChannelIDCount());
}
157
// TODO(mattm): add more tests without a persistent store?
TEST(DefaultChannelIDStoreTest, TestSettingAndGetting) {
  // With no persistent store there is nothing to load, so every call below is
  // expected to finish synchronously; the callback fails the test if run.
  DefaultChannelIDStore store(NULL);
  base::Time expires;
  std::string key_out;
  std::string cert_out;

  EXPECT_EQ(0, store.GetChannelIDCount());

  // Unknown server: the lookup fails and the key/cert outputs stay empty.
  EXPECT_EQ(ERR_FILE_NOT_FOUND,
            store.GetChannelID("verisign.com", &expires, &key_out, &cert_out,
                               base::Bind(&GetChannelIDCallbackNotCalled)));
  EXPECT_TRUE(key_out.empty());
  EXPECT_TRUE(cert_out.empty());

  store.SetChannelID("verisign.com",
                     base::Time::FromInternalValue(123),
                     base::Time::FromInternalValue(456),
                     "i", "j");

  // Known server: the lookup succeeds and returns the stored values.
  EXPECT_EQ(OK,
            store.GetChannelID("verisign.com", &expires, &key_out, &cert_out,
                               base::Bind(&GetChannelIDCallbackNotCalled)));
  EXPECT_EQ(456, expires.ToInternalValue());
  EXPECT_EQ("i", key_out);
  EXPECT_EQ("j", cert_out);
}
188
TEST(DefaultChannelIDStoreTest, TestDuplicateChannelIds) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());

  base::Time expires;
  std::string key_out;
  std::string cert_out;
  EXPECT_EQ(0, store.GetChannelIDCount());

  // Two sets for the same server, both issued before the load has run; only
  // one entry may survive, and it must be the later one.
  store.SetChannelID("verisign.com",
                     base::Time::FromInternalValue(123),
                     base::Time::FromInternalValue(1234),
                     "a", "b");
  store.SetChannelID("verisign.com",
                     base::Time::FromInternalValue(456),
                     base::Time::FromInternalValue(4567),
                     "c", "d");

  // Wait for load & queued set tasks.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(1, store.GetChannelIDCount());
  EXPECT_EQ(OK,
            store.GetChannelID("verisign.com", &expires, &key_out, &cert_out,
                               base::Bind(&GetChannelIDCallbackNotCalled)));
  EXPECT_EQ(4567, expires.ToInternalValue());
  EXPECT_EQ("c", key_out);
  EXPECT_EQ("d", cert_out);
}
220
// A GetChannelID() issued before the backing store has loaded returns
// ERR_IO_PENDING. The loaded entry is later reported through the callback,
// and the caller's |cert| out-parameter is left untouched.
TEST(DefaultChannelIDStoreTest, TestAsyncGet) {
  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "verisign.com",
      base::Time::FromInternalValue(123),
      base::Time::FromInternalValue(1234),
      "a", "b"));

  DefaultChannelIDStore store(persistent_store.get());
  AsyncGetChannelIDHelper helper;
  base::Time expiration_time;
  std::string private_key;
  // Sentinel value: proves the async path does not write through |&cert|.
  std::string cert = "not set";
  EXPECT_EQ(0, store.GetChannelIDCount());
  EXPECT_EQ(ERR_IO_PENDING,
      store.GetChannelID("verisign.com",
                         &expiration_time,
                         &private_key,
                         &cert,
                         base::Bind(&AsyncGetChannelIDHelper::Callback,
                                    base::Unretained(&helper))));

  // Wait for load & queued get tasks.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(1, store.GetChannelIDCount());
  EXPECT_EQ("not set", cert);
  // The result arrives through the callback instead.
  EXPECT_TRUE(helper.called_);
  EXPECT_EQ(OK, helper.err_);
  EXPECT_EQ("verisign.com", helper.server_identifier_);
  EXPECT_EQ(1234, helper.expiration_time_.ToInternalValue());
  EXPECT_EQ("a", helper.private_key_);
  EXPECT_EQ("b", helper.cert_);
}
254
TEST(DefaultChannelIDStoreTest, TestDeleteAll) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());

  // Three distinct servers, all set (in this order) before the load has run.
  struct Entry {
    const char* server;
    const char* key;
    const char* cert;
  };
  const Entry kEntries[] = {
    {"verisign.com", "a", "b"},
    {"google.com", "c", "d"},
    {"harvard.com", "e", "f"},
  };
  for (size_t i = 0; i < sizeof(kEntries) / sizeof(kEntries[0]); ++i) {
    store.SetChannelID(kEntries[i].server, base::Time(), base::Time(),
                       kEntries[i].key, kEntries[i].cert);
  }
  // Wait for load & queued set tasks.
  base::MessageLoop::current()->RunUntilIdle();

  EXPECT_EQ(3, store.GetChannelIDCount());

  // Loading is done, so DeleteAll() completes and reports back immediately.
  int deletions_reported = 0;
  store.DeleteAll(base::Bind(&CallCounter, &deletions_reported));
  ASSERT_EQ(1, deletions_reported);
  EXPECT_EQ(0, store.GetChannelIDCount());
}
283
// Queues GetAll / DeleteAll / GetAll before the backing store has loaded and
// checks that, once the loop runs, they execute in order:
// - the first snapshot sees both loaded entries;
// - the DeleteAll callback fires exactly once;
// - the second snapshot is empty.
TEST(DefaultChannelIDStoreTest, TestAsyncGetAndDeleteAll) {
  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "verisign.com",
      base::Time(),
      base::Time(),
      "a", "b"));
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "google.com",
      base::Time(),
      base::Time(),
      "c", "d"));

  ChannelIDStore::ChannelIDList pre_channel_ids;
  ChannelIDStore::ChannelIDList post_channel_ids;
  int delete_finished = 0;
  DefaultChannelIDStore store(persistent_store.get());

  store.GetAllChannelIDs(base::Bind(GetAllCallback, &pre_channel_ids));
  store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
  store.GetAllChannelIDs(base::Bind(GetAllCallback, &post_channel_ids));
  // Tasks have not run yet.
  EXPECT_EQ(0u, pre_channel_ids.size());
  EXPECT_EQ(0, delete_finished);
  // Wait for load & queued tasks.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(0, store.GetChannelIDCount());
  EXPECT_EQ(2u, pre_channel_ids.size());
  // The queued DeleteAll() must have run and reported completion exactly once,
  // mirroring the delete_finished check in TestAsyncDelete.
  EXPECT_EQ(1, delete_finished);
  EXPECT_EQ(0u, post_channel_ids.size());
}
313
TEST(DefaultChannelIDStoreTest, TestDelete) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());

  base::Time expires;
  std::string key_out;
  std::string cert_out;
  EXPECT_EQ(0, store.GetChannelIDCount());

  store.SetChannelID("verisign.com", base::Time(), base::Time(), "a", "b");
  // Wait for load & queued set task.
  base::MessageLoop::current()->RunUntilIdle();

  // Loading is finished; this set and every call below completes inline.
  store.SetChannelID("google.com", base::Time(), base::Time(), "c", "d");
  EXPECT_EQ(2, store.GetChannelIDCount());

  // Remove the first server; the second must still be retrievable.
  int first_delete_calls = 0;
  store.DeleteChannelID("verisign.com",
                        base::Bind(&CallCounter, &first_delete_calls));
  ASSERT_EQ(1, first_delete_calls);
  EXPECT_EQ(1, store.GetChannelIDCount());
  EXPECT_EQ(ERR_FILE_NOT_FOUND,
            store.GetChannelID("verisign.com", &expires, &key_out, &cert_out,
                               base::Bind(&GetChannelIDCallbackNotCalled)));
  EXPECT_EQ(OK,
            store.GetChannelID("google.com", &expires, &key_out, &cert_out,
                               base::Bind(&GetChannelIDCallbackNotCalled)));

  // Remove the second server; the store ends up empty.
  int second_delete_calls = 0;
  store.DeleteChannelID("google.com",
                        base::Bind(&CallCounter, &second_delete_calls));
  ASSERT_EQ(1, second_delete_calls);
  EXPECT_EQ(0, store.GetChannelIDCount());
  EXPECT_EQ(ERR_FILE_NOT_FOUND,
            store.GetChannelID("google.com", &expires, &key_out, &cert_out,
                               base::Bind(&GetChannelIDCallbackNotCalled)));
}
365
// A DeleteChannelID() followed by two GetChannelID() calls, all issued before
// the backing store has loaded, are deferred and then run in issue order. The
// get for the deleted server therefore misses, while the other get succeeds.
TEST(DefaultChannelIDStoreTest, TestAsyncDelete) {
  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "a.com",
      base::Time::FromInternalValue(1),
      base::Time::FromInternalValue(2),
      "a", "b"));
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "b.com",
      base::Time::FromInternalValue(3),
      base::Time::FromInternalValue(4),
      "c", "d"));
  DefaultChannelIDStore store(persistent_store.get());
  int delete_finished = 0;
  // Issued before the load completes, so it is queued (checked below).
  store.DeleteChannelID("a.com",
                        base::Bind(&CallCounter, &delete_finished));

  AsyncGetChannelIDHelper a_helper;
  AsyncGetChannelIDHelper b_helper;
  base::Time expiration_time;
  std::string private_key;
  // Sentinel value: proves the async path does not write through |&cert|.
  std::string cert = "not set";
  EXPECT_EQ(0, store.GetChannelIDCount());
  EXPECT_EQ(ERR_IO_PENDING,
      store.GetChannelID(
          "a.com", &expiration_time, &private_key, &cert,
          base::Bind(&AsyncGetChannelIDHelper::Callback,
                     base::Unretained(&a_helper))));
  EXPECT_EQ(ERR_IO_PENDING,
      store.GetChannelID(
          "b.com", &expiration_time, &private_key, &cert,
          base::Bind(&AsyncGetChannelIDHelper::Callback,
                     base::Unretained(&b_helper))));

  // Nothing has run yet: the delete and both gets are still pending.
  EXPECT_EQ(0, delete_finished);
  EXPECT_FALSE(a_helper.called_);
  EXPECT_FALSE(b_helper.called_);
  // Wait for load & queued tasks.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(1, delete_finished);
  EXPECT_EQ(1, store.GetChannelIDCount());
  EXPECT_EQ("not set", cert);
  // "a.com" was deleted before its get ran: not-found, with empty results.
  EXPECT_TRUE(a_helper.called_);
  EXPECT_EQ(ERR_FILE_NOT_FOUND, a_helper.err_);
  EXPECT_EQ("a.com", a_helper.server_identifier_);
  EXPECT_EQ(0, a_helper.expiration_time_.ToInternalValue());
  EXPECT_EQ("", a_helper.private_key_);
  EXPECT_EQ("", a_helper.cert_);
  // "b.com" is untouched and comes back with its loaded values.
  EXPECT_TRUE(b_helper.called_);
  EXPECT_EQ(OK, b_helper.err_);
  EXPECT_EQ("b.com", b_helper.server_identifier_);
  EXPECT_EQ(4, b_helper.expiration_time_.ToInternalValue());
  EXPECT_EQ("c", b_helper.private_key_);
  EXPECT_EQ("d", b_helper.cert_);
}
421
TEST(DefaultChannelIDStoreTest, TestGetAll) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());

  EXPECT_EQ(0, store.GetChannelIDCount());

  // Four distinct servers, all set (in this order) before the load has run.
  const char* const kServers[] = {
    "verisign.com", "google.com", "harvard.com", "mit.com"
  };
  const char* const kKeys[] = {"a", "c", "e", "g"};
  const char* const kCerts[] = {"b", "d", "f", "h"};
  for (size_t i = 0; i < sizeof(kServers) / sizeof(kServers[0]); ++i) {
    store.SetChannelID(kServers[i], base::Time(), base::Time(),
                       kKeys[i], kCerts[i]);
  }
  // Wait for load & queued set tasks.
  base::MessageLoop::current()->RunUntilIdle();

  EXPECT_EQ(4, store.GetChannelIDCount());

  // With loading done, the list is delivered before GetAllChannelIDs()
  // returns.
  ChannelIDStore::ChannelIDList all_ids;
  store.GetAllChannelIDs(base::Bind(&GetAllCallback, &all_ids));
  EXPECT_EQ(4u, all_ids.size());
}
455
TEST(DefaultChannelIDStoreTest, TestInitializeFrom) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());

  store.SetChannelID("preexisting.com", base::Time(), base::Time(), "a", "b");
  store.SetChannelID("both.com", base::Time(), base::Time(), "c", "d");
  // Wait for load & queued set tasks.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(2, store.GetChannelIDCount());

  // "both.com" collides with an existing entry but carries a different key,
  // which shows whether InitializeFrom() overwrites. "copied.com" is new.
  ChannelIDStore::ChannelIDList imported;
  imported.push_back(ChannelIDStore::ChannelID(
      "both.com", base::Time(), base::Time(), "e", "f"));
  imported.push_back(ChannelIDStore::ChannelID(
      "copied.com", base::Time(), base::Time(), "g", "h"));
  store.InitializeFrom(imported);
  EXPECT_EQ(3, store.GetChannelIDCount());

  ChannelIDStore::ChannelIDList result;
  store.GetAllChannelIDs(base::Bind(&GetAllCallback, &result));
  ASSERT_EQ(3u, result.size());

  // The expectations below assume entries come back sorted by server
  // identifier.
  ChannelIDStore::ChannelIDList::iterator entry = result.begin();
  EXPECT_EQ("both.com", entry->server_identifier());
  EXPECT_EQ("e", entry->private_key());

  ++entry;
  EXPECT_EQ("copied.com", entry->server_identifier());
  EXPECT_EQ("g", entry->private_key());

  ++entry;
  EXPECT_EQ("preexisting.com", entry->server_identifier());
  EXPECT_EQ("a", entry->private_key());
}
505
// InitializeFrom() called before the backing store has loaded is deferred.
// Once the loop runs, the imported entries are merged with the loaded ones,
// and the imported value wins for the overlapping "both.com".
TEST(DefaultChannelIDStoreTest, TestAsyncInitializeFrom) {
  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "preexisting.com",
      base::Time(),
      base::Time(),
      "a", "b"));
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "both.com",
      base::Time(),
      base::Time(),
      "c", "d"));

  DefaultChannelIDStore store(persistent_store.get());
  ChannelIDStore::ChannelIDList source_channel_ids;
  source_channel_ids.push_back(ChannelIDStore::ChannelID(
      "both.com",
      base::Time(),
      base::Time(),
      // Key differs from above to test that existing entries are overwritten.
      "e", "f"));
  source_channel_ids.push_back(ChannelIDStore::ChannelID(
      "copied.com",
      base::Time(),
      base::Time(),
      "g", "h"));
  store.InitializeFrom(source_channel_ids);
  // Still empty: the load (and thus the queued InitializeFrom) has not run.
  EXPECT_EQ(0, store.GetChannelIDCount());
  // Wait for load & queued tasks.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(3, store.GetChannelIDCount());

  ChannelIDStore::ChannelIDList channel_ids;
  store.GetAllChannelIDs(base::Bind(GetAllCallback, &channel_ids));
  ASSERT_EQ(3u, channel_ids.size());

  // The expectations below assume entries come back sorted by server
  // identifier.
  ChannelIDStore::ChannelIDList::iterator channel_id = channel_ids.begin();
  EXPECT_EQ("both.com", channel_id->server_identifier());
  EXPECT_EQ("e", channel_id->private_key());

  ++channel_id;
  EXPECT_EQ("copied.com", channel_id->server_identifier());
  EXPECT_EQ("g", channel_id->private_key());

  ++channel_id;
  EXPECT_EQ("preexisting.com", channel_id->server_identifier());
  EXPECT_EQ("a", channel_id->private_key());
}
554
555 } // namespace net
556