• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * nghttp2 - HTTP/2 C Library
3  *
4  * Copyright (c) 2012 Tatsuhiro Tsujikawa
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining
7  * a copy of this software and associated documentation files (the
8  * "Software"), to deal in the Software without restriction, including
9  * without limitation the rights to use, copy, modify, merge, publish,
10  * distribute, sublicense, and/or sell copies of the Software, and to
11  * permit persons to whom the Software is furnished to do so, subject to
12  * the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be
15  * included in all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24  */
25 #include "shrpx_worker.h"
26 
27 #ifdef HAVE_UNISTD_H
28 #  include <unistd.h>
29 #endif // HAVE_UNISTD_H
30 #include <netinet/udp.h>
31 
32 #include <cstdio>
33 #include <memory>
34 
35 #include <openssl/rand.h>
36 
37 #ifdef HAVE_LIBBPF
38 #  include <bpf/bpf.h>
39 #  include <bpf/libbpf.h>
40 #endif // HAVE_LIBBPF
41 
42 #include "shrpx_tls.h"
43 #include "shrpx_log.h"
44 #include "shrpx_client_handler.h"
45 #include "shrpx_http2_session.h"
46 #include "shrpx_log_config.h"
47 #include "shrpx_memcached_dispatcher.h"
48 #ifdef HAVE_MRUBY
49 #  include "shrpx_mruby.h"
50 #endif // HAVE_MRUBY
51 #ifdef ENABLE_HTTP3
52 #  include "shrpx_quic_listener.h"
53 #endif // ENABLE_HTTP3
54 #include "shrpx_connection_handler.h"
55 #include "util.h"
56 #include "template.h"
57 #include "xsi_strerror.h"
58 
59 namespace shrpx {
60 
61 namespace {
eventcb(struct ev_loop * loop,ev_async * w,int revents)62 void eventcb(struct ev_loop *loop, ev_async *w, int revents) {
63   auto worker = static_cast<Worker *>(w->data);
64   worker->process_events();
65 }
66 } // namespace
67 
68 namespace {
mcpool_clear_cb(struct ev_loop * loop,ev_timer * w,int revents)69 void mcpool_clear_cb(struct ev_loop *loop, ev_timer *w, int revents) {
70   auto worker = static_cast<Worker *>(w->data);
71   if (worker->get_worker_stat()->num_connections != 0) {
72     return;
73   }
74   auto mcpool = worker->get_mcpool();
75   if (mcpool->freelistsize == mcpool->poolsize) {
76     worker->get_mcpool()->clear();
77   }
78 }
79 } // namespace
80 
81 namespace {
proc_wev_cb(struct ev_loop * loop,ev_timer * w,int revents)82 void proc_wev_cb(struct ev_loop *loop, ev_timer *w, int revents) {
83   auto worker = static_cast<Worker *>(w->data);
84   worker->process_events();
85 }
86 } // namespace
87 
// A group starts in the active (not retired) state.
DownstreamAddrGroup::DownstreamAddrGroup() : retired{false} {}
89 
// NOTE(review): empty out-of-line destructor — presumably kept here so
// member types need only be complete in this translation unit; confirm.
DownstreamAddrGroup::~DownstreamAddrGroup() {}
91 
// DownstreamKey is used to index SharedDownstreamAddr in order to
// find the same configuration.  Two backend groups with an equal
// DownstreamKey share one SharedDownstreamAddr (and thus one set of
// connection pools); see Worker::replace_downstream_config.
using DownstreamKey = std::tuple<
    std::vector<
        std::tuple<StringRef, StringRef, StringRef, size_t, size_t, Proto,
                   uint32_t, uint32_t, uint32_t, bool, bool, bool, bool>>,
    bool, SessionAffinity, StringRef, StringRef, SessionAffinityCookieSecure,
    SessionAffinityCookieStickiness, int64_t, int64_t, StringRef, bool>;
100 
namespace {
// Builds a DownstreamKey from |shared_addr| and |mruby_file| so that
// backend groups with identical configuration can be detected and made
// to share one SharedDownstreamAddr.
// NOTE: the StringRef fields reference memory owned by |shared_addr|
// (and |mruby_file|'s source); the key must not outlive them.
DownstreamKey
create_downstream_key(const std::shared_ptr<SharedDownstreamAddr> &shared_addr,
                      const StringRef &mruby_file) {
  DownstreamKey dkey;

  // Per-address portion: one tuple per backend address.
  auto &addrs = std::get<0>(dkey);
  addrs.resize(shared_addr->addrs.size());
  auto p = std::begin(addrs);
  for (auto &a : shared_addr->addrs) {
    std::get<0>(*p) = a.host;
    std::get<1>(*p) = a.sni;
    std::get<2>(*p) = a.group;
    std::get<3>(*p) = a.fall;
    std::get<4>(*p) = a.rise;
    std::get<5>(*p) = a.proto;
    std::get<6>(*p) = a.port;
    std::get<7>(*p) = a.weight;
    std::get<8>(*p) = a.group_weight;
    std::get<9>(*p) = a.host_unix;
    std::get<10>(*p) = a.tls;
    std::get<11>(*p) = a.dns;
    std::get<12>(*p) = a.upgrade_scheme;
    ++p;
  }
  // Sort so that the key does not depend on the order in which the
  // addresses were configured.
  std::sort(std::begin(addrs), std::end(addrs));

  // Group-wide settings.
  std::get<1>(dkey) = shared_addr->redirect_if_not_tls;

  auto &affinity = shared_addr->affinity;
  std::get<2>(dkey) = affinity.type;
  std::get<3>(dkey) = affinity.cookie.name;
  std::get<4>(dkey) = affinity.cookie.path;
  std::get<5>(dkey) = affinity.cookie.secure;
  std::get<6>(dkey) = affinity.cookie.stickiness;
  auto &timeout = shared_addr->timeout;
  std::get<7>(dkey) = timeout.read;
  std::get<8>(dkey) = timeout.write;
  std::get<9>(dkey) = mruby_file;
  std::get<10>(dkey) = shared_addr->dnf;

  return dkey;
}
} // namespace
145 
// Constructs a per-thread Worker.  Each worker owns its event loop,
// its TLS contexts and its copy of the backend address groups.  The
// QUIC/HTTP3 and eBPF parameters exist only when the corresponding
// features are compiled in.
Worker::Worker(struct ev_loop *loop, SSL_CTX *sv_ssl_ctx, SSL_CTX *cl_ssl_ctx,
               SSL_CTX *tls_session_cache_memcached_ssl_ctx,
               tls::CertLookupTree *cert_tree,
#ifdef ENABLE_HTTP3
               SSL_CTX *quic_sv_ssl_ctx, tls::CertLookupTree *quic_cert_tree,
               WorkerID wid,
#  ifdef HAVE_LIBBPF
               size_t index,
#  endif // HAVE_LIBBPF
#endif   // ENABLE_HTTP3
               const std::shared_ptr<TicketKeys> &ticket_keys,
               ConnectionHandler *conn_handler,
               std::shared_ptr<DownstreamConfig> downstreamconf)
    :
#if defined(ENABLE_HTTP3) && defined(HAVE_LIBBPF)
      index_{index},
#endif // ENABLE_HTTP3 && HAVE_LIBBPF
      randgen_(util::make_mt19937()),
      worker_stat_{},
      dns_tracker_(loop, get_config()->conn.downstream->family),
#ifdef ENABLE_HTTP3
      worker_id_{std::move(wid)},
      quic_upstream_addrs_{get_config()->conn.quic_listener.addrs},
#endif // ENABLE_HTTP3
      loop_(loop),
      sv_ssl_ctx_(sv_ssl_ctx),
      cl_ssl_ctx_(cl_ssl_ctx),
      cert_tree_(cert_tree),
      conn_handler_(conn_handler),
#ifdef ENABLE_HTTP3
      quic_sv_ssl_ctx_{quic_sv_ssl_ctx},
      quic_cert_tree_{quic_cert_tree},
      quic_conn_handler_{this},
#endif // ENABLE_HTTP3
      ticket_keys_(ticket_keys),
      connect_blocker_(
          std::make_unique<ConnectBlocker>(randgen_, loop_, nullptr, nullptr)),
      graceful_shutdown_(false) {
  // w_ wakes this worker up when another thread queues a WorkerEvent
  // through send().
  ev_async_init(&w_, eventcb);
  w_.data = this;
  ev_async_start(loop_, &w_);

  // Both timers are started on demand and stopped in the destructor.
  ev_timer_init(&mcpool_clear_timer_, mcpool_clear_cb, 0., 0.);
  mcpool_clear_timer_.data = this;

  ev_timer_init(&proc_wev_timer_, proc_wev_cb, 0., 0.);
  proc_wev_timer_.data = this;

  auto &session_cacheconf = get_config()->tls.session_cache;

  // The memcached-backed TLS session cache is optional; only set it up
  // when a memcached host is configured.
  if (!session_cacheconf.memcached.host.empty()) {
    session_cache_memcached_dispatcher_ = std::make_unique<MemcachedDispatcher>(
        &session_cacheconf.memcached.addr, loop,
        tls_session_cache_memcached_ssl_ctx,
        StringRef{session_cacheconf.memcached.host}, &mcpool_, randgen_);
  }

  replace_downstream_config(std::move(downstreamconf));
}
205 
206 namespace {
ensure_enqueue_addr(std::priority_queue<WeightGroupEntry,std::vector<WeightGroupEntry>,WeightGroupEntryGreater> & wgpq,WeightGroup * wg,DownstreamAddr * addr)207 void ensure_enqueue_addr(
208     std::priority_queue<WeightGroupEntry, std::vector<WeightGroupEntry>,
209                         WeightGroupEntryGreater> &wgpq,
210     WeightGroup *wg, DownstreamAddr *addr) {
211   uint32_t cycle;
212   if (!wg->pq.empty()) {
213     auto &top = wg->pq.top();
214     cycle = top.cycle;
215   } else {
216     cycle = 0;
217   }
218 
219   addr->cycle = cycle;
220   addr->pending_penalty = 0;
221   wg->pq.push(DownstreamAddrEntry{addr, addr->seq, addr->cycle});
222   addr->queued = true;
223 
224   if (!wg->queued) {
225     if (!wgpq.empty()) {
226       auto &top = wgpq.top();
227       cycle = top.cycle;
228     } else {
229       cycle = 0;
230     }
231 
232     wg->cycle = cycle;
233     wg->pending_penalty = 0;
234     wgpq.push(WeightGroupEntry{wg, wg->seq, wg->cycle});
235     wg->queued = true;
236   }
237 }
238 } // namespace
239 
// Installs a new backend (downstream) configuration.  Existing groups
// are retired and their pooled backend connections dropped.  Groups in
// the new configuration whose DownstreamKey matches an earlier group
// share that group's SharedDownstreamAddr, and hence its connection
// pools.
void Worker::replace_downstream_config(
    std::shared_ptr<DownstreamConfig> downstreamconf) {
  // Retire the current groups; their pooled connections must not be
  // reused under the new configuration.
  for (auto &g : downstream_addr_groups_) {
    g->retired = true;

    auto &shared_addr = g->shared_addr;
    for (auto &addr : shared_addr->addrs) {
      addr.dconn_pool->remove_all();
    }
  }

  downstreamconf_ = downstreamconf;

  // Making a copy is much faster with multiple thread on
  // backendconfig API call.
  auto groups = downstreamconf->addr_groups;

  downstream_addr_groups_ =
      std::vector<std::shared_ptr<DownstreamAddrGroup>>(groups.size());

  // Maps a DownstreamKey to the index of the first group built with
  // that configuration, for the sharing logic below.
  std::map<DownstreamKey, size_t> addr_groups_indexer;
#ifdef HAVE_MRUBY
  // TODO It is a bit less efficient because
  // mruby::create_mruby_context returns std::unique_ptr and we cannot
  // use std::make_shared.
  std::map<StringRef, std::shared_ptr<mruby::MRubyContext>> shared_mruby_ctxs;
#endif // HAVE_MRUBY

  for (size_t i = 0; i < groups.size(); ++i) {
    auto &src = groups[i];
    auto &dst = downstream_addr_groups_[i];

    dst = std::make_shared<DownstreamAddrGroup>();
    dst->pattern =
        ImmutableString{std::begin(src.pattern), std::end(src.pattern)};

    auto shared_addr = std::make_shared<SharedDownstreamAddr>();

    // addrs is sized once here and never resized afterwards, so
    // pointers/references to its elements taken below remain valid.
    shared_addr->addrs.resize(src.addrs.size());
    shared_addr->affinity.type = src.affinity.type;
    if (src.affinity.type == SessionAffinity::COOKIE) {
      shared_addr->affinity.cookie.name =
          make_string_ref(shared_addr->balloc, src.affinity.cookie.name);
      if (!src.affinity.cookie.path.empty()) {
        shared_addr->affinity.cookie.path =
            make_string_ref(shared_addr->balloc, src.affinity.cookie.path);
      }
      shared_addr->affinity.cookie.secure = src.affinity.cookie.secure;
      shared_addr->affinity.cookie.stickiness = src.affinity.cookie.stickiness;
    }
    shared_addr->affinity_hash = src.affinity_hash;
    shared_addr->affinity_hash_map = src.affinity_hash_map;
    shared_addr->redirect_if_not_tls = src.redirect_if_not_tls;
    shared_addr->dnf = src.dnf;
    shared_addr->timeout.read = src.timeout.read;
    shared_addr->timeout.write = src.timeout.write;

    // Copy every backend address; strings are re-allocated from
    // shared_addr->balloc so that shared_addr is self-contained.
    for (size_t j = 0; j < src.addrs.size(); ++j) {
      auto &src_addr = src.addrs[j];
      auto &dst_addr = shared_addr->addrs[j];

      dst_addr.addr = src_addr.addr;
      dst_addr.host = make_string_ref(shared_addr->balloc, src_addr.host);
      dst_addr.hostport =
          make_string_ref(shared_addr->balloc, src_addr.hostport);
      dst_addr.port = src_addr.port;
      dst_addr.host_unix = src_addr.host_unix;
      dst_addr.weight = src_addr.weight;
      dst_addr.group = make_string_ref(shared_addr->balloc, src_addr.group);
      dst_addr.group_weight = src_addr.group_weight;
      dst_addr.affinity_hash = src_addr.affinity_hash;
      dst_addr.proto = src_addr.proto;
      dst_addr.tls = src_addr.tls;
      dst_addr.sni = make_string_ref(shared_addr->balloc, src_addr.sni);
      dst_addr.fall = src_addr.fall;
      dst_addr.rise = src_addr.rise;
      dst_addr.dns = src_addr.dns;
      dst_addr.upgrade_scheme = src_addr.upgrade_scheme;
    }

#ifdef HAVE_MRUBY
    // Groups configured with the same mruby file share one context.
    auto mruby_ctx_it = shared_mruby_ctxs.find(src.mruby_file);
    if (mruby_ctx_it == std::end(shared_mruby_ctxs)) {
      shared_addr->mruby_ctx = mruby::create_mruby_context(src.mruby_file);
      assert(shared_addr->mruby_ctx);
      shared_mruby_ctxs.emplace(src.mruby_file, shared_addr->mruby_ctx);
    } else {
      shared_addr->mruby_ctx = (*mruby_ctx_it).second;
    }
#endif // HAVE_MRUBY

    // share the connection if patterns have the same set of backend
    // addresses.

    auto dkey = create_downstream_key(shared_addr, src.mruby_file);
    auto it = addr_groups_indexer.find(dkey);

    if (it == std::end(addr_groups_indexer)) {
      // First time this configuration is seen: finish setting up
      // shared_addr (connect blockers, live checks, connection pools,
      // weight groups).
      auto shared_addr_ptr = shared_addr.get();

      for (auto &addr : shared_addr->addrs) {
        // When the backend becomes reachable again, re-enqueue it into
        // its weight group.  Capturing &addr is safe here because both
        // addr and the ConnectBlocker are owned by shared_addr and
        // addrs is never resized after this point.
        addr.connect_blocker = std::make_unique<ConnectBlocker>(
            randgen_, loop_, nullptr, [shared_addr_ptr, &addr]() {
              if (!addr.queued) {
                if (!addr.wg) {
                  return;
                }
                ensure_enqueue_addr(shared_addr_ptr->pq, addr.wg, &addr);
              }
            });

        addr.live_check = std::make_unique<LiveCheck>(loop_, cl_ssl_ctx_, this,
                                                      &addr, randgen_);
      }

      size_t seq = 0;
      for (auto &addr : shared_addr->addrs) {
        addr.dconn_pool = std::make_unique<DownstreamConnectionPool>();
        addr.seq = seq++;
      }

      // Randomize the tie-breaking sequence numbers so that workers do
      // not all pick backends in the same order.
      util::shuffle(std::begin(shared_addr->addrs),
                    std::end(shared_addr->addrs), randgen_,
                    [](auto i, auto j) { std::swap((*i).seq, (*j).seq); });

      if (shared_addr->affinity.type == SessionAffinity::NONE) {
        // Build one WeightGroup per distinct addr.group, then enqueue
        // every address into its group and every group into the pq.
        std::map<StringRef, WeightGroup *> wgs;
        size_t num_wgs = 0;
        for (auto &addr : shared_addr->addrs) {
          if (wgs.find(addr.group) == std::end(wgs)) {
            ++num_wgs;
            wgs.emplace(addr.group, nullptr);
          }
        }

        shared_addr->wgs = std::vector<WeightGroup>(num_wgs);

        for (auto &addr : shared_addr->addrs) {
          auto &wg = wgs[addr.group];
          if (wg == nullptr) {
            // Hand out WeightGroup slots from the back; num_wgs hits 0
            // exactly when every group has been assigned one.
            wg = &shared_addr->wgs[--num_wgs];
            wg->seq = num_wgs;
          }

          wg->weight = addr.group_weight;
          wg->pq.push(DownstreamAddrEntry{&addr, addr.seq, addr.cycle});
          addr.queued = true;
          addr.wg = wg;
        }

        assert(num_wgs == 0);

        for (auto &kv : wgs) {
          shared_addr->pq.push(
              WeightGroupEntry{kv.second, kv.second->seq, kv.second->cycle});
          kv.second->queued = true;
        }
      }

      dst->shared_addr = std::move(shared_addr);

      addr_groups_indexer.emplace(std::move(dkey), i);
    } else {
      // An identical configuration already exists: share its
      // SharedDownstreamAddr (and therefore its connection pools).
      auto &g = *(std::begin(downstream_addr_groups_) + (*it).second);
      if (LOG_ENABLED(INFO)) {
        LOG(INFO) << dst->pattern << " shares the same backend group with "
                  << g->pattern;
      }
      dst->shared_addr = g->shared_addr;
    }
  }
}
412 
// Stops every libev watcher registered by this worker so no callback
// can fire against a destroyed Worker.
Worker::~Worker() {
  ev_async_stop(loop_, &w_);
  ev_timer_stop(loop_, &mcpool_clear_timer_);
  ev_timer_stop(loop_, &proc_wev_timer_);
}
418 
// Arms the timer that clears the memchunk pool once the worker becomes
// idle (see mcpool_clear_cb).
void Worker::schedule_clear_mcpool() {
  // libev manual says: "If the watcher is already active nothing will
  // happen."  Since we don't change any timeout here, we don't have
  // to worry about querying ev_is_active.
  ev_timer_start(loop_, &mcpool_clear_timer_);
}
425 
// Blocks until the thread started by run_async() finishes.  No-op when
// built without thread support.
void Worker::wait() {
#ifndef NOTHREADS
  fut_.get();
#endif // !NOTHREADS
}
431 
// Runs this worker's event loop on a dedicated thread.  The thread
// first (re)opens log files for itself, then loops until ev_break, and
// tears down its thread-local log config on exit.  No-op without
// thread support.
void Worker::run_async() {
#ifndef NOTHREADS
  fut_ = std::async(std::launch::async, [this] {
    (void)reopen_log_files(get_config()->logging);
    ev_run(loop_);
    delete_log_config();
  });
#endif // !NOTHREADS
}
441 
// Queues |event| for this worker and wakes its event loop up.  May be
// called from other threads: the queue is guarded by m_, and
// ev_async_send is the libev-sanctioned cross-thread wakeup.
void Worker::send(WorkerEvent event) {
  {
    std::lock_guard<std::mutex> g(m_);

    q_.emplace_back(std::move(event));
  }

  ev_async_send(loop_, &w_);
}
451 
// Dequeues and handles exactly one WorkerEvent.  proc_wev_timer_ is
// (re)started after each event so the remainder of the queue is
// drained on subsequent loop iterations rather than all at once.
void Worker::process_events() {
  WorkerEvent wev;
  {
    std::lock_guard<std::mutex> g(m_);

    // Process event one at a time.  This is important for
    // WorkerEventType::NEW_CONNECTION event since accepting large
    // number of new connections at once may delay time to 1st byte
    // for existing connections.

    if (q_.empty()) {
      ev_timer_stop(loop_, &proc_wev_timer_);
      return;
    }

    wev = std::move(q_.front());
    q_.pop_front();
  }

  ev_timer_start(loop_, &proc_wev_timer_);

  auto config = get_config();

  auto worker_connections = config->conn.upstream.worker_connections;

  switch (wev.type) {
  case WorkerEventType::NEW_CONNECTION: {
    if (LOG_ENABLED(INFO)) {
      WLOG(INFO, this) << "WorkerEvent: client_fd=" << wev.client_fd
                       << ", addrlen=" << wev.client_addrlen;
    }

    // Enforce the per-worker connection limit before attempting the
    // TLS accept.
    if (worker_stat_.num_connections >= worker_connections) {

      if (LOG_ENABLED(INFO)) {
        WLOG(INFO, this) << "Too many connections >= " << worker_connections;
      }

      close(wev.client_fd);

      break;
    }

    auto client_handler =
        tls::accept_connection(this, wev.client_fd, &wev.client_addr.sa,
                               wev.client_addrlen, wev.faddr);
    if (!client_handler) {
      // NOTE(review): an ERROR-level log gated behind LOG_ENABLED(INFO)
      // looks unintentional — confirm against the logging conventions.
      if (LOG_ENABLED(INFO)) {
        WLOG(ERROR, this) << "ClientHandler creation failed";
      }
      close(wev.client_fd);
      break;
    }

    if (LOG_ENABLED(INFO)) {
      WLOG(INFO, this) << "CLIENT_HANDLER:" << client_handler << " created";
    }

    break;
  }
  case WorkerEventType::REOPEN_LOG:
    WLOG(NOTICE, this) << "Reopening log files: worker process (thread " << this
                       << ")";

    reopen_log_files(config->logging);

    break;
  case WorkerEventType::GRACEFUL_SHUTDOWN:
    WLOG(NOTICE, this) << "Graceful shutdown commencing";

    graceful_shutdown_ = true;

    // If nothing is in flight, the loop can terminate right away.
    if (worker_stat_.num_connections == 0 &&
        worker_stat_.num_close_waits == 0) {
      ev_break(loop_);

      return;
    }

    break;
  case WorkerEventType::REPLACE_DOWNSTREAM:
    WLOG(NOTICE, this) << "Replace downstream";

    replace_downstream_config(wev.downstreamconf);

    break;
#ifdef ENABLE_HTTP3
  case WorkerEventType::QUIC_PKT_FORWARD: {
    const UpstreamAddr *faddr;

    // Index of -1 means the sender did not identify the listener;
    // resolve it from the packet's local address instead.
    if (wev.quic_pkt->upstream_addr_index == static_cast<size_t>(-1)) {
      faddr = find_quic_upstream_addr(wev.quic_pkt->local_addr);
      if (faddr == nullptr) {
        LOG(ERROR) << "No suitable upstream address found";

        break;
      }
    } else if (quic_upstream_addrs_.size() <=
               wev.quic_pkt->upstream_addr_index) {
      LOG(ERROR) << "upstream_addr_index is too large";

      break;
    } else {
      faddr = &quic_upstream_addrs_[wev.quic_pkt->upstream_addr_index];
    }

    quic_conn_handler_.handle_packet(faddr, wev.quic_pkt->remote_addr,
                                     wev.quic_pkt->local_addr, wev.quic_pkt->pi,
                                     wev.quic_pkt->data);

    break;
  }
#endif // ENABLE_HTTP3
  default:
    if (LOG_ENABLED(INFO)) {
      WLOG(INFO, this) << "unknown event type " << static_cast<int>(wev.type);
    }
  }
}
571 
// Certificate lookup tree used for frontend TLS SNI matching.
tls::CertLookupTree *Worker::get_cert_lookup_tree() const { return cert_tree_; }
573 
#ifdef ENABLE_HTTP3
// Certificate lookup tree used for the QUIC/HTTP3 frontend.
tls::CertLookupTree *Worker::get_quic_cert_lookup_tree() const {
  return quic_cert_tree_;
}
#endif // ENABLE_HTTP3
579 
// Returns the current TLS session ticket keys.  Thread-safe: uses an
// atomic shared_ptr load when the toolchain supports it, otherwise a
// mutex.
std::shared_ptr<TicketKeys> Worker::get_ticket_keys() {
#ifdef HAVE_ATOMIC_STD_SHARED_PTR
  return ticket_keys_.load(std::memory_order_acquire);
#else  // !HAVE_ATOMIC_STD_SHARED_PTR
  std::lock_guard<std::mutex> g(ticket_keys_m_);
  return ticket_keys_;
#endif // !HAVE_ATOMIC_STD_SHARED_PTR
}
588 
// Replaces the TLS session ticket keys.  Pairs with get_ticket_keys():
// release store / acquire load, or the shared mutex.
void Worker::set_ticket_keys(std::shared_ptr<TicketKeys> ticket_keys) {
#ifdef HAVE_ATOMIC_STD_SHARED_PTR
  // This is single writer
  ticket_keys_.store(std::move(ticket_keys), std::memory_order_release);
#else  // !HAVE_ATOMIC_STD_SHARED_PTR
  std::lock_guard<std::mutex> g(ticket_keys_m_);
  ticket_keys_ = std::move(ticket_keys);
#endif // !HAVE_ATOMIC_STD_SHARED_PTR
}
598 
// Per-worker connection statistics.
WorkerStat *Worker::get_worker_stat() { return &worker_stat_; }

// The event loop owned by this worker.
struct ev_loop *Worker::get_loop() const { return loop_; }

// Frontend (server-side) TLS context.
SSL_CTX *Worker::get_sv_ssl_ctx() const { return sv_ssl_ctx_; }

// Backend (client-side) TLS context.
SSL_CTX *Worker::get_cl_ssl_ctx() const { return cl_ssl_ctx_; }
606 
#ifdef ENABLE_HTTP3
// Frontend TLS context used for QUIC/HTTP3 connections.
SSL_CTX *Worker::get_quic_sv_ssl_ctx() const { return quic_sv_ssl_ctx_; }
#endif // ENABLE_HTTP3
610 
// Sets the graceful shutdown flag (also set by the
// GRACEFUL_SHUTDOWN worker event).
void Worker::set_graceful_shutdown(bool f) { graceful_shutdown_ = f; }

bool Worker::get_graceful_shutdown() const { return graceful_shutdown_; }

// Memory chunk pool shared by this worker's connections.
MemchunkPool *Worker::get_mcpool() { return &mcpool_; }
616 
// Dispatcher for the memcached-backed TLS session cache; nullptr when
// no memcached host was configured.
MemcachedDispatcher *Worker::get_session_cache_memcached_dispatcher() {
  return session_cache_memcached_dispatcher_.get();
}

// Per-worker PRNG (std::mt19937 — not cryptographically secure).
std::mt19937 &Worker::get_randgen() { return randgen_; }
622 
623 #ifdef HAVE_MRUBY
// Creates this worker's mruby context from the configured script
// file.  Returns 0 on success, -1 on failure.
int Worker::create_mruby_context() {
  mruby_ctx_ = mruby::create_mruby_context(StringRef{get_config()->mruby_file});
  if (!mruby_ctx_) {
    return -1;
  }

  return 0;
}
632 
// The worker-level mruby context created by create_mruby_context().
mruby::MRubyContext *Worker::get_mruby_context() const {
  return mruby_ctx_.get();
}
636 #endif // HAVE_MRUBY
637 
// All backend address groups currently installed in this worker.
std::vector<std::shared_ptr<DownstreamAddrGroup>> &
Worker::get_downstream_addr_groups() {
  return downstream_addr_groups_;
}
642 
// Worker-level connect blocker (not tied to a particular backend).
ConnectBlocker *Worker::get_connect_blocker() const {
  return connect_blocker_.get();
}

// The downstream configuration currently in effect.
const DownstreamConfig *Worker::get_downstream_config() const {
  return downstreamconf_.get();
}

ConnectionHandler *Worker::get_connection_handler() const {
  return conn_handler_;
}
654 
#ifdef ENABLE_HTTP3
// Handler for QUIC connections terminated by this worker.
QUICConnectionHandler *Worker::get_quic_connection_handler() {
  return &quic_conn_handler_;
}
#endif // ENABLE_HTTP3
660 
// Per-worker DNS tracker, bound to this worker's event loop.
DNSTracker *Worker::get_dns_tracker() { return &dns_tracker_; }
662 
663 #ifdef ENABLE_HTTP3
664 #  ifdef HAVE_LIBBPF
// Returns true if this worker is the one responsible for attaching
// the eBPF program used to steer QUIC packets; exactly one worker
// must do the attach.
bool Worker::should_attach_bpf() const {
  auto config = get_config();
  auto &quicconf = config->quic;
  auto &apiconf = config->api;

  if (quicconf.bpf.disabled) {
    return false;
  }

  // NOTE(review): with the API enabled in multi-threaded mode the
  // first QUIC worker is index 1, presumably because index 0 is
  // reserved (mirrors compute_sk_index's -1 shift) — confirm against
  // the code that assigns worker indices.
  if (!config->single_thread && apiconf.enabled) {
    return index_ == 1;
  }

  return index_ == 0;
}
680 
// Returns true if this worker should register its QUIC socket in the
// shared eBPF map, i.e. whenever eBPF support is not disabled.
bool Worker::should_update_bpf_map() const {
  auto config = get_config();
  auto &quicconf = config->quic;

  return !quicconf.bpf.disabled;
}
687 
// Index of this worker's socket within the eBPF reuseport socket map.
// When the API endpoint is enabled in multi-threaded mode, indices are
// shifted down by one (matching the index_==1 base in
// should_attach_bpf).
uint32_t Worker::compute_sk_index() const {
  auto config = get_config();
  auto &apiconf = config->api;

  if (!config->single_thread && apiconf.enabled) {
    return index_ - 1;
  }

  return index_;
}
698 #  endif // HAVE_LIBBPF
699 
// Creates a UDP server socket for every configured QUIC frontend
// address and wraps each in a QUICListener.  Returns 0 on success,
// -1 on failure (including duplicate endpoints).
int Worker::setup_quic_server_socket() {
  size_t n = 0;

  for (auto &addr : quic_upstream_addrs_) {
    assert(!addr.host_unix);
    if (create_quic_server_socket(addr) != 0) {
      return -1;
    }

    // Make sure that each endpoint has a unique address.
    // (O(n^2) scan over the endpoints created so far; the number of
    // QUIC frontend addresses is expected to be small.)
    for (size_t i = 0; i < n; ++i) {
      const auto &a = quic_upstream_addrs_[i];

      if (addr.hostport == a.hostport) {
        LOG(FATAL)
            << "QUIC frontend endpoint must be unique: a duplicate found for "
            << addr.hostport;

        return -1;
      }
    }

    ++n;

    quic_listeners_.emplace_back(std::make_unique<QUICListener>(&addr, this));
  }

  return 0;
}
729 
730 #  ifdef HAVE_LIBBPF
namespace {
// https://github.com/kokke/tiny-AES-c
//
// License is Public Domain.
// Commit hash: 12e7744b4919e9d55de75b7ab566326a1c8e7a67

// The number of columns comprising a state in AES. This is a constant
// in AES. Value=4
#    define Nb 4

#    define Nk 4  // The number of 32 bit words in a key.
#    define Nr 10 // The number of rounds in AES Cipher.

// The lookup-tables are marked const so they can be placed in
// read-only storage instead of RAM The numbers below can be computed
// dynamically trading ROM for RAM - This can be useful in (embedded)
// bootloader applications, where ROM is often limited.
const uint8_t sbox[256] = {
    // 0 1 2 3 4 5 6 7 8 9 A B C D E F
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b,
    0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
    0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26,
    0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
    0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
    0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed,
    0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f,
    0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
    0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec,
    0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14,
    0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
    0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
    0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f,
    0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
    0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11,
    0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f,
    0xb0, 0x54, 0xbb, 0x16};

#    define getSBoxValue(num) (sbox[(num)])

// The round constant word array, Rcon[i], contains the values given
// by x to the power (i-1) being powers of x (x is denoted as {02}) in
// the field GF(2^8)
const uint8_t Rcon[11] = {0x8d, 0x01, 0x02, 0x04, 0x08, 0x10,
                          0x20, 0x40, 0x80, 0x1b, 0x36};

// Expands the 16-byte cipher key |Key| into the Nb*(Nr+1) = 44 round
// key words (176 bytes) written to |RoundKey|, per FIPS-197 sec. 5.2.
void KeyExpansion(uint8_t *RoundKey, const uint8_t *Key) {
  // The first Nk words of the schedule are the cipher key itself.
  for (unsigned w = 0; w < Nk * 4; ++w) {
    RoundKey[w] = Key[w];
  }

  // Every further word w[i] = w[i-Nk] ^ f(w[i-1]), where f is the
  // identity except at Nk-word boundaries.
  for (unsigned w = Nk; w < Nb * (Nr + 1); ++w) {
    uint8_t t[4];
    const uint8_t *prev = &RoundKey[(w - 1) * 4];
    t[0] = prev[0];
    t[1] = prev[1];
    t[2] = prev[2];
    t[3] = prev[3];

    if (w % Nk == 0) {
      // RotWord (rotate left one byte) fused with SubWord (S-box on
      // each byte), then XOR the round constant into the first byte.
      const uint8_t first = t[0];
      t[0] = getSBoxValue(t[1]);
      t[1] = getSBoxValue(t[2]);
      t[2] = getSBoxValue(t[3]);
      t[3] = getSBoxValue(first);
      t[0] ^= Rcon[w / Nk];
    }

    uint8_t *out = &RoundKey[w * 4];
    const uint8_t *base = &RoundKey[(w - Nk) * 4];
    out[0] = base[0] ^ t[0];
    out[1] = base[1] ^ t[1];
    out[2] = base[2] ^ t[2];
    out[3] = base[3] ^ t[3];
  }
}
} // namespace
841 #  endif // HAVE_LIBBPF
842 
// Creates a non-blocking UDP socket for QUIC, bound to the address
// described by |faddr|, and on success stores the file descriptor and
// the canonical "host:port" string back into |faddr|.
//
// Returns 0 on success, or -1 if no resolved address could be fully
// configured and bound (or if any BPF setup step fails).
//
// When built with HAVE_LIBBPF, the first worker to get here (per
// should_attach_bpf()) loads the SK_REUSEPORT eBPF program and shared
// maps; every worker then registers its socket fd and worker ID in
// those maps so the kernel can steer QUIC packets to the right socket.
int Worker::create_quic_server_socket(UpstreamAddr &faddr) {
  std::array<char, STRERROR_BUFSIZE> errbuf;
  int fd = -1;
  int rv;

  auto service = util::utos(faddr.port);
  addrinfo hints{};
  hints.ai_family = faddr.family;
  hints.ai_socktype = SOCK_DGRAM;
  hints.ai_flags = AI_PASSIVE;
#  ifdef AI_ADDRCONFIG
  hints.ai_flags |= AI_ADDRCONFIG;
#  endif // AI_ADDRCONFIG

  // "*" is the configuration wildcard; getaddrinfo() expects a null
  // node for a wildcard (AI_PASSIVE) bind.
  auto node = faddr.host == "*"_sr ? nullptr : faddr.host.data();

  addrinfo *res, *rp;
  rv = getaddrinfo(node, service.c_str(), &hints, &res);
#  ifdef AI_ADDRCONFIG
  if (rv != 0) {
    // Retry without AI_ADDRCONFIG
    hints.ai_flags &= ~AI_ADDRCONFIG;
    rv = getaddrinfo(node, service.c_str(), &hints, &res);
  }
#  endif // AI_ADDRCONFIG
  if (rv != 0) {
    LOG(FATAL) << "Unable to get IPv" << (faddr.family == AF_INET ? "4" : "6")
               << " address for " << faddr.host << ", port " << faddr.port
               << ": " << gai_strerror(rv);
    return -1;
  }

  // Frees the whole addrinfo chain on every exit path below.
  auto res_d = defer(freeaddrinfo, res);

  std::array<char, NI_MAXHOST> host;

  // Try each resolved address in turn; a `continue` abandons the
  // current candidate (closing its fd first), and `break` ends the
  // loop on the first fully-configured, bound socket.
  for (rp = res; rp; rp = rp->ai_next) {
    rv = getnameinfo(rp->ai_addr, rp->ai_addrlen, host.data(), host.size(),
                     nullptr, 0, NI_NUMERICHOST);
    if (rv != 0) {
      LOG(WARN) << "getnameinfo() failed: " << gai_strerror(rv);
      continue;
    }

#  ifdef SOCK_NONBLOCK
    // Atomically request non-blocking + close-on-exec when the kernel
    // supports it; otherwise fall back to fcntl-based helpers below.
    fd = socket(rp->ai_family, rp->ai_socktype | SOCK_NONBLOCK | SOCK_CLOEXEC,
                rp->ai_protocol);
    if (fd == -1) {
      auto error = errno;
      LOG(WARN) << "socket() syscall failed: "
                << xsi_strerror(error, errbuf.data(), errbuf.size());
      continue;
    }
#  else  // !SOCK_NONBLOCK
    fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (fd == -1) {
      auto error = errno;
      LOG(WARN) << "socket() syscall failed: "
                << xsi_strerror(error, errbuf.data(), errbuf.size());
      continue;
    }
    util::make_socket_nonblocking(fd);
    util::make_socket_closeonexec(fd);
#  endif // !SOCK_NONBLOCK

    // `val` (== 1) is reused to enable every boolean socket option
    // below.  SO_REUSEPORT is required so that each worker can bind
    // its own socket to the same address.
    int val = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val,
                   static_cast<socklen_t>(sizeof(val))) == -1) {
      auto error = errno;
      LOG(WARN) << "Failed to set SO_REUSEADDR option to listener socket: "
                << xsi_strerror(error, errbuf.data(), errbuf.size());
      close(fd);
      continue;
    }

    if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &val,
                   static_cast<socklen_t>(sizeof(val))) == -1) {
      auto error = errno;
      LOG(WARN) << "Failed to set SO_REUSEPORT option to listener socket: "
                << xsi_strerror(error, errbuf.data(), errbuf.size());
      close(fd);
      continue;
    }

    // Per-family options: request reception of destination address
    // (PKTINFO) and TOS/TCLASS ancillary data, and force path MTU
    // discovery (QUIC must not let the kernel fragment datagrams).
    if (faddr.family == AF_INET6) {
#  ifdef IPV6_V6ONLY
      if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &val,
                     static_cast<socklen_t>(sizeof(val))) == -1) {
        auto error = errno;
        LOG(WARN) << "Failed to set IPV6_V6ONLY option to listener socket: "
                  << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        continue;
      }
