• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/ssl/default_channel_id_store.h"
6 
7 #include <map>
8 #include <string>
9 #include <vector>
10 
11 #include "base/bind.h"
12 #include "base/compiler_specific.h"
13 #include "base/logging.h"
14 #include "base/memory/scoped_ptr.h"
15 #include "base/message_loop/message_loop.h"
16 #include "net/base/net_errors.h"
17 #include "testing/gtest/include/gtest/gtest.h"
18 
19 namespace net {
20 
21 namespace {
22 
// Completion callback that records how many times it has been invoked by
// incrementing the integer that |counter| points to.
void CallCounter(int* counter) {
  ++*counter;
}
26 
GetChannelIDCallbackNotCalled(int err,const std::string & server_identifier,base::Time expiration_time,const std::string & private_key_result,const std::string & cert_result)27 void GetChannelIDCallbackNotCalled(int err,
28                                    const std::string& server_identifier,
29                                    base::Time expiration_time,
30                                    const std::string& private_key_result,
31                                    const std::string& cert_result) {
32   ADD_FAILURE() << "Unexpected callback execution.";
33 }
34 
35 class AsyncGetChannelIDHelper {
36  public:
AsyncGetChannelIDHelper()37   AsyncGetChannelIDHelper() : called_(false) {}
38 
Callback(int err,const std::string & server_identifier,base::Time expiration_time,const std::string & private_key_result,const std::string & cert_result)39   void Callback(int err,
40                 const std::string& server_identifier,
41                 base::Time expiration_time,
42                 const std::string& private_key_result,
43                 const std::string& cert_result) {
44     err_ = err;
45     server_identifier_ = server_identifier;
46     expiration_time_ = expiration_time;
47     private_key_ = private_key_result;
48     cert_ = cert_result;
49     called_ = true;
50   }
51 
52   int err_;
53   std::string server_identifier_;
54   base::Time expiration_time_;
55   std::string private_key_;
56   std::string cert_;
57   bool called_;
58 };
59 
GetAllCallback(ChannelIDStore::ChannelIDList * dest,const ChannelIDStore::ChannelIDList & result)60 void GetAllCallback(
61     ChannelIDStore::ChannelIDList* dest,
62     const ChannelIDStore::ChannelIDList& result) {
63   *dest = result;
64 }
65 
66 class MockPersistentStore
67     : public DefaultChannelIDStore::PersistentStore {
68  public:
69   MockPersistentStore();
70 
71   // DefaultChannelIDStore::PersistentStore implementation.
72   virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE;
73   virtual void AddChannelID(
74       const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
75   virtual void DeleteChannelID(
76       const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
77   virtual void SetForceKeepSessionState() OVERRIDE;
78 
79  protected:
80   virtual ~MockPersistentStore();
81 
82  private:
83   typedef std::map<std::string, DefaultChannelIDStore::ChannelID>
84       ChannelIDMap;
85 
86   ChannelIDMap channel_ids_;
87 };
88 
MockPersistentStore()89 MockPersistentStore::MockPersistentStore() {}
90 
Load(const LoadedCallback & loaded_callback)91 void MockPersistentStore::Load(const LoadedCallback& loaded_callback) {
92   scoped_ptr<ScopedVector<DefaultChannelIDStore::ChannelID> >
93       channel_ids(new ScopedVector<DefaultChannelIDStore::ChannelID>());
94   ChannelIDMap::iterator it;
95 
96   for (it = channel_ids_.begin(); it != channel_ids_.end(); ++it) {
97     channel_ids->push_back(
98         new DefaultChannelIDStore::ChannelID(it->second));
99   }
100 
101   base::MessageLoop::current()->PostTask(
102       FROM_HERE, base::Bind(loaded_callback, base::Passed(&channel_ids)));
103 }
104 
AddChannelID(const DefaultChannelIDStore::ChannelID & channel_id)105 void MockPersistentStore::AddChannelID(
106     const DefaultChannelIDStore::ChannelID& channel_id) {
107   channel_ids_[channel_id.server_identifier()] = channel_id;
108 }
109 
DeleteChannelID(const DefaultChannelIDStore::ChannelID & channel_id)110 void MockPersistentStore::DeleteChannelID(
111     const DefaultChannelIDStore::ChannelID& channel_id) {
112   channel_ids_.erase(channel_id.server_identifier());
113 }
114 
SetForceKeepSessionState()115 void MockPersistentStore::SetForceKeepSessionState() {}
116 
~MockPersistentStore()117 MockPersistentStore::~MockPersistentStore() {}
118 
119 }  // namespace
120 
// Entries already in the persistent store become visible only after the
// asynchronous load. A SetChannelID issued before the load is queued; one
// issued afterwards takes effect immediately.
TEST(DefaultChannelIDStoreTest, TestLoading) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  const base::Time null_time = base::Time();

  backing->AddChannelID(DefaultChannelIDStore::ChannelID(
      "google.com", null_time, null_time, "a", "b"));
  backing->AddChannelID(DefaultChannelIDStore::ChannelID(
      "verisign.com", null_time, null_time, "c", "d"));

  DefaultChannelIDStore store(backing.get());
  // Nothing has been loaded yet.
  EXPECT_EQ(0, store.GetChannelIDCount());

  // Same server as one of the persisted entries, so once the load and this
  // queued set have run there are still only two entries.
  store.SetChannelID("verisign.com", null_time, null_time, "e", "f");
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(2, store.GetChannelIDCount());

  // The load has finished, so this insertion is applied synchronously.
  store.SetChannelID("twitter.com", null_time, null_time, "g", "h");
  EXPECT_EQ(3, store.GetChannelIDCount());
}
157 
// TODO(mattm): add more tests without a persistent store?
// Without a persistent store there is nothing to load, so every call below
// completes synchronously. The callback handed to GetChannelID must never run.
TEST(DefaultChannelIDStoreTest, TestSettingAndGetting) {
  DefaultChannelIDStore store(NULL);
  base::Time expiration;
  std::string key;
  std::string certificate;

  EXPECT_EQ(0, store.GetChannelIDCount());

  // Looking up a server that was never set fails and leaves the outputs empty.
  int rv = store.GetChannelID("verisign.com", &expiration, &key, &certificate,
                              base::Bind(&GetChannelIDCallbackNotCalled));
  EXPECT_EQ(ERR_FILE_NOT_FOUND, rv);
  EXPECT_TRUE(key.empty());
  EXPECT_TRUE(certificate.empty());

  store.SetChannelID("verisign.com",
                     base::Time::FromInternalValue(123),
                     base::Time::FromInternalValue(456),  // Expiration.
                     "i", "j");

  rv = store.GetChannelID("verisign.com", &expiration, &key, &certificate,
                          base::Bind(&GetChannelIDCallbackNotCalled));
  EXPECT_EQ(OK, rv);
  EXPECT_EQ(456, expiration.ToInternalValue());
  EXPECT_EQ("i", key);
  EXPECT_EQ("j", certificate);
}
188 
// Setting the same server twice before the load completes leaves a single
// entry that holds the values from the later call.
TEST(DefaultChannelIDStoreTest, TestDuplicateChannelIds) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());

  EXPECT_EQ(0, store.GetChannelIDCount());
  store.SetChannelID("verisign.com",
                     base::Time::FromInternalValue(123),
                     base::Time::FromInternalValue(1234),
                     "a", "b");
  store.SetChannelID("verisign.com",
                     base::Time::FromInternalValue(456),
                     base::Time::FromInternalValue(4567),
                     "c", "d");

  // Let the load and both queued sets run.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(1, store.GetChannelIDCount());

  base::Time expiration;
  std::string key;
  std::string certificate;
  const int rv =
      store.GetChannelID("verisign.com", &expiration, &key, &certificate,
                         base::Bind(&GetChannelIDCallbackNotCalled));
  EXPECT_EQ(OK, rv);
  EXPECT_EQ(4567, expiration.ToInternalValue());
  EXPECT_EQ("c", key);
  EXPECT_EQ("d", certificate);
}
220 
// A GetChannelID issued before the persistent store has loaded returns
// ERR_IO_PENDING. The result is reported only through the callback once the
// load finishes; the |cert| out-parameter keeps its prior value.
TEST(DefaultChannelIDStoreTest, TestAsyncGet) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  backing->AddChannelID(ChannelIDStore::ChannelID(
      "verisign.com",
      base::Time::FromInternalValue(123),
      base::Time::FromInternalValue(1234),
      "a", "b"));

  DefaultChannelIDStore store(backing.get());
  AsyncGetChannelIDHelper result;
  base::Time expiration;
  std::string key;
  std::string certificate("not set");

  EXPECT_EQ(0, store.GetChannelIDCount());
  const int rv = store.GetChannelID(
      "verisign.com", &expiration, &key, &certificate,
      base::Bind(&AsyncGetChannelIDHelper::Callback,
                 base::Unretained(&result)));
  EXPECT_EQ(ERR_IO_PENDING, rv);

  // Let the load and the queued get run.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(1, store.GetChannelIDCount());
  EXPECT_EQ("not set", certificate);

  EXPECT_TRUE(result.called_);
  EXPECT_EQ(OK, result.err_);
  EXPECT_EQ("verisign.com", result.server_identifier_);
  EXPECT_EQ(1234, result.expiration_time_.ToInternalValue());
  EXPECT_EQ("a", result.private_key_);
  EXPECT_EQ("b", result.cert_);
}
254 
// Once loading has finished, DeleteAll removes every entry and runs its
// callback before returning.
TEST(DefaultChannelIDStoreTest, TestDeleteAll) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());
  const base::Time null_time = base::Time();

  store.SetChannelID("verisign.com", null_time, null_time, "a", "b");
  store.SetChannelID("google.com", null_time, null_time, "c", "d");
  store.SetChannelID("harvard.com", null_time, null_time, "e", "f");
  // Let the load and the queued sets run.
  base::MessageLoop::current()->RunUntilIdle();

  EXPECT_EQ(3, store.GetChannelIDCount());
  int callbacks_run = 0;
  store.DeleteAll(base::Bind(&CallCounter, &callbacks_run));
  ASSERT_EQ(1, callbacks_run);
  EXPECT_EQ(0, store.GetChannelIDCount());
}
283 
// GetAllChannelIDs / DeleteAll / GetAllChannelIDs issued before the persistent
// store has loaded are all deferred. After the load they run in submission
// order: the first get sees both persisted entries, the delete completes, and
// the second get sees none.
TEST(DefaultChannelIDStoreTest, TestAsyncGetAndDeleteAll) {
  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "verisign.com",
      base::Time(),
      base::Time(),
      "a", "b"));
  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
      "google.com",
      base::Time(),
      base::Time(),
      "c", "d"));

  ChannelIDStore::ChannelIDList pre_channel_ids;
  ChannelIDStore::ChannelIDList post_channel_ids;
  int delete_finished = 0;
  DefaultChannelIDStore store(persistent_store.get());

  store.GetAllChannelIDs(base::Bind(&GetAllCallback, &pre_channel_ids));
  store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
  store.GetAllChannelIDs(base::Bind(&GetAllCallback, &post_channel_ids));
  // Tasks have not run yet; the delete callback must not have fired either.
  EXPECT_EQ(0u, pre_channel_ids.size());
  EXPECT_EQ(0, delete_finished);
  // Wait for load & queued tasks.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(0, store.GetChannelIDCount());
  EXPECT_EQ(2u, pre_channel_ids.size());
  // The queued DeleteAll must have reported completion exactly once.
  EXPECT_EQ(1, delete_finished);
  EXPECT_EQ(0u, post_channel_ids.size());
}
313 
// After loading, DeleteChannelID removes only the named server's entry and
// runs its callback before returning.
TEST(DefaultChannelIDStoreTest, TestDelete) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());
  const base::Time null_time = base::Time();
  base::Time expiration;
  std::string key;
  std::string certificate;

  EXPECT_EQ(0, store.GetChannelIDCount());
  store.SetChannelID("verisign.com", null_time, null_time, "a", "b");
  // Let the load and the queued set run; later calls are synchronous.
  base::MessageLoop::current()->RunUntilIdle();

  store.SetChannelID("google.com", null_time, null_time, "c", "d");

  EXPECT_EQ(2, store.GetChannelIDCount());
  int first_delete_done = 0;
  store.DeleteChannelID("verisign.com",
                        base::Bind(&CallCounter, &first_delete_done));
  ASSERT_EQ(1, first_delete_done);
  EXPECT_EQ(1, store.GetChannelIDCount());
  EXPECT_EQ(ERR_FILE_NOT_FOUND,
            store.GetChannelID("verisign.com", &expiration, &key, &certificate,
                               base::Bind(&GetChannelIDCallbackNotCalled)));
  EXPECT_EQ(OK,
            store.GetChannelID("google.com", &expiration, &key, &certificate,
                               base::Bind(&GetChannelIDCallbackNotCalled)));

  int second_delete_done = 0;
  store.DeleteChannelID("google.com",
                        base::Bind(&CallCounter, &second_delete_done));
  ASSERT_EQ(1, second_delete_done);
  EXPECT_EQ(0, store.GetChannelIDCount());
  EXPECT_EQ(ERR_FILE_NOT_FOUND,
            store.GetChannelID("google.com", &expiration, &key, &certificate,
                               base::Bind(&GetChannelIDCallbackNotCalled)));
}
365 
// A DeleteChannelID queued before the load runs ahead of gets queued after it.
// The get for the deleted server reports ERR_FILE_NOT_FOUND; the other server
// is still found.
TEST(DefaultChannelIDStoreTest, TestAsyncDelete) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  backing->AddChannelID(ChannelIDStore::ChannelID(
      "a.com",
      base::Time::FromInternalValue(1),
      base::Time::FromInternalValue(2),
      "a", "b"));
  backing->AddChannelID(ChannelIDStore::ChannelID(
      "b.com",
      base::Time::FromInternalValue(3),
      base::Time::FromInternalValue(4),
      "c", "d"));
  DefaultChannelIDStore store(backing.get());

  int delete_done = 0;
  store.DeleteChannelID("a.com", base::Bind(&CallCounter, &delete_done));

  AsyncGetChannelIDHelper a_result;
  AsyncGetChannelIDHelper b_result;
  base::Time expiration;
  std::string key;
  std::string certificate("not set");

  EXPECT_EQ(0, store.GetChannelIDCount());
  EXPECT_EQ(ERR_IO_PENDING,
            store.GetChannelID("a.com", &expiration, &key, &certificate,
                               base::Bind(&AsyncGetChannelIDHelper::Callback,
                                          base::Unretained(&a_result))));
  EXPECT_EQ(ERR_IO_PENDING,
            store.GetChannelID("b.com", &expiration, &key, &certificate,
                               base::Bind(&AsyncGetChannelIDHelper::Callback,
                                          base::Unretained(&b_result))));

  // The load is still pending, so nothing has run yet.
  EXPECT_EQ(0, delete_done);
  EXPECT_FALSE(a_result.called_);
  EXPECT_FALSE(b_result.called_);

  // Let the load and all queued tasks run.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(1, delete_done);
  EXPECT_EQ(1, store.GetChannelIDCount());
  EXPECT_EQ("not set", certificate);

  EXPECT_TRUE(a_result.called_);
  EXPECT_EQ(ERR_FILE_NOT_FOUND, a_result.err_);
  EXPECT_EQ("a.com", a_result.server_identifier_);
  EXPECT_EQ(0, a_result.expiration_time_.ToInternalValue());
  EXPECT_EQ("", a_result.private_key_);
  EXPECT_EQ("", a_result.cert_);

  EXPECT_TRUE(b_result.called_);
  EXPECT_EQ(OK, b_result.err_);
  EXPECT_EQ("b.com", b_result.server_identifier_);
  EXPECT_EQ(4, b_result.expiration_time_.ToInternalValue());
  EXPECT_EQ("c", b_result.private_key_);
  EXPECT_EQ("d", b_result.cert_);
}
421 
// After loading, GetAllChannelIDs reports every stored entry, and does so
// before returning.
TEST(DefaultChannelIDStoreTest, TestGetAll) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());

  EXPECT_EQ(0, store.GetChannelIDCount());

  // {server, private key, cert}, inserted in this order.
  static const char* const kEntries[][3] = {
    {"verisign.com", "a", "b"},
    {"google.com", "c", "d"},
    {"harvard.com", "e", "f"},
    {"mit.com", "g", "h"},
  };
  for (size_t i = 0; i < sizeof(kEntries) / sizeof(kEntries[0]); ++i) {
    store.SetChannelID(kEntries[i][0], base::Time(), base::Time(),
                       kEntries[i][1], kEntries[i][2]);
  }
  // Let the load and the queued sets run.
  base::MessageLoop::current()->RunUntilIdle();

  EXPECT_EQ(4, store.GetChannelIDCount());
  ChannelIDStore::ChannelIDList reported;
  store.GetAllChannelIDs(base::Bind(GetAllCallback, &reported));
  EXPECT_EQ(4u, reported.size());
}
455 
// InitializeFrom on an already-loaded store merges the given list into it.
// Entries for servers already present are overwritten and new servers are
// added; the merged result is listed in server-name order.
TEST(DefaultChannelIDStoreTest, TestInitializeFrom) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  DefaultChannelIDStore store(backing.get());
  const base::Time null_time = base::Time();

  store.SetChannelID("preexisting.com", null_time, null_time, "a", "b");
  store.SetChannelID("both.com", null_time, null_time, "c", "d");
  // Let the load and the queued sets run.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(2, store.GetChannelIDCount());

  ChannelIDStore::ChannelIDList incoming;
  // Key differs from above to test that existing entries are overwritten.
  incoming.push_back(
      ChannelIDStore::ChannelID("both.com", null_time, null_time, "e", "f"));
  incoming.push_back(
      ChannelIDStore::ChannelID("copied.com", null_time, null_time, "g", "h"));
  store.InitializeFrom(incoming);
  EXPECT_EQ(3, store.GetChannelIDCount());

  ChannelIDStore::ChannelIDList merged;
  store.GetAllChannelIDs(base::Bind(GetAllCallback, &merged));
  ASSERT_EQ(3u, merged.size());

  static const struct {
    const char* server;
    const char* key;
  } kExpected[] = {
    {"both.com", "e"},
    {"copied.com", "g"},
    {"preexisting.com", "a"},
  };
  ChannelIDStore::ChannelIDList::iterator it = merged.begin();
  for (size_t i = 0; i < sizeof(kExpected) / sizeof(kExpected[0]);
       ++i, ++it) {
    EXPECT_EQ(kExpected[i].server, it->server_identifier());
    EXPECT_EQ(kExpected[i].key, it->private_key());
  }
}
505 
// InitializeFrom issued before the persistent store has loaded is deferred.
// Once the load runs, its entries override the persisted ones for shared
// servers, and the merged result is listed in server-name order.
TEST(DefaultChannelIDStoreTest, TestAsyncInitializeFrom) {
  scoped_refptr<MockPersistentStore> backing(new MockPersistentStore);
  const base::Time null_time = base::Time();
  backing->AddChannelID(ChannelIDStore::ChannelID(
      "preexisting.com", null_time, null_time, "a", "b"));
  backing->AddChannelID(ChannelIDStore::ChannelID(
      "both.com", null_time, null_time, "c", "d"));

  DefaultChannelIDStore store(backing.get());
  ChannelIDStore::ChannelIDList incoming;
  // Key differs from above to test that existing entries are overwritten.
  incoming.push_back(
      ChannelIDStore::ChannelID("both.com", null_time, null_time, "e", "f"));
  incoming.push_back(
      ChannelIDStore::ChannelID("copied.com", null_time, null_time, "g", "h"));
  store.InitializeFrom(incoming);
  EXPECT_EQ(0, store.GetChannelIDCount());

  // Let the load and the queued InitializeFrom run.
  base::MessageLoop::current()->RunUntilIdle();
  EXPECT_EQ(3, store.GetChannelIDCount());

  ChannelIDStore::ChannelIDList merged;
  store.GetAllChannelIDs(base::Bind(GetAllCallback, &merged));
  ASSERT_EQ(3u, merged.size());

  static const struct {
    const char* server;
    const char* key;
  } kExpected[] = {
    {"both.com", "e"},
    {"copied.com", "g"},
    {"preexisting.com", "a"},
  };
  ChannelIDStore::ChannelIDList::iterator it = merged.begin();
  for (size_t i = 0; i < sizeof(kExpected) / sizeof(kExpected[0]);
       ++i, ++it) {
    EXPECT_EQ(kExpected[i].server, it->server_identifier());
    EXPECT_EQ(kExpected[i].key, it->private_key());
  }
}
554 
555 }  // namespace net
556