• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include "inspector_socket.h"
2 #include "llhttp.h"
3 
4 #include "base64-inl.h"
5 #include "util-inl.h"
6 
7 #include "openssl/sha.h"  // Sha-1 hash
8 
9 #include <algorithm>
10 #include <cstring>
11 #include <map>
12 
// Size of the Sec-WebSocket-Accept value: base64 of a 20-byte SHA-1 digest.
#define ACCEPT_KEY_LENGTH base64_encoded_size(20)

// Debug switches: set to 1 to hex-dump raw socket reads / writes to stdout.
#define DUMP_READS 0
#define DUMP_WRITES 0

18 namespace node {
19 namespace inspector {
20 
21 class TcpHolder {
22  public:
23   static void DisconnectAndDispose(TcpHolder* holder);
24   using Pointer = DeleteFnPtr<TcpHolder, DisconnectAndDispose>;
25 
26   static Pointer Accept(uv_stream_t* server,
27                         InspectorSocket::DelegatePointer delegate);
28   void SetHandler(ProtocolHandler* handler);
29   int WriteRaw(const std::vector<char>& buffer, uv_write_cb write_cb);
tcp()30   uv_tcp_t* tcp() {
31     return &tcp_;
32   }
33   InspectorSocket::Delegate* delegate();
34 
35  private:
From(void * handle)36   static TcpHolder* From(void* handle) {
37     return node::ContainerOf(&TcpHolder::tcp_,
38                              reinterpret_cast<uv_tcp_t*>(handle));
39   }
40   static void OnClosed(uv_handle_t* handle);
41   static void OnDataReceivedCb(uv_stream_t* stream, ssize_t nread,
42                                const uv_buf_t* buf);
43   explicit TcpHolder(InspectorSocket::DelegatePointer delegate);
44   ~TcpHolder() = default;
45   void ReclaimUvBuf(const uv_buf_t* buf, ssize_t read);
46 
47   uv_tcp_t tcp_;
48   const InspectorSocket::DelegatePointer delegate_;
49   ProtocolHandler* handler_;
50   std::vector<char> buffer;
51 };
52 
53 
54 class ProtocolHandler {
55  public:
56   ProtocolHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp);
57 
58   virtual void AcceptUpgrade(const std::string& accept_key) = 0;
59   virtual void OnData(std::vector<char>* data) = 0;
60   virtual void OnEof() = 0;
61   virtual void Write(const std::vector<char> data) = 0;
62   virtual void CancelHandshake() = 0;
63 
64   std::string GetHost() const;
65 
inspector()66   InspectorSocket* inspector() {
67     return inspector_;
68   }
69   virtual void Shutdown() = 0;
70 
71  protected:
72   virtual ~ProtocolHandler() = default;
73   int WriteRaw(const std::vector<char>& buffer, uv_write_cb write_cb);
74   InspectorSocket::Delegate* delegate();
75 
76   InspectorSocket* const inspector_;
77   TcpHolder::Pointer tcp_;
78 };
79 
80 namespace {
81 
#if DUMP_READS || DUMP_WRITES
// Debug helper: prints |buf| as rows of up to 16 bytes, hex on the left and
// printable characters (anything above 0x19, '.' otherwise) on the right.
static void dump_hex(const char* buf, size_t len) {
  const char* const end = buf + len;
  const char* row = buf;
  while (row < end) {
    const char* const row_end = (end - row > 16) ? row + 16 : end;
    int count = 0;
    for (const char* p = row; p < row_end; ++p, ++count)
      printf("%2.2X  ", static_cast<unsigned char>(*p));
    // Pad short rows so the character column lines up.
    for (int pad = 72 - count * 4; pad > 0; --pad)
      printf(" ");
    for (const char* p = row; p < row_end; ++p) {
      const char c = *p;
      printf("%c", (c > 0x19) ? c : '.');
    }
    printf("\n");
    row = row_end;
  }
  printf("\n\n");
}
#endif

108 class WriteRequest {
109  public:
WriteRequest(ProtocolHandler * handler,const std::vector<char> & buffer)110   WriteRequest(ProtocolHandler* handler, const std::vector<char>& buffer)
111       : handler(handler)
112       , storage(buffer)
113       , req(uv_write_t())
114       , buf(uv_buf_init(storage.data(), storage.size())) {}
115 
from_write_req(uv_write_t * req)116   static WriteRequest* from_write_req(uv_write_t* req) {
117     return node::ContainerOf(&WriteRequest::req, req);
118   }
119 
Cleanup(uv_write_t * req,int status)120   static void Cleanup(uv_write_t* req, int status) {
121     delete WriteRequest::from_write_req(req);
122   }
123 
124   ProtocolHandler* const handler;
125   std::vector<char> storage;
126   uv_write_t req;
127   uv_buf_t buf;
128 };
129 
allocate_buffer(uv_handle_t * stream,size_t len,uv_buf_t * buf)130 void allocate_buffer(uv_handle_t* stream, size_t len, uv_buf_t* buf) {
131   *buf = uv_buf_init(new char[len], len);
132 }
133 
// Drops the first |count| bytes of |buffer| (count must not exceed its size).
static void remove_from_beginning(std::vector<char>* buffer, size_t count) {
  const auto first = buffer->begin();
  buffer->erase(first, first + count);
}

// A complete close frame: FIN + opcode 0x8, unmasked, zero-length payload.
static const char CLOSE_FRAME[] = {'\x88', '\x00'};

// Outcome of decoding one frame from the front of the receive buffer.
enum ws_decode_result {
  FRAME_OK,
  FRAME_INCOMPLETE,
  FRAME_CLOSE,
  FRAME_ERROR
};

generate_accept_string(const std::string & client_key,char (* buffer)[ACCEPT_KEY_LENGTH])144 static void generate_accept_string(const std::string& client_key,
145                                    char (*buffer)[ACCEPT_KEY_LENGTH]) {
146   // Magic string from websockets spec.
147   static const char ws_magic[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
148   std::string input(client_key + ws_magic);
149   char hash[SHA_DIGEST_LENGTH];
150   SHA1(reinterpret_cast<const unsigned char*>(&input[0]), input.size(),
151        reinterpret_cast<unsigned char*>(hash));
152   node::base64_encode(hash, sizeof(hash), *buffer, sizeof(*buffer));
153 }
154 
// Strips a trailing ":port" from a Host header value. A colon inside a
// bracketed IPv6 literal ("[::1]") is not treated as a port separator.
static std::string TrimPort(const std::string& host) {
  const size_t colon = host.rfind(':');
  if (colon == std::string::npos)
    return host;
  const size_t close_bracket = host.rfind(']');
  const bool colon_outside_brackets =
      close_bracket == std::string::npos || colon > close_bracket;
  return colon_outside_brackets ? host.substr(0, colon) : host;
}

IsIPAddress(const std::string & host)165 static bool IsIPAddress(const std::string& host) {
166   // TODO(tniessen): add CVEs to the following bullet points
167   // To avoid DNS rebinding attacks, we are aware of the following requirements:
168   // * the host name must be an IP address,
169   // * the IP address must be routable, and
170   // * the IP address must be formatted unambiguously.
171 
172   // The logic below assumes that the string is null-terminated, so ensure that
173   // we did not somehow end up with null characters within the string.
174   if (host.find('\0') != std::string::npos) return false;
175 
176   // All IPv6 addresses must be enclosed in square brackets, and anything
177   // enclosed in square brackets must be an IPv6 address.
178   if (host.length() >= 4 && host.front() == '[' && host.back() == ']') {
179     // INET6_ADDRSTRLEN is the maximum length of the dual format (including the
180     // terminating null character), which is the longest possible representation
181     // of an IPv6 address: xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:ddd.ddd.ddd.ddd
182     if (host.length() - 2 >= INET6_ADDRSTRLEN) return false;
183 
184     // Annoyingly, libuv's implementation of inet_pton() deviates from other
185     // implementations of the function in that it allows '%' in IPv6 addresses.
186     if (host.find('%') != std::string::npos) return false;
187 
188     // Parse the IPv6 address to ensure it is syntactically valid.
189     char ipv6_str[INET6_ADDRSTRLEN];
190     std::copy(host.begin() + 1, host.end() - 1, ipv6_str);
191     ipv6_str[host.length()] = '\0';
192     unsigned char ipv6[sizeof(struct in6_addr)];
193     if (uv_inet_pton(AF_INET6, ipv6_str, ipv6) != 0) return false;
194 
195     // The only non-routable IPv6 address is ::/128. It should not be necessary
196     // to explicitly reject it because it will still be enclosed in square
197     // brackets and not even macOS should make DNS requests in that case, but
198     // history has taught us that we cannot be careful enough.
199     // Note that RFC 4291 defines both "IPv4-Compatible IPv6 Addresses" and
200     // "IPv4-Mapped IPv6 Addresses", which means that there are IPv6 addresses
201     // (other than ::/128) that represent non-routable IPv4 addresses. However,
202     // this translation assumes that the host is interpreted as an IPv6 address
203     // in the first place, at which point DNS rebinding should not be an issue.
204     if (std::all_of(ipv6, ipv6 + sizeof(ipv6), [](auto b) { return b == 0; })) {
205       return false;
206     }
207 
208     // It is a syntactically valid and routable IPv6 address enclosed in square
209     // brackets. No client should be able to misinterpret this.
210     return true;
211   }
212 
213   // Anything not enclosed in square brackets must be an IPv4 address. It is
214   // important here that inet_pton() accepts only the so-called dotted-decimal
215   // notation, which is a strict subset of the so-called numbers-and-dots
216   // notation that is allowed by inet_aton() and inet_addr(). This subset does
217   // not allow hexadecimal or octal number formats.
218   unsigned char ipv4[sizeof(struct in_addr)];
219   if (uv_inet_pton(AF_INET, host.c_str(), ipv4) != 0) return false;
220 
221   // The only strictly non-routable IPv4 address is 0.0.0.0, and macOS will make
222   // DNS requests for this IP address, so we need to explicitly reject it. In
223   // fact, we can safely reject all of 0.0.0.0/8 (see Section 3.2 of RFC 791 and
224   // Section 3.2.1.3 of RFC 1122).
225   // Note that inet_pton() stores the IPv4 address in network byte order.
226   if (ipv4[0] == 0) return false;
227 
228   // It is a routable IPv4 address in dotted-decimal notation.
229   return true;
230 }
231 
// Constants for the hybi-10 (RFC 6455) frame format.

using OpCode = int;

// Frame opcodes (low nibble of the first header byte).
constexpr OpCode kOpCodeContinuation = 0x0;
constexpr OpCode kOpCodeText = 0x1;
constexpr OpCode kOpCodeBinary = 0x2;
constexpr OpCode kOpCodeClose = 0x8;
constexpr OpCode kOpCodePing = 0x9;
constexpr OpCode kOpCodePong = 0xA;

// Bits of the first header byte.
constexpr unsigned char kFinalBit = 0x80;
constexpr unsigned char kReserved1Bit = 0x40;
constexpr unsigned char kReserved2Bit = 0x20;
constexpr unsigned char kReserved3Bit = 0x10;
constexpr unsigned char kOpCodeMask = 0xF;
// Bits of the second header byte.
constexpr unsigned char kMaskBit = 0x80;
constexpr unsigned char kPayloadLengthMask = 0x7F;

// Payload length encoding: values up to 125 fit in the header byte; 126 and
// 127 announce a 2-byte or 8-byte extended length field respectively.
constexpr size_t kMaxSingleBytePayloadLength = 125;
constexpr size_t kTwoBytePayloadLengthField = 126;
constexpr size_t kEightBytePayloadLengthField = 127;
constexpr size_t kMaskingKeyWidthInBytes = 4;

encode_frame_hybi17(const std::vector<char> & message)256 static std::vector<char> encode_frame_hybi17(const std::vector<char>& message) {
257   std::vector<char> frame;
258   OpCode op_code = kOpCodeText;
259   frame.push_back(kFinalBit | op_code);
260   const size_t data_length = message.size();
261   if (data_length <= kMaxSingleBytePayloadLength) {
262     frame.push_back(static_cast<char>(data_length));
263   } else if (data_length <= 0xFFFF) {
264     frame.push_back(kTwoBytePayloadLengthField);
265     frame.push_back((data_length & 0xFF00) >> 8);
266     frame.push_back(data_length & 0xFF);
267   } else {
268     frame.push_back(kEightBytePayloadLengthField);
269     char extended_payload_length[8];
270     size_t remaining = data_length;
271     // Fill the length into extended_payload_length in the network byte order.
272     for (int i = 0; i < 8; ++i) {
273       extended_payload_length[7 - i] = remaining & 0xFF;
274       remaining >>= 8;
275     }
276     frame.insert(frame.end(), extended_payload_length,
277                  extended_payload_length + 8);
278     CHECK_EQ(0, remaining);
279   }
280   frame.insert(frame.end(), message.begin(), message.end());
281   return frame;
282 }
283 
decode_frame_hybi17(const std::vector<char> & buffer,bool client_frame,int * bytes_consumed,std::vector<char> * output,bool * compressed)284 static ws_decode_result decode_frame_hybi17(const std::vector<char>& buffer,
285                                             bool client_frame,
286                                             int* bytes_consumed,
287                                             std::vector<char>* output,
288                                             bool* compressed) {
289   *bytes_consumed = 0;
290   if (buffer.size() < 2)
291     return FRAME_INCOMPLETE;
292 
293   auto it = buffer.begin();
294 
295   unsigned char first_byte = *it++;
296   unsigned char second_byte = *it++;
297 
298   bool final = (first_byte & kFinalBit) != 0;
299   bool reserved1 = (first_byte & kReserved1Bit) != 0;
300   bool reserved2 = (first_byte & kReserved2Bit) != 0;
301   bool reserved3 = (first_byte & kReserved3Bit) != 0;
302   int op_code = first_byte & kOpCodeMask;
303   bool masked = (second_byte & kMaskBit) != 0;
304   *compressed = reserved1;
305   if (!final || reserved2 || reserved3)
306     return FRAME_ERROR;  // Only compression extension is supported.
307 
308   bool closed = false;
309   switch (op_code) {
310     case kOpCodeClose:
311       closed = true;
312       break;
313     case kOpCodeText:
314       break;
315     case kOpCodeBinary:        // We don't support binary frames yet.
316     case kOpCodeContinuation:  // We don't support binary frames yet.
317     case kOpCodePing:          // We don't support binary frames yet.
318     case kOpCodePong:          // We don't support binary frames yet.
319     default:
320       return FRAME_ERROR;
321   }
322 
323   // In Hybi-17 spec client MUST mask its frame.
324   if (client_frame && !masked) {
325     return FRAME_ERROR;
326   }
327 
328   uint64_t payload_length64 = second_byte & kPayloadLengthMask;
329   if (payload_length64 > kMaxSingleBytePayloadLength) {
330     int extended_payload_length_size;
331     if (payload_length64 == kTwoBytePayloadLengthField) {
332       extended_payload_length_size = 2;
333     } else if (payload_length64 == kEightBytePayloadLengthField) {
334       extended_payload_length_size = 8;
335     } else {
336       return FRAME_ERROR;
337     }
338     if ((buffer.end() - it) < extended_payload_length_size)
339       return FRAME_INCOMPLETE;
340     payload_length64 = 0;
341     for (int i = 0; i < extended_payload_length_size; ++i) {
342       payload_length64 <<= 8;
343       payload_length64 |= static_cast<unsigned char>(*it++);
344     }
345   }
346 
347   static const uint64_t max_payload_length = 0x7FFFFFFFFFFFFFFFull;
348   static const size_t max_length = SIZE_MAX;
349   if (payload_length64 > max_payload_length ||
350       payload_length64 > max_length - kMaskingKeyWidthInBytes) {
351     // WebSocket frame length too large.
352     return FRAME_ERROR;
353   }
354   size_t payload_length = static_cast<size_t>(payload_length64);
355 
356   if (buffer.size() - kMaskingKeyWidthInBytes < payload_length)
357     return FRAME_INCOMPLETE;
358 
359   std::vector<char>::const_iterator masking_key = it;
360   std::vector<char>::const_iterator payload = it + kMaskingKeyWidthInBytes;
361   for (size_t i = 0; i < payload_length; ++i)  // Unmask the payload.
362     output->insert(output->end(),
363                    payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]);
364 
365   size_t pos = it + kMaskingKeyWidthInBytes + payload_length - buffer.begin();
366   *bytes_consumed = pos;
367   return closed ? FRAME_CLOSE : FRAME_OK;
368 }
369 
370 // WS protocol
371 class WsHandler : public ProtocolHandler {
372  public:
WsHandler(InspectorSocket * inspector,TcpHolder::Pointer tcp)373   WsHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp)
374             : ProtocolHandler(inspector, std::move(tcp)),
375               OnCloseSent(&WsHandler::WaitForCloseReply),
376               OnCloseReceived(&WsHandler::CloseFrameReceived),
377               dispose_(false) { }
378 
AcceptUpgrade(const std::string & accept_key)379   void AcceptUpgrade(const std::string& accept_key) override { }
CancelHandshake()380   void CancelHandshake() override {}
381 
OnEof()382   void OnEof() override {
383     tcp_.reset();
384     if (dispose_)
385       delete this;
386   }
387 
OnData(std::vector<char> * data)388   void OnData(std::vector<char>* data) override {
389     // 1. Parse.
390     int processed = 0;
391     do {
392       processed = ParseWsFrames(*data);
393       // 2. Fix the data size & length
394       if (processed > 0) {
395         remove_from_beginning(data, processed);
396       }
397     } while (processed > 0 && !data->empty());
398   }
399 
Write(const std::vector<char> data)400   void Write(const std::vector<char> data) override {
401     std::vector<char> output = encode_frame_hybi17(data);
402     WriteRaw(output, WriteRequest::Cleanup);
403   }
404 
405  protected:
Shutdown()406   void Shutdown() override {
407     if (tcp_) {
408       dispose_ = true;
409       SendClose();
410     } else {
411       delete this;
412     }
413   }
414 
415  private:
416   using Callback = void (WsHandler::*)();
417 
OnCloseFrameWritten(uv_write_t * req,int status)418   static void OnCloseFrameWritten(uv_write_t* req, int status) {
419     WriteRequest* wr = WriteRequest::from_write_req(req);
420     WsHandler* handler = static_cast<WsHandler*>(wr->handler);
421     delete wr;
422     Callback cb = handler->OnCloseSent;
423     (handler->*cb)();
424   }
425 
WaitForCloseReply()426   void WaitForCloseReply() {
427     OnCloseReceived = &WsHandler::OnEof;
428   }
429 
SendClose()430   void SendClose() {
431     WriteRaw(std::vector<char>(CLOSE_FRAME, CLOSE_FRAME + sizeof(CLOSE_FRAME)),
432              OnCloseFrameWritten);
433   }
434 
CloseFrameReceived()435   void CloseFrameReceived() {
436     OnCloseSent = &WsHandler::OnEof;
437     SendClose();
438   }
439 
ParseWsFrames(const std::vector<char> & buffer)440   int ParseWsFrames(const std::vector<char>& buffer) {
441     int bytes_consumed = 0;
442     std::vector<char> output;
443     bool compressed = false;
444 
445     ws_decode_result r =  decode_frame_hybi17(buffer,
446                                               true /* client_frame */,
447                                               &bytes_consumed, &output,
448                                               &compressed);
449     // Compressed frame means client is ignoring the headers and misbehaves
450     if (compressed || r == FRAME_ERROR) {
451       OnEof();
452       bytes_consumed = 0;
453     } else if (r == FRAME_CLOSE) {
454       (this->*OnCloseReceived)();
455       bytes_consumed = 0;
456     } else if (r == FRAME_OK) {
457       delegate()->OnWsFrame(output);
458     }
459     return bytes_consumed;
460   }
461 
462 
463   Callback OnCloseSent;
464   Callback OnCloseReceived;
465   bool dispose_;
466 };
467 
// HTTP protocol

// One fully parsed HTTP request, queued until the parser returns.
class HttpEvent {
 public:
  HttpEvent(const std::string& path, bool upgrade, bool isGET,
            const std::string& ws_key, const std::string& host)
      : path{path},
        upgrade{upgrade},
        isGET{isGET},
        ws_key{ws_key},
        host{host} {}

  std::string path;    // Request target.
  bool upgrade;        // Parser reported an upgrade request.
  bool isGET;          // Method was GET.
  std::string ws_key;  // Sec-WebSocket-Key value ("" if missing/ambiguous).
  std::string host;    // Host header value ("" if missing/ambiguous).
};

483 class HttpHandler : public ProtocolHandler {
484  public:
HttpHandler(InspectorSocket * inspector,TcpHolder::Pointer tcp)485   explicit HttpHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp)
486                        : ProtocolHandler(inspector, std::move(tcp)),
487                          parsing_value_(false) {
488     llhttp_init(&parser_, HTTP_REQUEST, &parser_settings);
489     llhttp_settings_init(&parser_settings);
490     parser_settings.on_header_field = OnHeaderField;
491     parser_settings.on_header_value = OnHeaderValue;
492     parser_settings.on_message_complete = OnMessageComplete;
493     parser_settings.on_url = OnPath;
494   }
495 
AcceptUpgrade(const std::string & accept_key)496   void AcceptUpgrade(const std::string& accept_key) override {
497     char accept_string[ACCEPT_KEY_LENGTH];
498     generate_accept_string(accept_key, &accept_string);
499     const char accept_ws_prefix[] = "HTTP/1.1 101 Switching Protocols\r\n"
500                                     "Upgrade: websocket\r\n"
501                                     "Connection: Upgrade\r\n"
502                                     "Sec-WebSocket-Accept: ";
503     const char accept_ws_suffix[] = "\r\n\r\n";
504     std::vector<char> reply(accept_ws_prefix,
505                             accept_ws_prefix + sizeof(accept_ws_prefix) - 1);
506     reply.insert(reply.end(), accept_string,
507                  accept_string + sizeof(accept_string));
508     reply.insert(reply.end(), accept_ws_suffix,
509                  accept_ws_suffix + sizeof(accept_ws_suffix) - 1);
510     if (WriteRaw(reply, WriteRequest::Cleanup) >= 0) {
511       inspector_->SwitchProtocol(new WsHandler(inspector_, std::move(tcp_)));
512     } else {
513       tcp_.reset();
514     }
515   }
516 
CancelHandshake()517   void CancelHandshake() override {
518     const char HANDSHAKE_FAILED_RESPONSE[] =
519         "HTTP/1.0 400 Bad Request\r\n"
520         "Content-Type: text/html; charset=UTF-8\r\n\r\n"
521         "WebSockets request was expected\r\n";
522     WriteRaw(std::vector<char>(HANDSHAKE_FAILED_RESPONSE,
523              HANDSHAKE_FAILED_RESPONSE + sizeof(HANDSHAKE_FAILED_RESPONSE) - 1),
524              ThenCloseAndReportFailure);
525   }
526 
527 
OnEof()528   void OnEof() override {
529     tcp_.reset();
530   }
531 
OnData(std::vector<char> * data)532   void OnData(std::vector<char>* data) override {
533     llhttp_errno_t err;
534     err = llhttp_execute(&parser_, data->data(), data->size());
535 
536     if (err == HPE_PAUSED_UPGRADE) {
537       err = HPE_OK;
538       llhttp_resume_after_upgrade(&parser_);
539     }
540     data->clear();
541     if (err != HPE_OK) {
542       CancelHandshake();
543     }
544     // Event handling may delete *this
545     std::vector<HttpEvent> events;
546     std::swap(events, events_);
547     for (const HttpEvent& event : events) {
548       if (!IsAllowedHost(event.host) || !event.isGET) {
549         CancelHandshake();
550         return;
551       } else if (!event.upgrade) {
552         delegate()->OnHttpGet(event.host, event.path);
553       } else if (event.ws_key.empty()) {
554         CancelHandshake();
555         return;
556       } else {
557         delegate()->OnSocketUpgrade(event.host, event.path, event.ws_key);
558       }
559     }
560   }
561 
Write(const std::vector<char> data)562   void Write(const std::vector<char> data) override {
563     WriteRaw(data, WriteRequest::Cleanup);
564   }
565 
566  protected:
Shutdown()567   void Shutdown() override {
568     delete this;
569   }
570 
571  private:
ThenCloseAndReportFailure(uv_write_t * req,int status)572   static void ThenCloseAndReportFailure(uv_write_t* req, int status) {
573     ProtocolHandler* handler = WriteRequest::from_write_req(req)->handler;
574     WriteRequest::Cleanup(req, status);
575     handler->inspector()->SwitchProtocol(nullptr);
576   }
577 
OnHeaderValue(llhttp_t * parser,const char * at,size_t length)578   static int OnHeaderValue(llhttp_t* parser, const char* at, size_t length) {
579     HttpHandler* handler = From(parser);
580     handler->parsing_value_ = true;
581     handler->headers_[handler->current_header_].append(at, length);
582     return 0;
583   }
584 
OnHeaderField(llhttp_t * parser,const char * at,size_t length)585   static int OnHeaderField(llhttp_t* parser, const char* at, size_t length) {
586     HttpHandler* handler = From(parser);
587     if (handler->parsing_value_) {
588       handler->parsing_value_ = false;
589       handler->current_header_.clear();
590     }
591     handler->current_header_.append(at, length);
592     return 0;
593   }
594 
OnPath(llhttp_t * parser,const char * at,size_t length)595   static int OnPath(llhttp_t* parser, const char* at, size_t length) {
596     HttpHandler* handler = From(parser);
597     handler->path_.append(at, length);
598     return 0;
599   }
600 
From(llhttp_t * parser)601   static HttpHandler* From(llhttp_t* parser) {
602     return node::ContainerOf(&HttpHandler::parser_, parser);
603   }
604 
OnMessageComplete(llhttp_t * parser)605   static int OnMessageComplete(llhttp_t* parser) {
606     // Event needs to be fired after the parser is done.
607     HttpHandler* handler = From(parser);
608     handler->events_.emplace_back(handler->path_,
609                                   parser->upgrade,
610                                   parser->method == HTTP_GET,
611                                   handler->HeaderValue("Sec-WebSocket-Key"),
612                                   handler->HeaderValue("Host"));
613     handler->path_ = "";
614     handler->parsing_value_ = false;
615     handler->headers_.clear();
616     handler->current_header_ = "";
617     return 0;
618   }
619 
HeaderValue(const std::string & header) const620   std::string HeaderValue(const std::string& header) const {
621     bool header_found = false;
622     std::string value;
623     for (const auto& header_value : headers_) {
624       if (node::StringEqualNoCaseN(header_value.first.data(), header.data(),
625                                    header.length())) {
626         if (header_found)
627           return "";
628         value = header_value.second;
629         header_found = true;
630       }
631     }
632     return value;
633   }
634 
IsAllowedHost(const std::string & host_with_port) const635   bool IsAllowedHost(const std::string& host_with_port) const {
636     std::string host = TrimPort(host_with_port);
637     return host.empty() || IsIPAddress(host)
638            || node::StringEqualNoCase(host.data(), "localhost");
639   }
640 
641   bool parsing_value_;
642   llhttp_t parser_;
643   llhttp_settings_t parser_settings;
644   std::vector<HttpEvent> events_;
645   std::string current_header_;
646   std::map<std::string, std::string> headers_;
647   std::string path_;
648 };
649 
650 }  // namespace
651 
652 // Any protocol
ProtocolHandler(InspectorSocket * inspector,TcpHolder::Pointer tcp)653 ProtocolHandler::ProtocolHandler(InspectorSocket* inspector,
654                                  TcpHolder::Pointer tcp)
655                                  : inspector_(inspector), tcp_(std::move(tcp)) {
656   CHECK_NOT_NULL(tcp_);
657   tcp_->SetHandler(this);
658 }
659 
WriteRaw(const std::vector<char> & buffer,uv_write_cb write_cb)660 int ProtocolHandler::WriteRaw(const std::vector<char>& buffer,
661                               uv_write_cb write_cb) {
662   return tcp_->WriteRaw(buffer, write_cb);
663 }
664 
delegate()665 InspectorSocket::Delegate* ProtocolHandler::delegate() {
666   return tcp_->delegate();
667 }
668 
GetHost() const669 std::string ProtocolHandler::GetHost() const {
670   char ip[INET6_ADDRSTRLEN];
671   sockaddr_storage addr;
672   int len = sizeof(addr);
673   int err = uv_tcp_getsockname(tcp_->tcp(),
674                                reinterpret_cast<struct sockaddr*>(&addr),
675                                &len);
676   if (err != 0)
677     return "";
678   if (addr.ss_family == AF_INET6) {
679     const sockaddr_in6* v6 = reinterpret_cast<const sockaddr_in6*>(&addr);
680     err = uv_ip6_name(v6, ip, sizeof(ip));
681   } else {
682     const sockaddr_in* v4 = reinterpret_cast<const sockaddr_in*>(&addr);
683     err = uv_ip4_name(v4, ip, sizeof(ip));
684   }
685   if (err != 0)
686     return "";
687   return ip;
688 }
689 
690 // RAII uv_tcp_t wrapper
TcpHolder(InspectorSocket::DelegatePointer delegate)691 TcpHolder::TcpHolder(InspectorSocket::DelegatePointer delegate)
692                      : tcp_(),
693                        delegate_(std::move(delegate)),
694                        handler_(nullptr) { }
695 
696 // static
Accept(uv_stream_t * server,InspectorSocket::DelegatePointer delegate)697 TcpHolder::Pointer TcpHolder::Accept(
698     uv_stream_t* server,
699     InspectorSocket::DelegatePointer delegate) {
700   TcpHolder* result = new TcpHolder(std::move(delegate));
701   uv_stream_t* tcp = reinterpret_cast<uv_stream_t*>(&result->tcp_);
702   int err = uv_tcp_init(server->loop, &result->tcp_);
703   if (err == 0) {
704     err = uv_accept(server, tcp);
705   }
706   if (err == 0) {
707     err = uv_read_start(tcp, allocate_buffer, OnDataReceivedCb);
708   }
709   if (err == 0) {
710     return TcpHolder::Pointer(result);
711   } else {
712     delete result;
713     return nullptr;
714   }
715 }
716 
SetHandler(ProtocolHandler * handler)717 void TcpHolder::SetHandler(ProtocolHandler* handler) {
718   handler_ = handler;
719 }
720 
WriteRaw(const std::vector<char> & buffer,uv_write_cb write_cb)721 int TcpHolder::WriteRaw(const std::vector<char>& buffer, uv_write_cb write_cb) {
722 #if DUMP_WRITES
723   printf("%s (%ld bytes):\n", __FUNCTION__, buffer.size());
724   dump_hex(buffer.data(), buffer.size());
725   printf("\n");
726 #endif
727 
728   // Freed in write_request_cleanup
729   WriteRequest* wr = new WriteRequest(handler_, buffer);
730   uv_stream_t* stream = reinterpret_cast<uv_stream_t*>(&tcp_);
731   int err = uv_write(&wr->req, stream, &wr->buf, 1, write_cb);
732   if (err < 0)
733     delete wr;
734   return err < 0;
735 }
736 
delegate()737 InspectorSocket::Delegate* TcpHolder::delegate() {
738   return delegate_.get();
739 }
740 
741 // static
OnClosed(uv_handle_t * handle)742 void TcpHolder::OnClosed(uv_handle_t* handle) {
743   delete From(handle);
744 }
745 
OnDataReceivedCb(uv_stream_t * tcp,ssize_t nread,const uv_buf_t * buf)746 void TcpHolder::OnDataReceivedCb(uv_stream_t* tcp, ssize_t nread,
747                                  const uv_buf_t* buf) {
748 #if DUMP_READS
749   if (nread >= 0) {
750     printf("%s (%ld bytes)\n", __FUNCTION__, nread);
751     dump_hex(buf->base, nread);
752   } else {
753     printf("[%s:%d] %s\n", __FUNCTION__, __LINE__, uv_err_name(nread));
754   }
755 #endif
756   TcpHolder* holder = From(tcp);
757   holder->ReclaimUvBuf(buf, nread);
758   if (nread < 0 || nread == UV_EOF) {
759     holder->handler_->OnEof();
760   } else {
761     holder->handler_->OnData(&holder->buffer);
762   }
763 }
764 
765 // static
DisconnectAndDispose(TcpHolder * holder)766 void TcpHolder::DisconnectAndDispose(TcpHolder* holder) {
767   uv_handle_t* handle = reinterpret_cast<uv_handle_t*>(&holder->tcp_);
768   uv_close(handle, OnClosed);
769 }
770 
ReclaimUvBuf(const uv_buf_t * buf,ssize_t read)771 void TcpHolder::ReclaimUvBuf(const uv_buf_t* buf, ssize_t read) {
772   if (read > 0) {
773     buffer.insert(buffer.end(), buf->base, buf->base + read);
774   }
775   delete[] buf->base;
776 }
777 
778 InspectorSocket::~InspectorSocket() = default;
779 
780 // static
Shutdown(ProtocolHandler * handler)781 void InspectorSocket::Shutdown(ProtocolHandler* handler) {
782   handler->Shutdown();
783 }
784 
785 // static
Accept(uv_stream_t * server,DelegatePointer delegate)786 InspectorSocket::Pointer InspectorSocket::Accept(uv_stream_t* server,
787                                                  DelegatePointer delegate) {
788   auto tcp = TcpHolder::Accept(server, std::move(delegate));
789   if (tcp) {
790     InspectorSocket* inspector = new InspectorSocket();
791     inspector->SwitchProtocol(new HttpHandler(inspector, std::move(tcp)));
792     return InspectorSocket::Pointer(inspector);
793   } else {
794     return InspectorSocket::Pointer(nullptr);
795   }
796 }
797 
AcceptUpgrade(const std::string & ws_key)798 void InspectorSocket::AcceptUpgrade(const std::string& ws_key) {
799   protocol_handler_->AcceptUpgrade(ws_key);
800 }
801 
CancelHandshake()802 void InspectorSocket::CancelHandshake() {
803   protocol_handler_->CancelHandshake();
804 }
805 
GetHost()806 std::string InspectorSocket::GetHost() {
807   return protocol_handler_->GetHost();
808 }
809 
SwitchProtocol(ProtocolHandler * handler)810 void InspectorSocket::SwitchProtocol(ProtocolHandler* handler) {
811   protocol_handler_.reset(std::move(handler));
812 }
813 
Write(const char * data,size_t len)814 void InspectorSocket::Write(const char* data, size_t len) {
815   protocol_handler_->Write(std::vector<char>(data, data + len));
816 }
817 
818 }  // namespace inspector
819 }  // namespace node
820