1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <https/WebSocketHandler.h>

#include <https/ClientSocket.h>
#include <https/Support.h>

#include <iostream>
#include <sstream>
#include <utility>
24
handleRequest(uint8_t * data,size_t size,bool isEOS)25 ssize_t WebSocketHandler::handleRequest(
26 uint8_t *data, size_t size, bool isEOS) {
27 (void)isEOS;
28
29 size_t offset = 0;
30 while (offset + 1 < size) {
31 uint8_t *packet = &data[offset];
32 const size_t avail = size - offset;
33
34 size_t packetOffset = 0;
35 const uint8_t headerByte = packet[packetOffset];
36
37 const bool hasMask = (packet[packetOffset + 1] & 0x80) != 0;
38 size_t payloadLen = packet[packetOffset + 1] & 0x7f;
39 packetOffset += 2;
40
41 if (payloadLen == 126) {
42 if (packetOffset + 1 >= avail) {
43 break;
44 }
45
46 payloadLen = U16_AT(&packet[packetOffset]);
47 packetOffset += 2;
48 } else if (payloadLen == 127) {
49 if (packetOffset + 7 >= avail) {
50 break;
51 }
52
53 payloadLen = U64_AT(&packet[packetOffset]);
54 packetOffset += 8;
55 }
56
57 uint32_t mask = 0;
58 if (hasMask) {
59 if (packetOffset + 3 >= avail) {
60 break;
61 }
62
63 mask = U32_AT(&packet[packetOffset]);
64 packetOffset += 4;
65 }
66
67 if (packetOffset + payloadLen > avail) {
68 break;
69 }
70
71 if (mask) {
72 for (size_t i = 0; i < payloadLen; ++i) {
73 packet[packetOffset + i] ^= ((mask >> (8 * (3 - (i % 4)))) & 0xff);
74 }
75 }
76
77 int err = handleMessage(headerByte, &packet[packetOffset], payloadLen);
78
79 offset += packetOffset + payloadLen;
80
81 if (err < 0) {
82 return err;
83 }
84 }
85
86 return offset;
87 }
88
isConnected()89 bool WebSocketHandler::isConnected() {
90 return mOutputCallback != nullptr || mClientSocket.lock() != nullptr;
91 }
92
setClientSocket(std::weak_ptr<ClientSocket> clientSocket)93 void WebSocketHandler::setClientSocket(std::weak_ptr<ClientSocket> clientSocket) {
94 mClientSocket = clientSocket;
95 }
96
setOutputCallback(const sockaddr_in & remoteAddr,OutputCallback fn)97 void WebSocketHandler::setOutputCallback(
98 const sockaddr_in &remoteAddr, OutputCallback fn) {
99 mOutputCallback = fn;
100 mRemoteAddr = remoteAddr;
101 }
102
handleMessage(uint8_t headerByte,const uint8_t * msg,size_t len)103 int WebSocketHandler::handleMessage(
104 uint8_t headerByte, const uint8_t *msg, size_t len) {
105 std::cerr
106 << "WebSocketHandler::handleMessage(0x"
107 << std::hex
108 << (unsigned)headerByte
109 << std::dec
110 << ")"
111 << std::endl;
112
113 hexdump(msg, len);
114
115 const uint8_t opcode = headerByte & 0x0f;
116 if (opcode == 8) {
117 // Connection close.
118 return -1;
119 }
120
121 return 0;
122 }
123
sendMessage(const void * data,size_t size,SendMode mode)124 int WebSocketHandler::sendMessage(
125 const void *data, size_t size, SendMode mode) {
126 static constexpr bool kUseMask = false;
127
128 size_t numHeaderBytes = 2 + (kUseMask ? 4 : 0);
129 if (size > 65535) {
130 numHeaderBytes += 8;
131 } else if (size > 125) {
132 numHeaderBytes += 2;
133 }
134
135 static constexpr uint8_t kOpCodeBySendMode[] = {
136 0x1, // text
137 0x2, // binary
138 0x8, // closeConnection
139 };
140
141 auto opcode = kOpCodeBySendMode[static_cast<uint8_t>(mode)];
142
143 std::unique_ptr<uint8_t[]> buffer(new uint8_t[numHeaderBytes + size]);
144 uint8_t *msg = buffer.get();
145 msg[0] = 0x80 | opcode; // FIN==1
146 msg[1] = kUseMask ? 0x80 : 0x00;
147
148 if (size > 65535) {
149 msg[1] |= 127;
150 msg[2] = 0x00;
151 msg[3] = 0x00;
152 msg[4] = 0x00;
153 msg[5] = 0x00;
154 msg[6] = (size >> 24) & 0xff;
155 msg[7] = (size >> 16) & 0xff;
156 msg[8] = (size >> 8) & 0xff;
157 msg[9] = size & 0xff;
158 } else if (size > 125) {
159 msg[1] |= 126;
160 msg[2] = (size >> 8) & 0xff;
161 msg[3] = size & 0xff;
162 } else {
163 msg[1] |= size;
164 }
165
166 if (kUseMask) {
167 uint32_t mask = rand();
168 msg[numHeaderBytes - 4] = (mask >> 24) & 0xff;
169 msg[numHeaderBytes - 3] = (mask >> 16) & 0xff;
170 msg[numHeaderBytes - 2] = (mask >> 8) & 0xff;
171 msg[numHeaderBytes - 1] = mask & 0xff;
172
173 for (size_t i = 0; i < size; ++i) {
174 msg[numHeaderBytes + i] =
175 ((const uint8_t *)data)[i]
176 ^ ((mask >> (8 * (3 - (i % 4)))) & 0xff);
177 }
178 } else {
179 memcpy(&msg[numHeaderBytes], data, size);
180 }
181
182 if (mOutputCallback) {
183 mOutputCallback(msg, numHeaderBytes + size);
184 } else {
185 auto clientSocket = mClientSocket.lock();
186 if (clientSocket) {
187 clientSocket->queueOutputData(msg, numHeaderBytes + size);
188 }
189 }
190
191 return 0;
192 }
193
remoteHost() const194 std::string WebSocketHandler::remoteHost() const {
195 sockaddr_in remoteAddr;
196
197 if (mOutputCallback) {
198 remoteAddr = mRemoteAddr;
199 } else {
200 auto clientSocket = mClientSocket.lock();
201 if (clientSocket) {
202 remoteAddr = clientSocket->remoteAddr();
203 } else {
204 return "0.0.0.0";
205 }
206 }
207
208 const uint32_t ipAddress = ntohl(remoteAddr.sin_addr.s_addr);
209
210 std::stringstream ss;
211 ss << (ipAddress >> 24)
212 << "."
213 << ((ipAddress >> 16) & 0xff)
214 << "."
215 << ((ipAddress >> 8) & 0xff)
216 << "."
217 << (ipAddress & 0xff);
218
219 return ss.str();
220 }
221
222