1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/protocol/authenticator_test_base.h"
6
7 #include "base/base64.h"
8 #include "base/file_util.h"
9 #include "base/files/file_path.h"
10 #include "base/path_service.h"
11 #include "base/test/test_timeouts.h"
12 #include "base/timer/timer.h"
13 #include "net/base/test_data_directory.h"
14 #include "remoting/base/rsa_key_pair.h"
15 #include "remoting/protocol/authenticator.h"
16 #include "remoting/protocol/channel_authenticator.h"
17 #include "remoting/protocol/fake_session.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 #include "third_party/libjingle/source/talk/xmllite/xmlelement.h"
20
21 using testing::_;
22 using testing::SaveArg;
23
24 namespace remoting {
25 namespace protocol {
26
27 namespace {
28
ACTION_P(QuitThreadOnCounter,counter)29 ACTION_P(QuitThreadOnCounter, counter) {
30 --(*counter);
31 EXPECT_GE(*counter, 0);
32 if (*counter == 0)
33 base::MessageLoop::current()->Quit();
34 }
35
36 } // namespace
37
MockChannelDoneCallback()38 AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {}
39
~MockChannelDoneCallback()40 AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {}
41
AuthenticatorTestBase()42 AuthenticatorTestBase::AuthenticatorTestBase() {}
43
~AuthenticatorTestBase()44 AuthenticatorTestBase::~AuthenticatorTestBase() {}
45
SetUp()46 void AuthenticatorTestBase::SetUp() {
47 base::FilePath certs_dir(net::GetTestCertsDirectory());
48
49 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
50 ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_));
51
52 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
53 std::string key_string;
54 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
55 std::string key_base64;
56 base::Base64Encode(key_string, &key_base64);
57 key_pair_ = RsaKeyPair::FromString(key_base64);
58 ASSERT_TRUE(key_pair_.get());
59 host_public_key_ = key_pair_->GetPublicKey();
60 }
61
RunAuthExchange()62 void AuthenticatorTestBase::RunAuthExchange() {
63 ContinueAuthExchangeWith(client_.get(),
64 host_.get(),
65 client_->started(),
66 host_->started());
67 }
68
RunHostInitiatedAuthExchange()69 void AuthenticatorTestBase::RunHostInitiatedAuthExchange() {
70 ContinueAuthExchangeWith(host_.get(),
71 client_.get(),
72 host_->started(),
73 client_->started());
74 }
75
76 // static
77 // This function sends a message from the sender and receiver and recursively
78 // calls itself to the send the next message from the receiver to the sender
79 // untils the authentication completes.
ContinueAuthExchangeWith(Authenticator * sender,Authenticator * receiver,bool sender_started,bool receiver_started)80 void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender,
81 Authenticator* receiver,
82 bool sender_started,
83 bool receiver_started) {
84 scoped_ptr<buzz::XmlElement> message;
85 ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state());
86 if (sender->state() == Authenticator::ACCEPTED ||
87 sender->state() == Authenticator::REJECTED)
88 return;
89
90 // Verify that once the started flag for either party is set to true,
91 // it should always stay true.
92 if (receiver_started) {
93 ASSERT_TRUE(receiver->started());
94 }
95
96 if (sender_started) {
97 ASSERT_TRUE(sender->started());
98 }
99
100 ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state());
101 message = sender->GetNextMessage();
102 ASSERT_TRUE(message.get());
103 ASSERT_NE(Authenticator::MESSAGE_READY, sender->state());
104
105 ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state());
106 receiver->ProcessMessage(message.get(), base::Bind(
107 &AuthenticatorTestBase::ContinueAuthExchangeWith,
108 base::Unretained(receiver), base::Unretained(sender),
109 receiver->started(), sender->started()));
110 }
111
RunChannelAuth(bool expected_fail)112 void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) {
113 client_fake_socket_.reset(new FakeSocket());
114 host_fake_socket_.reset(new FakeSocket());
115 client_fake_socket_->PairWith(host_fake_socket_.get());
116
117 client_auth_->SecureAndAuthenticate(
118 client_fake_socket_.PassAs<net::StreamSocket>(),
119 base::Bind(&AuthenticatorTestBase::OnClientConnected,
120 base::Unretained(this)));
121
122 host_auth_->SecureAndAuthenticate(
123 host_fake_socket_.PassAs<net::StreamSocket>(),
124 base::Bind(&AuthenticatorTestBase::OnHostConnected,
125 base::Unretained(this)));
126
127 // Expect two callbacks to be called - the client callback and the host
128 // callback.
129 int callback_counter = 2;
130
131 EXPECT_CALL(client_callback_, OnDone(net::OK))
132 .WillOnce(QuitThreadOnCounter(&callback_counter));
133 if (expected_fail) {
134 EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED))
135 .WillOnce(QuitThreadOnCounter(&callback_counter));
136 } else {
137 EXPECT_CALL(host_callback_, OnDone(net::OK))
138 .WillOnce(QuitThreadOnCounter(&callback_counter));
139 }
140
141 // Ensure that .Run() does not run unbounded if the callbacks are never
142 // called.
143 base::Timer shutdown_timer(false, false);
144 shutdown_timer.Start(FROM_HERE,
145 TestTimeouts::action_timeout(),
146 base::MessageLoop::QuitClosure());
147 message_loop_.Run();
148 shutdown_timer.Stop();
149
150 testing::Mock::VerifyAndClearExpectations(&client_callback_);
151 testing::Mock::VerifyAndClearExpectations(&host_callback_);
152
153 if (!expected_fail) {
154 ASSERT_TRUE(client_socket_.get() != NULL);
155 ASSERT_TRUE(host_socket_.get() != NULL);
156 }
157 }
158
OnHostConnected(net::Error error,scoped_ptr<net::StreamSocket> socket)159 void AuthenticatorTestBase::OnHostConnected(
160 net::Error error,
161 scoped_ptr<net::StreamSocket> socket) {
162 host_callback_.OnDone(error);
163 host_socket_ = socket.Pass();
164 }
165
OnClientConnected(net::Error error,scoped_ptr<net::StreamSocket> socket)166 void AuthenticatorTestBase::OnClientConnected(
167 net::Error error,
168 scoped_ptr<net::StreamSocket> socket) {
169 client_callback_.OnDone(error);
170 client_socket_ = socket.Pass();
171 }
172
173 } // namespace protocol
174 } // namespace remoting
175