#  endif // IPV6_V6ONLY

      if (setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &val,
                     static_cast<socklen_t>(sizeof(val))) == -1) {
        auto error = errno;
        LOG(WARN)
            << "Failed to set IPV6_RECVPKTINFO option to listener socket: "
            << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        continue;
      }

      if (setsockopt(fd, IPPROTO_IPV6, IPV6_RECVTCLASS, &val,
                     static_cast<socklen_t>(sizeof(val))) == -1) {
        auto error = errno;
        LOG(WARN) << "Failed to set IPV6_RECVTCLASS option to listener socket: "
                  << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        continue;
      }

#  if defined(IPV6_MTU_DISCOVER) && defined(IPV6_PMTUDISC_DO)
      int mtu_disc = IPV6_PMTUDISC_DO;
      if (setsockopt(fd, IPPROTO_IPV6, IPV6_MTU_DISCOVER, &mtu_disc,
                     static_cast<socklen_t>(sizeof(mtu_disc))) == -1) {
        auto error = errno;
        LOG(WARN)
            << "Failed to set IPV6_MTU_DISCOVER option to listener socket: "
            << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        continue;
      }
#  endif // defined(IPV6_MTU_DISCOVER) && defined(IPV6_PMTUDISC_DO)
    } else {
      if (setsockopt(fd, IPPROTO_IP, IP_PKTINFO, &val,
                     static_cast<socklen_t>(sizeof(val))) == -1) {
        auto error = errno;
        LOG(WARN) << "Failed to set IP_PKTINFO option to listener socket: "
                  << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        continue;
      }

      if (setsockopt(fd, IPPROTO_IP, IP_RECVTOS, &val,
                     static_cast<socklen_t>(sizeof(val))) == -1) {
        auto error = errno;
        LOG(WARN) << "Failed to set IP_RECVTOS option to listener socket: "
                  << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        continue;
      }

#  if defined(IP_MTU_DISCOVER) && defined(IP_PMTUDISC_DO)
      int mtu_disc = IP_PMTUDISC_DO;
      if (setsockopt(fd, IPPROTO_IP, IP_MTU_DISCOVER, &mtu_disc,
                     static_cast<socklen_t>(sizeof(mtu_disc))) == -1) {
        auto error = errno;
        LOG(WARN) << "Failed to set IP_MTU_DISCOVER option to listener socket: "
                  << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        continue;
      }
#  endif // defined(IP_MTU_DISCOVER) && defined(IP_PMTUDISC_DO)
    }

