• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * nghttp2 - HTTP/2 C Library
3  *
4  * Copyright (c) 2021 Tatsuhiro Tsujikawa
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining
7  * a copy of this software and associated documentation files (the
8  * "Software"), to deal in the Software without restriction, including
9  * without limitation the rights to use, copy, modify, merge, publish,
10  * distribute, sublicense, and/or sell copies of the Software, and to
11  * permit persons to whom the Software is furnished to do so, subject to
12  * the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be
15  * included in all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24  */
25 #include "shrpx_quic.h"
26 
27 #include <sys/types.h>
28 #include <sys/socket.h>
29 #include <netdb.h>
30 #include <netinet/udp.h>
31 
32 #include <array>
33 #include <chrono>
34 
35 #include <ngtcp2/ngtcp2_crypto.h>
36 
37 #include <nghttp3/nghttp3.h>
38 
39 #include <openssl/rand.h>
40 
41 #include "shrpx_config.h"
42 #include "shrpx_log.h"
43 #include "util.h"
44 #include "xsi_strerror.h"
45 
operator ==(const ngtcp2_cid & lhs,const ngtcp2_cid & rhs)46 bool operator==(const ngtcp2_cid &lhs, const ngtcp2_cid &rhs) {
47   return ngtcp2_cid_eq(&lhs, &rhs);
48 }
49 
50 namespace shrpx {
51 
quic_timestamp()52 ngtcp2_tstamp quic_timestamp() {
53   return std::chrono::duration_cast<std::chrono::nanoseconds>(
54              std::chrono::steady_clock::now().time_since_epoch())
55       .count();
56 }
57 
quic_send_packet(const UpstreamAddr * faddr,const sockaddr * remote_sa,size_t remote_salen,const sockaddr * local_sa,size_t local_salen,const ngtcp2_pkt_info & pi,std::span<const uint8_t> data,size_t gso_size)58 int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa,
59                      size_t remote_salen, const sockaddr *local_sa,
60                      size_t local_salen, const ngtcp2_pkt_info &pi,
61                      std::span<const uint8_t> data, size_t gso_size) {
62   assert(gso_size);
63 
64   iovec msg_iov = {const_cast<uint8_t *>(data.data()), data.size()};
65   msghdr msg{};
66   msg.msg_name = const_cast<sockaddr *>(remote_sa);
67   msg.msg_namelen = remote_salen;
68   msg.msg_iov = &msg_iov;
69   msg.msg_iovlen = 1;
70 
71   uint8_t msg_ctrl[CMSG_SPACE(sizeof(int)) +
72 #ifdef UDP_SEGMENT
73                    CMSG_SPACE(sizeof(uint16_t)) +
74 #endif // UDP_SEGMENT
75                    CMSG_SPACE(sizeof(in6_pktinfo))];
76 
77   memset(msg_ctrl, 0, sizeof(msg_ctrl));
78 
79   msg.msg_control = msg_ctrl;
80   msg.msg_controllen = sizeof(msg_ctrl);
81 
82   size_t controllen = 0;
83 
84   auto cm = CMSG_FIRSTHDR(&msg);
85 
86   switch (local_sa->sa_family) {
87   case AF_INET: {
88     controllen += CMSG_SPACE(sizeof(in_pktinfo));
89     cm->cmsg_level = IPPROTO_IP;
90     cm->cmsg_type = IP_PKTINFO;
91     cm->cmsg_len = CMSG_LEN(sizeof(in_pktinfo));
92     in_pktinfo pktinfo{};
93     auto addrin =
94         reinterpret_cast<sockaddr_in *>(const_cast<sockaddr *>(local_sa));
95     pktinfo.ipi_spec_dst = addrin->sin_addr;
96     memcpy(CMSG_DATA(cm), &pktinfo, sizeof(pktinfo));
97 
98     break;
99   }
100   case AF_INET6: {
101     controllen += CMSG_SPACE(sizeof(in6_pktinfo));
102     cm->cmsg_level = IPPROTO_IPV6;
103     cm->cmsg_type = IPV6_PKTINFO;
104     cm->cmsg_len = CMSG_LEN(sizeof(in6_pktinfo));
105     in6_pktinfo pktinfo{};
106     auto addrin =
107         reinterpret_cast<sockaddr_in6 *>(const_cast<sockaddr *>(local_sa));
108     pktinfo.ipi6_addr = addrin->sin6_addr;
109     memcpy(CMSG_DATA(cm), &pktinfo, sizeof(pktinfo));
110 
111     break;
112   }
113   default:
114     assert(0);
115   }
116 
117 #ifdef UDP_SEGMENT
118   if (data.size() > gso_size) {
119     controllen += CMSG_SPACE(sizeof(uint16_t));
120     cm = CMSG_NXTHDR(&msg, cm);
121     cm->cmsg_level = SOL_UDP;
122     cm->cmsg_type = UDP_SEGMENT;
123     cm->cmsg_len = CMSG_LEN(sizeof(uint16_t));
124     uint16_t n = gso_size;
125     memcpy(CMSG_DATA(cm), &n, sizeof(n));
126   }
127 #endif // UDP_SEGMENT
128 
129   controllen += CMSG_SPACE(sizeof(int));
130   cm = CMSG_NXTHDR(&msg, cm);
131   cm->cmsg_len = CMSG_LEN(sizeof(int));
132   unsigned int tos = pi.ecn;
133   memcpy(CMSG_DATA(cm), &tos, sizeof(tos));
134 
135   switch (local_sa->sa_family) {
136   case AF_INET:
137     cm->cmsg_level = IPPROTO_IP;
138     cm->cmsg_type = IP_TOS;
139 
140     break;
141   case AF_INET6:
142     cm->cmsg_level = IPPROTO_IPV6;
143     cm->cmsg_type = IPV6_TCLASS;
144 
145     break;
146   default:
147     assert(0);
148   }
149 
150   msg.msg_controllen = controllen;
151 
152   ssize_t nwrite;
153 
154   do {
155     nwrite = sendmsg(faddr->fd, &msg, 0);
156   } while (nwrite == -1 && errno == EINTR);
157 
158   if (nwrite == -1) {
159     if (LOG_ENABLED(INFO)) {
160       auto error = errno;
161       LOG(INFO) << "sendmsg failed: errno=" << error;
162     }
163 
164     return -errno;
165   }
166 
167   if (LOG_ENABLED(INFO)) {
168     LOG(INFO) << "QUIC sent packet: local="
169               << util::to_numeric_addr(local_sa, local_salen)
170               << " remote=" << util::to_numeric_addr(remote_sa, remote_salen)
171               << " ecn=" << log::hex << pi.ecn << log::dec << " " << nwrite
172               << " bytes";
173   }
174 
175   assert(static_cast<size_t>(nwrite) == data.size());
176 
177   return 0;
178 }
179 
generate_quic_retry_connection_id(ngtcp2_cid & cid,uint32_t server_id,uint8_t km_id,EVP_CIPHER_CTX * ctx)180 int generate_quic_retry_connection_id(ngtcp2_cid &cid, uint32_t server_id,
181                                       uint8_t km_id, EVP_CIPHER_CTX *ctx) {
182   if (RAND_bytes(cid.data, SHRPX_QUIC_SCIDLEN) != 1) {
183     return -1;
184   }
185 
186   cid.datalen = SHRPX_QUIC_SCIDLEN;
187   cid.data[0] = (cid.data[0] & (~SHRPX_QUIC_DCID_KM_ID_MASK)) | km_id;
188 
189   auto p = cid.data + SHRPX_QUIC_CID_WORKER_ID_OFFSET;
190 
191   std::copy_n(reinterpret_cast<uint8_t *>(&server_id), sizeof(server_id), p);
192 
193   return encrypt_quic_connection_id(p, p, ctx);
194 }
195 
generate_quic_connection_id(ngtcp2_cid & cid,const WorkerID & wid,uint8_t km_id,EVP_CIPHER_CTX * ctx)196 int generate_quic_connection_id(ngtcp2_cid &cid, const WorkerID &wid,
197                                 uint8_t km_id, EVP_CIPHER_CTX *ctx) {
198   if (RAND_bytes(cid.data, SHRPX_QUIC_SCIDLEN) != 1) {
199     return -1;
200   }
201 
202   cid.datalen = SHRPX_QUIC_SCIDLEN;
203   cid.data[0] = (cid.data[0] & (~SHRPX_QUIC_DCID_KM_ID_MASK)) | km_id;
204 
205   auto p = cid.data + SHRPX_QUIC_CID_WORKER_ID_OFFSET;
206 
207   std::copy_n(reinterpret_cast<const uint8_t *>(&wid), sizeof(wid), p);
208 
209   return encrypt_quic_connection_id(p, p, ctx);
210 }
211 
encrypt_quic_connection_id(uint8_t * dest,const uint8_t * src,EVP_CIPHER_CTX * ctx)212 int encrypt_quic_connection_id(uint8_t *dest, const uint8_t *src,
213                                EVP_CIPHER_CTX *ctx) {
214   int len;
215 
216   if (!EVP_EncryptUpdate(ctx, dest, &len, src, SHRPX_QUIC_DECRYPTED_DCIDLEN) ||
217       !EVP_EncryptFinal_ex(ctx, dest + len, &len)) {
218     return -1;
219   }
220 
221   return 0;
222 }
223 
decrypt_quic_connection_id(ConnectionID & dest,const uint8_t * src,EVP_CIPHER_CTX * ctx)224 int decrypt_quic_connection_id(ConnectionID &dest, const uint8_t *src,
225                                EVP_CIPHER_CTX *ctx) {
226   int len;
227   auto p = reinterpret_cast<uint8_t *>(&dest);
228 
229   if (!EVP_DecryptUpdate(ctx, p, &len, src, SHRPX_QUIC_DECRYPTED_DCIDLEN) ||
230       !EVP_DecryptFinal_ex(ctx, p + len, &len)) {
231     return -1;
232   }
233 
234   return 0;
235 }
236 
generate_quic_hashed_connection_id(ngtcp2_cid & dest,const Address & remote_addr,const Address & local_addr,const ngtcp2_cid & cid)237 int generate_quic_hashed_connection_id(ngtcp2_cid &dest,
238                                        const Address &remote_addr,
239                                        const Address &local_addr,
240                                        const ngtcp2_cid &cid) {
241   auto ctx = EVP_MD_CTX_new();
242   auto d = defer(EVP_MD_CTX_free, ctx);
243 
244   std::array<uint8_t, 32> h;
245   unsigned int hlen = EVP_MD_size(EVP_sha256());
246 
247   if (!EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr) ||
248       !EVP_DigestUpdate(ctx, &remote_addr.su.sa, remote_addr.len) ||
249       !EVP_DigestUpdate(ctx, &local_addr.su.sa, local_addr.len) ||
250       !EVP_DigestUpdate(ctx, cid.data, cid.datalen) ||
251       !EVP_DigestFinal_ex(ctx, h.data(), &hlen)) {
252     return -1;
253   }
254 
255   assert(hlen == h.size());
256 
257   std::copy_n(std::begin(h), sizeof(dest.data), std::begin(dest.data));
258   dest.datalen = sizeof(dest.data);
259 
260   return 0;
261 }
262 
generate_quic_stateless_reset_token(uint8_t * token,const ngtcp2_cid & cid,const uint8_t * secret,size_t secretlen)263 int generate_quic_stateless_reset_token(uint8_t *token, const ngtcp2_cid &cid,
264                                         const uint8_t *secret,
265                                         size_t secretlen) {
266   if (ngtcp2_crypto_generate_stateless_reset_token(token, secret, secretlen,
267                                                    &cid) != 0) {
268     return -1;
269   }
270 
271   return 0;
272 }
273 
274 std::optional<std::span<const uint8_t>>
generate_retry_token(std::span<uint8_t> token,uint32_t version,const sockaddr * sa,socklen_t salen,const ngtcp2_cid & retry_scid,const ngtcp2_cid & odcid,std::span<const uint8_t> secret)275 generate_retry_token(std::span<uint8_t> token, uint32_t version,
276                      const sockaddr *sa, socklen_t salen,
277                      const ngtcp2_cid &retry_scid, const ngtcp2_cid &odcid,
278                      std::span<const uint8_t> secret) {
279   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
280                std::chrono::system_clock::now().time_since_epoch())
281                .count();
282 
283   auto tokenlen = ngtcp2_crypto_generate_retry_token(
284       token.data(), secret.data(), secret.size(), version, sa, salen,
285       &retry_scid, &odcid, t);
286   if (tokenlen < 0) {
287     return {};
288   }
289 
290   return {{std::begin(token), static_cast<size_t>(tokenlen)}};
291 }
292 
verify_retry_token(ngtcp2_cid & odcid,std::span<const uint8_t> token,uint32_t version,const ngtcp2_cid & dcid,const sockaddr * sa,socklen_t salen,std::span<const uint8_t> secret)293 int verify_retry_token(ngtcp2_cid &odcid, std::span<const uint8_t> token,
294                        uint32_t version, const ngtcp2_cid &dcid,
295                        const sockaddr *sa, socklen_t salen,
296                        std::span<const uint8_t> secret) {
297   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
298                std::chrono::system_clock::now().time_since_epoch())
299                .count();
300 
301   if (ngtcp2_crypto_verify_retry_token(
302           &odcid, token.data(), token.size(), secret.data(), secret.size(),
303           version, sa, salen, &dcid, 10 * NGTCP2_SECONDS, t) != 0) {
304     return -1;
305   }
306 
307   return 0;
308 }
309 
310 std::optional<std::span<const uint8_t>>
generate_token(std::span<uint8_t> token,const sockaddr * sa,size_t salen,std::span<const uint8_t> secret,uint8_t km_id)311 generate_token(std::span<uint8_t> token, const sockaddr *sa, size_t salen,
312                std::span<const uint8_t> secret, uint8_t km_id) {
313   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
314                std::chrono::system_clock::now().time_since_epoch())
315                .count();
316 
317   auto tokenlen = ngtcp2_crypto_generate_regular_token(
318       token.data(), secret.data(), secret.size(), sa, salen, t);
319   if (tokenlen < 0) {
320     return {};
321   }
322 
323   token[tokenlen++] = km_id;
324 
325   return {{std::begin(token), static_cast<size_t>(tokenlen)}};
326 }
327 
verify_token(std::span<const uint8_t> token,const sockaddr * sa,socklen_t salen,std::span<const uint8_t> secret)328 int verify_token(std::span<const uint8_t> token, const sockaddr *sa,
329                  socklen_t salen, std::span<const uint8_t> secret) {
330   if (token.empty()) {
331     return -1;
332   }
333 
334   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
335                std::chrono::system_clock::now().time_since_epoch())
336                .count();
337 
338   if (ngtcp2_crypto_verify_regular_token(
339           token.data(), token.size() - 1, secret.data(), secret.size(), sa,
340           salen, 3600 * NGTCP2_SECONDS, t) != 0) {
341     return -1;
342   }
343 
344   return 0;
345 }
346 
generate_quic_connection_id_encryption_key(std::span<uint8_t> key,std::span<const uint8_t> secret,std::span<const uint8_t> salt)347 int generate_quic_connection_id_encryption_key(std::span<uint8_t> key,
348                                                std::span<const uint8_t> secret,
349                                                std::span<const uint8_t> salt) {
350   constexpr uint8_t info[] = "connection id encryption key";
351   ngtcp2_crypto_md sha256;
352   ngtcp2_crypto_md_init(
353       &sha256, reinterpret_cast<void *>(const_cast<EVP_MD *>(EVP_sha256())));
354 
355   if (ngtcp2_crypto_hkdf(key.data(), key.size(), &sha256, secret.data(),
356                          secret.size(), salt.data(), salt.size(), info,
357                          str_size(info)) != 0) {
358     return -1;
359   }
360 
361   return 0;
362 }
363 
364 const QUICKeyingMaterial *
select_quic_keying_material(const QUICKeyingMaterials & qkms,uint8_t km_id)365 select_quic_keying_material(const QUICKeyingMaterials &qkms, uint8_t km_id) {
366   for (auto &qkm : qkms.keying_materials) {
367     if (km_id == qkm.id) {
368       return &qkm;
369     }
370   }
371 
372   return &qkms.keying_materials.front();
373 }
374 
375 } // namespace shrpx
376