• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <https/SSLSocket.h>
18 
19 #include <https/SafeCallbackable.h>
20 #include <https/Support.h>
21 #include <glog/logging.h>
22 #include <sstream>
23 #include <sys/socket.h>
24 
25 // static
Init()26 void SSLSocket::Init() {
27     SSL_library_init();
28     SSL_load_error_strings();
29 }
30 
31 // static
CreateSSLContext()32 SSL_CTX *SSLSocket::CreateSSLContext() {
33     SSL_CTX *ctx = SSL_CTX_new(SSLv23_method());
34 
35      /* Recommended to avoid SSLv2 & SSLv3 */
36      SSL_CTX_set_options(
37             ctx, SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
38 
39     return ctx;
40 }
41 
SSLSocket(std::shared_ptr<RunLoop> rl,Mode mode,int sock,uint32_t flags)42 SSLSocket::SSLSocket(
43         std::shared_ptr<RunLoop> rl, Mode mode, int sock, uint32_t flags)
44     : BufferedSocket(rl, sock),
45       mMode(mode),
46       mFlags(flags),
47       mCtx(CreateSSLContext(), SSL_CTX_free),
48       mSSL(SSL_new(mCtx.get()), SSL_free),
49       mBioR(BIO_new(BIO_s_mem())),
50       mBioW(BIO_new(BIO_s_mem())),
51       mEOS(false),
52       mFinalErrno(0),
53       mRecvPending(false),
54       mRecvCallback(nullptr),
55       mSendPending(false),
56       mFlushFn(nullptr) {
57     if (mMode == Mode::ACCEPT) {
58         SSL_set_accept_state(mSSL.get());
59     } else {
60         SSL_set_connect_state(mSSL.get());
61     }
62     SSL_set_bio(mSSL.get(), mBioR, mBioW);
63 }
64 
useCertificate(const std::string & path)65 bool SSLSocket::useCertificate(const std::string &path) {
66     return 1 == SSL_use_certificate_file(
67                 mSSL.get(), path.c_str(), SSL_FILETYPE_PEM);
68 }
69 
usePrivateKey(const std::string & path)70 bool SSLSocket::usePrivateKey(const std::string &path) {
71     return 1 == SSL_use_PrivateKey_file(
72             mSSL.get(), path.c_str(), SSL_FILETYPE_PEM)
73         && 1 == SSL_check_private_key(mSSL.get());
74 }
75 
useTrustedCertificates(const std::string & path)76 bool SSLSocket::useTrustedCertificates(const std::string &path) {
77     return 1 == SSL_CTX_load_verify_locations(
78             mCtx.get(),
79             path.c_str(),
80             nullptr /* CApath */);
81 }
82 
SSLSocket(std::shared_ptr<RunLoop> rl,int sock,const std::string & certificate_pem_path,const std::string & private_key_pem_path,uint32_t flags)83 SSLSocket::SSLSocket(
84         std::shared_ptr<RunLoop> rl,
85         int sock,
86         const std::string &certificate_pem_path,
87         const std::string &private_key_pem_path,
88         uint32_t flags)
89     : SSLSocket(rl, Mode::ACCEPT, sock, flags) {
90 
91     // This flag makes no sense for a server.
92     CHECK(!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE));
93 
94     CHECK(useCertificate(certificate_pem_path)
95             && usePrivateKey(private_key_pem_path));
96 }
97 
SSLSocket(std::shared_ptr<RunLoop> rl,int sock,uint32_t flags,const std::optional<std::string> & trusted_pem_path)98 SSLSocket::SSLSocket(
99         std::shared_ptr<RunLoop> rl,
100         int sock,
101         uint32_t flags,
102         const std::optional<std::string> &trusted_pem_path)
103     : SSLSocket(rl, Mode::CONNECT, sock, flags) {
104 
105     if (!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)) {
106         CHECK(trusted_pem_path.has_value());
107         CHECK(useTrustedCertificates(*trusted_pem_path));
108     }
109 }
110 
~SSLSocket()111 SSLSocket::~SSLSocket() {
112     SSL_shutdown(mSSL.get());
113 
114     mBioW = mBioR = nullptr;
115 }
116 
postRecv(RunLoop::AsyncFunction fn)117 void SSLSocket::postRecv(RunLoop::AsyncFunction fn) {
118     char tmp[128];
119     int n = SSL_peek(mSSL.get(), tmp, sizeof(tmp));
120 
121     if (n > 0) {
122         fn();
123         return;
124     }
125 
126     CHECK(mRecvCallback == nullptr);
127     mRecvCallback = fn;
128 
129     if (!mRecvPending) {
130         mRecvPending = true;
131         runLoop()->postSocketRecv(
132                 fd(),
133                 makeSafeCallback(this, &SSLSocket::handleIncomingData));
134     }
135 }
136 
handleIncomingData()137 void SSLSocket::handleIncomingData() {
138     mRecvPending = false;
139 
140     uint8_t buffer[1024];
141     ssize_t len;
142     do {
143         len = ::recv(fd(), buffer, sizeof(buffer), 0);
144     } while (len < 0 && errno == EINTR);
145 
146     if (len <= 0) {
147         mEOS = true;
148         mFinalErrno = (len < 0) ? errno : 0;
149 
150         sendRecvCallback();
151         return;
152     }
153 
154     size_t offset = 0;
155     while (len > 0) {
156         int n = BIO_write(mBioR, &buffer[offset], len);
157         CHECK_GT(n, 0);
158 
159         offset += n;
160         len -= n;
161 
162         if (!SSL_is_init_finished(mSSL.get())) {
163             if (mMode == Mode::ACCEPT) {
164                 n = SSL_accept(mSSL.get());
165             } else {
166                 n = SSL_connect(mSSL.get());
167             }
168 
169             auto err = SSL_get_error(mSSL.get(), n);
170 
171             switch (err) {
172                 case SSL_ERROR_WANT_READ:
173                 {
174                     CHECK_EQ(len, 0);
175                     queueOutputDataFromSSL();
176 
177                     mRecvPending = true;
178 
179                     runLoop()->postSocketRecv(
180                             fd(),
181                             makeSafeCallback(
182                                 this, &SSLSocket::handleIncomingData));
183 
184                     return;
185                 }
186 
187                 case SSL_ERROR_WANT_WRITE:
188                 {
189                     CHECK_EQ(len, 0);
190 
191                     mRecvPending = true;
192 
193                     runLoop()->postSocketRecv(
194                             fd(),
195                             makeSafeCallback(
196                                 this, &SSLSocket::handleIncomingData));
197 
198                     return;
199                 }
200 
201                 case SSL_ERROR_NONE:
202                     break;
203 
204                 case SSL_ERROR_SYSCALL:
205                 default:
206                 {
207                     // This is where we end up if the client doesn't trust us.
208                     mEOS = true;
209                     mFinalErrno = ECONNREFUSED;
210 
211                     sendRecvCallback();
212                     return;
213                 }
214             }
215 
216             CHECK(SSL_is_init_finished(mSSL.get()));
217 
218             drainOutputBufferPlain();
219 
220             if (!(mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)
221                     && !isPeerCertificateValid()) {
222                 mEOS = true;
223                 mFinalErrno = ECONNREFUSED;
224                 sendRecvCallback();
225             }
226         }
227     }
228 
229     int n = SSL_peek(mSSL.get(), buffer, sizeof(buffer));
230 
231     if (n > 0) {
232         sendRecvCallback();
233         return;
234     }
235 
236     auto err = SSL_get_error(mSSL.get(), n);
237 
238     switch (err) {
239         case SSL_ERROR_WANT_READ:
240         {
241             queueOutputDataFromSSL();
242 
243             mRecvPending = true;
244 
245             runLoop()->postSocketRecv(
246                     fd(),
247                     makeSafeCallback(this, &SSLSocket::handleIncomingData));
248 
249             break;
250         }
251 
252         case SSL_ERROR_WANT_WRITE:
253         {
254             mRecvPending = true;
255 
256             runLoop()->postSocketRecv(
257                     fd(),
258                     makeSafeCallback(this, &SSLSocket::handleIncomingData));
259 
260             break;
261         }
262 
263         case SSL_ERROR_ZERO_RETURN:
264         {
265             mEOS = true;
266             mFinalErrno = 0;
267 
268             sendRecvCallback();
269             break;
270         }
271 
272         case SSL_ERROR_NONE:
273             break;
274 
275         case SSL_ERROR_SYSCALL:
276         default:
277         {
278             // This is where we end up if the client doesn't trust us.
279             mEOS = true;
280             mFinalErrno = ECONNREFUSED;
281 
282             sendRecvCallback();
283             break;
284         }
285     }
286 }
287 
sendRecvCallback()288 void SSLSocket::sendRecvCallback() {
289     const auto cb = mRecvCallback;
290     mRecvCallback = nullptr;
291     if (cb != nullptr) {
292         cb();
293     }
294 }
295 
postSend(RunLoop::AsyncFunction fn)296 void SSLSocket::postSend(RunLoop::AsyncFunction fn) {
297     runLoop()->post(fn);
298 }
299 
recvfrom(void * data,size_t size,sockaddr * address,socklen_t * addressLen)300 ssize_t SSLSocket::recvfrom(
301         void *data,
302         size_t size,
303         sockaddr *address,
304         socklen_t *addressLen) {
305     if (address || addressLen) {
306         errno = EINVAL;
307         return -1;
308     }
309 
310     if (mEOS) {
311         errno = mFinalErrno;
312         return (mFinalErrno == 0) ? 0 : -1;
313     }
314 
315     int n = SSL_read(mSSL.get(), data, size);
316 
317     // We should only get here after SSL_peek signaled that there's data to
318     // be read.
319     CHECK_GT(n, 0);
320 
321     return n;
322 }
323 
queueOutputDataFromSSL()324 void SSLSocket::queueOutputDataFromSSL() {
325     int n;
326     do {
327         char buf[1024];
328         n = BIO_read(mBioW, buf, sizeof(buf));
329 
330         if (n > 0) {
331             queueOutputData(buf, n);
332         } else if (BIO_should_retry(mBioW)) {
333             continue;
334         } else {
335             LOG(FATAL) << "Should not be here.";
336         }
337     } while (n > 0);
338 }
339 
queueOutputData(const void * data,size_t size)340 void SSLSocket::queueOutputData(const void *data, size_t size) {
341     if (!size) {
342         return;
343     }
344 
345     const size_t pos = mOutBuffer.size();
346     mOutBuffer.resize(pos + size);
347     memcpy(mOutBuffer.data() + pos, data, size);
348 
349     if (!mSendPending) {
350         mSendPending = true;
351         runLoop()->postSocketSend(
352                 fd(),
353                 makeSafeCallback(this, &SSLSocket::sendOutputData));
354     }
355 }
356 
sendOutputData()357 void SSLSocket::sendOutputData() {
358     mSendPending = false;
359 
360     const size_t size = mOutBuffer.size();
361     size_t offset = 0;
362 
363     while (offset < size) {
364         ssize_t n = ::send(
365                 fd(), mOutBuffer.data() + offset, size - offset, 0);
366 
367         if (n < 0) {
368             if (errno == EINTR) {
369                 continue;
370             } else if (errno == EAGAIN || errno == EWOULDBLOCK) {
371                 break;
372             }
373 
374             LOG(FATAL) << "Should not be here.";
375         }
376 
377         offset += static_cast<size_t>(n);
378     }
379 
380     mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);
381 
382     if (!mOutBufferPlain.empty()) {
383         drainOutputBufferPlain();
384     }
385 
386     if (!mOutBuffer.empty()) {
387         mSendPending = true;
388         runLoop()->postSocketSend(
389                 fd(),
390                 makeSafeCallback(this, &SSLSocket::sendOutputData));
391 
392         return;
393     }
394 
395     auto fn = mFlushFn;
396     mFlushFn = nullptr;
397     if (fn != nullptr) {
398         fn();
399     }
400 }
401 
sendto(const void * data,size_t size,const sockaddr * addr,socklen_t addrLen)402 ssize_t SSLSocket::sendto(
403         const void *data,
404         size_t size,
405         const sockaddr *addr,
406         socklen_t addrLen) {
407     if (addr || addrLen) {
408         errno = -EINVAL;
409         return -1;
410     }
411 
412     if (mEOS) {
413         errno = mFinalErrno;
414         return (mFinalErrno == 0) ? 0 : -1;
415     }
416 
417     const size_t pos = mOutBufferPlain.size();
418     mOutBufferPlain.resize(pos + size);
419     memcpy(&mOutBufferPlain[pos], data, size);
420 
421     drainOutputBufferPlain();
422 
423     return size;
424 }
425 
// Pushes queued plaintext (mOutBufferPlain) through SSL and forwards the
// resulting encrypted bytes to the socket output buffer.
//
// While the handshake is still in progress this also drives the handshake.
// If SSL needs the peer to respond first, the unsent plaintext stays queued
// and the function returns; it is called again from handleIncomingData()
// and sendOutputData().
//
// If the peer certificate turns out to be invalid, or a write fails on an
// established session, the stream is marked as ended (mEOS/mFinalErrno),
// queued plaintext is dropped and the pending receive callback is notified.
void SSLSocket::drainOutputBufferPlain() {
    auto *const ssl = mSSL.get();
    const size_t size = mOutBufferPlain.size();
    size_t offset = 0;

    // Keeps the not yet encrypted tail queued, forwards whatever SSL has
    // produced so far and, if requested, makes sure a socket read is
    // pending so that the peer's answer gets processed.
    const auto suspend = [this, &offset](bool waitForPeer) {
        mOutBufferPlain.erase(
                mOutBufferPlain.begin(), mOutBufferPlain.begin() + offset);

        queueOutputDataFromSSL();

        if (waitForPeer && !mRecvPending) {
            mRecvPending = true;

            runLoop()->postSocketRecv(
                    fd(),
                    makeSafeCallback(this, &SSLSocket::handleIncomingData));
        }
    };

    // Ends the stream: nothing queued may be sent anymore. The callback is
    // invoked last, after all state has been updated.
    const auto abandon = [this](int finalErrno) {
        mEOS = true;
        mFinalErrno = finalErrno;
        mOutBufferPlain.clear();

        sendRecvCallback();
    };

    while (offset < size) {
        const int written =
            SSL_write(ssl, &mOutBufferPlain[offset], size - offset);

        if (written > 0) {
            // Only bytes SSL actually accepted count as consumed.
            offset += static_cast<size_t>(written);
            continue;
        }

        if (SSL_is_init_finished(ssl)) {
            // The write failed on an established session. Never add the
            // non-positive result to the (unsigned) offset.
            switch (SSL_get_error(ssl, written)) {
                case SSL_ERROR_WANT_WRITE:
                    suspend(false /* waitForPeer */);
                    return;

                case SSL_ERROR_WANT_READ:
                    suspend(true /* waitForPeer */);
                    return;

                default:
                    abandon(EPIPE);
                    return;
            }
        }

        // Still handshaking: nothing was written, move the handshake along.
        const int res = (mMode == Mode::ACCEPT)
                ? SSL_accept(ssl) : SSL_connect(ssl);

        switch (SSL_get_error(ssl, res)) {
            case SSL_ERROR_WANT_WRITE:
                suspend(false /* waitForPeer */);
                return;

            case SSL_ERROR_WANT_READ:
                suspend(true /* waitForPeer */);
                return;

            case SSL_ERROR_SYSCALL:
            {
                // This is where we end up if the client doesn't trust us.
                mEOS = true;
                mFinalErrno = ECONNREFUSED;

                LOG(FATAL) << "Should not be here.";
                return;
            }

            case SSL_ERROR_NONE:
                break;

            default:
                LOG(FATAL) << "Should not be here.";
        }

        CHECK(SSL_is_init_finished(ssl));

        if (!isPeerCertificateValid()) {
            // Do not encrypt anything for a peer we refuse to talk to.
            abandon(ECONNREFUSED);
            return;
        }

        // The handshake completed just now; loop around to retry the write
        // from the same offset.
    }

    mOutBufferPlain.erase(
            mOutBufferPlain.begin(), mOutBufferPlain.begin() + offset);

    queueOutputDataFromSSL();
}
506 
isPeerCertificateValid()507 bool SSLSocket::isPeerCertificateValid() {
508     if (mMode == Mode::ACCEPT || (mFlags & FLAG_DONT_CHECK_PEER_CERTIFICATE)) {
509         // For now we won't validate the client if we are the server.
510         return true;
511     }
512 
513     std::unique_ptr<X509, std::function<void(X509 *)>> cert(
514             SSL_get_peer_certificate(mSSL.get()), X509_free);
515 
516     if (!cert) {
517         LOG(ERROR) << "SSLSocket::isPeerCertificateValid no certificate.";
518 
519         return false;
520     }
521 
522     int res = SSL_get_verify_result(mSSL.get());
523 
524     bool valid = (res == X509_V_OK);
525 
526     if (!valid) {
527         LOG(ERROR) << "SSLSocket::isPeerCertificateValid invalid certificate.";
528 
529         const EVP_MD *digest = EVP_get_digestbyname("sha256");
530 
531         unsigned char md[EVP_MAX_MD_SIZE];
532         unsigned int n;
533         int res = X509_digest(cert.get(), digest, md, &n);
534         CHECK_EQ(res, 1);
535 
536         std::stringstream ss;
537         for (unsigned int i = 0; i < n; ++i) {
538             if (i > 0) {
539                 ss << ":";
540             }
541 
542             auto byte = md[i];
543 
544             auto nibble = byte >> 4;
545             ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
546 
547             nibble = byte & 0x0f;
548             ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
549         }
550 
551         LOG(ERROR)
552             << "Server offered certificate w/ fingerprint "
553             << ss.str();
554     }
555 
556     return valid;
557 }
558 
postFlush(RunLoop::AsyncFunction fn)559 void SSLSocket::postFlush(RunLoop::AsyncFunction fn) {
560     CHECK(mFlushFn == nullptr);
561 
562     if (!mSendPending) {
563         fn();
564         return;
565     }
566 
567     mFlushFn = fn;
568 }
569 
570