1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <https/SSLSocket.h>

#include <https/SafeCallbackable.h>
#include <https/Support.h>
#include <glog/logging.h>
#include <limits>
#include <sstream>
#include <sys/socket.h>
24
25 // static
Init()26 void SSLSocket::Init() {
27 SSL_library_init();
28 SSL_load_error_strings();
29 }
30
31 // static
CreateSSLContext()32 SSL_CTX *SSLSocket::CreateSSLContext() {
33 SSL_CTX *ctx = SSL_CTX_new(SSLv23_method());
34
35 /* Recommended to avoid SSLv2 & SSLv3 */
36 SSL_CTX_set_options(
37 ctx, SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
38
39 return ctx;
40 }
41
SSLSocket(std::shared_ptr<RunLoop> rl,Mode mode,int sock,uint32_t flags)42 SSLSocket::SSLSocket(
43 std::shared_ptr<RunLoop> rl, Mode mode, int sock, uint32_t flags)
44 : BufferedSocket(rl, sock),
45 mMode(mode),
46 mFlags(flags),
47 mCtx(CreateSSLContext(), SSL_CTX_free),
48 mSSL(SSL_new(mCtx.get()), SSL_free),
49 mBioR(BIO_new(BIO_s_mem())),
50 mBioW(BIO_new(BIO_s_mem())),
51 mEOS(false),
52 mFinalErrno(0),
53 mRecvPending(false),
54 mRecvCallback(nullptr),
55 mSendPending(false),
56 mFlushFn(nullptr) {
57 if (mMode == Mode::ACCEPT) {
58 SSL_set_accept_state(mSSL.get());
59 } else {
60 SSL_set_connect_state(mSSL.get());
61 }
62 SSL_set_bio(mSSL.get(), mBioR, mBioW);
63 }
64
useCertificate(const std::string & path)65 bool SSLSocket::useCertificate(const std::string &path) {
66 return 1 == SSL_use_certificate_file(
67 mSSL.get(), path.c_str(), SSL_FILETYPE_PEM);
68 }
69
usePrivateKey(const std::string & path)70 bool SSLSocket::usePrivateKey(const std::string &path) {
71 return 1 == SSL_use_PrivateKey_file(
72 mSSL.get(), path.c_str(), SSL_FILETYPE_PEM)
73 && 1 == SSL_check_private_key(mSSL.get());
74 }
75
useTrustedCertificates(const std::string & path)76 bool SSLSocket::useTrustedCertificates(const std::string &path) {
77 return 1 == SSL_CTX_load_verify_locations(
78 mCtx.get(),
79 path.c_str(),
80 nullptr /* CApath */);
81 }
82
SSLSocket(std::shared_ptr<RunLoop> rl,int sock,const std::string & certificate_pem_path,const std::string & private_key_pem_path,uint32_t flags)83 SSLSocket::SSLSocket(
84 std::shared_ptr<RunLoop> rl,
85 int sock,
86 const std::string &certificate_pem_path,
87 const std::string &private_key_pem_path,
88 uint32_t flags)
89 : SSLSocket(rl, Mode::ACCEPT, sock, flags) {
90
91 // This flag makes no sense for a server.
92 CHECK(!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE));
93
94 CHECK(useCertificate(certificate_pem_path)
95 && usePrivateKey(private_key_pem_path));
96 }
97
SSLSocket(std::shared_ptr<RunLoop> rl,int sock,uint32_t flags,const std::optional<std::string> & trusted_pem_path)98 SSLSocket::SSLSocket(
99 std::shared_ptr<RunLoop> rl,
100 int sock,
101 uint32_t flags,
102 const std::optional<std::string> &trusted_pem_path)
103 : SSLSocket(rl, Mode::CONNECT, sock, flags) {
104
105 if (!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)) {
106 CHECK(trusted_pem_path.has_value());
107 CHECK(useTrustedCertificates(*trusted_pem_path));
108 }
109 }
110
~SSLSocket()111 SSLSocket::~SSLSocket() {
112 SSL_shutdown(mSSL.get());
113
114 mBioW = mBioR = nullptr;
115 }
116
postRecv(RunLoop::AsyncFunction fn)117 void SSLSocket::postRecv(RunLoop::AsyncFunction fn) {
118 char tmp[128];
119 int n = SSL_peek(mSSL.get(), tmp, sizeof(tmp));
120
121 if (n > 0) {
122 fn();
123 return;
124 }
125
126 CHECK(mRecvCallback == nullptr);
127 mRecvCallback = fn;
128
129 if (!mRecvPending) {
130 mRecvPending = true;
131 runLoop()->postSocketRecv(
132 fd(),
133 makeSafeCallback(this, &SSLSocket::handleIncomingData));
134 }
135 }
136
// Run-loop callback, fired when the socket is readable.
//
// Reads one chunk of ciphertext from the socket and feeds it to OpenSSL
// through the read BIO. While the handshake is incomplete, it drives the
// handshake forward. Once the handshake finishes, it verifies the peer
// certificate (client mode, unless disabled) BEFORE any queued plaintext is
// encrypted and sent.
//
// Finally, it either wakes the pending receive callback (decrypted data
// available, end of stream, or error) or re-arms itself to wait for more
// ciphertext.
void SSLSocket::handleIncomingData() {
    mRecvPending = false;

    uint8_t buffer[1024];
    ssize_t len;
    do {
        len = ::recv(fd(), buffer, sizeof(buffer), 0);
    } while (len < 0 && errno == EINTR);

    if (len < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
        // Spurious readiness on a non-blocking socket: nothing to read yet.
        // Wait again instead of reporting EAGAIN as a fatal end of stream.
        mRecvPending = true;

        runLoop()->postSocketRecv(
                fd(),
                makeSafeCallback(this, &SSLSocket::handleIncomingData));

        return;
    }

    if (len <= 0) {
        // Orderly shutdown by the peer (0) or a hard socket error (< 0).
        mEOS = true;
        mFinalErrno = (len < 0) ? errno : 0;

        sendRecvCallback();
        return;
    }

    size_t offset = 0;
    while (len > 0) {
        int n = BIO_write(mBioR, &buffer[offset], len);
        CHECK_GT(n, 0);

        offset += n;
        len -= n;

        if (!SSL_is_init_finished(mSSL.get())) {
            if (mMode == Mode::ACCEPT) {
                n = SSL_accept(mSSL.get());
            } else {
                n = SSL_connect(mSSL.get());
            }

            auto err = SSL_get_error(mSSL.get(), n);

            switch (err) {
                case SSL_ERROR_WANT_READ:
                {
                    // The handshake needs more input from the peer. Send
                    // whatever handshake bytes OpenSSL produced, then wait.
                    CHECK_EQ(len, 0);
                    queueOutputDataFromSSL();

                    mRecvPending = true;

                    runLoop()->postSocketRecv(
                            fd(),
                            makeSafeCallback(
                                this, &SSLSocket::handleIncomingData));

                    return;
                }

                case SSL_ERROR_WANT_WRITE:
                {
                    CHECK_EQ(len, 0);

                    mRecvPending = true;

                    runLoop()->postSocketRecv(
                            fd(),
                            makeSafeCallback(
                                this, &SSLSocket::handleIncomingData));

                    return;
                }

                case SSL_ERROR_NONE:
                    break;

                case SSL_ERROR_SYSCALL:
                default:
                {
                    // This is where we end up if the client doesn't trust us.
                    mEOS = true;
                    mFinalErrno = ECONNREFUSED;

                    sendRecvCallback();
                    return;
                }
            }

            CHECK(SSL_is_init_finished(mSSL.get()));

            // Verify the peer BEFORE draining queued plaintext. Otherwise the
            // application data buffered by sendto() would be encrypted and
            // sent to a server we do not trust. On failure, stop here and
            // drop that data.
            if (!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)
                    && !isPeerCertificateValid()) {
                mEOS = true;
                mFinalErrno = ECONNREFUSED;
                mOutBufferPlain.clear();

                sendRecvCallback();
                return;
            }

            // Handshake complete and peer accepted: flush the plaintext that
            // was buffered while the handshake was in progress.
            drainOutputBufferPlain();
        }
    }

    // Only wake the receiver once decrypted application data is available.
    int n = SSL_peek(mSSL.get(), buffer, sizeof(buffer));

    if (n > 0) {
        sendRecvCallback();
        return;
    }

    auto err = SSL_get_error(mSSL.get(), n);

    switch (err) {
        case SSL_ERROR_WANT_READ:
        {
            // Not enough ciphertext for a full record yet. Send any bytes
            // OpenSSL produced while processing, and wait for more input.
            queueOutputDataFromSSL();

            mRecvPending = true;

            runLoop()->postSocketRecv(
                    fd(),
                    makeSafeCallback(this, &SSLSocket::handleIncomingData));

            break;
        }

        case SSL_ERROR_WANT_WRITE:
        {
            mRecvPending = true;

            runLoop()->postSocketRecv(
                    fd(),
                    makeSafeCallback(this, &SSLSocket::handleIncomingData));

            break;
        }

        case SSL_ERROR_ZERO_RETURN:
        {
            // The peer closed the TLS session cleanly.
            mEOS = true;
            mFinalErrno = 0;

            sendRecvCallback();
            break;
        }

        case SSL_ERROR_NONE:
            break;

        case SSL_ERROR_SYSCALL:
        default:
        {
            // This is where we end up if the client doesn't trust us.
            mEOS = true;
            mFinalErrno = ECONNREFUSED;

            sendRecvCallback();
            break;
        }
    }
}
287
sendRecvCallback()288 void SSLSocket::sendRecvCallback() {
289 const auto cb = mRecvCallback;
290 mRecvCallback = nullptr;
291 if (cb != nullptr) {
292 cb();
293 }
294 }
295
postSend(RunLoop::AsyncFunction fn)296 void SSLSocket::postSend(RunLoop::AsyncFunction fn) {
297 runLoop()->post(fn);
298 }
299
recvfrom(void * data,size_t size,sockaddr * address,socklen_t * addressLen)300 ssize_t SSLSocket::recvfrom(
301 void *data,
302 size_t size,
303 sockaddr *address,
304 socklen_t *addressLen) {
305 if (address || addressLen) {
306 errno = EINVAL;
307 return -1;
308 }
309
310 if (mEOS) {
311 errno = mFinalErrno;
312 return (mFinalErrno == 0) ? 0 : -1;
313 }
314
315 int n = SSL_read(mSSL.get(), data, size);
316
317 // We should only get here after SSL_peek signaled that there's data to
318 // be read.
319 CHECK_GT(n, 0);
320
321 return n;
322 }
323
queueOutputDataFromSSL()324 void SSLSocket::queueOutputDataFromSSL() {
325 int n;
326 do {
327 char buf[1024];
328 n = BIO_read(mBioW, buf, sizeof(buf));
329
330 if (n > 0) {
331 queueOutputData(buf, n);
332 } else if (BIO_should_retry(mBioW)) {
333 continue;
334 } else {
335 LOG(FATAL) << "Should not be here.";
336 }
337 } while (n > 0);
338 }
339
queueOutputData(const void * data,size_t size)340 void SSLSocket::queueOutputData(const void *data, size_t size) {
341 if (!size) {
342 return;
343 }
344
345 const size_t pos = mOutBuffer.size();
346 mOutBuffer.resize(pos + size);
347 memcpy(mOutBuffer.data() + pos, data, size);
348
349 if (!mSendPending) {
350 mSendPending = true;
351 runLoop()->postSocketSend(
352 fd(),
353 makeSafeCallback(this, &SSLSocket::sendOutputData));
354 }
355 }
356
sendOutputData()357 void SSLSocket::sendOutputData() {
358 mSendPending = false;
359
360 const size_t size = mOutBuffer.size();
361 size_t offset = 0;
362
363 while (offset < size) {
364 ssize_t n = ::send(
365 fd(), mOutBuffer.data() + offset, size - offset, 0);
366
367 if (n < 0) {
368 if (errno == EINTR) {
369 continue;
370 } else if (errno == EAGAIN || errno == EWOULDBLOCK) {
371 break;
372 }
373
374 LOG(FATAL) << "Should not be here.";
375 }
376
377 offset += static_cast<size_t>(n);
378 }
379
380 mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);
381
382 if (!mOutBufferPlain.empty()) {
383 drainOutputBufferPlain();
384 }
385
386 if (!mOutBuffer.empty()) {
387 mSendPending = true;
388 runLoop()->postSocketSend(
389 fd(),
390 makeSafeCallback(this, &SSLSocket::sendOutputData));
391
392 return;
393 }
394
395 auto fn = mFlushFn;
396 mFlushFn = nullptr;
397 if (fn != nullptr) {
398 fn();
399 }
400 }
401
sendto(const void * data,size_t size,const sockaddr * addr,socklen_t addrLen)402 ssize_t SSLSocket::sendto(
403 const void *data,
404 size_t size,
405 const sockaddr *addr,
406 socklen_t addrLen) {
407 if (addr || addrLen) {
408 errno = -EINVAL;
409 return -1;
410 }
411
412 if (mEOS) {
413 errno = mFinalErrno;
414 return (mFinalErrno == 0) ? 0 : -1;
415 }
416
417 const size_t pos = mOutBufferPlain.size();
418 mOutBufferPlain.resize(pos + size);
419 memcpy(&mOutBufferPlain[pos], data, size);
420
421 drainOutputBufferPlain();
422
423 return size;
424 }
425
// Encrypts as much of the pending plaintext (mOutBufferPlain) as possible
// with SSL_write(), and hands the resulting ciphertext to the socket output
// buffer via queueOutputDataFromSSL().
//
// If the TLS handshake has not finished, it is driven forward first. Once it
// completes, the peer certificate is validated, and no application data is
// encrypted for a peer that fails validation. Plaintext that cannot be
// written yet stays at the front of mOutBufferPlain for a later call.
void SSLSocket::drainOutputBufferPlain() {
    size_t offset = 0;
    const size_t size = mOutBufferPlain.size();

    while (offset < size) {
        const int written =
            SSL_write(mSSL.get(), &mOutBufferPlain[offset], size - offset);

        // Count only bytes SSL_write() actually accepted. The handshake
        // return code below is NOT a byte count, and a negative result must
        // never be added to the unsigned offset.
        if (written > 0) {
            offset += static_cast<size_t>(written);
        }

        if (!SSL_is_init_finished(mSSL.get())) {
            int n;
            if (mMode == Mode::ACCEPT) {
                n = SSL_accept(mSSL.get());
            } else {
                n = SSL_connect(mSSL.get());
            }

            auto err = SSL_get_error(mSSL.get(), n);

            switch (err) {
                case SSL_ERROR_WANT_WRITE:
                {
                    mOutBufferPlain.erase(
                            mOutBufferPlain.begin(),
                            mOutBufferPlain.begin() + offset);

                    queueOutputDataFromSSL();
                    return;
                }

                case SSL_ERROR_WANT_READ:
                {
                    // Keep the unwritten plaintext, send the handshake bytes
                    // produced so far, and make sure a read is armed for the
                    // peer's reply.
                    mOutBufferPlain.erase(
                            mOutBufferPlain.begin(),
                            mOutBufferPlain.begin() + offset);

                    queueOutputDataFromSSL();

                    if (!mRecvPending) {
                        mRecvPending = true;

                        runLoop()->postSocketRecv(
                                fd(),
                                makeSafeCallback(
                                    this, &SSLSocket::handleIncomingData));
                    }
                    return;
                }

                case SSL_ERROR_SYSCALL:
                {
                    // This is where we end up if the client doesn't trust us.
                    mEOS = true;
                    mFinalErrno = ECONNREFUSED;

                    LOG(FATAL) << "Should not be here.";
                    return;
                }

                case SSL_ERROR_NONE:
                    break;

                default:
                    LOG(FATAL) << "Should not be here.";
            }

            CHECK(SSL_is_init_finished(mSSL.get()));

            if (!isPeerCertificateValid()) {
                // Never encrypt application data for an untrusted peer.
                // Drop it, so later drains (e.g. from sendOutputData()) cannot
                // send it either.
                mEOS = true;
                mFinalErrno = ECONNREFUSED;
                mOutBufferPlain.clear();

                sendRecvCallback();
                return;
            }

            // Handshake just completed: retry SSL_write() for the remainder.
            continue;
        }

        if (written <= 0) {
            const int err = SSL_get_error(mSSL.get(), written);

            if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
                // OpenSSL cannot take more data right now. Keep the rest of
                // the plaintext for a later call.
                break;
            }

            LOG(ERROR)
                << "SSLSocket::drainOutputBufferPlain SSL_write failed "
                << "(SSL error " << err << ").";

            mEOS = true;
            mFinalErrno = EPIPE;
            mOutBufferPlain.clear();

            queueOutputDataFromSSL();
            return;
        }
    }

    mOutBufferPlain.erase(
            mOutBufferPlain.begin(), mOutBufferPlain.begin() + offset);

    queueOutputDataFromSSL();
}
506
isPeerCertificateValid()507 bool SSLSocket::isPeerCertificateValid() {
508 if (mMode == Mode::ACCEPT || (mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)) {
509 // For now we won't validate the client if we are the server.
510 return true;
511 }
512
513 std::unique_ptr<X509, std::function<void(X509 *)>> cert(
514 SSL_get_peer_certificate(mSSL.get()), X509_free);
515
516 if (!cert) {
517 LOG(ERROR) << "SSLSocket::isPeerCertificateValid no certificate.";
518
519 return false;
520 }
521
522 int res = SSL_get_verify_result(mSSL.get());
523
524 bool valid = (res == X509_V_OK);
525
526 if (!valid) {
527 LOG(ERROR) << "SSLSocket::isPeerCertificateValid invalid certificate.";
528
529 const EVP_MD *digest = EVP_get_digestbyname("sha256");
530
531 unsigned char md[EVP_MAX_MD_SIZE];
532 unsigned int n;
533 int res = X509_digest(cert.get(), digest, md, &n);
534 CHECK_EQ(res, 1);
535
536 std::stringstream ss;
537 for (unsigned int i = 0; i < n; ++i) {
538 if (i > 0) {
539 ss << ":";
540 }
541
542 auto byte = md[i];
543
544 auto nibble = byte >> 4;
545 ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
546
547 nibble = byte & 0x0f;
548 ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
549 }
550
551 LOG(ERROR)
552 << "Server offered certificate w/ fingerprint "
553 << ss.str();
554 }
555
556 return valid;
557 }
558
postFlush(RunLoop::AsyncFunction fn)559 void SSLSocket::postFlush(RunLoop::AsyncFunction fn) {
560 CHECK(mFlushFn == nullptr);
561
562 if (!mSendPending) {
563 fn();
564 return;
565 }
566
567 mFlushFn = fn;
568 }
569
570