• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "platform/impl/tls_connection_factory_posix.h"
6 
7 #include <errno.h>
8 #include <fcntl.h>
9 #include <netinet/in.h>
10 #include <netinet/ip.h>
11 #include <openssl/ssl.h>
12 #include <sys/ioctl.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
16 
17 #include <cstring>
18 #include <utility>
19 #include <vector>
20 
21 #include "platform/api/task_runner.h"
22 #include "platform/api/tls_connection_factory.h"
23 #include "platform/base/tls_connect_options.h"
24 #include "platform/base/tls_credentials.h"
25 #include "platform/base/tls_listen_options.h"
26 #include "platform/impl/stream_socket.h"
27 #include "platform/impl/tls_connection_posix.h"
28 #include "util/crypto/certificate_utils.h"
29 #include "util/crypto/openssl_util.h"
30 #include "util/osp_logging.h"
31 #include "util/trace_logging.h"
32 
33 namespace openscreen {
34 
35 namespace {
36 
GetDEREncodedPeerCertificate(const SSL & ssl)37 ErrorOr<std::vector<uint8_t>> GetDEREncodedPeerCertificate(const SSL& ssl) {
38   X509* const peer_cert = SSL_get_peer_certificate(&ssl);
39   ErrorOr<std::vector<uint8_t>> der_peer_cert =
40       ExportX509CertificateToDer(*peer_cert);
41   X509_free(peer_cert);
42   return der_peer_cert;
43 }
44 
45 }  // namespace
46 
CreateFactory(Client * client,TaskRunner * task_runner)47 std::unique_ptr<TlsConnectionFactory> TlsConnectionFactory::CreateFactory(
48     Client* client,
49     TaskRunner* task_runner) {
50   return std::unique_ptr<TlsConnectionFactory>(
51       new TlsConnectionFactoryPosix(client, task_runner));
52 }
53 
TlsConnectionFactoryPosix(Client * client,TaskRunner * task_runner,PlatformClientPosix * platform_client)54 TlsConnectionFactoryPosix::TlsConnectionFactoryPosix(
55     Client* client,
56     TaskRunner* task_runner,
57     PlatformClientPosix* platform_client)
58     : client_(client),
59       task_runner_(task_runner),
60       platform_client_(platform_client) {
61   OSP_DCHECK(client_);
62   OSP_DCHECK(task_runner_);
63 }
64 
~TlsConnectionFactoryPosix()65 TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() {
66   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
67   if (platform_client_) {
68     platform_client_->tls_data_router()->DeregisterAcceptObserver(this);
69   }
70 }
71 
72 // TODO(rwkeane): Add support for resuming sessions.
73 // TODO(rwkeane): Integrate with Auth.
Connect(const IPEndpoint & remote_address,const TlsConnectOptions & options)74 void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address,
75                                         const TlsConnectOptions& options) {
76   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
77   TRACE_SCOPED(TraceCategory::kSsl, "TlsConnectionFactoryPosix::Connect");
78   IPAddress::Version version = remote_address.address.version();
79   std::unique_ptr<TlsConnectionPosix> connection(
80       new TlsConnectionPosix(version, task_runner_));
81   Error connect_error = connection->socket_->Connect(remote_address);
82   if (!connect_error.ok()) {
83     TRACE_SET_RESULT(connect_error);
84     DispatchConnectionFailed(remote_address);
85     return;
86   }
87 
88   if (!ConfigureSsl(connection.get())) {
89     return;
90   }
91 
92   if (options.unsafely_skip_certificate_validation) {
93     // Verifies the server certificate but does not make errors fatal.
94     SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_NONE, nullptr);
95   } else {
96     // Make server certificate errors fatal.
97     SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_PEER, nullptr);
98   }
99 
100   Connect(std::move(connection));
101 }
102 
SetListenCredentials(const TlsCredentials & credentials)103 void TlsConnectionFactoryPosix::SetListenCredentials(
104     const TlsCredentials& credentials) {
105   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
106   EnsureInitialized();
107 
108   ErrorOr<bssl::UniquePtr<X509>> cert = ImportCertificate(
109       credentials.der_x509_cert.data(), credentials.der_x509_cert.size());
110   ErrorOr<bssl::UniquePtr<EVP_PKEY>> pkey =
111       ImportRSAPrivateKey(credentials.der_rsa_private_key.data(),
112                           credentials.der_rsa_private_key.size());
113 
114   if (!cert || !pkey ||
115       SSL_CTX_use_certificate(ssl_context_.get(), cert.value().get()) != 1 ||
116       SSL_CTX_use_PrivateKey(ssl_context_.get(), pkey.value().get()) != 1) {
117     DispatchError(Error::Code::kSocketListenFailure);
118     TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
119     return;
120   }
121 
122   listen_credentials_set_ = true;
123 }
124 
Listen(const IPEndpoint & local_address,const TlsListenOptions & options)125 void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address,
126                                        const TlsListenOptions& options) {
127   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
128   // Credentials must be set before Listen() is called.
129   OSP_DCHECK(listen_credentials_set_);
130 
131   auto socket = std::make_unique<StreamSocketPosix>(local_address);
132   socket->Bind();
133   socket->Listen(options.backlog_size);
134   if (socket->state() == TcpSocketState::kClosed) {
135     DispatchError(Error::Code::kSocketListenFailure);
136     TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
137     return;
138   }
139   OSP_DCHECK(socket->state() == TcpSocketState::kListening);
140 
141   OSP_DCHECK(platform_client_);
142   if (platform_client_) {
143     platform_client_->tls_data_router()->RegisterAcceptObserver(
144         std::move(socket), this);
145   }
146 }
147 
OnConnectionPending(StreamSocketPosix * socket)148 void TlsConnectionFactoryPosix::OnConnectionPending(StreamSocketPosix* socket) {
149   task_runner_->PostTask([connection_factory_weak_ptr =
150                               weak_factory_.GetWeakPtr(),
151                           socket_weak_ptr = socket->GetWeakPtr()] {
152     if (!connection_factory_weak_ptr || !socket_weak_ptr) {
153       // Cancel the Accept() since either the factory or the listener socket
154       // went away before this task has run.
155       return;
156     }
157 
158     ErrorOr<std::unique_ptr<StreamSocket>> accepted = socket_weak_ptr->Accept();
159     if (accepted.is_error()) {
160       // Check for special error code. Because this call doesn't get executed
161       // until it gets through the task runner, OnConnectionPending may get
162       // called multiple times. This check ensures only the first such call will
163       // create a new SSL connection.
164       if (accepted.error().code() != Error::Code::kAgain) {
165         connection_factory_weak_ptr->DispatchError(std::move(accepted.error()));
166       }
167       return;
168     }
169 
170     connection_factory_weak_ptr->OnSocketAccepted(std::move(accepted.value()));
171   });
172 }
173 
OnSocketAccepted(std::unique_ptr<StreamSocket> socket)174 void TlsConnectionFactoryPosix::OnSocketAccepted(
175     std::unique_ptr<StreamSocket> socket) {
176   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
177 
178   TRACE_SCOPED(TraceCategory::kSsl,
179                "TlsConnectionFactoryPosix::OnSocketAccepted");
180   std::unique_ptr<TlsConnectionPosix> connection(
181       new TlsConnectionPosix(std::move(socket), task_runner_));
182 
183   if (!ConfigureSsl(connection.get())) {
184     return;
185   }
186 
187   Accept(std::move(connection));
188 }
189 
ConfigureSsl(TlsConnectionPosix * connection)190 bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) {
191   ErrorOr<bssl::UniquePtr<SSL>> connection_result = GetSslConnection();
192   if (connection_result.is_error()) {
193     DispatchError(connection_result.error());
194     TRACE_SET_RESULT(connection_result.error());
195     return false;
196   }
197 
198   bssl::UniquePtr<SSL> ssl = std::move(connection_result.value());
199   if (!SSL_set_fd(ssl.get(), connection->socket_->socket_handle().fd)) {
200     DispatchConnectionFailed(connection->GetRemoteEndpoint());
201     TRACE_SET_RESULT(Error(Error::Code::kSocketBindFailure));
202     return false;
203   }
204 
205   connection->ssl_.swap(ssl);
206   return true;
207 }
208 
GetSslConnection()209 ErrorOr<bssl::UniquePtr<SSL>> TlsConnectionFactoryPosix::GetSslConnection() {
210   EnsureInitialized();
211   if (!ssl_context_.get()) {
212     return Error::Code::kFatalSSLError;
213   }
214 
215   SSL* ssl = SSL_new(ssl_context_.get());
216   if (ssl == nullptr) {
217     return Error::Code::kFatalSSLError;
218   }
219 
220   return bssl::UniquePtr<SSL>(ssl);
221 }
222 
EnsureInitialized()223 void TlsConnectionFactoryPosix::EnsureInitialized() {
224   std::call_once(init_instance_flag_, [this]() { this->Initialize(); });
225 }
226 
Initialize()227 void TlsConnectionFactoryPosix::Initialize() {
228   EnsureOpenSSLInit();
229   SSL_CTX* context = SSL_CTX_new(TLS_method());
230   if (context == nullptr) {
231     return;
232   }
233 
234   SSL_CTX_set_mode(context, SSL_MODE_ENABLE_PARTIAL_WRITE);
235 
236   ssl_context_.reset(context);
237 }
238 
Connect(std::unique_ptr<TlsConnectionPosix> connection)239 void TlsConnectionFactoryPosix::Connect(
240     std::unique_ptr<TlsConnectionPosix> connection) {
241   if (connection->socket_->state() == TcpSocketState::kClosed) {
242     return;
243   }
244   OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected);
245   ClearOpenSSLERRStack(CURRENT_LOCATION);
246   const int connection_status = SSL_connect(connection->ssl_.get());
247   if (connection_status != 1) {
248     Error error = GetSSLError(connection->ssl_.get(), connection_status);
249     if (error.code() == Error::Code::kAgain) {
250       task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
251                               conn = std::move(connection)]() mutable {
252         if (auto* self = weak_this.get()) {
253           self->Connect(std::move(conn));
254         }
255       });
256       return;
257     } else {
258       OSP_DVLOG << "SSL_connect failed with error: " << error;
259       DispatchConnectionFailed(connection->GetRemoteEndpoint());
260       TRACE_SET_RESULT(error);
261       return;
262     }
263   }
264 
265   ErrorOr<std::vector<uint8_t>> der_peer_cert =
266       GetDEREncodedPeerCertificate(*connection->ssl_);
267   if (!der_peer_cert) {
268     DispatchConnectionFailed(connection->GetRemoteEndpoint());
269     TRACE_SET_RESULT(der_peer_cert.error());
270     return;
271   }
272 
273   connection->RegisterConnectionWithDataRouter(platform_client_);
274   task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
275                           der = std::move(der_peer_cert.value()),
276                           moved_connection = std::move(connection)]() mutable {
277     if (auto* self = weak_this.get()) {
278       self->client_->OnConnected(self, std::move(der),
279                                  std::move(moved_connection));
280     }
281   });
282 }
283 
Accept(std::unique_ptr<TlsConnectionPosix> connection)284 void TlsConnectionFactoryPosix::Accept(
285     std::unique_ptr<TlsConnectionPosix> connection) {
286   if (connection->socket_->state() == TcpSocketState::kClosed) {
287     return;
288   }
289   OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected);
290 
291   ClearOpenSSLERRStack(CURRENT_LOCATION);
292   const int connection_status = SSL_accept(connection->ssl_.get());
293   if (connection_status != 1) {
294     Error error = GetSSLError(connection->ssl_.get(), connection_status);
295     if (error.code() == Error::Code::kAgain) {
296       task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
297                               conn = std::move(connection)]() mutable {
298         if (auto* self = weak_this.get()) {
299           self->Accept(std::move(conn));
300         }
301       });
302       return;
303     } else {
304       OSP_DVLOG << "SSL_accept failed with error: " << error;
305       DispatchConnectionFailed(connection->GetRemoteEndpoint());
306       TRACE_SET_RESULT(error);
307       return;
308     }
309   }
310 
311   ErrorOr<std::vector<uint8_t>> der_peer_cert =
312       GetDEREncodedPeerCertificate(*connection->ssl_);
313   std::vector<uint8_t> der;
314   if (der_peer_cert) {
315     der = std::move(der_peer_cert.value());
316   }
317   connection->RegisterConnectionWithDataRouter(platform_client_);
318   task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
319                           der = std::move(der),
320                           moved_connection = std::move(connection)]() mutable {
321     if (auto* self = weak_this.get()) {
322       self->client_->OnAccepted(self, std::move(der),
323                                 std::move(moved_connection));
324     }
325   });
326 }
327 
DispatchConnectionFailed(const IPEndpoint & remote_endpoint)328 void TlsConnectionFactoryPosix::DispatchConnectionFailed(
329     const IPEndpoint& remote_endpoint) {
330   task_runner_->PostTask(
331       [weak_this = weak_factory_.GetWeakPtr(), remote = remote_endpoint] {
332         if (auto* self = weak_this.get()) {
333           self->client_->OnConnectionFailed(self, remote);
334         }
335       });
336 }
337 
DispatchError(Error error)338 void TlsConnectionFactoryPosix::DispatchError(Error error) {
339   task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
340                           moved_error = std::move(error)]() mutable {
341     if (auto* self = weak_this.get()) {
342       self->client_->OnError(self, std::move(moved_error));
343     }
344   });
345 }
346 
347 }  // namespace openscreen
348