1 /*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "socket.h"
18
19 #include "message.h"
20 #include "utils.h"
21
22 #include <errno.h>
23 #include <linux/if_packet.h>
24 #include <netinet/ip.h>
25 #include <netinet/udp.h>
26 #include <string.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <sys/uio.h>
30 #include <unistd.h>
31
// Combine the checksum of |buffer| with |size| bytes with |checksum|. This is
// used for checksum calculations for IP and UDP (the RFC 1071 one's complement
// sum).
//
// Bytes are summed as native-endian 16-bit words. The one's complement sum is
// byte-order independent, so the final value can be stored straight into a
// header field without byte swapping. An odd trailing byte is treated as if it
// were followed by a zero byte. The returned value is folded to 16 bits, so
// its upper 16 bits are always zero.
static uint32_t addChecksum(const uint8_t* buffer, size_t size, uint32_t checksum) {
    // Accumulate in 64 bits so no carry is lost before folding, even for large
    // buffers.
    uint64_t sum = checksum;
    size_t offset = 0;
    for (; offset + 1 < size; offset += 2) {
        // memcpy instead of dereferencing a uint16_t* avoids unaligned-access
        // and strict-aliasing undefined behavior; it compiles to a plain load.
        uint16_t word;
        memcpy(&word, buffer + offset, sizeof(word));
        sum += word;
    }
    if (offset < size) {
        // Odd size: pad the last byte with a zero byte in memory order, which
        // gives the correct word on both little- and big-endian hosts.
        const uint8_t padded[2] = {buffer[offset], 0};
        uint16_t word;
        memcpy(&word, padded, sizeof(word));
        sum += word;
    }
    // Fold the carries above bit 15 back into the low 16 bits.
    while ((sum >> 16) != 0) {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    return static_cast<uint32_t>(sum);
}
50
51 // Convenienct template function for checksum calculation
addChecksum(const T & data,uint32_t checksum)52 template <typename T> static uint32_t addChecksum(const T& data, uint32_t checksum) {
53 return addChecksum(reinterpret_cast<const uint8_t*>(&data), sizeof(T), checksum);
54 }
55
// Finalize an IP or UDP |checksum|: take its one's complement and keep only
// the low 16 bits, producing the value that goes into the header field.
static uint32_t finishChecksum(uint32_t checksum) {
    const uint32_t inverted = ~checksum;
    return inverted & 0xFFFFu;
}
60
Socket()61 Socket::Socket() : mSocketFd(-1) {}
62
~Socket()63 Socket::~Socket() {
64 if (mSocketFd != -1) {
65 ::close(mSocketFd);
66 mSocketFd = -1;
67 }
68 }
69
open(int domain,int type,int protocol)70 Result Socket::open(int domain, int type, int protocol) {
71 if (mSocketFd != -1) {
72 return Result::error("Socket already open");
73 }
74 mSocketFd = ::socket(domain, type, protocol);
75 if (mSocketFd == -1) {
76 return Result::error("Failed to open socket: %s", strerror(errno));
77 }
78 return Result::success();
79 }
80
bind(const void * sockaddr,size_t sockaddrLength)81 Result Socket::bind(const void* sockaddr, size_t sockaddrLength) {
82 if (mSocketFd == -1) {
83 return Result::error("Socket not open");
84 }
85
86 int status =
87 ::bind(mSocketFd, reinterpret_cast<const struct sockaddr*>(sockaddr), sockaddrLength);
88 if (status != 0) {
89 return Result::error("Unable to bind raw socket: %s", strerror(errno));
90 }
91
92 return Result::success();
93 }
94
bindIp(in_addr_t address,uint16_t port)95 Result Socket::bindIp(in_addr_t address, uint16_t port) {
96 struct sockaddr_in sockaddr;
97 memset(&sockaddr, 0, sizeof(sockaddr));
98 sockaddr.sin_family = AF_INET;
99 sockaddr.sin_port = htons(port);
100 sockaddr.sin_addr.s_addr = address;
101
102 return bind(&sockaddr, sizeof(sockaddr));
103 }
104
bindRaw(unsigned int interfaceIndex)105 Result Socket::bindRaw(unsigned int interfaceIndex) {
106 struct sockaddr_ll sockaddr;
107 memset(&sockaddr, 0, sizeof(sockaddr));
108 sockaddr.sll_family = AF_PACKET;
109 sockaddr.sll_protocol = htons(ETH_P_IP);
110 sockaddr.sll_ifindex = interfaceIndex;
111
112 return bind(&sockaddr, sizeof(sockaddr));
113 }
114
sendOnInterface(unsigned int interfaceIndex,in_addr_t destinationAddress,uint16_t destinationPort,const Message & message)115 Result Socket::sendOnInterface(unsigned int interfaceIndex, in_addr_t destinationAddress,
116 uint16_t destinationPort, const Message& message) {
117 if (mSocketFd == -1) {
118 return Result::error("Socket not open");
119 }
120
121 char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))] = {0};
122 struct sockaddr_in addr;
123 memset(&addr, 0, sizeof(addr));
124 addr.sin_family = AF_INET;
125 addr.sin_port = htons(destinationPort);
126 addr.sin_addr.s_addr = destinationAddress;
127
128 struct msghdr header;
129 memset(&header, 0, sizeof(header));
130 struct iovec iov;
131 // The struct member is non-const since it's used for receiving but it's
132 // safe to cast away const for sending.
133 iov.iov_base = const_cast<uint8_t*>(message.data());
134 iov.iov_len = message.size();
135 header.msg_name = &addr;
136 header.msg_namelen = sizeof(addr);
137 header.msg_iov = &iov;
138 header.msg_iovlen = 1;
139 header.msg_control = &controlData;
140 header.msg_controllen = sizeof(controlData);
141
142 struct cmsghdr* controlHeader = CMSG_FIRSTHDR(&header);
143 controlHeader->cmsg_level = IPPROTO_IP;
144 controlHeader->cmsg_type = IP_PKTINFO;
145 controlHeader->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
146 auto packetInfo = reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(controlHeader));
147 memset(packetInfo, 0, sizeof(*packetInfo));
148 packetInfo->ipi_ifindex = interfaceIndex;
149
150 ssize_t status = ::sendmsg(mSocketFd, &header, 0);
151 if (status <= 0) {
152 return Result::error("Failed to send packet: %s", strerror(errno));
153 }
154 return Result::success();
155 }
156
sendRawUdp(in_addr_t source,uint16_t sourcePort,in_addr_t destination,uint16_t destinationPort,unsigned int interfaceIndex,const Message & message)157 Result Socket::sendRawUdp(in_addr_t source, uint16_t sourcePort, in_addr_t destination,
158 uint16_t destinationPort, unsigned int interfaceIndex,
159 const Message& message) {
160 struct iphdr ip;
161 struct udphdr udp;
162
163 ip.version = IPVERSION;
164 ip.ihl = sizeof(ip) >> 2;
165 ip.tos = 0;
166 ip.tot_len = htons(sizeof(ip) + sizeof(udp) + message.size());
167 ip.id = 0;
168 ip.frag_off = 0;
169 ip.ttl = IPDEFTTL;
170 ip.protocol = IPPROTO_UDP;
171 ip.check = 0;
172 ip.saddr = source;
173 ip.daddr = destination;
174 ip.check = finishChecksum(addChecksum(ip, 0));
175
176 udp.source = htons(sourcePort);
177 udp.dest = htons(destinationPort);
178 udp.len = htons(sizeof(udp) + message.size());
179 udp.check = 0;
180
181 uint32_t udpChecksum = 0;
182 udpChecksum = addChecksum(ip.saddr, udpChecksum);
183 udpChecksum = addChecksum(ip.daddr, udpChecksum);
184 udpChecksum = addChecksum(htons(IPPROTO_UDP), udpChecksum);
185 udpChecksum = addChecksum(udp.len, udpChecksum);
186 udpChecksum = addChecksum(udp, udpChecksum);
187 udpChecksum = addChecksum(message.data(), message.size(), udpChecksum);
188 udp.check = finishChecksum(udpChecksum);
189
190 struct iovec iov[3];
191
192 iov[0].iov_base = static_cast<void*>(&ip);
193 iov[0].iov_len = sizeof(ip);
194 iov[1].iov_base = static_cast<void*>(&udp);
195 iov[1].iov_len = sizeof(udp);
196 // sendmsg requires these to be non-const but for sending won't modify them
197 iov[2].iov_base = static_cast<void*>(const_cast<uint8_t*>(message.data()));
198 iov[2].iov_len = message.size();
199
200 struct sockaddr_ll dest;
201 memset(&dest, 0, sizeof(dest));
202 dest.sll_family = AF_PACKET;
203 dest.sll_protocol = htons(ETH_P_IP);
204 dest.sll_ifindex = interfaceIndex;
205 dest.sll_halen = ETH_ALEN;
206 memset(dest.sll_addr, 0xFF, ETH_ALEN);
207
208 struct msghdr header;
209 memset(&header, 0, sizeof(header));
210 header.msg_name = &dest;
211 header.msg_namelen = sizeof(dest);
212 header.msg_iov = iov;
213 header.msg_iovlen = sizeof(iov) / sizeof(iov[0]);
214
215 ssize_t res = ::sendmsg(mSocketFd, &header, 0);
216 if (res == -1) {
217 return Result::error("Failed to send message: %s", strerror(errno));
218 }
219 return Result::success();
220 }
221
receiveFromInterface(Message * message,unsigned int * interfaceIndex)222 Result Socket::receiveFromInterface(Message* message, unsigned int* interfaceIndex) {
223 char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))];
224 struct msghdr header;
225 memset(&header, 0, sizeof(header));
226 struct iovec iov;
227 iov.iov_base = message->data();
228 iov.iov_len = message->capacity();
229 header.msg_iov = &iov;
230 header.msg_iovlen = 1;
231 header.msg_control = &controlData;
232 header.msg_controllen = sizeof(controlData);
233
234 ssize_t bytesRead = ::recvmsg(mSocketFd, &header, 0);
235 if (bytesRead < 0) {
236 return Result::error("Error receiving on socket: %s", strerror(errno));
237 }
238 message->setSize(static_cast<size_t>(bytesRead));
239 if (header.msg_controllen >= sizeof(struct cmsghdr)) {
240 for (struct cmsghdr* ctrl = CMSG_FIRSTHDR(&header); ctrl;
241 ctrl = CMSG_NXTHDR(&header, ctrl)) {
242 if (ctrl->cmsg_level == SOL_IP && ctrl->cmsg_type == IP_PKTINFO) {
243 auto packetInfo = reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(ctrl));
244 *interfaceIndex = packetInfo->ipi_ifindex;
245 }
246 }
247 }
248 return Result::success();
249 }
250
receiveRawUdp(uint16_t expectedPort,Message * message,bool * isValid)251 Result Socket::receiveRawUdp(uint16_t expectedPort, Message* message, bool* isValid) {
252 struct iphdr ip;
253 struct udphdr udp;
254
255 struct iovec iov[3];
256 iov[0].iov_base = &ip;
257 iov[0].iov_len = sizeof(ip);
258 iov[1].iov_base = &udp;
259 iov[1].iov_len = sizeof(udp);
260 iov[2].iov_base = message->data();
261 iov[2].iov_len = message->capacity();
262
263 ssize_t bytesRead = ::readv(mSocketFd, iov, 3);
264 if (bytesRead < 0) {
265 return Result::error("Unable to read from socket: %s", strerror(errno));
266 }
267 if (static_cast<size_t>(bytesRead) < sizeof(ip) + sizeof(udp)) {
268 // Not enough bytes to even cover IP and UDP headers
269 *isValid = false;
270 return Result::success();
271 }
272 *isValid = ip.version == IPVERSION && ip.ihl == (sizeof(ip) >> 2) &&
273 ip.protocol == IPPROTO_UDP && udp.dest == htons(expectedPort);
274
275 message->setSize(bytesRead - sizeof(ip) - sizeof(udp));
276 return Result::success();
277 }
278
enableOption(int level,int optionName)279 Result Socket::enableOption(int level, int optionName) {
280 if (mSocketFd == -1) {
281 return Result::error("Socket not open");
282 }
283
284 int enabled = 1;
285 int status = ::setsockopt(mSocketFd, level, optionName, &enabled, sizeof(enabled));
286 if (status == -1) {
287 return Result::error("Failed to set socket option: %s", strerror(errno));
288 }
289 return Result::success();
290 }
291