1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "platform/impl/tls_connection_factory_posix.h"
6
7 #include <errno.h>
8 #include <fcntl.h>
9 #include <netinet/in.h>
10 #include <netinet/ip.h>
11 #include <openssl/ssl.h>
12 #include <sys/ioctl.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
15 #include <unistd.h>
16
17 #include <cstring>
18 #include <utility>
19 #include <vector>
20
21 #include "platform/api/task_runner.h"
22 #include "platform/api/tls_connection_factory.h"
23 #include "platform/base/tls_connect_options.h"
24 #include "platform/base/tls_credentials.h"
25 #include "platform/base/tls_listen_options.h"
26 #include "platform/impl/stream_socket.h"
27 #include "platform/impl/tls_connection_posix.h"
28 #include "util/crypto/certificate_utils.h"
29 #include "util/crypto/openssl_util.h"
30 #include "util/osp_logging.h"
31 #include "util/trace_logging.h"
32
33 namespace openscreen {
34
35 namespace {
36
GetDEREncodedPeerCertificate(const SSL & ssl)37 ErrorOr<std::vector<uint8_t>> GetDEREncodedPeerCertificate(const SSL& ssl) {
38 X509* const peer_cert = SSL_get_peer_certificate(&ssl);
39 ErrorOr<std::vector<uint8_t>> der_peer_cert =
40 ExportX509CertificateToDer(*peer_cert);
41 X509_free(peer_cert);
42 return der_peer_cert;
43 }
44
45 } // namespace
46
CreateFactory(Client * client,TaskRunner * task_runner)47 std::unique_ptr<TlsConnectionFactory> TlsConnectionFactory::CreateFactory(
48 Client* client,
49 TaskRunner* task_runner) {
50 return std::unique_ptr<TlsConnectionFactory>(
51 new TlsConnectionFactoryPosix(client, task_runner));
52 }
53
// Stores the collaborators used by this factory.
//
// |client| receives every connection/error notification and |task_runner| is
// the sequence all work is posted to; both must be non-null (checked in debug
// builds). |platform_client| may be null: the destructor then skips
// deregistration and Listen() does not register its socket.
TlsConnectionFactoryPosix::TlsConnectionFactoryPosix(
    Client* client,
    TaskRunner* task_runner,
    PlatformClientPosix* platform_client)
    : client_(client),
      task_runner_(task_runner),
      platform_client_(platform_client) {
  OSP_DCHECK(client_);
  OSP_DCHECK(task_runner_);
}
64
~TlsConnectionFactoryPosix()65 TlsConnectionFactoryPosix::~TlsConnectionFactoryPosix() {
66 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
67 if (platform_client_) {
68 platform_client_->tls_data_router()->DeregisterAcceptObserver(this);
69 }
70 }
71
72 // TODO(rwkeane): Add support for resuming sessions.
73 // TODO(rwkeane): Integrate with Auth.
Connect(const IPEndpoint & remote_address,const TlsConnectOptions & options)74 void TlsConnectionFactoryPosix::Connect(const IPEndpoint& remote_address,
75 const TlsConnectOptions& options) {
76 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
77 TRACE_SCOPED(TraceCategory::kSsl, "TlsConnectionFactoryPosix::Connect");
78 IPAddress::Version version = remote_address.address.version();
79 std::unique_ptr<TlsConnectionPosix> connection(
80 new TlsConnectionPosix(version, task_runner_));
81 Error connect_error = connection->socket_->Connect(remote_address);
82 if (!connect_error.ok()) {
83 TRACE_SET_RESULT(connect_error);
84 DispatchConnectionFailed(remote_address);
85 return;
86 }
87
88 if (!ConfigureSsl(connection.get())) {
89 return;
90 }
91
92 if (options.unsafely_skip_certificate_validation) {
93 // Verifies the server certificate but does not make errors fatal.
94 SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_NONE, nullptr);
95 } else {
96 // Make server certificate errors fatal.
97 SSL_set_verify(connection->ssl_.get(), SSL_VERIFY_PEER, nullptr);
98 }
99
100 Connect(std::move(connection));
101 }
102
SetListenCredentials(const TlsCredentials & credentials)103 void TlsConnectionFactoryPosix::SetListenCredentials(
104 const TlsCredentials& credentials) {
105 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
106 EnsureInitialized();
107
108 ErrorOr<bssl::UniquePtr<X509>> cert = ImportCertificate(
109 credentials.der_x509_cert.data(), credentials.der_x509_cert.size());
110 ErrorOr<bssl::UniquePtr<EVP_PKEY>> pkey =
111 ImportRSAPrivateKey(credentials.der_rsa_private_key.data(),
112 credentials.der_rsa_private_key.size());
113
114 if (!cert || !pkey ||
115 SSL_CTX_use_certificate(ssl_context_.get(), cert.value().get()) != 1 ||
116 SSL_CTX_use_PrivateKey(ssl_context_.get(), pkey.value().get()) != 1) {
117 DispatchError(Error::Code::kSocketListenFailure);
118 TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
119 return;
120 }
121
122 listen_credentials_set_ = true;
123 }
124
Listen(const IPEndpoint & local_address,const TlsListenOptions & options)125 void TlsConnectionFactoryPosix::Listen(const IPEndpoint& local_address,
126 const TlsListenOptions& options) {
127 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
128 // Credentials must be set before Listen() is called.
129 OSP_DCHECK(listen_credentials_set_);
130
131 auto socket = std::make_unique<StreamSocketPosix>(local_address);
132 socket->Bind();
133 socket->Listen(options.backlog_size);
134 if (socket->state() == TcpSocketState::kClosed) {
135 DispatchError(Error::Code::kSocketListenFailure);
136 TRACE_SET_RESULT(Error::Code::kSocketListenFailure);
137 return;
138 }
139 OSP_DCHECK(socket->state() == TcpSocketState::kListening);
140
141 OSP_DCHECK(platform_client_);
142 if (platform_client_) {
143 platform_client_->tls_data_router()->RegisterAcceptObserver(
144 std::move(socket), this);
145 }
146 }
147
OnConnectionPending(StreamSocketPosix * socket)148 void TlsConnectionFactoryPosix::OnConnectionPending(StreamSocketPosix* socket) {
149 task_runner_->PostTask([connection_factory_weak_ptr =
150 weak_factory_.GetWeakPtr(),
151 socket_weak_ptr = socket->GetWeakPtr()] {
152 if (!connection_factory_weak_ptr || !socket_weak_ptr) {
153 // Cancel the Accept() since either the factory or the listener socket
154 // went away before this task has run.
155 return;
156 }
157
158 ErrorOr<std::unique_ptr<StreamSocket>> accepted = socket_weak_ptr->Accept();
159 if (accepted.is_error()) {
160 // Check for special error code. Because this call doesn't get executed
161 // until it gets through the task runner, OnConnectionPending may get
162 // called multiple times. This check ensures only the first such call will
163 // create a new SSL connection.
164 if (accepted.error().code() != Error::Code::kAgain) {
165 connection_factory_weak_ptr->DispatchError(std::move(accepted.error()));
166 }
167 return;
168 }
169
170 connection_factory_weak_ptr->OnSocketAccepted(std::move(accepted.value()));
171 });
172 }
173
OnSocketAccepted(std::unique_ptr<StreamSocket> socket)174 void TlsConnectionFactoryPosix::OnSocketAccepted(
175 std::unique_ptr<StreamSocket> socket) {
176 OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
177
178 TRACE_SCOPED(TraceCategory::kSsl,
179 "TlsConnectionFactoryPosix::OnSocketAccepted");
180 std::unique_ptr<TlsConnectionPosix> connection(
181 new TlsConnectionPosix(std::move(socket), task_runner_));
182
183 if (!ConfigureSsl(connection.get())) {
184 return;
185 }
186
187 Accept(std::move(connection));
188 }
189
ConfigureSsl(TlsConnectionPosix * connection)190 bool TlsConnectionFactoryPosix::ConfigureSsl(TlsConnectionPosix* connection) {
191 ErrorOr<bssl::UniquePtr<SSL>> connection_result = GetSslConnection();
192 if (connection_result.is_error()) {
193 DispatchError(connection_result.error());
194 TRACE_SET_RESULT(connection_result.error());
195 return false;
196 }
197
198 bssl::UniquePtr<SSL> ssl = std::move(connection_result.value());
199 if (!SSL_set_fd(ssl.get(), connection->socket_->socket_handle().fd)) {
200 DispatchConnectionFailed(connection->GetRemoteEndpoint());
201 TRACE_SET_RESULT(Error(Error::Code::kSocketBindFailure));
202 return false;
203 }
204
205 connection->ssl_.swap(ssl);
206 return true;
207 }
208
GetSslConnection()209 ErrorOr<bssl::UniquePtr<SSL>> TlsConnectionFactoryPosix::GetSslConnection() {
210 EnsureInitialized();
211 if (!ssl_context_.get()) {
212 return Error::Code::kFatalSSLError;
213 }
214
215 SSL* ssl = SSL_new(ssl_context_.get());
216 if (ssl == nullptr) {
217 return Error::Code::kFatalSSLError;
218 }
219
220 return bssl::UniquePtr<SSL>(ssl);
221 }
222
EnsureInitialized()223 void TlsConnectionFactoryPosix::EnsureInitialized() {
224 std::call_once(init_instance_flag_, [this]() { this->Initialize(); });
225 }
226
Initialize()227 void TlsConnectionFactoryPosix::Initialize() {
228 EnsureOpenSSLInit();
229 SSL_CTX* context = SSL_CTX_new(TLS_method());
230 if (context == nullptr) {
231 return;
232 }
233
234 SSL_CTX_set_mode(context, SSL_MODE_ENABLE_PARTIAL_WRITE);
235
236 ssl_context_.reset(context);
237 }
238
Connect(std::unique_ptr<TlsConnectionPosix> connection)239 void TlsConnectionFactoryPosix::Connect(
240 std::unique_ptr<TlsConnectionPosix> connection) {
241 if (connection->socket_->state() == TcpSocketState::kClosed) {
242 return;
243 }
244 OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected);
245 ClearOpenSSLERRStack(CURRENT_LOCATION);
246 const int connection_status = SSL_connect(connection->ssl_.get());
247 if (connection_status != 1) {
248 Error error = GetSSLError(connection->ssl_.get(), connection_status);
249 if (error.code() == Error::Code::kAgain) {
250 task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
251 conn = std::move(connection)]() mutable {
252 if (auto* self = weak_this.get()) {
253 self->Connect(std::move(conn));
254 }
255 });
256 return;
257 } else {
258 OSP_DVLOG << "SSL_connect failed with error: " << error;
259 DispatchConnectionFailed(connection->GetRemoteEndpoint());
260 TRACE_SET_RESULT(error);
261 return;
262 }
263 }
264
265 ErrorOr<std::vector<uint8_t>> der_peer_cert =
266 GetDEREncodedPeerCertificate(*connection->ssl_);
267 if (!der_peer_cert) {
268 DispatchConnectionFailed(connection->GetRemoteEndpoint());
269 TRACE_SET_RESULT(der_peer_cert.error());
270 return;
271 }
272
273 connection->RegisterConnectionWithDataRouter(platform_client_);
274 task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
275 der = std::move(der_peer_cert.value()),
276 moved_connection = std::move(connection)]() mutable {
277 if (auto* self = weak_this.get()) {
278 self->client_->OnConnected(self, std::move(der),
279 std::move(moved_connection));
280 }
281 });
282 }
283
Accept(std::unique_ptr<TlsConnectionPosix> connection)284 void TlsConnectionFactoryPosix::Accept(
285 std::unique_ptr<TlsConnectionPosix> connection) {
286 if (connection->socket_->state() == TcpSocketState::kClosed) {
287 return;
288 }
289 OSP_DCHECK(connection->socket_->state() == TcpSocketState::kConnected);
290
291 ClearOpenSSLERRStack(CURRENT_LOCATION);
292 const int connection_status = SSL_accept(connection->ssl_.get());
293 if (connection_status != 1) {
294 Error error = GetSSLError(connection->ssl_.get(), connection_status);
295 if (error.code() == Error::Code::kAgain) {
296 task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
297 conn = std::move(connection)]() mutable {
298 if (auto* self = weak_this.get()) {
299 self->Accept(std::move(conn));
300 }
301 });
302 return;
303 } else {
304 OSP_DVLOG << "SSL_accept failed with error: " << error;
305 DispatchConnectionFailed(connection->GetRemoteEndpoint());
306 TRACE_SET_RESULT(error);
307 return;
308 }
309 }
310
311 ErrorOr<std::vector<uint8_t>> der_peer_cert =
312 GetDEREncodedPeerCertificate(*connection->ssl_);
313 std::vector<uint8_t> der;
314 if (der_peer_cert) {
315 der = std::move(der_peer_cert.value());
316 }
317 connection->RegisterConnectionWithDataRouter(platform_client_);
318 task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
319 der = std::move(der),
320 moved_connection = std::move(connection)]() mutable {
321 if (auto* self = weak_this.get()) {
322 self->client_->OnAccepted(self, std::move(der),
323 std::move(moved_connection));
324 }
325 });
326 }
327
DispatchConnectionFailed(const IPEndpoint & remote_endpoint)328 void TlsConnectionFactoryPosix::DispatchConnectionFailed(
329 const IPEndpoint& remote_endpoint) {
330 task_runner_->PostTask(
331 [weak_this = weak_factory_.GetWeakPtr(), remote = remote_endpoint] {
332 if (auto* self = weak_this.get()) {
333 self->client_->OnConnectionFailed(self, remote);
334 }
335 });
336 }
337
DispatchError(Error error)338 void TlsConnectionFactoryPosix::DispatchError(Error error) {
339 task_runner_->PostTask([weak_this = weak_factory_.GetWeakPtr(),
340 moved_error = std::move(error)]() mutable {
341 if (auto* self = weak_this.get()) {
342 self->client_->OnError(self, std::move(moved_error));
343 }
344 });
345 }
346
347 } // namespace openscreen
348