• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
6 
7 #include "base/bind.h"
8 #include "base/bind_helpers.h"
9 #include "crypto/secure_util.h"
10 #include "net/base/host_port_pair.h"
11 #include "net/base/io_buffer.h"
12 #include "net/base/net_errors.h"
13 #include "net/cert/cert_verifier.h"
14 #include "net/cert/x509_certificate.h"
15 #include "net/http/transport_security_state.h"
16 #include "net/socket/client_socket_factory.h"
17 #include "net/socket/client_socket_handle.h"
18 #include "net/socket/ssl_client_socket.h"
19 #include "net/socket/ssl_server_socket.h"
20 #include "net/ssl/ssl_config_service.h"
21 #include "remoting/base/rsa_key_pair.h"
22 #include "remoting/protocol/auth_util.h"
23 
24 namespace remoting {
25 namespace protocol {
26 
27 // static
28 scoped_ptr<SslHmacChannelAuthenticator>
CreateForClient(const std::string & remote_cert,const std::string & auth_key)29 SslHmacChannelAuthenticator::CreateForClient(
30       const std::string& remote_cert,
31       const std::string& auth_key) {
32   scoped_ptr<SslHmacChannelAuthenticator> result(
33       new SslHmacChannelAuthenticator(auth_key));
34   result->remote_cert_ = remote_cert;
35   return result.Pass();
36 }
37 
38 scoped_ptr<SslHmacChannelAuthenticator>
CreateForHost(const std::string & local_cert,scoped_refptr<RsaKeyPair> key_pair,const std::string & auth_key)39 SslHmacChannelAuthenticator::CreateForHost(
40     const std::string& local_cert,
41     scoped_refptr<RsaKeyPair> key_pair,
42     const std::string& auth_key) {
43   scoped_ptr<SslHmacChannelAuthenticator> result(
44       new SslHmacChannelAuthenticator(auth_key));
45   result->local_cert_ = local_cert;
46   result->local_key_pair_ = key_pair;
47   return result.Pass();
48 }
49 
SslHmacChannelAuthenticator(const std::string & auth_key)50 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
51     const std::string& auth_key)
52     : auth_key_(auth_key) {
53 }
54 
~SslHmacChannelAuthenticator()55 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
56 }
57 
SecureAndAuthenticate(scoped_ptr<net::StreamSocket> socket,const DoneCallback & done_callback)58 void SslHmacChannelAuthenticator::SecureAndAuthenticate(
59     scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) {
60   DCHECK(CalledOnValidThread());
61   DCHECK(socket->IsConnected());
62 
63   done_callback_ = done_callback;
64 
65   int result;
66   if (is_ssl_server()) {
67     scoped_refptr<net::X509Certificate> cert =
68         net::X509Certificate::CreateFromBytes(
69             local_cert_.data(), local_cert_.length());
70     if (!cert.get()) {
71       LOG(ERROR) << "Failed to parse X509Certificate";
72       NotifyError(net::ERR_FAILED);
73       return;
74     }
75 
76     net::SSLConfig ssl_config;
77     ssl_config.require_forward_secrecy = true;
78 
79     scoped_ptr<net::SSLServerSocket> server_socket =
80         net::CreateSSLServerSocket(socket.Pass(),
81                                    cert.get(),
82                                    local_key_pair_->private_key(),
83                                    ssl_config);
84     net::SSLServerSocket* raw_server_socket = server_socket.get();
85     socket_ = server_socket.Pass();
86     result = raw_server_socket->Handshake(
87         base::Bind(&SslHmacChannelAuthenticator::OnConnected,
88                    base::Unretained(this)));
89   } else {
90     cert_verifier_.reset(net::CertVerifier::CreateDefault());
91     transport_security_state_.reset(new net::TransportSecurityState);
92 
93     net::SSLConfig::CertAndStatus cert_and_status;
94     cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
95     cert_and_status.der_cert = remote_cert_;
96 
97     net::SSLConfig ssl_config;
98     // Certificate verification and revocation checking are not needed
99     // because we use self-signed certs. Disable it so that the SSL
100     // layer doesn't try to initialize OCSP (OCSP works only on the IO
101     // thread).
102     ssl_config.cert_io_enabled = false;
103     ssl_config.rev_checking_enabled = false;
104     ssl_config.allowed_bad_certs.push_back(cert_and_status);
105 
106     net::HostPortPair host_and_port(kSslFakeHostName, 0);
107     net::SSLClientSocketContext context;
108     context.cert_verifier = cert_verifier_.get();
109     context.transport_security_state = transport_security_state_.get();
110     scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle);
111     connection->SetSocket(socket.Pass());
112     socket_ =
113         net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
114             connection.Pass(), host_and_port, ssl_config, context);
115 
116     result = socket_->Connect(
117         base::Bind(&SslHmacChannelAuthenticator::OnConnected,
118                    base::Unretained(this)));
119   }
120 
121   if (result == net::ERR_IO_PENDING)
122     return;
123 
124   OnConnected(result);
125 }
126 
is_ssl_server()127 bool SslHmacChannelAuthenticator::is_ssl_server() {
128   return local_key_pair_.get() != NULL;
129 }
130 
OnConnected(int result)131 void SslHmacChannelAuthenticator::OnConnected(int result) {
132   if (result != net::OK) {
133     LOG(WARNING) << "Failed to establish SSL connection";
134     NotifyError(result);
135     return;
136   }
137 
138   // Generate authentication digest to write to the socket.
139   std::string auth_bytes = GetAuthBytes(
140       socket_.get(), is_ssl_server() ?
141       kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
142   if (auth_bytes.empty()) {
143     NotifyError(net::ERR_FAILED);
144     return;
145   }
146 
147   // Allocate a buffer to write the digest.
148   auth_write_buf_ = new net::DrainableIOBuffer(
149       new net::StringIOBuffer(auth_bytes), auth_bytes.size());
150 
151   // Read an incoming token.
152   auth_read_buf_ = new net::GrowableIOBuffer();
153   auth_read_buf_->SetCapacity(kAuthDigestLength);
154 
155   // If WriteAuthenticationBytes() results in |done_callback_| being
156   // called then we must not do anything else because this object may
157   // be destroyed at that point.
158   bool callback_called = false;
159   WriteAuthenticationBytes(&callback_called);
160   if (!callback_called)
161     ReadAuthenticationBytes();
162 }
163 
WriteAuthenticationBytes(bool * callback_called)164 void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
165     bool* callback_called) {
166   while (true) {
167     int result = socket_->Write(
168         auth_write_buf_.get(),
169         auth_write_buf_->BytesRemaining(),
170         base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
171                    base::Unretained(this)));
172     if (result == net::ERR_IO_PENDING)
173       break;
174     if (!HandleAuthBytesWritten(result, callback_called))
175       break;
176   }
177 }
178 
OnAuthBytesWritten(int result)179 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
180   DCHECK(CalledOnValidThread());
181 
182   if (HandleAuthBytesWritten(result, NULL))
183     WriteAuthenticationBytes(NULL);
184 }
185 
HandleAuthBytesWritten(int result,bool * callback_called)186 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
187     int result, bool* callback_called) {
188   if (result <= 0) {
189     LOG(ERROR) << "Error writing authentication: " << result;
190     if (callback_called)
191       *callback_called = false;
192     NotifyError(result);
193     return false;
194   }
195 
196   auth_write_buf_->DidConsume(result);
197   if (auth_write_buf_->BytesRemaining() > 0)
198     return true;
199 
200   auth_write_buf_ = NULL;
201   CheckDone(callback_called);
202   return false;
203 }
204 
ReadAuthenticationBytes()205 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
206   while (true) {
207     int result =
208         socket_->Read(auth_read_buf_.get(),
209                       auth_read_buf_->RemainingCapacity(),
210                       base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead,
211                                  base::Unretained(this)));
212     if (result == net::ERR_IO_PENDING)
213       break;
214     if (!HandleAuthBytesRead(result))
215       break;
216   }
217 }
218 
OnAuthBytesRead(int result)219 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
220   DCHECK(CalledOnValidThread());
221 
222   if (HandleAuthBytesRead(result))
223     ReadAuthenticationBytes();
224 }
225 
HandleAuthBytesRead(int read_result)226 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
227   if (read_result <= 0) {
228     NotifyError(read_result);
229     return false;
230   }
231 
232   auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
233   if (auth_read_buf_->RemainingCapacity() > 0)
234     return true;
235 
236   if (!VerifyAuthBytes(std::string(
237           auth_read_buf_->StartOfBuffer(),
238           auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
239     LOG(WARNING) << "Mismatched authentication";
240     NotifyError(net::ERR_FAILED);
241     return false;
242   }
243 
244   auth_read_buf_ = NULL;
245   CheckDone(NULL);
246   return false;
247 }
248 
VerifyAuthBytes(const std::string & received_auth_bytes)249 bool SslHmacChannelAuthenticator::VerifyAuthBytes(
250     const std::string& received_auth_bytes) {
251   DCHECK(received_auth_bytes.length() == kAuthDigestLength);
252 
253   // Compute expected auth bytes.
254   std::string auth_bytes = GetAuthBytes(
255       socket_.get(), is_ssl_server() ?
256       kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
257   if (auth_bytes.empty())
258     return false;
259 
260   return crypto::SecureMemEqual(received_auth_bytes.data(),
261                                 &(auth_bytes[0]), kAuthDigestLength);
262 }
263 
CheckDone(bool * callback_called)264 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
265   if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) {
266     DCHECK(socket_.get() != NULL);
267     if (callback_called)
268       *callback_called = true;
269     done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>());
270   }
271 }
272 
NotifyError(int error)273 void SslHmacChannelAuthenticator::NotifyError(int error) {
274   done_callback_.Run(static_cast<net::Error>(error),
275                      scoped_ptr<net::StreamSocket>());
276 }
277 
278 }  // namespace protocol
279 }  // namespace remoting
280