1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/protocol/pairing_registry.h"
6
7 #include "base/base64.h"
8 #include "base/bind.h"
9 #include "base/guid.h"
10 #include "base/json/json_string_value_serializer.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/strings/string_number_conversions.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "base/values.h"
16 #include "crypto/random.h"
17
18 namespace remoting {
19 namespace protocol {
20
namespace {

// Number of random bytes drawn for each newly created shared secret, before
// Base64 encoding (see PairingRegistry::Pairing::Create).
const int kKeySize = 16;

}  // namespace
23
24 const char PairingRegistry::kCreatedTimeKey[] = "createdTime";
25 const char PairingRegistry::kClientIdKey[] = "clientId";
26 const char PairingRegistry::kClientNameKey[] = "clientName";
27 const char PairingRegistry::kSharedSecretKey[] = "sharedSecret";
28
Pairing()29 PairingRegistry::Pairing::Pairing() {
30 }
31
Pairing(const base::Time & created_time,const std::string & client_name,const std::string & client_id,const std::string & shared_secret)32 PairingRegistry::Pairing::Pairing(const base::Time& created_time,
33 const std::string& client_name,
34 const std::string& client_id,
35 const std::string& shared_secret)
36 : created_time_(created_time),
37 client_name_(client_name),
38 client_id_(client_id),
39 shared_secret_(shared_secret) {
40 }
41
~Pairing()42 PairingRegistry::Pairing::~Pairing() {
43 }
44
Create(const std::string & client_name)45 PairingRegistry::Pairing PairingRegistry::Pairing::Create(
46 const std::string& client_name) {
47 base::Time created_time = base::Time::Now();
48 std::string client_id = base::GenerateGUID();
49 std::string shared_secret;
50 char buffer[kKeySize];
51 crypto::RandBytes(buffer, arraysize(buffer));
52 base::Base64Encode(base::StringPiece(buffer, arraysize(buffer)),
53 &shared_secret);
54 return Pairing(created_time, client_name, client_id, shared_secret);
55 }
56
CreateFromValue(const base::DictionaryValue & pairing)57 PairingRegistry::Pairing PairingRegistry::Pairing::CreateFromValue(
58 const base::DictionaryValue& pairing) {
59 std::string client_name, client_id;
60 double created_time_value;
61 if (pairing.GetDouble(kCreatedTimeKey, &created_time_value) &&
62 pairing.GetString(kClientNameKey, &client_name) &&
63 pairing.GetString(kClientIdKey, &client_id)) {
64 // The shared secret is optional.
65 std::string shared_secret;
66 pairing.GetString(kSharedSecretKey, &shared_secret);
67 base::Time created_time = base::Time::FromJsTime(created_time_value);
68 return Pairing(created_time, client_name, client_id, shared_secret);
69 }
70
71 LOG(ERROR) << "Failed to load pairing information: unexpected format.";
72 return Pairing();
73 }
74
ToValue() const75 scoped_ptr<base::DictionaryValue> PairingRegistry::Pairing::ToValue() const {
76 scoped_ptr<base::DictionaryValue> pairing(new base::DictionaryValue());
77 pairing->SetDouble(kCreatedTimeKey, created_time().ToJsTime());
78 pairing->SetString(kClientNameKey, client_name());
79 pairing->SetString(kClientIdKey, client_id());
80 if (!shared_secret().empty())
81 pairing->SetString(kSharedSecretKey, shared_secret());
82 return pairing.Pass();
83 }
84
operator ==(const Pairing & other) const85 bool PairingRegistry::Pairing::operator==(const Pairing& other) const {
86 return created_time_ == other.created_time_ &&
87 client_id_ == other.client_id_ &&
88 client_name_ == other.client_name_ &&
89 shared_secret_ == other.shared_secret_;
90 }
91
is_valid() const92 bool PairingRegistry::Pairing::is_valid() const {
93 return !client_id_.empty() && !shared_secret_.empty();
94 }
95
PairingRegistry(scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner,scoped_ptr<Delegate> delegate)96 PairingRegistry::PairingRegistry(
97 scoped_refptr<base::SingleThreadTaskRunner> delegate_task_runner,
98 scoped_ptr<Delegate> delegate)
99 : caller_task_runner_(base::ThreadTaskRunnerHandle::Get()),
100 delegate_task_runner_(delegate_task_runner),
101 delegate_(delegate.Pass()) {
102 DCHECK(delegate_);
103 }
104
CreatePairing(const std::string & client_name)105 PairingRegistry::Pairing PairingRegistry::CreatePairing(
106 const std::string& client_name) {
107 DCHECK(caller_task_runner_->BelongsToCurrentThread());
108
109 Pairing result = Pairing::Create(client_name);
110 AddPairing(result);
111 return result;
112 }
113
GetPairing(const std::string & client_id,const GetPairingCallback & callback)114 void PairingRegistry::GetPairing(const std::string& client_id,
115 const GetPairingCallback& callback) {
116 DCHECK(caller_task_runner_->BelongsToCurrentThread());
117
118 GetPairingCallback wrapped_callback = base::Bind(
119 &PairingRegistry::InvokeGetPairingCallbackAndScheduleNext,
120 this, callback);
121 base::Closure request = base::Bind(
122 &PairingRegistry::DoLoad, this, client_id, wrapped_callback);
123 ServiceOrQueueRequest(request);
124 }
125
GetAllPairings(const GetAllPairingsCallback & callback)126 void PairingRegistry::GetAllPairings(
127 const GetAllPairingsCallback& callback) {
128 DCHECK(caller_task_runner_->BelongsToCurrentThread());
129
130 GetAllPairingsCallback wrapped_callback = base::Bind(
131 &PairingRegistry::InvokeGetAllPairingsCallbackAndScheduleNext,
132 this, callback);
133 GetAllPairingsCallback sanitize_callback = base::Bind(
134 &PairingRegistry::SanitizePairings,
135 this, wrapped_callback);
136 base::Closure request = base::Bind(
137 &PairingRegistry::DoLoadAll, this, sanitize_callback);
138 ServiceOrQueueRequest(request);
139 }
140
DeletePairing(const std::string & client_id,const DoneCallback & callback)141 void PairingRegistry::DeletePairing(
142 const std::string& client_id, const DoneCallback& callback) {
143 DCHECK(caller_task_runner_->BelongsToCurrentThread());
144
145 DoneCallback wrapped_callback = base::Bind(
146 &PairingRegistry::InvokeDoneCallbackAndScheduleNext,
147 this, callback);
148 base::Closure request = base::Bind(
149 &PairingRegistry::DoDelete, this, client_id, wrapped_callback);
150 ServiceOrQueueRequest(request);
151 }
152
ClearAllPairings(const DoneCallback & callback)153 void PairingRegistry::ClearAllPairings(
154 const DoneCallback& callback) {
155 DCHECK(caller_task_runner_->BelongsToCurrentThread());
156
157 DoneCallback wrapped_callback = base::Bind(
158 &PairingRegistry::InvokeDoneCallbackAndScheduleNext,
159 this, callback);
160 base::Closure request = base::Bind(
161 &PairingRegistry::DoDeleteAll, this, wrapped_callback);
162 ServiceOrQueueRequest(request);
163 }
164
~PairingRegistry()165 PairingRegistry::~PairingRegistry() {
166 }
167
PostTask(const scoped_refptr<base::SingleThreadTaskRunner> & task_runner,const tracked_objects::Location & from_here,const base::Closure & task)168 void PairingRegistry::PostTask(
169 const scoped_refptr<base::SingleThreadTaskRunner>& task_runner,
170 const tracked_objects::Location& from_here,
171 const base::Closure& task) {
172 task_runner->PostTask(from_here, task);
173 }
174
AddPairing(const Pairing & pairing)175 void PairingRegistry::AddPairing(const Pairing& pairing) {
176 DoneCallback wrapped_callback = base::Bind(
177 &PairingRegistry::InvokeDoneCallbackAndScheduleNext,
178 this, DoneCallback());
179 base::Closure request = base::Bind(
180 &PairingRegistry::DoSave, this, pairing, wrapped_callback);
181 ServiceOrQueueRequest(request);
182 }
183
DoLoadAll(const protocol::PairingRegistry::GetAllPairingsCallback & callback)184 void PairingRegistry::DoLoadAll(
185 const protocol::PairingRegistry::GetAllPairingsCallback& callback) {
186 DCHECK(delegate_task_runner_->BelongsToCurrentThread());
187
188 scoped_ptr<base::ListValue> pairings = delegate_->LoadAll();
189 PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback,
190 base::Passed(&pairings)));
191 }
192
DoDeleteAll(const protocol::PairingRegistry::DoneCallback & callback)193 void PairingRegistry::DoDeleteAll(
194 const protocol::PairingRegistry::DoneCallback& callback) {
195 DCHECK(delegate_task_runner_->BelongsToCurrentThread());
196
197 bool success = delegate_->DeleteAll();
198 PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success));
199 }
200
DoLoad(const std::string & client_id,const protocol::PairingRegistry::GetPairingCallback & callback)201 void PairingRegistry::DoLoad(
202 const std::string& client_id,
203 const protocol::PairingRegistry::GetPairingCallback& callback) {
204 DCHECK(delegate_task_runner_->BelongsToCurrentThread());
205
206 Pairing pairing = delegate_->Load(client_id);
207 PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, pairing));
208 }
209
DoSave(const protocol::PairingRegistry::Pairing & pairing,const protocol::PairingRegistry::DoneCallback & callback)210 void PairingRegistry::DoSave(
211 const protocol::PairingRegistry::Pairing& pairing,
212 const protocol::PairingRegistry::DoneCallback& callback) {
213 DCHECK(delegate_task_runner_->BelongsToCurrentThread());
214
215 bool success = delegate_->Save(pairing);
216 PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success));
217 }
218
DoDelete(const std::string & client_id,const protocol::PairingRegistry::DoneCallback & callback)219 void PairingRegistry::DoDelete(
220 const std::string& client_id,
221 const protocol::PairingRegistry::DoneCallback& callback) {
222 DCHECK(delegate_task_runner_->BelongsToCurrentThread());
223
224 bool success = delegate_->Delete(client_id);
225 PostTask(caller_task_runner_, FROM_HERE, base::Bind(callback, success));
226 }
227
InvokeDoneCallbackAndScheduleNext(const DoneCallback & callback,bool success)228 void PairingRegistry::InvokeDoneCallbackAndScheduleNext(
229 const DoneCallback& callback, bool success) {
230 // CreatePairing doesn't have a callback, so the callback can be null.
231 if (!callback.is_null())
232 callback.Run(success);
233
234 pending_requests_.pop();
235 ServiceNextRequest();
236 }
237
InvokeGetPairingCallbackAndScheduleNext(const GetPairingCallback & callback,Pairing pairing)238 void PairingRegistry::InvokeGetPairingCallbackAndScheduleNext(
239 const GetPairingCallback& callback, Pairing pairing) {
240 callback.Run(pairing);
241 pending_requests_.pop();
242 ServiceNextRequest();
243 }
244
InvokeGetAllPairingsCallbackAndScheduleNext(const GetAllPairingsCallback & callback,scoped_ptr<base::ListValue> pairings)245 void PairingRegistry::InvokeGetAllPairingsCallbackAndScheduleNext(
246 const GetAllPairingsCallback& callback,
247 scoped_ptr<base::ListValue> pairings) {
248 callback.Run(pairings.Pass());
249 pending_requests_.pop();
250 ServiceNextRequest();
251 }
252
SanitizePairings(const GetAllPairingsCallback & callback,scoped_ptr<base::ListValue> pairings)253 void PairingRegistry::SanitizePairings(const GetAllPairingsCallback& callback,
254 scoped_ptr<base::ListValue> pairings) {
255 DCHECK(caller_task_runner_->BelongsToCurrentThread());
256
257 scoped_ptr<base::ListValue> sanitized_pairings(new base::ListValue());
258 for (size_t i = 0; i < pairings->GetSize(); ++i) {
259 DictionaryValue* pairing_json;
260 if (!pairings->GetDictionary(i, &pairing_json)) {
261 LOG(WARNING) << "A pairing entry is not a dictionary.";
262 continue;
263 }
264
265 // Parse the pairing data.
266 Pairing pairing = Pairing::CreateFromValue(*pairing_json);
267 if (!pairing.is_valid()) {
268 LOG(WARNING) << "Could not parse a pairing entry.";
269 continue;
270 }
271
272 // Clear the shared secrect and append the pairing data to the list.
273 Pairing sanitized_pairing(
274 pairing.created_time(),
275 pairing.client_name(),
276 pairing.client_id(),
277 "");
278 sanitized_pairings->Append(sanitized_pairing.ToValue().release());
279 }
280
281 callback.Run(sanitized_pairings.Pass());
282 }
283
ServiceOrQueueRequest(const base::Closure & request)284 void PairingRegistry::ServiceOrQueueRequest(const base::Closure& request) {
285 bool servicing_request = !pending_requests_.empty();
286 pending_requests_.push(request);
287 if (!servicing_request) {
288 ServiceNextRequest();
289 }
290 }
291
ServiceNextRequest()292 void PairingRegistry::ServiceNextRequest() {
293 if (pending_requests_.empty())
294 return;
295
296 PostTask(delegate_task_runner_, FROM_HERE, pending_requests_.front());
297 }
298
299 } // namespace protocol
300 } // namespace remoting
301