1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/protocol/pairing_registry.h"
6
7 #include <stdlib.h>
8
9 #include <algorithm>
10
11 #include "base/bind.h"
12 #include "base/compiler_specific.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/message_loop/message_loop.h"
15 #include "base/run_loop.h"
16 #include "base/thread_task_runner_handle.h"
17 #include "base/values.h"
18 #include "remoting/protocol/protocol_mock_objects.h"
19 #include "testing/gmock/include/gmock/gmock.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21
22 using testing::Sequence;
23
24 namespace {
25
26 using remoting::protocol::PairingRegistry;
27
28 class MockPairingRegistryCallbacks {
29 public:
MockPairingRegistryCallbacks()30 MockPairingRegistryCallbacks() {}
~MockPairingRegistryCallbacks()31 virtual ~MockPairingRegistryCallbacks() {}
32
33 MOCK_METHOD1(DoneCallback, void(bool));
34 MOCK_METHOD1(GetAllPairingsCallbackPtr, void(base::ListValue*));
35 MOCK_METHOD1(GetPairingCallback, void(PairingRegistry::Pairing));
36
GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings)37 void GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings) {
38 GetAllPairingsCallbackPtr(pairings.get());
39 }
40
41 private:
42 DISALLOW_COPY_AND_ASSIGN(MockPairingRegistryCallbacks);
43 };
44
45 // Verify that a pairing Dictionary has correct entries, but doesn't include
46 // any shared secret.
VerifyPairing(PairingRegistry::Pairing expected,const base::DictionaryValue & actual)47 void VerifyPairing(PairingRegistry::Pairing expected,
48 const base::DictionaryValue& actual) {
49 std::string value;
50 EXPECT_TRUE(actual.GetString(PairingRegistry::kClientNameKey, &value));
51 EXPECT_EQ(expected.client_name(), value);
52 EXPECT_TRUE(actual.GetString(PairingRegistry::kClientIdKey, &value));
53 EXPECT_EQ(expected.client_id(), value);
54
55 EXPECT_FALSE(actual.HasKey(PairingRegistry::kSharedSecretKey));
56 }
57
58 } // namespace
59
60 namespace remoting {
61 namespace protocol {
62
63 class PairingRegistryTest : public testing::Test {
64 public:
SetUp()65 virtual void SetUp() OVERRIDE {
66 callback_count_ = 0;
67 }
68
set_pairings(scoped_ptr<base::ListValue> pairings)69 void set_pairings(scoped_ptr<base::ListValue> pairings) {
70 pairings_ = pairings.Pass();
71 }
72
ExpectSecret(const std::string & expected,PairingRegistry::Pairing actual)73 void ExpectSecret(const std::string& expected,
74 PairingRegistry::Pairing actual) {
75 EXPECT_EQ(expected, actual.shared_secret());
76 ++callback_count_;
77 }
78
ExpectSaveSuccess(bool success)79 void ExpectSaveSuccess(bool success) {
80 EXPECT_TRUE(success);
81 ++callback_count_;
82 }
83
84 protected:
85 base::MessageLoop message_loop_;
86 base::RunLoop run_loop_;
87
88 int callback_count_;
89 scoped_ptr<base::ListValue> pairings_;
90 };
91
// Creates two pairings under the same client name and checks that each can be
// looked up by its client id and yields its own, distinct shared secret.
TEST_F(PairingRegistryTest, CreateAndGetPairings) {
  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing first_pairing = registry->CreatePairing("my_client");
  PairingRegistry::Pairing second_pairing =
      registry->CreatePairing("my_client");

  // Same client name, but the secrets must differ.
  EXPECT_NE(first_pairing.shared_secret(), second_pairing.shared_secret());

  // Look up the first pairing; the callback is expected to have run by the
  // time GetPairing() returns.
  registry->GetPairing(
      first_pairing.client_id(),
      base::Bind(&PairingRegistryTest::ExpectSecret,
                 base::Unretained(this),
                 first_pairing.shared_secret()));
  EXPECT_EQ(1, callback_count_);

  // Check that the second client is paired with a different shared secret.
  registry->GetPairing(
      second_pairing.client_id(),
      base::Bind(&PairingRegistryTest::ExpectSecret,
                 base::Unretained(this),
                 second_pairing.shared_secret()));
  EXPECT_EQ(2, callback_count_);
}
113
// Creates two pairings and checks that GetAllPairings() reports both of them,
// each with the right client name and id and without a shared secret.
TEST_F(PairingRegistryTest, GetAllPairings) {
  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");

  registry->GetAllPairings(
      base::Bind(&PairingRegistryTest::set_pairings,
                 base::Unretained(this)));

  // |pairings_| is only set by set_pairings(). If the registry never ran the
  // callback, fail the test here instead of dereferencing a null pointer.
  ASSERT_TRUE(pairings_.get() != NULL);
  ASSERT_EQ(2u, pairings_->GetSize());
  const base::DictionaryValue* actual_pairing_1 = NULL;
  const base::DictionaryValue* actual_pairing_2 = NULL;
  ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_1));
  ASSERT_TRUE(pairings_->GetDictionary(1, &actual_pairing_2));

  // Ordering is not guaranteed, so swap if necessary.
  std::string actual_client_id;
  ASSERT_TRUE(actual_pairing_1->GetString(PairingRegistry::kClientIdKey,
                                          &actual_client_id));
  if (actual_client_id != pairing_1.client_id()) {
    std::swap(actual_pairing_1, actual_pairing_2);
  }

  VerifyPairing(pairing_1, *actual_pairing_1);
  VerifyPairing(pairing_2, *actual_pairing_2);
}
141
// Creates two pairings, deletes the first, and checks that only the second
// one is reported by GetAllPairings() afterwards.
TEST_F(PairingRegistryTest, DeletePairing) {
  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");

  registry->DeletePairing(
      pairing_1.client_id(),
      base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
                 base::Unretained(this)));
  // ExpectSaveSuccess() only checks its argument when it is invoked, so also
  // verify that the done callback actually ran.
  EXPECT_EQ(1, callback_count_);

  // Re-read the list, and verify it only has the pairing_2 client.
  registry->GetAllPairings(
      base::Bind(&PairingRegistryTest::set_pairings,
                 base::Unretained(this)));

  // |pairings_| is only set by set_pairings(). If the registry never ran the
  // callback, fail the test here instead of dereferencing a null pointer.
  ASSERT_TRUE(pairings_.get() != NULL);
  ASSERT_EQ(1u, pairings_->GetSize());
  const base::DictionaryValue* actual_pairing_2 = NULL;
  ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_2));
  std::string actual_client_id;
  ASSERT_TRUE(actual_pairing_2->GetString(PairingRegistry::kClientIdKey,
                                          &actual_client_id));
  EXPECT_EQ(pairing_2.client_id(), actual_client_id);
}
166
// Creates two pairings, clears the registry, and checks that GetAllPairings()
// then reports an empty list.
TEST_F(PairingRegistryTest, ClearAllPairings) {
  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");

  registry->ClearAllPairings(
      base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
                 base::Unretained(this)));
  // ExpectSaveSuccess() only checks its argument when it is invoked, so also
  // verify that the done callback actually ran.
  EXPECT_EQ(1, callback_count_);

  // Re-read the list, and verify it is empty.
  registry->GetAllPairings(
      base::Bind(&PairingRegistryTest::set_pairings,
                 base::Unretained(this)));

  // |pairings_| is only set by set_pairings(). If the registry never ran the
  // callback, fail the test here instead of dereferencing a null pointer.
  ASSERT_TRUE(pairings_.get() != NULL);
  EXPECT_TRUE(pairings_->empty());
}
184
// gMock action that runs the closure it was constructed with. The test below
// passes a RunLoop quit closure so that the last expected call ends the loop.
ACTION_P(QuitMessageLoop, quit_closure) {
  quit_closure.Run();
}
188
189 MATCHER_P(EqualsClientName, client_name, "") {
190 return arg.client_name() == client_name;
191 }
192
193 MATCHER(NoPairings, "") {
194 return arg->empty();
195 }
196
// Issues a batch of requests against a PairingRegistry that uses the current
// thread's task runner, then runs the message loop and checks through an
// ordered gMock sequence that the callbacks arrive in the order the requests
// were made.
TEST_F(PairingRegistryTest, SerializedRequests) {
  MockPairingRegistryCallbacks callbacks;

  // One expectation per callback-taking request issued below, in issue order.
  Sequence order;
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
      .InSequence(order);
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client2")))
      .InSequence(order);
  EXPECT_CALL(callbacks, DoneCallback(true))
      .InSequence(order);
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
      .InSequence(order);
  // Lookup of the pairing deleted just before: an empty client name is
  // expected.
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("")))
      .InSequence(order);
  EXPECT_CALL(callbacks, DoneCallback(true))
      .InSequence(order);
  EXPECT_CALL(callbacks, GetAllPairingsCallbackPtr(NoPairings()))
      .InSequence(order);
  // The final expected callback also quits the run loop started at the end
  // of this test.
  EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client3")))
      .InSequence(order)
      .WillOnce(QuitMessageLoop(run_loop_.QuitClosure()));

  // Callbacks bound to the mock once and reused for every request of the
  // corresponding kind.
  const base::Callback<void(PairingRegistry::Pairing)> on_get_pairing =
      base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
                 base::Unretained(&callbacks));
  const base::Callback<void(bool)> on_done =
      base::Bind(&MockPairingRegistryCallbacks::DoneCallback,
                 base::Unretained(&callbacks));

  scoped_refptr<PairingRegistry> registry = new PairingRegistry(
      base::ThreadTaskRunnerHandle::Get(),
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));

  PairingRegistry::Pairing first_pairing = registry->CreatePairing("client1");
  PairingRegistry::Pairing second_pairing = registry->CreatePairing("client2");

  // Read both pairings back.
  registry->GetPairing(first_pairing.client_id(), on_get_pairing);
  registry->GetPairing(second_pairing.client_id(), on_get_pairing);

  // Delete the second pairing, then read both again.
  registry->DeletePairing(second_pairing.client_id(), on_done);
  registry->GetPairing(first_pairing.client_id(), on_get_pairing);
  registry->GetPairing(second_pairing.client_id(), on_get_pairing);

  // Clear everything and list what is left.
  registry->ClearAllPairings(on_done);
  registry->GetAllPairings(
      base::Bind(&MockPairingRegistryCallbacks::GetAllPairingsCallback,
                 base::Unretained(&callbacks)));

  // Create and read one more pairing after the clear.
  PairingRegistry::Pairing third_pairing = registry->CreatePairing("client3");
  registry->GetPairing(third_pairing.client_id(), on_get_pairing);

  run_loop_.Run();
}
257
258 } // namespace protocol
259 } // namespace remoting
260