• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "socket.h"
18 
19 #include "message.h"
20 #include "utils.h"
21 
22 #include <errno.h>
23 #include <linux/if_packet.h>
24 #include <netinet/ip.h>
25 #include <netinet/udp.h>
26 #include <string.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <sys/uio.h>
30 #include <unistd.h>
31 
// Fold |size| bytes starting at |buffer| into the running ones'-complement
// sum |checksum|, as used by the IPv4 header and UDP checksums (RFC 1071).
// Data is summed as 16-bit words in host byte order; the ones'-complement sum
// is byte-order independent, so the finished value can be stored directly into
// a header field. Returns the sum folded to 16 bits (always <= 0xFFFF).
static uint32_t addChecksum(const uint8_t* buffer, size_t size, uint32_t checksum) {
    // Accumulate in 64 bits so long buffers cannot overflow before folding.
    uint64_t sum = checksum;
    while (size > 1) {
        // memcpy rather than dereferencing a uint16_t* avoids unaligned
        // access and strict-aliasing violations.
        uint16_t word;
        memcpy(&word, buffer, sizeof(word));
        sum += word;
        buffer += sizeof(word);
        size -= sizeof(word);
    }
    if (size > 0) {
        // Odd length: the last byte is padded on the right with a zero byte
        // (RFC 1071). Building the padded word in memory keeps this correct
        // on both little- and big-endian hosts.
        const uint8_t padded[2] = {*buffer, 0};
        uint16_t word;
        memcpy(&word, padded, sizeof(word));
        sum += word;
    }
    // Fold any carries above bit 15 back into the low 16 bits.
    while ((sum >> 16) != 0) {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    return static_cast<uint32_t>(sum);
}
50 
51 // Convenienct template function for checksum calculation
addChecksum(const T & data,uint32_t checksum)52 template <typename T> static uint32_t addChecksum(const T& data, uint32_t checksum) {
53     return addChecksum(reinterpret_cast<const uint8_t*>(&data), sizeof(T), checksum);
54 }
55 
// Turn a folded running sum into the value stored in an IP/UDP checksum
// field: its ones' complement, keeping only the low 16 bits.
static uint32_t finishChecksum(uint32_t checksum) {
    const uint16_t inverted = static_cast<uint16_t>(~checksum);
    return inverted;
}
60 
Socket()61 Socket::Socket() : mSocketFd(-1) {}
62 
~Socket()63 Socket::~Socket() {
64     if (mSocketFd != -1) {
65         ::close(mSocketFd);
66         mSocketFd = -1;
67     }
68 }
69 
open(int domain,int type,int protocol)70 Result Socket::open(int domain, int type, int protocol) {
71     if (mSocketFd != -1) {
72         return Result::error("Socket already open");
73     }
74     mSocketFd = ::socket(domain, type, protocol);
75     if (mSocketFd == -1) {
76         return Result::error("Failed to open socket: %s", strerror(errno));
77     }
78     return Result::success();
79 }
80 
bind(const void * sockaddr,size_t sockaddrLength)81 Result Socket::bind(const void* sockaddr, size_t sockaddrLength) {
82     if (mSocketFd == -1) {
83         return Result::error("Socket not open");
84     }
85 
86     int status =
87         ::bind(mSocketFd, reinterpret_cast<const struct sockaddr*>(sockaddr), sockaddrLength);
88     if (status != 0) {
89         return Result::error("Unable to bind raw socket: %s", strerror(errno));
90     }
91 
92     return Result::success();
93 }
94 
bindIp(in_addr_t address,uint16_t port)95 Result Socket::bindIp(in_addr_t address, uint16_t port) {
96     struct sockaddr_in sockaddr;
97     memset(&sockaddr, 0, sizeof(sockaddr));
98     sockaddr.sin_family = AF_INET;
99     sockaddr.sin_port = htons(port);
100     sockaddr.sin_addr.s_addr = address;
101 
102     return bind(&sockaddr, sizeof(sockaddr));
103 }
104 
bindRaw(unsigned int interfaceIndex)105 Result Socket::bindRaw(unsigned int interfaceIndex) {
106     struct sockaddr_ll sockaddr;
107     memset(&sockaddr, 0, sizeof(sockaddr));
108     sockaddr.sll_family = AF_PACKET;
109     sockaddr.sll_protocol = htons(ETH_P_IP);
110     sockaddr.sll_ifindex = interfaceIndex;
111 
112     return bind(&sockaddr, sizeof(sockaddr));
113 }
114 
sendOnInterface(unsigned int interfaceIndex,in_addr_t destinationAddress,uint16_t destinationPort,const Message & message)115 Result Socket::sendOnInterface(unsigned int interfaceIndex, in_addr_t destinationAddress,
116                                uint16_t destinationPort, const Message& message) {
117     if (mSocketFd == -1) {
118         return Result::error("Socket not open");
119     }
120 
121     char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))] = {0};
122     struct sockaddr_in addr;
123     memset(&addr, 0, sizeof(addr));
124     addr.sin_family = AF_INET;
125     addr.sin_port = htons(destinationPort);
126     addr.sin_addr.s_addr = destinationAddress;
127 
128     struct msghdr header;
129     memset(&header, 0, sizeof(header));
130     struct iovec iov;
131     // The struct member is non-const since it's used for receiving but it's
132     // safe to cast away const for sending.
133     iov.iov_base = const_cast<uint8_t*>(message.data());
134     iov.iov_len = message.size();
135     header.msg_name = &addr;
136     header.msg_namelen = sizeof(addr);
137     header.msg_iov = &iov;
138     header.msg_iovlen = 1;
139     header.msg_control = &controlData;
140     header.msg_controllen = sizeof(controlData);
141 
142     struct cmsghdr* controlHeader = CMSG_FIRSTHDR(&header);
143     controlHeader->cmsg_level = IPPROTO_IP;
144     controlHeader->cmsg_type = IP_PKTINFO;
145     controlHeader->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
146     auto packetInfo = reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(controlHeader));
147     memset(packetInfo, 0, sizeof(*packetInfo));
148     packetInfo->ipi_ifindex = interfaceIndex;
149 
150     ssize_t status = ::sendmsg(mSocketFd, &header, 0);
151     if (status <= 0) {
152         return Result::error("Failed to send packet: %s", strerror(errno));
153     }
154     return Result::success();
155 }
156 
sendRawUdp(in_addr_t source,uint16_t sourcePort,in_addr_t destination,uint16_t destinationPort,unsigned int interfaceIndex,const Message & message)157 Result Socket::sendRawUdp(in_addr_t source, uint16_t sourcePort, in_addr_t destination,
158                           uint16_t destinationPort, unsigned int interfaceIndex,
159                           const Message& message) {
160     struct iphdr ip;
161     struct udphdr udp;
162 
163     ip.version = IPVERSION;
164     ip.ihl = sizeof(ip) >> 2;
165     ip.tos = 0;
166     ip.tot_len = htons(sizeof(ip) + sizeof(udp) + message.size());
167     ip.id = 0;
168     ip.frag_off = 0;
169     ip.ttl = IPDEFTTL;
170     ip.protocol = IPPROTO_UDP;
171     ip.check = 0;
172     ip.saddr = source;
173     ip.daddr = destination;
174     ip.check = finishChecksum(addChecksum(ip, 0));
175 
176     udp.source = htons(sourcePort);
177     udp.dest = htons(destinationPort);
178     udp.len = htons(sizeof(udp) + message.size());
179     udp.check = 0;
180 
181     uint32_t udpChecksum = 0;
182     udpChecksum = addChecksum(ip.saddr, udpChecksum);
183     udpChecksum = addChecksum(ip.daddr, udpChecksum);
184     udpChecksum = addChecksum(htons(IPPROTO_UDP), udpChecksum);
185     udpChecksum = addChecksum(udp.len, udpChecksum);
186     udpChecksum = addChecksum(udp, udpChecksum);
187     udpChecksum = addChecksum(message.data(), message.size(), udpChecksum);
188     udp.check = finishChecksum(udpChecksum);
189 
190     struct iovec iov[3];
191 
192     iov[0].iov_base = static_cast<void*>(&ip);
193     iov[0].iov_len = sizeof(ip);
194     iov[1].iov_base = static_cast<void*>(&udp);
195     iov[1].iov_len = sizeof(udp);
196     // sendmsg requires these to be non-const but for sending won't modify them
197     iov[2].iov_base = static_cast<void*>(const_cast<uint8_t*>(message.data()));
198     iov[2].iov_len = message.size();
199 
200     struct sockaddr_ll dest;
201     memset(&dest, 0, sizeof(dest));
202     dest.sll_family = AF_PACKET;
203     dest.sll_protocol = htons(ETH_P_IP);
204     dest.sll_ifindex = interfaceIndex;
205     dest.sll_halen = ETH_ALEN;
206     memset(dest.sll_addr, 0xFF, ETH_ALEN);
207 
208     struct msghdr header;
209     memset(&header, 0, sizeof(header));
210     header.msg_name = &dest;
211     header.msg_namelen = sizeof(dest);
212     header.msg_iov = iov;
213     header.msg_iovlen = sizeof(iov) / sizeof(iov[0]);
214 
215     ssize_t res = ::sendmsg(mSocketFd, &header, 0);
216     if (res == -1) {
217         return Result::error("Failed to send message: %s", strerror(errno));
218     }
219     return Result::success();
220 }
221 
receiveFromInterface(Message * message,unsigned int * interfaceIndex)222 Result Socket::receiveFromInterface(Message* message, unsigned int* interfaceIndex) {
223     char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))];
224     struct msghdr header;
225     memset(&header, 0, sizeof(header));
226     struct iovec iov;
227     iov.iov_base = message->data();
228     iov.iov_len = message->capacity();
229     header.msg_iov = &iov;
230     header.msg_iovlen = 1;
231     header.msg_control = &controlData;
232     header.msg_controllen = sizeof(controlData);
233 
234     ssize_t bytesRead = ::recvmsg(mSocketFd, &header, 0);
235     if (bytesRead < 0) {
236         return Result::error("Error receiving on socket: %s", strerror(errno));
237     }
238     message->setSize(static_cast<size_t>(bytesRead));
239     if (header.msg_controllen >= sizeof(struct cmsghdr)) {
240         for (struct cmsghdr* ctrl = CMSG_FIRSTHDR(&header); ctrl;
241              ctrl = CMSG_NXTHDR(&header, ctrl)) {
242             if (ctrl->cmsg_level == SOL_IP && ctrl->cmsg_type == IP_PKTINFO) {
243                 auto packetInfo = reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(ctrl));
244                 *interfaceIndex = packetInfo->ipi_ifindex;
245             }
246         }
247     }
248     return Result::success();
249 }
250 
receiveRawUdp(uint16_t expectedPort,Message * message,bool * isValid)251 Result Socket::receiveRawUdp(uint16_t expectedPort, Message* message, bool* isValid) {
252     struct iphdr ip;
253     struct udphdr udp;
254 
255     struct iovec iov[3];
256     iov[0].iov_base = &ip;
257     iov[0].iov_len = sizeof(ip);
258     iov[1].iov_base = &udp;
259     iov[1].iov_len = sizeof(udp);
260     iov[2].iov_base = message->data();
261     iov[2].iov_len = message->capacity();
262 
263     ssize_t bytesRead = ::readv(mSocketFd, iov, 3);
264     if (bytesRead < 0) {
265         return Result::error("Unable to read from socket: %s", strerror(errno));
266     }
267     if (static_cast<size_t>(bytesRead) < sizeof(ip) + sizeof(udp)) {
268         // Not enough bytes to even cover IP and UDP headers
269         *isValid = false;
270         return Result::success();
271     }
272     *isValid = ip.version == IPVERSION && ip.ihl == (sizeof(ip) >> 2) &&
273                ip.protocol == IPPROTO_UDP && udp.dest == htons(expectedPort);
274 
275     message->setSize(bytesRead - sizeof(ip) - sizeof(udp));
276     return Result::success();
277 }
278 
enableOption(int level,int optionName)279 Result Socket::enableOption(int level, int optionName) {
280     if (mSocketFd == -1) {
281         return Result::error("Socket not open");
282     }
283 
284     int enabled = 1;
285     int status = ::setsockopt(mSocketFd, level, optionName, &enabled, sizeof(enabled));
286     if (status == -1) {
287         return Result::error("Failed to set socket option: %s", strerror(errno));
288     }
289     return Result::success();
290 }
291