• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <webrtc/DTLS.h>

#include <webrtc/RTPSocketHandler.h>

#include <https/SafeCallbackable.h>
#include <https/SSLSocket.h>
#include <https/Support.h>

#include <android-base/logging.h>

#include <openssl/err.h>

#include <strings.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstring>
#include <sstream>
31 
32 static int gDTLSInstanceIndex;
33 
34 // static
Init()35 void DTLS::Init() {
36     SSL_library_init();
37     SSL_load_error_strings();
38     OpenSSL_add_ssl_algorithms();
39 
40     auto err = srtp_init();
41     CHECK_EQ(err, srtp_err_status_ok);
42 
43     gDTLSInstanceIndex = SSL_get_ex_new_index(
44             0, const_cast<char *>("DTLSInstance index"), NULL, NULL, NULL);
45 
46 }
47 
useCertificate(std::shared_ptr<X509> cert)48 bool DTLS::useCertificate(std::shared_ptr<X509> cert) {
49     // I'm assuming that ownership of the certificate is transferred, so I'm
50     // adding an extra reference...
51     CHECK_EQ(1, X509_up_ref(cert.get()));
52 
53     return cert != nullptr && 1 == SSL_CTX_use_certificate(mCtx, cert.get());
54 }
55 
usePrivateKey(std::shared_ptr<EVP_PKEY> key)56 bool DTLS::usePrivateKey(std::shared_ptr<EVP_PKEY> key) {
57     // I'm assuming that ownership of the key in SSL_CTX_use_PrivateKey is
58     // transferred, so I'm adding an extra reference...
59     CHECK_EQ(1, EVP_PKEY_up_ref(key.get()));
60 
61     return key != nullptr
62         && 1 == SSL_CTX_use_PrivateKey(mCtx, key.get())
63         && 1 == SSL_CTX_check_private_key(mCtx);
64 }
65 
DTLS(std::shared_ptr<RTPSocketHandler> handler,DTLS::Mode mode,std::shared_ptr<X509> cert,std::shared_ptr<EVP_PKEY> key,const std::string & remoteFingerprint,bool useSRTP)66 DTLS::DTLS(
67         std::shared_ptr<RTPSocketHandler> handler,
68         DTLS::Mode mode,
69         std::shared_ptr<X509> cert,
70         std::shared_ptr<EVP_PKEY> key,
71         const std::string &remoteFingerprint,
72         bool useSRTP)
73     : mState(State::UNINITIALIZED),
74       mHandler(handler),
75       mMode(mode),
76       mRemoteFingerprint(remoteFingerprint),
77       mUseSRTP(useSRTP),
78       mCtx(nullptr),
79       mSSL(nullptr),
80       mBioR(nullptr),
81       mBioW(nullptr),
82       mSRTPInbound(nullptr),
83       mSRTPOutbound(nullptr) {
84     mCtx = SSL_CTX_new(DTLSv1_2_method());
85     CHECK(mCtx);
86 
87     int result = SSL_CTX_set_cipher_list(
88             mCtx, "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
89 
90     CHECK_EQ(result, 1);
91 
92     SSL_CTX_set_verify(
93             mCtx,
94             SSL_VERIFY_PEER
95                 | SSL_VERIFY_CLIENT_ONCE
96                 | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
97             &DTLS::OnVerifyPeerCertificate);
98 
99     CHECK(useCertificate(cert));
100     CHECK(usePrivateKey(key));
101 
102     if (mUseSRTP) {
103         result = SSL_CTX_set_tlsext_use_srtp(mCtx, "SRTP_AES128_CM_SHA1_80");
104         CHECK_EQ(result, 0);
105     }
106 
107     mSSL = SSL_new(mCtx);
108     CHECK(mSSL);
109 
110     SSL_set_ex_data(mSSL, gDTLSInstanceIndex, this);
111 
112     mBioR = BIO_new(BIO_s_mem());
113     CHECK(mBioR);
114 
115     mBioW = BIO_new(BIO_s_mem());
116     CHECK(mBioW);
117 
118     SSL_set_bio(mSSL, mBioR, mBioW);
119 
120     if (mode == Mode::CONNECT) {
121         SSL_set_connect_state(mSSL);
122     } else {
123         SSL_set_accept_state(mSSL);
124     }
125 }
126 
~DTLS()127 DTLS::~DTLS() {
128     if (mSRTPOutbound) {
129         srtp_dealloc(mSRTPOutbound);
130         mSRTPOutbound = nullptr;
131     }
132 
133     if (mSRTPInbound) {
134         srtp_dealloc(mSRTPInbound);
135         mSRTPInbound = nullptr;
136     }
137 
138     if (mSSL) {
139         SSL_shutdown(mSSL);
140     }
141 
142     SSL_free(mSSL);
143     mSSL = nullptr;
144 
145     mBioW = mBioR = nullptr;
146 
147     SSL_CTX_free(mCtx);
148     mCtx = nullptr;
149 }
150 
151 // static
OnVerifyPeerCertificate(int,X509_STORE_CTX * ctx)152 int DTLS::OnVerifyPeerCertificate(int /* ok */, X509_STORE_CTX *ctx) {
153     LOG(VERBOSE) << "OnVerifyPeerCertificate";
154 
155     SSL *ssl = static_cast<SSL *>(X509_STORE_CTX_get_ex_data(
156             ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
157 
158     DTLS *me = static_cast<DTLS *>(SSL_get_ex_data(ssl, gDTLSInstanceIndex));
159 
160     std::unique_ptr<X509, std::function<void(X509 *)>> cert(
161             SSL_get_peer_certificate(ssl), X509_free);
162 
163     if (!cert) {
164         LOG(ERROR) << "SSLSocket::isPeerCertificateValid no certificate.";
165 
166         return 0;
167     }
168 
169     auto spacePos = me->mRemoteFingerprint.find(' ');
170     CHECK(spacePos != std::string::npos);
171     auto digestName = me->mRemoteFingerprint.substr(0, spacePos);
172     CHECK(!strcasecmp(digestName.c_str(), "sha-256"));
173 
174     const EVP_MD *digest = EVP_get_digestbyname("sha256");
175 
176     unsigned char md[EVP_MAX_MD_SIZE];
177     unsigned int n;
178     int res = X509_digest(cert.get(), digest, md, &n);
179     CHECK_EQ(res, 1);
180 
181     std::stringstream ss;
182     for (unsigned int i = 0; i < n; ++i) {
183         if (i > 0) {
184             ss << ":";
185         }
186 
187         auto byte = md[i];
188 
189         auto nibble = byte >> 4;
190         ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
191 
192         nibble = byte & 0x0f;
193         ss << (char)((nibble < 10) ? ('0' + nibble) : ('A' + nibble - 10));
194     }
195 
196     LOG(VERBOSE)
197         << "Client offered certificate w/ fingerprint "
198         << ss.str();
199 
200     LOG(VERBOSE) << "should be: " << me->mRemoteFingerprint;
201 
202     auto remoteFingerprintHash = me->mRemoteFingerprint.substr(spacePos + 1);
203     bool match = !strcasecmp(remoteFingerprintHash.c_str(), ss.str().c_str());
204 
205     if (!match) {
206         LOG(ERROR)
207             << "The peer's certificate's fingerprint does not match that "
208             << "published in the SDP!";
209     }
210 
211     return match;
212 }
213 
connect(const sockaddr_storage & remoteAddr)214 void DTLS::connect(const sockaddr_storage &remoteAddr) {
215     CHECK_EQ(static_cast<int>(mState), static_cast<int>(State::UNINITIALIZED));
216 
217     mRemoteAddr = remoteAddr;
218     mState = State::CONNECTING;
219 
220     tryConnecting();
221 }
222 
doTheThing(int res)223 void DTLS::doTheThing(int res) {
224     LOG(VERBOSE) << "doTheThing(" << res << ")";
225 
226     int err = SSL_get_error(mSSL, res);
227 
228     switch (err) {
229         case SSL_ERROR_WANT_READ:
230         {
231             LOG(VERBOSE) << "SSL_ERROR_WANT_READ";
232 
233             queueOutputDataFromDTLS();
234             break;
235         }
236 
237         case SSL_ERROR_WANT_WRITE:
238         {
239             LOG(VERBOSE) << "SSL_ERROR_WANT_WRITE";
240             break;
241         }
242 
243         case SSL_ERROR_NONE:
244         {
245             LOG(VERBOSE) << "SSL_ERROR_NONE";
246             break;
247         }
248 
249         case SSL_ERROR_SYSCALL:
250         default:
251         {
252             LOG(ERROR)
253                 << "DTLS stack returned error "
254                 << err
255                 << " ("
256                 << SSL_state_string_long(mSSL)
257                 << ")";
258         }
259     }
260 }
261 
queueOutputDataFromDTLS()262 void DTLS::queueOutputDataFromDTLS() {
263     auto handler = mHandler.lock();
264 
265     if (!handler) {
266         return;
267     }
268 
269     int n;
270 
271     do {
272         char buf[RTPSocketHandler::kMaxUDPPayloadSize];
273         n = BIO_read(mBioW, buf, sizeof(buf));
274 
275         if (n > 0) {
276             LOG(VERBOSE) << "queueing " << n << " bytes of output data from DTLS.";
277 
278             handler->queueDatagram(
279                     mRemoteAddr, buf, static_cast<size_t>(n));
280         } else if (BIO_should_retry(mBioW)) {
281             continue;
282         } else {
283             CHECK(!"Should not be here");
284         }
285     } while (n > 0);
286 }
287 
tryConnecting()288 void DTLS::tryConnecting() {
289     CHECK_EQ(static_cast<int>(mState), static_cast<int>(State::CONNECTING));
290 
291     int res =
292         (mMode == Mode::CONNECT)
293             ? SSL_connect(mSSL) : SSL_accept(mSSL);
294 
295     if (res != 1) {
296         doTheThing(res);
297     } else {
298         queueOutputDataFromDTLS();
299 
300         LOG(INFO) << "DTLS connection established.";
301         mState = State::CONNECTED;
302 
303         auto handler = mHandler.lock();
304         if (handler) {
305             if (mUseSRTP) {
306                 getKeyingMaterial();
307             }
308 
309             handler->notifyDTLSConnected();
310         }
311     }
312 }
313 
inject(const uint8_t * data,size_t size)314 void DTLS::inject(const uint8_t *data, size_t size) {
315     LOG(VERBOSE) << "injecting " << size << " bytes into DTLS stack.";
316 
317     auto n = BIO_write(mBioR, data, size);
318     CHECK_EQ(n, static_cast<int>(size));
319 
320     if (mState == State::CONNECTING) {
321         if (!SSL_is_init_finished(mSSL)) {
322             tryConnecting();
323         }
324     }
325 }
326 
getKeyingMaterial()327 void DTLS::getKeyingMaterial() {
328     static constexpr char kLabel[] = "EXTRACTOR-dtls_srtp";
329 
330     // These correspond to the chosen option SRTP_AES128_CM_SHA1_80, passed
331     // to SSL_CTX_set_tlsext_use_srtp before. c/f RFC 5764 4.1.2
332 
333     uint8_t material[(SRTP_AES_128_KEY_LEN + SRTP_SALT_LEN) * 2];
334 
335     auto res = SSL_export_keying_material(
336             mSSL,
337             material,
338             sizeof(material),
339             kLabel,
340             strlen(kLabel),
341             nullptr /* context */,
342             0 /* contextlen */,
343             0 /* use_context */);
344 
345     CHECK_EQ(res, 1);
346 
347     // LOG(INFO) << "keying material:";
348     // hexdump(material, sizeof(material));
349 
350     size_t offset = 0;
351     const uint8_t *clientKey = &material[offset];
352     offset += SRTP_AES_128_KEY_LEN;
353     const uint8_t *serverKey = &material[offset];
354     offset += SRTP_AES_128_KEY_LEN;
355     const uint8_t *clientSalt = &material[offset];
356     offset += SRTP_SALT_LEN;
357     const uint8_t *serverSalt = &material[offset];
358     offset += SRTP_SALT_LEN;
359 
360     CHECK_EQ(offset, sizeof(material));
361 
362     std::string sendKey(
363             reinterpret_cast<const char *>(clientKey), SRTP_AES_128_KEY_LEN);
364 
365     sendKey.append(
366             reinterpret_cast<const char *>(clientSalt), SRTP_SALT_LEN);
367 
368     std::string receiveKey(
369             reinterpret_cast<const char *>(serverKey), SRTP_AES_128_KEY_LEN);
370 
371     receiveKey.append(
372             reinterpret_cast<const char *>(serverSalt), SRTP_SALT_LEN);
373 
374     if (mMode == Mode::CONNECT) {
375         CreateSRTPSession(&mSRTPInbound, receiveKey, ssrc_any_inbound);
376         CreateSRTPSession(&mSRTPOutbound, sendKey, ssrc_any_outbound);
377     } else {
378         CreateSRTPSession(&mSRTPInbound, sendKey, ssrc_any_inbound);
379         CreateSRTPSession(&mSRTPOutbound, receiveKey, ssrc_any_outbound);
380     }
381 }
382 
383 // static
CreateSRTPSession(srtp_t * session,const std::string & keyAndSalt,srtp_ssrc_type_t direction)384 void DTLS::CreateSRTPSession(
385         srtp_t *session,
386         const std::string &keyAndSalt,
387         srtp_ssrc_type_t direction) {
388     srtp_policy_t policy;
389     memset(&policy, 0, sizeof(policy));
390 
391     srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp);
392     srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp);
393 
394     policy.ssrc.type = direction;
395     policy.ssrc.value = 0;
396 
397     policy.key =
398         const_cast<unsigned char *>(
399                 reinterpret_cast<const unsigned char *>(keyAndSalt.c_str()));
400 
401     policy.allow_repeat_tx = 1;
402     policy.next = nullptr;
403 
404     auto ret = srtp_create(session, &policy);
405     CHECK_EQ(ret, srtp_err_status_ok);
406 }
407 
protect(void * data,size_t size,bool isRTP)408 size_t DTLS::protect(void *data, size_t size, bool isRTP) {
409     int len = static_cast<int>(size);
410 
411     auto ret =
412         isRTP
413             ? srtp_protect(mSRTPOutbound, data, &len)
414             : srtp_protect_rtcp(mSRTPOutbound, data, &len);
415 
416     CHECK_EQ(ret, srtp_err_status_ok);
417 
418     return static_cast<size_t>(len);
419 }
420 
unprotect(void * data,size_t size,bool isRTP)421 size_t DTLS::unprotect(void *data, size_t size, bool isRTP) {
422     int len = static_cast<int>(size);
423 
424     auto ret =
425         isRTP
426             ? srtp_unprotect(mSRTPInbound, data, &len)
427             : srtp_unprotect_rtcp(mSRTPInbound, data, &len);
428 
429     if (ret == srtp_err_status_replay_fail) {
430         LOG(WARNING)
431             << "srtp_unprotect"
432             << (isRTP ? "" : "_rtcp")
433             << " returned srtp_err_status_replay_fail, ignoring packet.";
434 
435         return 0;
436     }
437 
438     CHECK_EQ(ret, srtp_err_status_ok);
439 
440     return static_cast<size_t>(len);
441 }
442 
readApplicationData(void * data,size_t size)443 ssize_t DTLS::readApplicationData(void *data, size_t size) {
444     auto res = SSL_read(mSSL, data, size);
445 
446     if (res < 0) {
447         doTheThing(res);
448         return -1;
449     }
450 
451     return res;
452 }
453 
writeApplicationData(const void * data,size_t size)454 ssize_t DTLS::writeApplicationData(const void *data, size_t size) {
455     auto res = SSL_write(mSSL, data, size);
456 
457     queueOutputDataFromDTLS();
458 
459     // May have to queue the data and "doTheThing" on failure...
460     CHECK_EQ(res, static_cast<int>(size));
461 
462     return res;
463 }
464