• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "remoting/protocol/pairing_registry.h"
6 
7 #include <stdlib.h>
8 
9 #include <algorithm>
10 
11 #include "base/bind.h"
12 #include "base/compiler_specific.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/message_loop/message_loop.h"
15 #include "base/run_loop.h"
16 #include "base/thread_task_runner_handle.h"
17 #include "base/values.h"
18 #include "remoting/protocol/protocol_mock_objects.h"
19 #include "testing/gmock/include/gmock/gmock.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21 
22 using testing::Sequence;
23 
24 namespace {
25 
26 using remoting::protocol::PairingRegistry;
27 
28 class MockPairingRegistryCallbacks {
29  public:
MockPairingRegistryCallbacks()30   MockPairingRegistryCallbacks() {}
~MockPairingRegistryCallbacks()31   virtual ~MockPairingRegistryCallbacks() {}
32 
33   MOCK_METHOD1(DoneCallback, void(bool));
34   MOCK_METHOD1(GetAllPairingsCallbackPtr, void(base::ListValue*));
35   MOCK_METHOD1(GetPairingCallback, void(PairingRegistry::Pairing));
36 
GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings)37   void GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings) {
38     GetAllPairingsCallbackPtr(pairings.get());
39   }
40 
41  private:
42   DISALLOW_COPY_AND_ASSIGN(MockPairingRegistryCallbacks);
43 };
44 
45 // Verify that a pairing Dictionary has correct entries, but doesn't include
46 // any shared secret.
VerifyPairing(PairingRegistry::Pairing expected,const base::DictionaryValue & actual)47 void VerifyPairing(PairingRegistry::Pairing expected,
48                    const base::DictionaryValue& actual) {
49   std::string value;
50   EXPECT_TRUE(actual.GetString(PairingRegistry::kClientNameKey, &value));
51   EXPECT_EQ(expected.client_name(), value);
52   EXPECT_TRUE(actual.GetString(PairingRegistry::kClientIdKey, &value));
53   EXPECT_EQ(expected.client_id(), value);
54 
55   EXPECT_FALSE(actual.HasKey(PairingRegistry::kSharedSecretKey));
56 }
57 
58 }  // namespace
59 
60 namespace remoting {
61 namespace protocol {
62 
63 class PairingRegistryTest : public testing::Test {
64  public:
SetUp()65   virtual void SetUp() OVERRIDE {
66     callback_count_ = 0;
67   }
68 
set_pairings(scoped_ptr<base::ListValue> pairings)69   void set_pairings(scoped_ptr<base::ListValue> pairings) {
70     pairings_ = pairings.Pass();
71   }
72 
ExpectSecret(const std::string & expected,PairingRegistry::Pairing actual)73   void ExpectSecret(const std::string& expected,
74                     PairingRegistry::Pairing actual) {
75     EXPECT_EQ(expected, actual.shared_secret());
76     ++callback_count_;
77   }
78 
ExpectSaveSuccess(bool success)79   void ExpectSaveSuccess(bool success) {
80     EXPECT_TRUE(success);
81     ++callback_count_;
82   }
83 
84  protected:
85   base::MessageLoop message_loop_;
86   base::RunLoop run_loop_;
87 
88   int callback_count_;
89   scoped_ptr<base::ListValue> pairings_;
90 };
91 
// Two pairings created for the same client name must get distinct shared
// secrets, and GetPairing() must return the secret belonging to each one.
TEST_F(PairingRegistryTest, CreateAndGetPairings) {
  scoped_ptr<PairingRegistry::Delegate> delegate(
      new MockPairingRegistryDelegate());
  scoped_refptr<PairingRegistry> registry =
      new SynchronousPairingRegistry(delegate.Pass());

  PairingRegistry::Pairing first = registry->CreatePairing("my_client");
  PairingRegistry::Pairing second = registry->CreatePairing("my_client");
  EXPECT_NE(first.shared_secret(), second.shared_secret());

  // The counter is checked right after each request. This relies on
  // SynchronousPairingRegistry invoking the callback before GetPairing()
  // returns.
  registry->GetPairing(first.client_id(),
                       base::Bind(&PairingRegistryTest::ExpectSecret,
                                  base::Unretained(this),
                                  first.shared_secret()));
  EXPECT_EQ(1, callback_count_);

  // The second client must be paired with its own, different, shared secret.
  registry->GetPairing(second.client_id(),
                       base::Bind(&PairingRegistryTest::ExpectSecret,
                                  base::Unretained(this),
                                  second.shared_secret()));
  EXPECT_EQ(2, callback_count_);
}
113 
// GetAllPairings() must list every pairing (name and id) without exposing the
// shared secrets.
TEST_F(PairingRegistryTest, GetAllPairings) {
  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");

  registry->GetAllPairings(
      base::Bind(&PairingRegistryTest::set_pairings,
                 base::Unretained(this)));

  // If the callback never ran, fail the test here rather than crash the
  // whole binary by dereferencing a NULL list.
  ASSERT_TRUE(pairings_.get());
  ASSERT_EQ(2u, pairings_->GetSize());
  const base::DictionaryValue* actual_pairing_1 = NULL;
  const base::DictionaryValue* actual_pairing_2 = NULL;
  ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_1));
  ASSERT_TRUE(pairings_->GetDictionary(1, &actual_pairing_2));

  // Ordering is not guaranteed, so swap if necessary.
  std::string actual_client_id;
  ASSERT_TRUE(actual_pairing_1->GetString(PairingRegistry::kClientIdKey,
                                          &actual_client_id));
  if (actual_client_id != pairing_1.client_id()) {
    std::swap(actual_pairing_1, actual_pairing_2);
  }

  VerifyPairing(pairing_1, *actual_pairing_1);
  VerifyPairing(pairing_2, *actual_pairing_2);
}
141 
// Deleting one pairing must report success and leave only the other pairing
// in the registry.
TEST_F(PairingRegistryTest, DeletePairing) {
  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");

  registry->DeletePairing(
      pairing_1.client_id(),
      base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
                 base::Unretained(this)));
  // ExpectSaveSuccess() can only check the result if it is invoked, so also
  // verify that the completion callback ran exactly once (the registry is
  // synchronous).
  EXPECT_EQ(1, callback_count_);

  // Re-read the list, and verify it only has the pairing_2 client.
  registry->GetAllPairings(
      base::Bind(&PairingRegistryTest::set_pairings,
                 base::Unretained(this)));

  // Fail cleanly instead of dereferencing NULL if the callback never ran.
  ASSERT_TRUE(pairings_.get());
  ASSERT_EQ(1u, pairings_->GetSize());
  const base::DictionaryValue* actual_pairing_2 = NULL;
  ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_2));
  std::string actual_client_id;
  ASSERT_TRUE(actual_pairing_2->GetString(PairingRegistry::kClientIdKey,
                                          &actual_client_id));
  EXPECT_EQ(pairing_2.client_id(), actual_client_id);
}
166 
// ClearAllPairings() must report success and leave the registry empty.
TEST_F(PairingRegistryTest, ClearAllPairings) {
  scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
      scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
  PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
  PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");

  registry->ClearAllPairings(
      base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
                 base::Unretained(this)));
  // ExpectSaveSuccess() can only check the result if it is invoked, so also
  // verify that the completion callback ran exactly once (the registry is
  // synchronous).
  EXPECT_EQ(1, callback_count_);

  // Re-read the list, and verify it is empty.
  registry->GetAllPairings(
      base::Bind(&PairingRegistryTest::set_pairings,
                 base::Unretained(this)));

  // Fail cleanly instead of dereferencing NULL if the callback never ran.
  ASSERT_TRUE(pairings_.get());
  EXPECT_TRUE(pairings_->empty());
}
184 
// gmock action that runs |quit_closure| when the mocked call it is attached to
// is made. In this file it receives a RunLoop quit closure, which ends
// run_loop_.Run().
ACTION_P(QuitMessageLoop, quit_closure) {
  quit_closure.Run();
}
188 
189 MATCHER_P(EqualsClientName, client_name, "") {
190   return arg.client_name() == client_name;
191 }
192 
193 MATCHER(NoPairings, "") {
194   return arg->empty();
195 }
196 
TEST_F(PairingRegistryTest,SerializedRequests)197 TEST_F(PairingRegistryTest, SerializedRequests) {
198   MockPairingRegistryCallbacks callbacks;
199   Sequence s;
200   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
201       .InSequence(s);
202   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client2")))
203       .InSequence(s);
204   EXPECT_CALL(callbacks, DoneCallback(true))
205       .InSequence(s);
206   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
207       .InSequence(s);
208   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("")))
209       .InSequence(s);
210   EXPECT_CALL(callbacks, DoneCallback(true))
211       .InSequence(s);
212   EXPECT_CALL(callbacks, GetAllPairingsCallbackPtr(NoPairings()))
213       .InSequence(s);
214   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client3")))
215       .InSequence(s)
216       .WillOnce(QuitMessageLoop(run_loop_.QuitClosure()));
217 
218   scoped_refptr<PairingRegistry> registry = new PairingRegistry(
219       base::ThreadTaskRunnerHandle::Get(),
220       scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
221   PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
222   PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
223   registry->GetPairing(
224       pairing_1.client_id(),
225       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
226                  base::Unretained(&callbacks)));
227   registry->GetPairing(
228       pairing_2.client_id(),
229       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
230                  base::Unretained(&callbacks)));
231   registry->DeletePairing(
232       pairing_2.client_id(),
233       base::Bind(&MockPairingRegistryCallbacks::DoneCallback,
234                  base::Unretained(&callbacks)));
235   registry->GetPairing(
236       pairing_1.client_id(),
237       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
238                  base::Unretained(&callbacks)));
239   registry->GetPairing(
240       pairing_2.client_id(),
241       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
242                  base::Unretained(&callbacks)));
243   registry->ClearAllPairings(
244       base::Bind(&MockPairingRegistryCallbacks::DoneCallback,
245                  base::Unretained(&callbacks)));
246   registry->GetAllPairings(
247       base::Bind(&MockPairingRegistryCallbacks::GetAllPairingsCallback,
248                  base::Unretained(&callbacks)));
249   PairingRegistry::Pairing pairing_3 = registry->CreatePairing("client3");
250   registry->GetPairing(
251       pairing_3.client_id(),
252       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
253                  base::Unretained(&callbacks)));
254 
255   run_loop_.Run();
256 }
257 
258 }  // namespace protocol
259 }  // namespace remoting
260