• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/ssl/default_channel_id_store.h"
6 
7 #include "base/bind.h"
8 #include "base/message_loop/message_loop.h"
9 #include "base/metrics/histogram.h"
10 #include "net/base/net_errors.h"
11 
12 namespace net {
13 
14 // --------------------------------------------------------------------------
15 // Task
16 class DefaultChannelIDStore::Task {
17  public:
18   virtual ~Task();
19 
20   // Runs the task and invokes the client callback on the thread that
21   // originally constructed the task.
22   virtual void Run(DefaultChannelIDStore* store) = 0;
23 
24  protected:
25   void InvokeCallback(base::Closure callback) const;
26 };
27 
~Task()28 DefaultChannelIDStore::Task::~Task() {
29 }
30 
InvokeCallback(base::Closure callback) const31 void DefaultChannelIDStore::Task::InvokeCallback(
32     base::Closure callback) const {
33   if (!callback.is_null())
34     callback.Run();
35 }
36 
37 // --------------------------------------------------------------------------
38 // GetChannelIDTask
39 class DefaultChannelIDStore::GetChannelIDTask
40     : public DefaultChannelIDStore::Task {
41  public:
42   GetChannelIDTask(const std::string& server_identifier,
43                    const GetChannelIDCallback& callback);
44   virtual ~GetChannelIDTask();
45   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
46 
47  private:
48   std::string server_identifier_;
49   GetChannelIDCallback callback_;
50 };
51 
GetChannelIDTask(const std::string & server_identifier,const GetChannelIDCallback & callback)52 DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
53     const std::string& server_identifier,
54     const GetChannelIDCallback& callback)
55     : server_identifier_(server_identifier),
56       callback_(callback) {
57 }
58 
~GetChannelIDTask()59 DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() {
60 }
61 
Run(DefaultChannelIDStore * store)62 void DefaultChannelIDStore::GetChannelIDTask::Run(
63     DefaultChannelIDStore* store) {
64   base::Time expiration_time;
65   std::string private_key_result;
66   std::string cert_result;
67   int err = store->GetChannelID(
68       server_identifier_, &expiration_time, &private_key_result,
69       &cert_result, GetChannelIDCallback());
70   DCHECK(err != ERR_IO_PENDING);
71 
72   InvokeCallback(base::Bind(callback_, err, server_identifier_,
73                             expiration_time, private_key_result, cert_result));
74 }
75 
76 // --------------------------------------------------------------------------
77 // SetChannelIDTask
78 class DefaultChannelIDStore::SetChannelIDTask
79     : public DefaultChannelIDStore::Task {
80  public:
81   SetChannelIDTask(const std::string& server_identifier,
82                    base::Time creation_time,
83                    base::Time expiration_time,
84                    const std::string& private_key,
85                    const std::string& cert);
86   virtual ~SetChannelIDTask();
87   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
88 
89  private:
90   std::string server_identifier_;
91   base::Time creation_time_;
92   base::Time expiration_time_;
93   std::string private_key_;
94   std::string cert_;
95 };
96 
SetChannelIDTask(const std::string & server_identifier,base::Time creation_time,base::Time expiration_time,const std::string & private_key,const std::string & cert)97 DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
98     const std::string& server_identifier,
99     base::Time creation_time,
100     base::Time expiration_time,
101     const std::string& private_key,
102     const std::string& cert)
103     : server_identifier_(server_identifier),
104       creation_time_(creation_time),
105       expiration_time_(expiration_time),
106       private_key_(private_key),
107       cert_(cert) {
108 }
109 
~SetChannelIDTask()110 DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() {
111 }
112 
Run(DefaultChannelIDStore * store)113 void DefaultChannelIDStore::SetChannelIDTask::Run(
114     DefaultChannelIDStore* store) {
115   store->SyncSetChannelID(server_identifier_, creation_time_,
116                           expiration_time_, private_key_, cert_);
117 }
118 
119 // --------------------------------------------------------------------------
120 // DeleteChannelIDTask
121 class DefaultChannelIDStore::DeleteChannelIDTask
122     : public DefaultChannelIDStore::Task {
123  public:
124   DeleteChannelIDTask(const std::string& server_identifier,
125                       const base::Closure& callback);
126   virtual ~DeleteChannelIDTask();
127   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
128 
129  private:
130   std::string server_identifier_;
131   base::Closure callback_;
132 };
133 
134 DefaultChannelIDStore::DeleteChannelIDTask::
DeleteChannelIDTask(const std::string & server_identifier,const base::Closure & callback)135     DeleteChannelIDTask(
136         const std::string& server_identifier,
137         const base::Closure& callback)
138         : server_identifier_(server_identifier),
139           callback_(callback) {
140 }
141 
142 DefaultChannelIDStore::DeleteChannelIDTask::
~DeleteChannelIDTask()143     ~DeleteChannelIDTask() {
144 }
145 
Run(DefaultChannelIDStore * store)146 void DefaultChannelIDStore::DeleteChannelIDTask::Run(
147     DefaultChannelIDStore* store) {
148   store->SyncDeleteChannelID(server_identifier_);
149 
150   InvokeCallback(callback_);
151 }
152 
153 // --------------------------------------------------------------------------
154 // DeleteAllCreatedBetweenTask
155 class DefaultChannelIDStore::DeleteAllCreatedBetweenTask
156     : public DefaultChannelIDStore::Task {
157  public:
158   DeleteAllCreatedBetweenTask(base::Time delete_begin,
159                               base::Time delete_end,
160                               const base::Closure& callback);
161   virtual ~DeleteAllCreatedBetweenTask();
162   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
163 
164  private:
165   base::Time delete_begin_;
166   base::Time delete_end_;
167   base::Closure callback_;
168 };
169 
170 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
DeleteAllCreatedBetweenTask(base::Time delete_begin,base::Time delete_end,const base::Closure & callback)171     DeleteAllCreatedBetweenTask(
172         base::Time delete_begin,
173         base::Time delete_end,
174         const base::Closure& callback)
175         : delete_begin_(delete_begin),
176           delete_end_(delete_end),
177           callback_(callback) {
178 }
179 
180 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
~DeleteAllCreatedBetweenTask()181     ~DeleteAllCreatedBetweenTask() {
182 }
183 
Run(DefaultChannelIDStore * store)184 void DefaultChannelIDStore::DeleteAllCreatedBetweenTask::Run(
185     DefaultChannelIDStore* store) {
186   store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_);
187 
188   InvokeCallback(callback_);
189 }
190 
191 // --------------------------------------------------------------------------
192 // GetAllChannelIDsTask
193 class DefaultChannelIDStore::GetAllChannelIDsTask
194     : public DefaultChannelIDStore::Task {
195  public:
196   explicit GetAllChannelIDsTask(const GetChannelIDListCallback& callback);
197   virtual ~GetAllChannelIDsTask();
198   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
199 
200  private:
201   std::string server_identifier_;
202   GetChannelIDListCallback callback_;
203 };
204 
205 DefaultChannelIDStore::GetAllChannelIDsTask::
GetAllChannelIDsTask(const GetChannelIDListCallback & callback)206     GetAllChannelIDsTask(const GetChannelIDListCallback& callback)
207         : callback_(callback) {
208 }
209 
210 DefaultChannelIDStore::GetAllChannelIDsTask::
~GetAllChannelIDsTask()211     ~GetAllChannelIDsTask() {
212 }
213 
Run(DefaultChannelIDStore * store)214 void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
215     DefaultChannelIDStore* store) {
216   ChannelIDList cert_list;
217   store->SyncGetAllChannelIDs(&cert_list);
218 
219   InvokeCallback(base::Bind(callback_, cert_list));
220 }
221 
222 // --------------------------------------------------------------------------
223 // DefaultChannelIDStore
224 
DefaultChannelIDStore(PersistentStore * store)225 DefaultChannelIDStore::DefaultChannelIDStore(
226     PersistentStore* store)
227     : initialized_(false),
228       loaded_(false),
229       store_(store),
230       weak_ptr_factory_(this) {}
231 
GetChannelID(const std::string & server_identifier,base::Time * expiration_time,std::string * private_key_result,std::string * cert_result,const GetChannelIDCallback & callback)232 int DefaultChannelIDStore::GetChannelID(
233     const std::string& server_identifier,
234     base::Time* expiration_time,
235     std::string* private_key_result,
236     std::string* cert_result,
237     const GetChannelIDCallback& callback) {
238   DCHECK(CalledOnValidThread());
239   InitIfNecessary();
240 
241   if (!loaded_) {
242     EnqueueTask(scoped_ptr<Task>(
243         new GetChannelIDTask(server_identifier, callback)));
244     return ERR_IO_PENDING;
245   }
246 
247   ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
248 
249   if (it == channel_ids_.end())
250     return ERR_FILE_NOT_FOUND;
251 
252   ChannelID* channel_id = it->second;
253   *expiration_time = channel_id->expiration_time();
254   *private_key_result = channel_id->private_key();
255   *cert_result = channel_id->cert();
256 
257   return OK;
258 }
259 
SetChannelID(const std::string & server_identifier,base::Time creation_time,base::Time expiration_time,const std::string & private_key,const std::string & cert)260 void DefaultChannelIDStore::SetChannelID(
261     const std::string& server_identifier,
262     base::Time creation_time,
263     base::Time expiration_time,
264     const std::string& private_key,
265     const std::string& cert) {
266   RunOrEnqueueTask(scoped_ptr<Task>(new SetChannelIDTask(
267       server_identifier, creation_time, expiration_time, private_key,
268       cert)));
269 }
270 
DeleteChannelID(const std::string & server_identifier,const base::Closure & callback)271 void DefaultChannelIDStore::DeleteChannelID(
272     const std::string& server_identifier,
273     const base::Closure& callback) {
274   RunOrEnqueueTask(scoped_ptr<Task>(
275       new DeleteChannelIDTask(server_identifier, callback)));
276 }
277 
DeleteAllCreatedBetween(base::Time delete_begin,base::Time delete_end,const base::Closure & callback)278 void DefaultChannelIDStore::DeleteAllCreatedBetween(
279     base::Time delete_begin,
280     base::Time delete_end,
281     const base::Closure& callback) {
282   RunOrEnqueueTask(scoped_ptr<Task>(
283       new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback)));
284 }
285 
DeleteAll(const base::Closure & callback)286 void DefaultChannelIDStore::DeleteAll(
287     const base::Closure& callback) {
288   DeleteAllCreatedBetween(base::Time(), base::Time(), callback);
289 }
290 
GetAllChannelIDs(const GetChannelIDListCallback & callback)291 void DefaultChannelIDStore::GetAllChannelIDs(
292     const GetChannelIDListCallback& callback) {
293   RunOrEnqueueTask(scoped_ptr<Task>(new GetAllChannelIDsTask(callback)));
294 }
295 
GetChannelIDCount()296 int DefaultChannelIDStore::GetChannelIDCount() {
297   DCHECK(CalledOnValidThread());
298 
299   return channel_ids_.size();
300 }
301 
SetForceKeepSessionState()302 void DefaultChannelIDStore::SetForceKeepSessionState() {
303   DCHECK(CalledOnValidThread());
304   InitIfNecessary();
305 
306   if (store_.get())
307     store_->SetForceKeepSessionState();
308 }
309 
~DefaultChannelIDStore()310 DefaultChannelIDStore::~DefaultChannelIDStore() {
311   DeleteAllInMemory();
312 }
313 
DeleteAllInMemory()314 void DefaultChannelIDStore::DeleteAllInMemory() {
315   DCHECK(CalledOnValidThread());
316 
317   for (ChannelIDMap::iterator it = channel_ids_.begin();
318        it != channel_ids_.end(); ++it) {
319     delete it->second;
320   }
321   channel_ids_.clear();
322 }
323 
InitStore()324 void DefaultChannelIDStore::InitStore() {
325   DCHECK(CalledOnValidThread());
326   DCHECK(store_.get()) << "Store must exist to initialize";
327   DCHECK(!loaded_);
328 
329   store_->Load(base::Bind(&DefaultChannelIDStore::OnLoaded,
330                           weak_ptr_factory_.GetWeakPtr()));
331 }
332 
OnLoaded(scoped_ptr<ScopedVector<ChannelID>> channel_ids)333 void DefaultChannelIDStore::OnLoaded(
334     scoped_ptr<ScopedVector<ChannelID> > channel_ids) {
335   DCHECK(CalledOnValidThread());
336 
337   for (std::vector<ChannelID*>::const_iterator it = channel_ids->begin();
338        it != channel_ids->end(); ++it) {
339     DCHECK(channel_ids_.find((*it)->server_identifier()) ==
340            channel_ids_.end());
341     channel_ids_[(*it)->server_identifier()] = *it;
342   }
343   channel_ids->weak_clear();
344 
345   loaded_ = true;
346 
347   base::TimeDelta wait_time;
348   if (!waiting_tasks_.empty())
349     wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_;
350   DVLOG(1) << "Task delay " << wait_time.InMilliseconds();
351   UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
352                              wait_time,
353                              base::TimeDelta::FromMilliseconds(1),
354                              base::TimeDelta::FromMinutes(1),
355                              50);
356   UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
357                            waiting_tasks_.size());
358 
359 
360   for (ScopedVector<Task>::iterator i = waiting_tasks_.begin();
361        i != waiting_tasks_.end(); ++i)
362     (*i)->Run(this);
363   waiting_tasks_.clear();
364 }
365 
SyncSetChannelID(const std::string & server_identifier,base::Time creation_time,base::Time expiration_time,const std::string & private_key,const std::string & cert)366 void DefaultChannelIDStore::SyncSetChannelID(
367     const std::string& server_identifier,
368     base::Time creation_time,
369     base::Time expiration_time,
370     const std::string& private_key,
371     const std::string& cert) {
372   DCHECK(CalledOnValidThread());
373   DCHECK(loaded_);
374 
375   InternalDeleteChannelID(server_identifier);
376   InternalInsertChannelID(
377       server_identifier,
378       new ChannelID(
379           server_identifier, creation_time, expiration_time, private_key,
380           cert));
381 }
382 
SyncDeleteChannelID(const std::string & server_identifier)383 void DefaultChannelIDStore::SyncDeleteChannelID(
384     const std::string& server_identifier) {
385   DCHECK(CalledOnValidThread());
386   DCHECK(loaded_);
387   InternalDeleteChannelID(server_identifier);
388 }
389 
SyncDeleteAllCreatedBetween(base::Time delete_begin,base::Time delete_end)390 void DefaultChannelIDStore::SyncDeleteAllCreatedBetween(
391     base::Time delete_begin,
392     base::Time delete_end) {
393   DCHECK(CalledOnValidThread());
394   DCHECK(loaded_);
395   for (ChannelIDMap::iterator it = channel_ids_.begin();
396        it != channel_ids_.end();) {
397     ChannelIDMap::iterator cur = it;
398     ++it;
399     ChannelID* channel_id = cur->second;
400     if ((delete_begin.is_null() ||
401          channel_id->creation_time() >= delete_begin) &&
402         (delete_end.is_null() || channel_id->creation_time() < delete_end)) {
403       if (store_.get())
404         store_->DeleteChannelID(*channel_id);
405       delete channel_id;
406       channel_ids_.erase(cur);
407     }
408   }
409 }
410 
SyncGetAllChannelIDs(ChannelIDList * channel_id_list)411 void DefaultChannelIDStore::SyncGetAllChannelIDs(
412     ChannelIDList* channel_id_list) {
413   DCHECK(CalledOnValidThread());
414   DCHECK(loaded_);
415   for (ChannelIDMap::iterator it = channel_ids_.begin();
416        it != channel_ids_.end(); ++it)
417     channel_id_list->push_back(*it->second);
418 }
419 
EnqueueTask(scoped_ptr<Task> task)420 void DefaultChannelIDStore::EnqueueTask(scoped_ptr<Task> task) {
421   DCHECK(CalledOnValidThread());
422   DCHECK(!loaded_);
423   if (waiting_tasks_.empty())
424     waiting_tasks_start_time_ = base::TimeTicks::Now();
425   waiting_tasks_.push_back(task.release());
426 }
427 
RunOrEnqueueTask(scoped_ptr<Task> task)428 void DefaultChannelIDStore::RunOrEnqueueTask(scoped_ptr<Task> task) {
429   DCHECK(CalledOnValidThread());
430   InitIfNecessary();
431 
432   if (!loaded_) {
433     EnqueueTask(task.Pass());
434     return;
435   }
436 
437   task->Run(this);
438 }
439 
InternalDeleteChannelID(const std::string & server_identifier)440 void DefaultChannelIDStore::InternalDeleteChannelID(
441     const std::string& server_identifier) {
442   DCHECK(CalledOnValidThread());
443   DCHECK(loaded_);
444 
445   ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
446   if (it == channel_ids_.end())
447     return;  // There is nothing to delete.
448 
449   ChannelID* channel_id = it->second;
450   if (store_.get())
451     store_->DeleteChannelID(*channel_id);
452   channel_ids_.erase(it);
453   delete channel_id;
454 }
455 
InternalInsertChannelID(const std::string & server_identifier,ChannelID * channel_id)456 void DefaultChannelIDStore::InternalInsertChannelID(
457     const std::string& server_identifier,
458     ChannelID* channel_id) {
459   DCHECK(CalledOnValidThread());
460   DCHECK(loaded_);
461 
462   if (store_.get())
463     store_->AddChannelID(*channel_id);
464   channel_ids_[server_identifier] = channel_id;
465 }
466 
PersistentStore()467 DefaultChannelIDStore::PersistentStore::PersistentStore() {}
468 
~PersistentStore()469 DefaultChannelIDStore::PersistentStore::~PersistentStore() {}
470 
471 }  // namespace net
472