• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <https/WebSocketHandler.h>
18 
19 #include <https/ClientSocket.h>
20 #include <https/Support.h>
21 
22 #include <iostream>
23 #include <sstream>
24 
handleRequest(uint8_t * data,size_t size,bool isEOS)25 ssize_t WebSocketHandler::handleRequest(
26         uint8_t *data, size_t size, bool isEOS) {
27     (void)isEOS;
28 
29     size_t offset = 0;
30     while (offset + 1 < size) {
31         uint8_t *packet = &data[offset];
32         const size_t avail = size - offset;
33 
34         size_t packetOffset = 0;
35         const uint8_t headerByte = packet[packetOffset];
36 
37         const bool hasMask = (packet[packetOffset + 1] & 0x80) != 0;
38         size_t payloadLen = packet[packetOffset + 1] & 0x7f;
39         packetOffset += 2;
40 
41         if (payloadLen == 126) {
42             if (packetOffset + 1 >= avail) {
43                 break;
44             }
45 
46             payloadLen = U16_AT(&packet[packetOffset]);
47             packetOffset += 2;
48         } else if (payloadLen == 127) {
49             if (packetOffset + 7 >= avail) {
50                 break;
51             }
52 
53             payloadLen = U64_AT(&packet[packetOffset]);
54             packetOffset += 8;
55         }
56 
57         uint32_t mask = 0;
58         if (hasMask) {
59             if (packetOffset + 3 >= avail) {
60                 break;
61             }
62 
63             mask = U32_AT(&packet[packetOffset]);
64             packetOffset += 4;
65         }
66 
67         if (packetOffset + payloadLen > avail) {
68             break;
69         }
70 
71         if (mask) {
72             for (size_t i = 0; i < payloadLen; ++i) {
73                 packet[packetOffset + i] ^= ((mask >> (8 * (3 - (i % 4)))) & 0xff);
74             }
75         }
76 
77         int err = handleMessage(headerByte, &packet[packetOffset], payloadLen);
78 
79         offset += packetOffset + payloadLen;
80 
81         if (err < 0) {
82             return err;
83         }
84     }
85 
86     return offset;
87 }
88 
isConnected()89 bool WebSocketHandler::isConnected() {
90     return mOutputCallback != nullptr || mClientSocket.lock() != nullptr;
91 }
92 
setClientSocket(std::weak_ptr<ClientSocket> clientSocket)93 void WebSocketHandler::setClientSocket(std::weak_ptr<ClientSocket> clientSocket) {
94     mClientSocket = clientSocket;
95 }
96 
setOutputCallback(const sockaddr_in & remoteAddr,OutputCallback fn)97 void WebSocketHandler::setOutputCallback(
98         const sockaddr_in &remoteAddr, OutputCallback fn) {
99     mOutputCallback = fn;
100     mRemoteAddr = remoteAddr;
101 }
102 
handleMessage(uint8_t headerByte,const uint8_t * msg,size_t len)103 int WebSocketHandler::handleMessage(
104         uint8_t headerByte, const uint8_t *msg, size_t len) {
105     std::cerr
106         << "WebSocketHandler::handleMessage(0x"
107         << std::hex
108         << (unsigned)headerByte
109         << std::dec
110         << ")"
111         << std::endl;
112 
113     hexdump(msg, len);
114 
115     const uint8_t opcode = headerByte & 0x0f;
116     if (opcode == 8) {
117         // Connection close.
118         return -1;
119     }
120 
121     return 0;
122 }
123 
sendMessage(const void * data,size_t size,SendMode mode)124 int WebSocketHandler::sendMessage(
125         const void *data, size_t size, SendMode mode) {
126     static constexpr bool kUseMask = false;
127 
128     size_t numHeaderBytes = 2 + (kUseMask ? 4 : 0);
129     if (size > 65535) {
130         numHeaderBytes += 8;
131     } else if (size > 125) {
132         numHeaderBytes += 2;
133     }
134 
135     static constexpr uint8_t kOpCodeBySendMode[] = {
136         0x1,  // text
137         0x2,  // binary
138         0x8,  // closeConnection
139     };
140 
141     auto opcode = kOpCodeBySendMode[static_cast<uint8_t>(mode)];
142 
143     std::unique_ptr<uint8_t[]> buffer(new uint8_t[numHeaderBytes + size]);
144     uint8_t *msg = buffer.get();
145     msg[0] = 0x80 | opcode;  // FIN==1
146     msg[1] = kUseMask ? 0x80 : 0x00;
147 
148     if (size > 65535) {
149         msg[1] |= 127;
150         msg[2] = 0x00;
151         msg[3] = 0x00;
152         msg[4] = 0x00;
153         msg[5] = 0x00;
154         msg[6] = (size >> 24) & 0xff;
155         msg[7] = (size >> 16) & 0xff;
156         msg[8] = (size >> 8) & 0xff;
157         msg[9] = size & 0xff;
158     } else if (size > 125) {
159         msg[1] |= 126;
160         msg[2] = (size >> 8) & 0xff;
161         msg[3] = size & 0xff;
162     } else {
163         msg[1] |= size;
164     }
165 
166     if (kUseMask) {
167         uint32_t mask = rand();
168         msg[numHeaderBytes - 4] = (mask >> 24) & 0xff;
169         msg[numHeaderBytes - 3] = (mask >> 16) & 0xff;
170         msg[numHeaderBytes - 2] = (mask >> 8) & 0xff;
171         msg[numHeaderBytes - 1] = mask & 0xff;
172 
173         for (size_t i = 0; i < size; ++i) {
174             msg[numHeaderBytes + i] =
175                 ((const uint8_t *)data)[i]
176                     ^ ((mask >> (8 * (3 - (i % 4)))) & 0xff);
177         }
178     } else {
179         memcpy(&msg[numHeaderBytes], data, size);
180     }
181 
182     if (mOutputCallback) {
183         mOutputCallback(msg, numHeaderBytes + size);
184     } else {
185         auto clientSocket = mClientSocket.lock();
186         if (clientSocket) {
187             clientSocket->queueOutputData(msg, numHeaderBytes + size);
188         }
189     }
190 
191     return 0;
192 }
193 
remoteHost() const194 std::string WebSocketHandler::remoteHost() const {
195     sockaddr_in remoteAddr;
196 
197     if (mOutputCallback) {
198         remoteAddr = mRemoteAddr;
199     } else {
200         auto clientSocket = mClientSocket.lock();
201         if (clientSocket) {
202             remoteAddr = clientSocket->remoteAddr();
203         } else {
204             return "0.0.0.0";
205         }
206     }
207 
208     const uint32_t ipAddress = ntohl(remoteAddr.sin_addr.s_addr);
209 
210     std::stringstream ss;
211     ss << (ipAddress >> 24)
212        << "."
213        << ((ipAddress >> 16) & 0xff)
214        << "."
215        << ((ipAddress >> 8) & 0xff)
216        << "."
217        << (ipAddress & 0xff);
218 
219     return ss.str();
220 }
221 
222