#  ifdef UDP_GRO
    // Generic Receive Offload lets recvmsg() deliver multiple
    // coalesced UDP datagrams in one call.
    if (setsockopt(fd, IPPROTO_UDP, UDP_GRO, &val, sizeof(val)) == -1) {
      auto error = errno;
      LOG(WARN) << "Failed to set UDP_GRO option to listener socket: "
                << xsi_strerror(error, errbuf.data(), errbuf.size());
      close(fd);
      continue;
    }
#  endif // UDP_GRO

    if (bind(fd, rp->ai_addr, rp->ai_addrlen) == -1) {
      auto error = errno;
      LOG(WARN) << "bind() syscall failed: "
                << xsi_strerror(error, errbuf.data(), errbuf.size());
      close(fd);
      continue;
    }

#  ifdef HAVE_LIBBPF
    auto config = get_config();

    // BPF state is shared across workers via the connection handler;
    // one slot per upstream address (indexed by faddr.index).
    auto &quic_bpf_refs = conn_handler_->get_quic_bpf_refs();

    if (should_attach_bpf()) {
      auto &bpfconf = config->quic.bpf;

      // NOTE(review): the error paths below close fd and return -1
      // without bpf_object__close(obj); for the failures before
      // `ref.obj = obj` the object appears to leak — confirm whether
      // callers treat this as fatal process termination.
      auto obj = bpf_object__open_file(bpfconf.prog_file.data(), nullptr);
      if (!obj) {
        auto error = errno;
        LOG(FATAL) << "Failed to open bpf object file: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      rv = bpf_object__load(obj);
      if (rv != 0) {
        auto error = errno;
        LOG(FATAL) << "Failed to load bpf object file: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      auto prog = bpf_object__find_program_by_name(obj, "select_reuseport");
      if (!prog) {
        auto error = errno;
        LOG(FATAL) << "Failed to find sk_reuseport program: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      auto &ref = quic_bpf_refs[faddr.index];

      ref.obj = obj;

      ref.reuseport_array =
          bpf_object__find_map_by_name(obj, "reuseport_array");
      if (!ref.reuseport_array) {
        auto error = errno;
        LOG(FATAL) << "Failed to get reuseport_array: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      ref.worker_id_map = bpf_object__find_map_by_name(obj, "worker_id_map");
      if (!ref.worker_id_map) {
        auto error = errno;
        LOG(FATAL) << "Failed to get worker_id_map: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      auto sk_info = bpf_object__find_map_by_name(obj, "sk_info");
      if (!sk_info) {
        auto error = errno;
        LOG(FATAL) << "Failed to get sk_info: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      // Publish the number of sockets sharing this reuseport group at
      // key 0 of sk_info so the BPF program knows the modulo base.
      constexpr uint32_t zero = 0;
      uint64_t num_socks = config->num_worker;

      rv = bpf_map__update_elem(sk_info, &zero, sizeof(zero), &num_socks,
                                sizeof(num_socks), BPF_ANY);
      if (rv != 0) {
        auto error = errno;
        LOG(FATAL) << "Failed to update sk_info: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      auto &qkms = conn_handler_->get_quic_keying_materials();
      auto &qkm = qkms->keying_materials.front();

      auto aes_key = bpf_object__find_map_by_name(obj, "aes_key");
      if (!aes_key) {
        auto error = errno;
        LOG(FATAL) << "Failed to get aes_key: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      // Expand the 128-bit CID encryption key into the full AES-128
      // round-key schedule (11 round keys * 16 bytes = 176 bytes) so
      // the BPF program can decrypt Connection IDs without doing the
      // key expansion itself.
      constexpr size_t expanded_aes_keylen = 176;
      std::array<uint8_t, expanded_aes_keylen> aes_exp_key;

      KeyExpansion(aes_exp_key.data(), qkm.cid_encryption_key.data());

      rv =
          bpf_map__update_elem(aes_key, &zero, sizeof(zero), aes_exp_key.data(),
                               aes_exp_key.size(), BPF_ANY);
      if (rv != 0) {
        auto error = errno;
        LOG(FATAL) << "Failed to update aes_key: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      auto prog_fd = bpf_program__fd(prog);

      // Attach the SK_REUSEPORT program; the kernel consults it to
      // pick a socket from the reuseport group for each packet.
      if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &prog_fd,
                     static_cast<socklen_t>(sizeof(prog_fd))) == -1) {
        LOG(FATAL) << "Failed to attach bpf program: "
                   << xsi_strerror(errno, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }
    }

    if (should_update_bpf_map()) {
      // Register this worker's socket under its reuseport index, and
      // map this worker's ID to that index, using BPF_NOEXIST so a
      // duplicate registration fails loudly.
      const auto &ref = quic_bpf_refs[faddr.index];
      auto sk_index = compute_sk_index();

      rv = bpf_map__update_elem(ref.reuseport_array, &sk_index,
                                sizeof(sk_index), &fd, sizeof(fd), BPF_NOEXIST);
      if (rv != 0) {
        auto error = errno;
        LOG(FATAL) << "Failed to update reuseport_array: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }

      rv = bpf_map__update_elem(ref.worker_id_map, &worker_id_,
                                sizeof(worker_id_), &sk_index, sizeof(sk_index),
                                BPF_NOEXIST);
      if (rv != 0) {
        auto error = errno;
        LOG(FATAL) << "Failed to update worker_id_map: "
                   << xsi_strerror(error, errbuf.data(), errbuf.size());
        close(fd);
        return -1;
      }
    }
#  endif // HAVE_LIBBPF

    break;
  }

  // rp == nullptr means every candidate address failed above.
  if (!rp) {
    LOG(FATAL) << "Listening " << (faddr.family == AF_INET ? "IPv4" : "IPv6")
               << " socket failed";

    return -1;
  }

  faddr.fd = fd;
  faddr.hostport = util::make_http_hostport(mod_config()->balloc,
                                            StringRef{host.data()}, faddr.port);

  LOG(NOTICE) << "Listening on " << faddr.hostport << ", quic";

  return 0;
}
1184 
get_worker_id() const1185 const WorkerID &Worker::get_worker_id() const { return worker_id_; }
1186 
find_quic_upstream_addr(const Address & local_addr)1187 const UpstreamAddr *Worker::find_quic_upstream_addr(const Address &local_addr) {
1188   std::array<char, NI_MAXHOST> host;
1189 
1190   auto rv = getnameinfo(&local_addr.su.sa, local_addr.len, host.data(),
1191                         host.size(), nullptr, 0, NI_NUMERICHOST);
1192   if (rv != 0) {
1193     LOG(ERROR) << "getnameinfo: " << gai_strerror(rv);
1194 
1195     return nullptr;
1196   }
1197 
1198   uint16_t port;
1199 
1200   switch (local_addr.su.sa.sa_family) {
1201   case AF_INET:
1202     port = htons(local_addr.su.in.sin_port);
1203 
1204     break;
1205   case AF_INET6:
1206     port = htons(local_addr.su.in6.sin6_port);
1207 
1208     break;
1209   default:
1210     assert(0);
1211     abort();
1212   }
1213 
1214   std::array<char, util::max_hostport> hostport_buf;
1215 
1216   auto hostport = util::make_http_hostport(std::begin(hostport_buf),
1217                                            StringRef{host.data()}, port);
1218   const UpstreamAddr *fallback_faddr = nullptr;
1219 
1220   for (auto &faddr : quic_upstream_addrs_) {
1221     if (faddr.hostport == hostport) {
1222       return &faddr;
1223     }
1224 
1225     if (faddr.port != port || faddr.family != local_addr.su.sa.sa_family) {
1226       continue;
1227     }
1228 
1229     if (faddr.port == 443 || faddr.port == 80) {
1230       switch (faddr.family) {
1231       case AF_INET:
1232         if (faddr.hostport == "0.0.0.0"_sr) {
1233           fallback_faddr = &faddr;
1234         }
1235 
1236         break;
1237       case AF_INET6:
1238         if (faddr.hostport == "[::]"_sr) {
1239           fallback_faddr = &faddr;
1240         }
1241 
1242         break;
1243       default:
1244         assert(0);
1245       }
1246     } else {
1247       switch (faddr.family) {
1248       case AF_INET:
1249         if (util::starts_with(faddr.hostport, "0.0.0.0:"_sr)) {
1250           fallback_faddr = &faddr;
1251         }
1252 
1253         break;
1254       case AF_INET6:
1255         if (util::starts_with(faddr.hostport, "[::]:"_sr)) {
1256           fallback_faddr = &faddr;
1257         }
1258 
1259         break;
1260       default:
1261         assert(0);
1262       }
1263     }
1264   }
1265 
1266   return fallback_faddr;
1267 }
1268 #endif // ENABLE_HTTP3
1269 
namespace {
// Selects the index into |groups| whose routing pattern best matches
// the request's |host| and |path|.  Falls back, in order: exact
// host+path router match -> longest wildcard-host pattern whose path
// router matches -> path-only match with empty host -> |catch_all|.
//
// |balloc| is used for a temporary reversed-host buffer when wildcard
// patterns are configured.
size_t match_downstream_addr_group_host(
    const RouterConfig &routerconf, const StringRef &host,
    const StringRef &path,
    const std::vector<std::shared_ptr<DownstreamAddrGroup>> &groups,
    size_t catch_all, BlockAllocator &balloc) {

  const auto &router = routerconf.router;
  const auto &rev_wildcard_router = routerconf.rev_wildcard_router;
  const auto &wildcard_patterns = routerconf.wildcard_patterns;

  if (LOG_ENABLED(INFO)) {
    LOG(INFO) << "Perform mapping selection, using host=" << host
              << ", path=" << path;
  }

  auto group = router.match(host, path);
  if (group != -1) {
    if (LOG_ENABLED(INFO)) {
      LOG(INFO) << "Found pattern with query " << host << path
                << ", matched pattern=" << groups[group]->pattern;
    }
    return group;
  }

  if (!wildcard_patterns.empty() && !host.empty()) {
    // Wildcard patterns are matched against the host reversed and with
    // its first character dropped (presumably the leading '*' or '.'
    // of "*.example.com"-style patterns is handled by the reversed
    // trie — confirm against how wildcard_patterns is built).
    auto rev_host_src = make_byte_ref(balloc, host.size() - 1);
    auto ep = std::copy(std::begin(host) + 1, std::end(host),
                        std::begin(rev_host_src));
    std::reverse(std::begin(rev_host_src), ep);
    auto rev_host = StringRef{std::span{std::begin(rev_host_src), ep}};

    ssize_t best_group = -1;
    const RNode *last_node = nullptr;

    // match_prefix keeps its position in |last_node| and reports how
    // many bytes it consumed via |nread|; each iteration resumes the
    // walk deeper into the reversed-host trie until no further
    // wildcard pattern prefixes the remaining host.
    for (;;) {
      size_t nread = 0;
      auto wcidx =
          rev_wildcard_router.match_prefix(&nread, &last_node, rev_host);
      if (wcidx == -1) {
        break;
      }

      rev_host = StringRef{std::begin(rev_host) + nread, std::end(rev_host)};

      auto &wc = wildcard_patterns[wcidx];
      auto group = wc.router.match(StringRef{}, path);
      if (group != -1) {
        // We sorted wildcard_patterns in a way that first match is the
        // longest host pattern.
        if (LOG_ENABLED(INFO)) {
          LOG(INFO) << "Found wildcard pattern with query " << host << path
                    << ", matched pattern=" << groups[group]->pattern;
        }

        best_group = group;
      }
    }

    if (best_group != -1) {
      return best_group;
    }
  }

  // No host-specific pattern matched; try patterns registered without
  // a host (path-only patterns).
  group = router.match(""_sr, path);
  if (group != -1) {
    if (LOG_ENABLED(INFO)) {
      LOG(INFO) << "Found pattern with query " << path
                << ", matched pattern=" << groups[group]->pattern;
    }
    return group;
  }

  if (LOG_ENABLED(INFO)) {
    LOG(INFO) << "None match.  Use catch-all pattern";
  }
  return catch_all;
}
} // namespace
1349 
match_downstream_addr_group(const RouterConfig & routerconf,const StringRef & hostport,const StringRef & raw_path,const std::vector<std::shared_ptr<DownstreamAddrGroup>> & groups,size_t catch_all,BlockAllocator & balloc)1350 size_t match_downstream_addr_group(
1351     const RouterConfig &routerconf, const StringRef &hostport,
1352     const StringRef &raw_path,
1353     const std::vector<std::shared_ptr<DownstreamAddrGroup>> &groups,
1354     size_t catch_all, BlockAllocator &balloc) {
1355   if (std::find(std::begin(hostport), std::end(hostport), '/') !=
1356       std::end(hostport)) {
1357     // We use '/' specially, and if '/' is included in host, it breaks
1358     // our code.  Select catch-all case.
1359     return catch_all;
1360   }
1361 
1362   auto fragment = std::find(std::begin(raw_path), std::end(raw_path), '#');
1363   auto query = std::find(std::begin(raw_path), fragment, '?');
1364   auto path = StringRef{std::begin(raw_path), query};
1365 
1366   if (path.empty() || path[0] != '/') {
1367     path = "/"_sr;
1368   }
1369 
1370   if (hostport.empty()) {
1371     return match_downstream_addr_group_host(routerconf, hostport, path, groups,
1372                                             catch_all, balloc);
1373   }
1374 
1375   StringRef host;
1376   if (hostport[0] == '[') {
1377     // assume this is IPv6 numeric address
1378     auto p = std::find(std::begin(hostport), std::end(hostport), ']');
1379     if (p == std::end(hostport)) {
1380       return catch_all;
1381     }
1382     if (p + 1 < std::end(hostport) && *(p + 1) != ':') {
1383       return catch_all;
1384     }
1385     host = StringRef{std::begin(hostport), p + 1};
1386   } else {
1387     auto p = std::find(std::begin(hostport), std::end(hostport), ':');
1388     if (p == std::begin(hostport)) {
1389       return catch_all;
1390     }
1391     host = StringRef{std::begin(hostport), p};
1392   }
1393 
1394   if (std::find_if(std::begin(host), std::end(host), [](char c) {
1395         return 'A' <= c || c <= 'Z';
1396       }) != std::end(host)) {
1397     auto low_host = make_byte_ref(balloc, host.size() + 1);
1398     auto ep = std::copy(std::begin(host), std::end(host), std::begin(low_host));
1399     *ep = '\0';
1400     util::inp_strlower(std::begin(low_host), ep);
1401     host = StringRef{std::span{std::begin(low_host), ep}};
1402   }
1403   return match_downstream_addr_group_host(routerconf, host, path, groups,
1404                                           catch_all, balloc);
1405 }
1406 
downstream_failure(DownstreamAddr * addr,const Address * raddr)1407 void downstream_failure(DownstreamAddr *addr, const Address *raddr) {
1408   const auto &connect_blocker = addr->connect_blocker;
1409 
1410   if (connect_blocker->in_offline()) {
1411     return;
1412   }
1413 
1414   connect_blocker->on_failure();
1415 
1416   if (addr->fall == 0) {
1417     return;
1418   }
1419 
1420   auto fail_count = connect_blocker->get_fail_count();
1421 
1422   if (fail_count >= addr->fall) {
1423     if (raddr) {
1424       LOG(WARN) << "Could not connect to " << util::to_numeric_addr(raddr)
1425                 << " " << fail_count
1426                 << " times in a row; considered as offline";
1427     } else {
1428       LOG(WARN) << "Could not connect to " << addr->host << ":" << addr->port
1429                 << " " << fail_count
1430                 << " times in a row; considered as offline";
1431     }
1432 
1433     connect_blocker->offline();
1434 
1435     if (addr->rise) {
1436       addr->live_check->schedule();
1437     }
1438   }
1439 }
1440 
1441 } // namespace shrpx
1442