• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "inspector_socket.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <map>
21 
22 #include "inspector/inspector_utils.h"
23 #include "llhttp.h"
24 #include "openssl/sha.h" // Sha-1 hash
25 
26 #define ACCEPT_KEY_LENGTH Base64EncodeSize(20)
27 
28 namespace jsvm {
29 namespace inspector {
30 
31 class TcpHolder {
32 public:
33     static void DisconnectAndDispose(TcpHolder* holder);
34     using Pointer = DeleteFnPtr<TcpHolder, DisconnectAndDispose>;
35 
36     static Pointer Accept(uv_stream_t* server, InspectorSocket::DelegatePointer delegate);
37     void SetHandler(ProtocolHandler* handler);
38     int WriteRaw(const std::vector<char>& buffer, uv_write_cb writeCb);
GetTcp()39     uv_tcp_t* GetTcp()
40     {
41         return &tcp;
42     }
43     InspectorSocket::Delegate* GetDelegate();
44 
45 private:
From(void * handle)46     static TcpHolder* From(void* handle)
47     {
48         return jsvm::inspector::ContainerOf(&TcpHolder::tcp, reinterpret_cast<uv_tcp_t*>(handle));
49     }
50     static void OnClosed(uv_handle_t* handle);
51     static void OnDataReceivedCb(uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf);
52     explicit TcpHolder(InspectorSocket::DelegatePointer delegate);
53     ~TcpHolder() = default;
54     void ReclaimUvBuf(const uv_buf_t* buf, ssize_t read);
55 
56     uv_tcp_t tcp;
57     const InspectorSocket::DelegatePointer delegate;
58     ProtocolHandler* handler;
59     std::vector<char> buffer;
60 };
61 
62 class ProtocolHandler {
63 public:
64     ProtocolHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp);
65 
66     virtual void AcceptUpgrade(const std::string& acceptKey) = 0;
67     virtual void OnData(std::vector<char>* data) = 0;
68     virtual void OnEof() = 0;
69     virtual void Write(const std::vector<char> data) = 0;
70     virtual void CancelHandshake() = 0;
71 
72     std::string GetHost() const;
73 
GetInspectorSocket()74     InspectorSocket* GetInspectorSocket()
75     {
76         return inspector;
77     }
78     virtual void Shutdown() = 0;
79 
80 protected:
81     virtual ~ProtocolHandler() = default;
82     int WriteRaw(const std::vector<char>& buffer, uv_write_cb writeCb);
83     InspectorSocket::Delegate* GetDelegate();
84 
85     InspectorSocket* const inspector;
86     TcpHolder::Pointer tcp;
87 };
88 
89 namespace {
90 class WriteRequest {
91 public:
WriteRequest(ProtocolHandler * handler,const std::vector<char> & buffer)92     WriteRequest(ProtocolHandler* handler, const std::vector<char>& buffer)
93         : handler(handler), storage(buffer), req(uv_write_t()), buf(uv_buf_init(storage.data(), storage.size()))
94     {}
95 
FromWriteReq(uv_write_t * req)96     static WriteRequest* FromWriteReq(uv_write_t* req)
97     {
98         return jsvm::inspector::ContainerOf(&WriteRequest::req, req);
99     }
100 
Cleanup(uv_write_t * req,int status)101     static void Cleanup(uv_write_t* req, int status)
102     {
103         delete WriteRequest::FromWriteReq(req);
104     }
105 
106     ProtocolHandler* const handler;
107     std::vector<char> storage;
108     uv_write_t req;
109     uv_buf_t buf;
110 };
111 
AllocateBuffer(uv_handle_t * stream,size_t len,uv_buf_t * buf)112 void AllocateBuffer(uv_handle_t* stream, size_t len, uv_buf_t* buf)
113 {
114     CHECK(len > 0);
115     *buf = uv_buf_init(new char[len], len);
116 }
117 
// Drops the first `count` elements of `buffer`. The caller must make sure that
// `count` does not exceed the buffer size.
static void RemoveFromBeginning(std::vector<char>* buffer, size_t count)
{
    const auto first = buffer->begin();
    const auto last = first + count;
    buffer->erase(first, last);
}
122 
// Complete close frame sent to the peer: final bit plus close opcode (0x88)
// followed by a zero payload length.
static const char CLOSE_FRAME[] = { '\x88', '\x00' };

// Result of trying to decode one WebSocket frame.
enum WsDecodeResult {
    FRAME_OK,
    FRAME_INCOMPLETE,
    FRAME_CLOSE,
    FRAME_ERROR,
    FRAME_PING
};
126 
GenerateAcceptString(const std::string & clientKey,char (* buffer)[ACCEPT_KEY_LENGTH])127 static void GenerateAcceptString(const std::string& clientKey, char (*buffer)[ACCEPT_KEY_LENGTH])
128 {
129     // Magic string from websockets spec.
130     static constexpr char wsMagic[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
131     std::string input(clientKey + wsMagic);
132     char hash[SHA_DIGEST_LENGTH];
133     USE(SHA1(reinterpret_cast<const unsigned char*>(input.data()), input.size(),
134              reinterpret_cast<unsigned char*>(hash)));
135     jsvm::inspector::Base64Encode(hash, sizeof(hash), *buffer, sizeof(*buffer));
136 }
137 
// Strips a trailing ":port" from `host`. A colon that sits before the last ']'
// belongs to a bracketed IPv6 literal and is left alone.
static std::string TrimPort(const std::string& host)
{
    const size_t colon = host.rfind(':');
    if (colon == std::string::npos) {
        return host;
    }
    const size_t closingBracket = host.rfind(']');
    const bool colonInsideBrackets = closingBracket != std::string::npos && colon <= closingBracket;
    return colonInsideBrackets ? host : host.substr(0, colon);
}
150 
IsIPAddress(const std::string & host)151 static bool IsIPAddress(const std::string& host)
152 {
153     // To avoid DNS rebinding attacks, we are aware of the following requirements:
154     // * the host name must be an IP address (CVE-2018-7160, CVE-2022-32212),
155     // * the IP address must be routable (hackerone.com/reports/1632921), and
156     // * the IP address must be formatted unambiguously (CVE-2022-43548).
157 
158     // The logic below assumes that the string is null-terminated, so ensure that
159     // we did not somehow end up with null characters within the string.
160     if (host.find('\0') != std::string::npos) {
161         return false;
162     }
163 
164     // All IPv6 addresses must be enclosed in square brackets, and anything
165     // enclosed in square brackets must be an IPv6 address.
166     if (host.length() >= ByteSize::SIZE_4_BYTES && host.front() == '[' && host.back() == ']') {
167         // INET6_ADDRSTRLEN is the maximum length of the dual format (including the
168         // terminating null character), which is the longest possible representation
169         // of an IPv6 address: xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:ddd.ddd.ddd.ddd
170         if (host.length() - ByteSize::SIZE_2_BYTES >= INET6_ADDRSTRLEN) {
171             return false;
172         }
173 
174         // Annoyingly, libuv's implementation of inet_pton() deviates from other
175         // implementations of the function in that it allows '%' in IPv6 addresses.
176         if (host.find('%') != std::string::npos) {
177             return false;
178         }
179 
180         // Parse the IPv6 address to ensure it is syntactically valid.
181         char ipv6Str[INET6_ADDRSTRLEN];
182         std::copy(host.begin() + 1, host.end() - 1, ipv6Str);
183         ipv6Str[host.length()] = '\0';
184         unsigned char ipv6[sizeof(struct in6_addr)];
185         if (uv_inet_pton(AF_INET6, ipv6Str, ipv6) != 0) {
186             return false;
187         }
188 
189         // The only non-routable IPv6 address is ::/128. It should not be necessary
190         // to explicitly reject it because it will still be enclosed in square
191         // brackets and not even macOS should make DNS requests in that case, but
192         // history has taught us that we cannot be careful enough.
193         // Note that RFC 4291 defines both "IPv4-Compatible IPv6 Addresses" and
194         // "IPv4-Mapped IPv6 Addresses", which means that there are IPv6 addresses
195         // (other than ::/128) that represent non-routable IPv4 addresses. However,
196         // this translation assumes that the host is interpreted as an IPv6 address
197         // in the first place, at which point DNS rebinding should not be an issue.
198         if (std::all_of(ipv6, ipv6 + sizeof(ipv6), [](auto b) { return b == 0; })) {
199             return false;
200         }
201 
202         // It is a syntactically valid and routable IPv6 address enclosed in square
203         // brackets. No client should be able to misinterpret this.
204         return true;
205     }
206 
207     // Anything not enclosed in square brackets must be an IPv4 address. It is
208     // important here that inet_pton() accepts only the so-called dotted-decimal
209     // notation, which is a strict subset of the so-called numbers-and-dots
210     // notation that is allowed by inet_aton() and inet_addr(). This subset does
211     // not allow hexadecimal or octal number formats.
212     unsigned char ipv4[sizeof(struct in_addr)];
213     if (uv_inet_pton(AF_INET, host.c_str(), ipv4) != 0) {
214         return false;
215     }
216 
217     if (ipv4[0] == 0) {
218         return false;
219     }
220 
221     // It is a routable IPv4 address in dotted-decimal notation.
222     return true;
223 }
224 
225 // Constants for hybi-10 frame format.
226 
using OpCode = int;

// Frame opcodes (low four bits of the first header byte).
constexpr OpCode K_OP_CODE_CONTINUATION = 0x0;
constexpr OpCode K_OP_CODE_TEXT = 0x1;
constexpr OpCode K_OP_CODE_BINARY = 0x2;
constexpr OpCode K_OP_CODE_CLOSE = 0x8;
constexpr OpCode K_OP_CODE_PING = 0x9;
constexpr OpCode K_OP_CODE_PONG = 0xA;

// Bit masks applied to the two fixed header bytes.
constexpr unsigned char K_FINAL_BIT = 0x80;
constexpr unsigned char K_RESERVED_1_BIT = 0x40;
constexpr unsigned char K_RESERVED_2_BIT = 0x20;
constexpr unsigned char K_RESERVED_3_BIT = 0x10;
constexpr unsigned char K_OP_CODE_MASK = 0xF;
constexpr unsigned char K_MASK_BIT = 0x80;
constexpr unsigned char K_PAYLOAD_LENGTH_MASK = 0x7F;
// First byte of a pong frame: final bit combined with the pong opcode.
constexpr unsigned char K_PONG_FRAME_HEADER = 0x8A;

// Payload length encoding: values up to 125 are stored directly, 126 and 127
// announce a two byte or eight byte extended length field.
constexpr size_t K_MAX_SINGLE_BYTE_PAYLOAD_LENGTH = 125;
constexpr size_t K_TWO_BYTE_PAYLOAD_LENGTH_FIELD = 126;
constexpr size_t K_EIGHT_BYTE_PAYLOAD_LENGTH_FIELD = 127;
constexpr size_t K_MASKING_KEY_WIDTH_IN_BYTES = 4;
249 
encode_frame_hybi17(const std::vector<char> & message)250 static std::vector<char> encode_frame_hybi17(const std::vector<char>& message)
251 {
252     std::vector<char> frame;
253     OpCode opCode = K_OP_CODE_TEXT;
254     frame.push_back(K_FINAL_BIT | opCode);
255     const size_t dataLength = message.size();
256     if (dataLength <= K_MAX_SINGLE_BYTE_PAYLOAD_LENGTH) {
257         frame.push_back(static_cast<char>(dataLength));
258     } else if (dataLength <= 0xFFFF) {
259         frame.push_back(K_TWO_BYTE_PAYLOAD_LENGTH_FIELD);
260         frame.push_back((dataLength & 0xFF00) >> ByteOffset::BIT_8);
261         frame.push_back(dataLength & 0xFF);
262     } else {
263         frame.push_back(K_EIGHT_BYTE_PAYLOAD_LENGTH_FIELD);
264         constexpr size_t byteCount = 8;
265         char extendedPayloadLength[byteCount];
266         size_t remaining = dataLength;
267         // Fill the length into extendedPayloadLength in the network byte order.
268         for (int i = 0; i < byteCount; ++i) {
269             extendedPayloadLength[byteCount - 1 - i] = remaining & 0xFF;
270             remaining >>= byteCount;
271         }
272         frame.insert(frame.end(), extendedPayloadLength, extendedPayloadLength + byteCount);
273         CHECK_EQ(0, remaining);
274     }
275     frame.insert(frame.end(), message.begin(), message.end());
276     return frame;
277 }
278 
DecodeFrameHybi17(const std::vector<char> & buffer,bool clientFrame,int * bytesConsumed,std::vector<char> * output,bool * compressed)279 static WsDecodeResult DecodeFrameHybi17(const std::vector<char>& buffer,
280                                         bool clientFrame,
281                                         int* bytesConsumed,
282                                         std::vector<char>* output,
283                                         bool* compressed)
284 {
285     *bytesConsumed = 0;
286     if (buffer.size() < ByteSize::SIZE_2_BYTES) {
287         return FRAME_INCOMPLETE;
288     }
289 
290     auto it = buffer.begin();
291 
292     unsigned char firstByte = *it++;
293     unsigned char secondByte = *it++;
294 
295     bool final = (firstByte & K_FINAL_BIT) != 0;
296     bool reserved1 = (firstByte & K_RESERVED_1_BIT) != 0;
297     bool reserved2 = (firstByte & K_RESERVED_2_BIT) != 0;
298     bool reserved3 = (firstByte & K_RESERVED_3_BIT) != 0;
299     int opCode = firstByte & K_OP_CODE_MASK;
300     bool masked = (secondByte & K_MASK_BIT) != 0;
301     *compressed = reserved1;
302     if (!final || reserved2 || reserved3) {
303         return FRAME_ERROR; // Only compression extension is supported.
304     }
305 
306     bool closed = false;
307     uint64_t payloadLength64 = secondByte & K_PAYLOAD_LENGTH_MASK;
308     switch (opCode) {
309         case K_OP_CODE_CLOSE:
310             closed = true;
311             break;
312         case K_OP_CODE_TEXT:
313             break;
314         case K_OP_CODE_PING: {
315             output->push_back(K_PONG_FRAME_HEADER);
316             output->push_back(static_cast<char>(payloadLength64));
317             output->insert(output->end(), it, it + payloadLength64);
318             return FRAME_PING;
319         }
320         case K_OP_CODE_BINARY:       // We don't support binary frames yet.
321         case K_OP_CODE_CONTINUATION: // We don't support binary frames yet.
322         case K_OP_CODE_PONG:         // We don't support binary frames yet.
323         default:
324             return FRAME_ERROR;
325     }
326 
327     // In Hybi-17 spec client MUST mask its frame.
328     if (clientFrame && !masked) {
329         return FRAME_ERROR;
330     }
331 
332     if (payloadLength64 > K_MAX_SINGLE_BYTE_PAYLOAD_LENGTH) {
333         int extendedPayloadLengthSize;
334         if (payloadLength64 == K_TWO_BYTE_PAYLOAD_LENGTH_FIELD) {
335             extendedPayloadLengthSize = ByteSize::SIZE_2_BYTES;
336         } else if (payloadLength64 == K_EIGHT_BYTE_PAYLOAD_LENGTH_FIELD) {
337             extendedPayloadLengthSize = ByteSize::SIZE_8_BYTES;
338         } else {
339             return FRAME_ERROR;
340         }
341         if ((buffer.end() - it) < extendedPayloadLengthSize) {
342             return FRAME_INCOMPLETE;
343         }
344         payloadLength64 = 0;
345         for (int i = 0; i < extendedPayloadLengthSize; ++i) {
346             payloadLength64 <<= ByteOffset::BIT_8;
347             payloadLength64 |= static_cast<unsigned char>(*it++);
348         }
349     }
350 
351     static const uint64_t maxPayloadLength = 0x7FFFFFFFFFFFFFFF;
352     static const size_t maxLength = SIZE_MAX;
353     if (payloadLength64 > maxPayloadLength || payloadLength64 > maxLength - K_MASKING_KEY_WIDTH_IN_BYTES) {
354         // WebSocket frame length too large.
355         return FRAME_ERROR;
356     }
357     size_t payloadLength = static_cast<size_t>(payloadLength64);
358 
359     if (buffer.size() - K_MASKING_KEY_WIDTH_IN_BYTES < payloadLength) {
360         return FRAME_INCOMPLETE;
361     }
362 
363     std::vector<char>::const_iterator maskingKey = it;
364     std::vector<char>::const_iterator payload = it + K_MASKING_KEY_WIDTH_IN_BYTES;
365     for (size_t i = 0; i < payloadLength; ++i) { // Unmask the payload.
366         output->insert(output->end(), payload[i] ^ maskingKey[i % K_MASKING_KEY_WIDTH_IN_BYTES]);
367     }
368 
369     size_t pos = it + K_MASKING_KEY_WIDTH_IN_BYTES + payloadLength - buffer.begin();
370     *bytesConsumed = pos;
371     return closed ? FRAME_CLOSE : FRAME_OK;
372 }
373 
374 // WS protocol
375 class WsHandler : public ProtocolHandler {
376 public:
WsHandler(InspectorSocket * inspector,TcpHolder::Pointer tcp)377     WsHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp)
378         : ProtocolHandler(inspector, std::move(tcp)), onCloseSent(&WsHandler::WaitForCloseReply),
379           onCloseReceived(&WsHandler::CloseFrameReceived), dispose(false)
380     {}
381 
AcceptUpgrade(const std::string & acceptKey)382     void AcceptUpgrade(const std::string& acceptKey) override {}
CancelHandshake()383     void CancelHandshake() override {}
384 
OnEof()385     void OnEof() override
386     {
387         tcp.reset();
388         if (dispose) {
389             delete this;
390         }
391     }
392 
Write(const std::vector<char> data)393     void Write(const std::vector<char> data) override
394     {
395         std::vector<char> output = encode_frame_hybi17(data);
396         WriteRaw(output, WriteRequest::Cleanup);
397     }
398 
OnData(std::vector<char> * data)399     void OnData(std::vector<char>* data) override
400     {
401         int processed = 0;
402         do {
403             processed = ParseWsFrames(*data);
404             if (processed > 0) {
405                 RemoveFromBeginning(data, processed);
406             }
407         } while (processed > 0 && !data->empty());
408     }
409 
410 protected:
Shutdown()411     void Shutdown() override
412     {
413         if (tcp) {
414             dispose = true;
415             SendClose();
416         } else {
417             // if tcp is null, delete this
418             delete this;
419         }
420     }
421 
422 private:
423     using Callback = void (WsHandler::*)();
424 
OnCloseFrameWritten(uv_write_t * req,int status)425     static void OnCloseFrameWritten(uv_write_t* req, int status)
426     {
427         WriteRequest* wr = WriteRequest::FromWriteReq(req);
428         WsHandler* handler = static_cast<WsHandler*>(wr->handler);
429         delete wr;
430         Callback cb = handler->onCloseSent;
431         (handler->*cb)();
432     }
433 
WaitForCloseReply()434     void WaitForCloseReply()
435     {
436         onCloseReceived = &WsHandler::OnEof;
437     }
438 
SendClose()439     void SendClose()
440     {
441         WriteRaw(std::vector<char>(CLOSE_FRAME, CLOSE_FRAME + sizeof(CLOSE_FRAME)), OnCloseFrameWritten);
442     }
443 
CloseFrameReceived()444     void CloseFrameReceived()
445     {
446         onCloseSent = &WsHandler::OnEof;
447         SendClose();
448     }
449 
ParseWsFrames(const std::vector<char> & buffer)450     int ParseWsFrames(const std::vector<char>& buffer)
451     {
452         int bytesConsumed = 0;
453         std::vector<char> output;
454         bool compressed = false;
455 
456         WsDecodeResult r = DecodeFrameHybi17(buffer, true /* clientFrame */, &bytesConsumed, &output, &compressed);
457         // Compressed frame means client is ignoring the headers and misbehaves
458         if (compressed || r == FRAME_ERROR) {
459             OnEof();
460             bytesConsumed = 0;
461         } else if (r == FRAME_CLOSE) {
462             (this->*onCloseReceived)();
463             bytesConsumed = 0;
464         } else if (r == FRAME_OK) {
465             GetDelegate()->OnWsFrame(output);
466         } else if (r == FRAME_PING) {
467             WriteRaw(output, WriteRequest::Cleanup);
468         }
469         return bytesConsumed;
470     }
471 
472     Callback onCloseSent;
473     Callback onCloseReceived;
474     bool dispose;
475 };
476 
477 // HTTP protocol
// Summary of one fully parsed HTTP request, queued by the parser callbacks and
// handled once parsing of the current chunk has finished.
class HttpEvent {
public:
    HttpEvent(const std::string& path, bool upgrade, bool isGET, const std::string& wsKey, const std::string& host)
        : path(path),
          upgrade(upgrade),
          isGET(isGET),
          wsKey(wsKey),
          host(host)
    {}

    std::string path;  // request target
    bool upgrade;      // parser reported an upgrade request
    bool isGET;        // request method is GET
    std::string wsKey; // value of the Sec-WebSocket-Key header
    std::string host;  // value of the Host header
};
490 
491 class HttpHandler : public ProtocolHandler {
492 public:
HttpHandler(InspectorSocket * inspector,TcpHolder::Pointer tcp)493     explicit HttpHandler(InspectorSocket* inspector, TcpHolder::Pointer tcp)
494         : ProtocolHandler(inspector, std::move(tcp)), parsingValue(false)
495     {
496         llhttp_init(&parser, HTTP_REQUEST, &parserSettings);
497         llhttp_settings_init(&parserSettings);
498         parserSettings.on_header_field = OnHeaderField;
499         parserSettings.on_header_value = OnHeaderValue;
500         parserSettings.on_message_complete = OnMessageComplete;
501         parserSettings.on_url = OnPath;
502     }
503 
AcceptUpgrade(const std::string & acceptKey)504     void AcceptUpgrade(const std::string& acceptKey) override
505     {
506         char acceptString[ACCEPT_KEY_LENGTH];
507         GenerateAcceptString(acceptKey, &acceptString);
508         const char acceptWsPrefix[] = "HTTP/1.1 101 Switching Protocols\r\n"
509                                       "Upgrade: websocket\r\n"
510                                       "Connection: Upgrade\r\n"
511                                       "Sec-WebSocket-Accept: ";
512         const char acceptWsSuffix[] = "\r\n\r\n";
513         std::vector<char> reply(acceptWsPrefix, acceptWsPrefix + sizeof(acceptWsPrefix) - 1);
514         reply.insert(reply.end(), acceptString, acceptString + sizeof(acceptString));
515         reply.insert(reply.end(), acceptWsSuffix, acceptWsSuffix + sizeof(acceptWsSuffix) - 1);
516         if (WriteRaw(reply, WriteRequest::Cleanup) >= 0) {
517             inspector->SwitchProtocol(new WsHandler(inspector, std::move(tcp)));
518         } else {
519             tcp.reset();
520         }
521     }
522 
CancelHandshake()523     void CancelHandshake() override
524     {
525         const char handshakeFailedResponse[] = "HTTP/1.0 400 Bad Request\r\n"
526                                                "Content-Type: text/html; charset=UTF-8\r\n\r\n"
527                                                "WebSockets request was expected\r\n";
528         WriteRaw(std::vector<char>(handshakeFailedResponse,
529                                    handshakeFailedResponse + sizeof(handshakeFailedResponse) - 1),
530                  ThenCloseAndReportFailure);
531     }
532 
OnEof()533     void OnEof() override
534     {
535         tcp.reset();
536     }
537 
Write(const std::vector<char> data)538     void Write(const std::vector<char> data) override
539     {
540         WriteRaw(data, WriteRequest::Cleanup);
541     }
542 
OnData(std::vector<char> * data)543     void OnData(std::vector<char>* data) override
544     {
545         llhttp_errno_t err = llhttp_execute(&parser, data->data(), data->size());
546         if (err == HPE_PAUSED_UPGRADE) {
547             err = HPE_OK;
548             llhttp_resume_after_upgrade(&parser);
549         }
550         data->clear();
551         if (err != HPE_OK) {
552             CancelHandshake();
553         }
554         // Event handling may delete *this
555         std::vector<HttpEvent> httpEvents;
556         std::swap(httpEvents, events);
557         for (const HttpEvent& event : httpEvents) {
558             if (!IsAllowedHost(event.host) || !event.isGET) {
559                 CancelHandshake();
560                 return;
561             } else if (!event.upgrade) {
562                 GetDelegate()->OnHttpGet(event.host, event.path);
563             } else if (event.wsKey.empty()) {
564                 CancelHandshake();
565                 return;
566             } else {
567                 GetDelegate()->OnSocketUpgrade(event.host, event.path, event.wsKey);
568             }
569         }
570     }
571 
572 protected:
Shutdown()573     void Shutdown() override
574     {
575         delete this;
576     }
577 
578 private:
ThenCloseAndReportFailure(uv_write_t * req,int status)579     static void ThenCloseAndReportFailure(uv_write_t* req, int status)
580     {
581         ProtocolHandler* handler = WriteRequest::FromWriteReq(req)->handler;
582         WriteRequest::Cleanup(req, status);
583         handler->GetInspectorSocket()->SwitchProtocol(nullptr);
584     }
585 
OnHeaderValue(llhttp_t * parser,const char * at,size_t length)586     static int OnHeaderValue(llhttp_t* parser, const char* at, size_t length)
587     {
588         HttpHandler* handler = From(parser);
589         handler->parsingValue = true;
590         handler->headers[handler->currentHeader].append(at, length);
591         return 0;
592     }
593 
OnHeaderField(llhttp_t * parser,const char * at,size_t length)594     static int OnHeaderField(llhttp_t* parser, const char* at, size_t length)
595     {
596         HttpHandler* handler = From(parser);
597         if (handler->parsingValue) {
598             handler->parsingValue = false;
599             handler->currentHeader.clear();
600         }
601         handler->currentHeader.append(at, length);
602         return 0;
603     }
604 
OnPath(llhttp_t * parser,const char * at,size_t length)605     static int OnPath(llhttp_t* parser, const char* at, size_t length)
606     {
607         HttpHandler* handler = From(parser);
608         handler->path.append(at, length);
609         return 0;
610     }
611 
From(llhttp_t * parser)612     static HttpHandler* From(llhttp_t* parser)
613     {
614         return jsvm::inspector::ContainerOf(&HttpHandler::parser, parser);
615     }
616 
OnMessageComplete(llhttp_t * parser)617     static int OnMessageComplete(llhttp_t* parser)
618     {
619         // Event needs to be fired after the parser is done.
620         HttpHandler* handler = From(parser);
621         handler->events.emplace_back(handler->path, parser->upgrade, parser->method == HTTP_GET,
622                                      handler->HeaderValue("Sec-WebSocket-Key"), handler->HeaderValue("Host"));
623         handler->path = "";
624         handler->parsingValue = false;
625         handler->headers.clear();
626         handler->currentHeader = "";
627         return 0;
628     }
629 
HeaderValue(const std::string & header) const630     std::string HeaderValue(const std::string& header) const
631     {
632         bool headerFound = false;
633         std::string value;
634         for (const auto& header_value : headers) {
635             if (jsvm::inspector::StringEqualNoCaseN(header_value.first.data(), header.data(), header.length())) {
636                 if (headerFound) {
637                     return "";
638                 }
639                 value = header_value.second;
640                 headerFound = true;
641             }
642         }
643         return value;
644     }
645 
IsAllowedHost(const std::string & hostWithPort) const646     bool IsAllowedHost(const std::string& hostWithPort) const
647     {
648         std::string host = TrimPort(hostWithPort);
649         return host.empty() || IsIPAddress(host) || jsvm::inspector::StringEqualNoCase(host.data(), "localhost");
650     }
651 
652     bool parsingValue;
653     llhttp_t parser;
654     llhttp_settings_t parserSettings;
655     std::vector<HttpEvent> events;
656     std::string currentHeader;
657     std::map<std::string, std::string> headers;
658     std::string path;
659 };
660 
661 } // namespace
662 
663 // Any protocol
ProtocolHandler(InspectorSocket * inspector,TcpHolder::Pointer tcpParam)664 ProtocolHandler::ProtocolHandler(InspectorSocket* inspector, TcpHolder::Pointer tcpParam)
665     : inspector(inspector), tcp(std::move(tcpParam))
666 {
667     CHECK_NOT_NULL(tcp);
668     tcp->SetHandler(this);
669 }
670 
WriteRaw(const std::vector<char> & buffer,uv_write_cb writeCb)671 int ProtocolHandler::WriteRaw(const std::vector<char>& buffer, uv_write_cb writeCb)
672 {
673     return tcp->WriteRaw(buffer, writeCb);
674 }
675 
GetDelegate()676 InspectorSocket::Delegate* ProtocolHandler::GetDelegate()
677 {
678     return tcp->GetDelegate();
679 }
680 
GetHost() const681 std::string ProtocolHandler::GetHost() const
682 {
683     char ip[INET6_ADDRSTRLEN];
684     sockaddr_storage addr;
685     int len = sizeof(addr);
686     int err = uv_tcp_getsockname(tcp->GetTcp(), reinterpret_cast<struct sockaddr*>(&addr), &len);
687     if (err) {
688         return "";
689     }
690     if (addr.ss_family == AF_INET6) {
691         // using ipv6
692         const sockaddr_in6* v6 = reinterpret_cast<const sockaddr_in6*>(&addr);
693         err = uv_ip6_name(v6, ip, sizeof(ip));
694     } else {
695         // using ipv4
696         const sockaddr_in* v4 = reinterpret_cast<const sockaddr_in*>(&addr);
697         err = uv_ip4_name(v4, ip, sizeof(ip));
698     }
699     if (err) {
700         return "";
701     }
702     return ip;
703 }
704 
705 // RAII uv_tcp_t wrapper
TcpHolder(InspectorSocket::DelegatePointer delegate)706 TcpHolder::TcpHolder(InspectorSocket::DelegatePointer delegate) : tcp(), delegate(std::move(delegate)), handler(nullptr)
707 {}
708 
709 // static
Accept(uv_stream_t * server,InspectorSocket::DelegatePointer delegate)710 TcpHolder::Pointer TcpHolder::Accept(uv_stream_t* server, InspectorSocket::DelegatePointer delegate)
711 {
712     TcpHolder* result = new TcpHolder(std::move(delegate));
713     uv_stream_t* tcp = reinterpret_cast<uv_stream_t*>(&result->tcp);
714     int err = uv_tcp_init(server->loop, &result->tcp);
715     if (err == 0) {
716         err = uv_accept(server, tcp);
717     }
718     if (err == 0) {
719         err = uv_read_start(tcp, AllocateBuffer, OnDataReceivedCb);
720     }
721     if (err == 0) {
722         return TcpHolder::Pointer(result);
723     } else {
724         delete result;
725         return nullptr;
726     }
727 }
728 
SetHandler(ProtocolHandler * protocalHandler)729 void TcpHolder::SetHandler(ProtocolHandler* protocalHandler)
730 {
731     handler = protocalHandler;
732 }
733 
WriteRaw(const std::vector<char> & buffer,uv_write_cb writeCb)734 int TcpHolder::WriteRaw(const std::vector<char>& buffer, uv_write_cb writeCb)
735 {
736     // Freed in write_request_cleanup
737     WriteRequest* wr = new WriteRequest(handler, buffer);
738     uv_stream_t* stream = reinterpret_cast<uv_stream_t*>(&tcp);
739     int err = uv_write(&wr->req, stream, &wr->buf, 1, writeCb);
740     if (err < 0) {
741         delete wr;
742     }
743     return err < 0;
744 }
745 
GetDelegate()746 InspectorSocket::Delegate* TcpHolder::GetDelegate()
747 {
748     return delegate.get();
749 }
750 
751 // static
OnClosed(uv_handle_t * handle)752 void TcpHolder::OnClosed(uv_handle_t* handle)
753 {
754     delete From(handle);
755 }
756 
OnDataReceivedCb(uv_stream_t * tcp,ssize_t nread,const uv_buf_t * buf)757 void TcpHolder::OnDataReceivedCb(uv_stream_t* tcp, ssize_t nread, const uv_buf_t* buf)
758 {
759     TcpHolder* holder = From(tcp);
760     holder->ReclaimUvBuf(buf, nread);
761     if (nread < 0 || nread == UV_EOF) {
762         holder->handler->OnEof();
763     } else {
764         holder->handler->OnData(&holder->buffer);
765     }
766 }
767 
768 // static
DisconnectAndDispose(TcpHolder * holder)769 void TcpHolder::DisconnectAndDispose(TcpHolder* holder)
770 {
771     uv_handle_t* handle = reinterpret_cast<uv_handle_t*>(&holder->tcp);
772     uv_close(handle, OnClosed);
773 }
774 
ReclaimUvBuf(const uv_buf_t * buf,ssize_t read)775 void TcpHolder::ReclaimUvBuf(const uv_buf_t* buf, ssize_t read)
776 {
777     if (read > 0) {
778         DCHECK(read <= buf.len);
779         // insert buffer
780         buffer.insert(buffer.end(), buf->base, buf->base + read);
781     }
782     delete[] buf->base;
783 }
784 
785 InspectorSocket::~InspectorSocket() = default;
786 
787 // static method
Shutdown(ProtocolHandler * handler)788 void InspectorSocket::Shutdown(ProtocolHandler* handler)
789 {
790     handler->Shutdown();
791 }
792 
793 // static method
Accept(uv_stream_t * server,DelegatePointer delegate)794 InspectorSocket::Pointer InspectorSocket::Accept(uv_stream_t* server, DelegatePointer delegate)
795 {
796     auto tcp = TcpHolder::Accept(server, std::move(delegate));
797     InspectorSocket* inspector = nullptr;
798     if (tcp) {
799         // If accept tcp, create new inspector socket
800         inspector = new InspectorSocket();
801         inspector->SwitchProtocol(new HttpHandler(inspector, std::move(tcp)));
802         return InspectorSocket::Pointer(inspector);
803     }
804     return InspectorSocket::Pointer(nullptr);
805 }
806 
AcceptUpgrade(const std::string & acceptKey)807 void InspectorSocket::AcceptUpgrade(const std::string& acceptKey)
808 {
809     protocolHandler->AcceptUpgrade(acceptKey);
810 }
811 
CancelHandshake()812 void InspectorSocket::CancelHandshake()
813 {
814     protocolHandler->CancelHandshake();
815 }
816 
GetHost()817 std::string InspectorSocket::GetHost()
818 {
819     return protocolHandler->GetHost();
820 }
821 
SwitchProtocol(ProtocolHandler * handler)822 void InspectorSocket::SwitchProtocol(ProtocolHandler* handler)
823 {
824     protocolHandler.reset(std::move(handler));
825 }
826 
Write(const char * data,size_t len)827 void InspectorSocket::Write(const char* data, size_t len)
828 {
829     protocolHandler->Write(std::vector<char>(data, data + len));
830 }
831 
832 } // namespace inspector
833 } // namespace jsvm
834