1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/ssl/default_channel_id_store.h"
6
7 #include "base/bind.h"
8 #include "base/message_loop/message_loop.h"
9 #include "base/metrics/histogram.h"
10 #include "net/base/net_errors.h"
11
12 namespace net {
13
14 // --------------------------------------------------------------------------
15 // Task
16 class DefaultChannelIDStore::Task {
17 public:
18 virtual ~Task();
19
20 // Runs the task and invokes the client callback on the thread that
21 // originally constructed the task.
22 virtual void Run(DefaultChannelIDStore* store) = 0;
23
24 protected:
25 void InvokeCallback(base::Closure callback) const;
26 };
27
~Task()28 DefaultChannelIDStore::Task::~Task() {
29 }
30
InvokeCallback(base::Closure callback) const31 void DefaultChannelIDStore::Task::InvokeCallback(
32 base::Closure callback) const {
33 if (!callback.is_null())
34 callback.Run();
35 }
36
37 // --------------------------------------------------------------------------
38 // GetChannelIDTask
39 class DefaultChannelIDStore::GetChannelIDTask
40 : public DefaultChannelIDStore::Task {
41 public:
42 GetChannelIDTask(const std::string& server_identifier,
43 const GetChannelIDCallback& callback);
44 virtual ~GetChannelIDTask();
45 virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
46
47 private:
48 std::string server_identifier_;
49 GetChannelIDCallback callback_;
50 };
51
GetChannelIDTask(const std::string & server_identifier,const GetChannelIDCallback & callback)52 DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
53 const std::string& server_identifier,
54 const GetChannelIDCallback& callback)
55 : server_identifier_(server_identifier),
56 callback_(callback) {
57 }
58
~GetChannelIDTask()59 DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() {
60 }
61
Run(DefaultChannelIDStore * store)62 void DefaultChannelIDStore::GetChannelIDTask::Run(
63 DefaultChannelIDStore* store) {
64 base::Time expiration_time;
65 std::string private_key_result;
66 std::string cert_result;
67 int err = store->GetChannelID(
68 server_identifier_, &expiration_time, &private_key_result,
69 &cert_result, GetChannelIDCallback());
70 DCHECK(err != ERR_IO_PENDING);
71
72 InvokeCallback(base::Bind(callback_, err, server_identifier_,
73 expiration_time, private_key_result, cert_result));
74 }
75
76 // --------------------------------------------------------------------------
77 // SetChannelIDTask
78 class DefaultChannelIDStore::SetChannelIDTask
79 : public DefaultChannelIDStore::Task {
80 public:
81 SetChannelIDTask(const std::string& server_identifier,
82 base::Time creation_time,
83 base::Time expiration_time,
84 const std::string& private_key,
85 const std::string& cert);
86 virtual ~SetChannelIDTask();
87 virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
88
89 private:
90 std::string server_identifier_;
91 base::Time creation_time_;
92 base::Time expiration_time_;
93 std::string private_key_;
94 std::string cert_;
95 };
96
SetChannelIDTask(const std::string & server_identifier,base::Time creation_time,base::Time expiration_time,const std::string & private_key,const std::string & cert)97 DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
98 const std::string& server_identifier,
99 base::Time creation_time,
100 base::Time expiration_time,
101 const std::string& private_key,
102 const std::string& cert)
103 : server_identifier_(server_identifier),
104 creation_time_(creation_time),
105 expiration_time_(expiration_time),
106 private_key_(private_key),
107 cert_(cert) {
108 }
109
~SetChannelIDTask()110 DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() {
111 }
112
Run(DefaultChannelIDStore * store)113 void DefaultChannelIDStore::SetChannelIDTask::Run(
114 DefaultChannelIDStore* store) {
115 store->SyncSetChannelID(server_identifier_, creation_time_,
116 expiration_time_, private_key_, cert_);
117 }
118
119 // --------------------------------------------------------------------------
120 // DeleteChannelIDTask
121 class DefaultChannelIDStore::DeleteChannelIDTask
122 : public DefaultChannelIDStore::Task {
123 public:
124 DeleteChannelIDTask(const std::string& server_identifier,
125 const base::Closure& callback);
126 virtual ~DeleteChannelIDTask();
127 virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
128
129 private:
130 std::string server_identifier_;
131 base::Closure callback_;
132 };
133
134 DefaultChannelIDStore::DeleteChannelIDTask::
DeleteChannelIDTask(const std::string & server_identifier,const base::Closure & callback)135 DeleteChannelIDTask(
136 const std::string& server_identifier,
137 const base::Closure& callback)
138 : server_identifier_(server_identifier),
139 callback_(callback) {
140 }
141
142 DefaultChannelIDStore::DeleteChannelIDTask::
~DeleteChannelIDTask()143 ~DeleteChannelIDTask() {
144 }
145
Run(DefaultChannelIDStore * store)146 void DefaultChannelIDStore::DeleteChannelIDTask::Run(
147 DefaultChannelIDStore* store) {
148 store->SyncDeleteChannelID(server_identifier_);
149
150 InvokeCallback(callback_);
151 }
152
153 // --------------------------------------------------------------------------
154 // DeleteAllCreatedBetweenTask
155 class DefaultChannelIDStore::DeleteAllCreatedBetweenTask
156 : public DefaultChannelIDStore::Task {
157 public:
158 DeleteAllCreatedBetweenTask(base::Time delete_begin,
159 base::Time delete_end,
160 const base::Closure& callback);
161 virtual ~DeleteAllCreatedBetweenTask();
162 virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
163
164 private:
165 base::Time delete_begin_;
166 base::Time delete_end_;
167 base::Closure callback_;
168 };
169
170 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
DeleteAllCreatedBetweenTask(base::Time delete_begin,base::Time delete_end,const base::Closure & callback)171 DeleteAllCreatedBetweenTask(
172 base::Time delete_begin,
173 base::Time delete_end,
174 const base::Closure& callback)
175 : delete_begin_(delete_begin),
176 delete_end_(delete_end),
177 callback_(callback) {
178 }
179
180 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
~DeleteAllCreatedBetweenTask()181 ~DeleteAllCreatedBetweenTask() {
182 }
183
Run(DefaultChannelIDStore * store)184 void DefaultChannelIDStore::DeleteAllCreatedBetweenTask::Run(
185 DefaultChannelIDStore* store) {
186 store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_);
187
188 InvokeCallback(callback_);
189 }
190
191 // --------------------------------------------------------------------------
192 // GetAllChannelIDsTask
193 class DefaultChannelIDStore::GetAllChannelIDsTask
194 : public DefaultChannelIDStore::Task {
195 public:
196 explicit GetAllChannelIDsTask(const GetChannelIDListCallback& callback);
197 virtual ~GetAllChannelIDsTask();
198 virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
199
200 private:
201 std::string server_identifier_;
202 GetChannelIDListCallback callback_;
203 };
204
205 DefaultChannelIDStore::GetAllChannelIDsTask::
GetAllChannelIDsTask(const GetChannelIDListCallback & callback)206 GetAllChannelIDsTask(const GetChannelIDListCallback& callback)
207 : callback_(callback) {
208 }
209
210 DefaultChannelIDStore::GetAllChannelIDsTask::
~GetAllChannelIDsTask()211 ~GetAllChannelIDsTask() {
212 }
213
Run(DefaultChannelIDStore * store)214 void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
215 DefaultChannelIDStore* store) {
216 ChannelIDList cert_list;
217 store->SyncGetAllChannelIDs(&cert_list);
218
219 InvokeCallback(base::Bind(callback_, cert_list));
220 }
221
222 // --------------------------------------------------------------------------
223 // DefaultChannelIDStore
224
DefaultChannelIDStore(PersistentStore * store)225 DefaultChannelIDStore::DefaultChannelIDStore(
226 PersistentStore* store)
227 : initialized_(false),
228 loaded_(false),
229 store_(store),
230 weak_ptr_factory_(this) {}
231
GetChannelID(const std::string & server_identifier,base::Time * expiration_time,std::string * private_key_result,std::string * cert_result,const GetChannelIDCallback & callback)232 int DefaultChannelIDStore::GetChannelID(
233 const std::string& server_identifier,
234 base::Time* expiration_time,
235 std::string* private_key_result,
236 std::string* cert_result,
237 const GetChannelIDCallback& callback) {
238 DCHECK(CalledOnValidThread());
239 InitIfNecessary();
240
241 if (!loaded_) {
242 EnqueueTask(scoped_ptr<Task>(
243 new GetChannelIDTask(server_identifier, callback)));
244 return ERR_IO_PENDING;
245 }
246
247 ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
248
249 if (it == channel_ids_.end())
250 return ERR_FILE_NOT_FOUND;
251
252 ChannelID* channel_id = it->second;
253 *expiration_time = channel_id->expiration_time();
254 *private_key_result = channel_id->private_key();
255 *cert_result = channel_id->cert();
256
257 return OK;
258 }
259
SetChannelID(const std::string & server_identifier,base::Time creation_time,base::Time expiration_time,const std::string & private_key,const std::string & cert)260 void DefaultChannelIDStore::SetChannelID(
261 const std::string& server_identifier,
262 base::Time creation_time,
263 base::Time expiration_time,
264 const std::string& private_key,
265 const std::string& cert) {
266 RunOrEnqueueTask(scoped_ptr<Task>(new SetChannelIDTask(
267 server_identifier, creation_time, expiration_time, private_key,
268 cert)));
269 }
270
DeleteChannelID(const std::string & server_identifier,const base::Closure & callback)271 void DefaultChannelIDStore::DeleteChannelID(
272 const std::string& server_identifier,
273 const base::Closure& callback) {
274 RunOrEnqueueTask(scoped_ptr<Task>(
275 new DeleteChannelIDTask(server_identifier, callback)));
276 }
277
DeleteAllCreatedBetween(base::Time delete_begin,base::Time delete_end,const base::Closure & callback)278 void DefaultChannelIDStore::DeleteAllCreatedBetween(
279 base::Time delete_begin,
280 base::Time delete_end,
281 const base::Closure& callback) {
282 RunOrEnqueueTask(scoped_ptr<Task>(
283 new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback)));
284 }
285
DeleteAll(const base::Closure & callback)286 void DefaultChannelIDStore::DeleteAll(
287 const base::Closure& callback) {
288 DeleteAllCreatedBetween(base::Time(), base::Time(), callback);
289 }
290
GetAllChannelIDs(const GetChannelIDListCallback & callback)291 void DefaultChannelIDStore::GetAllChannelIDs(
292 const GetChannelIDListCallback& callback) {
293 RunOrEnqueueTask(scoped_ptr<Task>(new GetAllChannelIDsTask(callback)));
294 }
295
GetChannelIDCount()296 int DefaultChannelIDStore::GetChannelIDCount() {
297 DCHECK(CalledOnValidThread());
298
299 return channel_ids_.size();
300 }
301
SetForceKeepSessionState()302 void DefaultChannelIDStore::SetForceKeepSessionState() {
303 DCHECK(CalledOnValidThread());
304 InitIfNecessary();
305
306 if (store_.get())
307 store_->SetForceKeepSessionState();
308 }
309
~DefaultChannelIDStore()310 DefaultChannelIDStore::~DefaultChannelIDStore() {
311 DeleteAllInMemory();
312 }
313
DeleteAllInMemory()314 void DefaultChannelIDStore::DeleteAllInMemory() {
315 DCHECK(CalledOnValidThread());
316
317 for (ChannelIDMap::iterator it = channel_ids_.begin();
318 it != channel_ids_.end(); ++it) {
319 delete it->second;
320 }
321 channel_ids_.clear();
322 }
323
InitStore()324 void DefaultChannelIDStore::InitStore() {
325 DCHECK(CalledOnValidThread());
326 DCHECK(store_.get()) << "Store must exist to initialize";
327 DCHECK(!loaded_);
328
329 store_->Load(base::Bind(&DefaultChannelIDStore::OnLoaded,
330 weak_ptr_factory_.GetWeakPtr()));
331 }
332
OnLoaded(scoped_ptr<ScopedVector<ChannelID>> channel_ids)333 void DefaultChannelIDStore::OnLoaded(
334 scoped_ptr<ScopedVector<ChannelID> > channel_ids) {
335 DCHECK(CalledOnValidThread());
336
337 for (std::vector<ChannelID*>::const_iterator it = channel_ids->begin();
338 it != channel_ids->end(); ++it) {
339 DCHECK(channel_ids_.find((*it)->server_identifier()) ==
340 channel_ids_.end());
341 channel_ids_[(*it)->server_identifier()] = *it;
342 }
343 channel_ids->weak_clear();
344
345 loaded_ = true;
346
347 base::TimeDelta wait_time;
348 if (!waiting_tasks_.empty())
349 wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_;
350 DVLOG(1) << "Task delay " << wait_time.InMilliseconds();
351 UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
352 wait_time,
353 base::TimeDelta::FromMilliseconds(1),
354 base::TimeDelta::FromMinutes(1),
355 50);
356 UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
357 waiting_tasks_.size());
358
359
360 for (ScopedVector<Task>::iterator i = waiting_tasks_.begin();
361 i != waiting_tasks_.end(); ++i)
362 (*i)->Run(this);
363 waiting_tasks_.clear();
364 }
365
SyncSetChannelID(const std::string & server_identifier,base::Time creation_time,base::Time expiration_time,const std::string & private_key,const std::string & cert)366 void DefaultChannelIDStore::SyncSetChannelID(
367 const std::string& server_identifier,
368 base::Time creation_time,
369 base::Time expiration_time,
370 const std::string& private_key,
371 const std::string& cert) {
372 DCHECK(CalledOnValidThread());
373 DCHECK(loaded_);
374
375 InternalDeleteChannelID(server_identifier);
376 InternalInsertChannelID(
377 server_identifier,
378 new ChannelID(
379 server_identifier, creation_time, expiration_time, private_key,
380 cert));
381 }
382
SyncDeleteChannelID(const std::string & server_identifier)383 void DefaultChannelIDStore::SyncDeleteChannelID(
384 const std::string& server_identifier) {
385 DCHECK(CalledOnValidThread());
386 DCHECK(loaded_);
387 InternalDeleteChannelID(server_identifier);
388 }
389
SyncDeleteAllCreatedBetween(base::Time delete_begin,base::Time delete_end)390 void DefaultChannelIDStore::SyncDeleteAllCreatedBetween(
391 base::Time delete_begin,
392 base::Time delete_end) {
393 DCHECK(CalledOnValidThread());
394 DCHECK(loaded_);
395 for (ChannelIDMap::iterator it = channel_ids_.begin();
396 it != channel_ids_.end();) {
397 ChannelIDMap::iterator cur = it;
398 ++it;
399 ChannelID* channel_id = cur->second;
400 if ((delete_begin.is_null() ||
401 channel_id->creation_time() >= delete_begin) &&
402 (delete_end.is_null() || channel_id->creation_time() < delete_end)) {
403 if (store_.get())
404 store_->DeleteChannelID(*channel_id);
405 delete channel_id;
406 channel_ids_.erase(cur);
407 }
408 }
409 }
410
SyncGetAllChannelIDs(ChannelIDList * channel_id_list)411 void DefaultChannelIDStore::SyncGetAllChannelIDs(
412 ChannelIDList* channel_id_list) {
413 DCHECK(CalledOnValidThread());
414 DCHECK(loaded_);
415 for (ChannelIDMap::iterator it = channel_ids_.begin();
416 it != channel_ids_.end(); ++it)
417 channel_id_list->push_back(*it->second);
418 }
419
EnqueueTask(scoped_ptr<Task> task)420 void DefaultChannelIDStore::EnqueueTask(scoped_ptr<Task> task) {
421 DCHECK(CalledOnValidThread());
422 DCHECK(!loaded_);
423 if (waiting_tasks_.empty())
424 waiting_tasks_start_time_ = base::TimeTicks::Now();
425 waiting_tasks_.push_back(task.release());
426 }
427
RunOrEnqueueTask(scoped_ptr<Task> task)428 void DefaultChannelIDStore::RunOrEnqueueTask(scoped_ptr<Task> task) {
429 DCHECK(CalledOnValidThread());
430 InitIfNecessary();
431
432 if (!loaded_) {
433 EnqueueTask(task.Pass());
434 return;
435 }
436
437 task->Run(this);
438 }
439
InternalDeleteChannelID(const std::string & server_identifier)440 void DefaultChannelIDStore::InternalDeleteChannelID(
441 const std::string& server_identifier) {
442 DCHECK(CalledOnValidThread());
443 DCHECK(loaded_);
444
445 ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
446 if (it == channel_ids_.end())
447 return; // There is nothing to delete.
448
449 ChannelID* channel_id = it->second;
450 if (store_.get())
451 store_->DeleteChannelID(*channel_id);
452 channel_ids_.erase(it);
453 delete channel_id;
454 }
455
InternalInsertChannelID(const std::string & server_identifier,ChannelID * channel_id)456 void DefaultChannelIDStore::InternalInsertChannelID(
457 const std::string& server_identifier,
458 ChannelID* channel_id) {
459 DCHECK(CalledOnValidThread());
460 DCHECK(loaded_);
461
462 if (store_.get())
463 store_->AddChannelID(*channel_id);
464 channel_ids_[server_identifier] = channel_id;
465 }
466
PersistentStore()467 DefaultChannelIDStore::PersistentStore::PersistentStore() {}
468
~PersistentStore()469 DefaultChannelIDStore::PersistentStore::~PersistentStore() {}
470
471 } // namespace net
472