1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/arp_client.h"
18
19 #include <linux/if_packet.h>
20 #include <net/ethernet.h>
21 #include <net/if_arp.h>
22 #include <netinet/in.h>
23 #include <string.h>
24
25 #include "shill/arp_packet.h"
26 #include "shill/logging.h"
27 #include "shill/net/byte_string.h"
28 #include "shill/net/sockets.h"
29
30 namespace shill {
31
32 // ARP opcode is the last uint16_t in the ARP header.
33 const size_t ArpClient::kArpOpOffset = sizeof(arphdr) - sizeof(uint16_t);
34
35 // The largest packet we expect is one with IPv6 addresses in it.
36 const size_t ArpClient::kMaxArpPacketLength =
37 sizeof(arphdr) + sizeof(in6_addr) * 2 + ETH_ALEN * 2;
38
ArpClient(int interface_index)39 ArpClient::ArpClient(int interface_index)
40 : interface_index_(interface_index),
41 sockets_(new Sockets()),
42 socket_(-1) {}
43
~ArpClient()44 ArpClient::~ArpClient() {}
45
StartReplyListener()46 bool ArpClient::StartReplyListener() {
47 return Start(ARPOP_REPLY);
48 }
49
StartRequestListener()50 bool ArpClient::StartRequestListener() {
51 return Start(ARPOP_REQUEST);
52 }
53
Start(uint16_t arp_opcode)54 bool ArpClient::Start(uint16_t arp_opcode) {
55 if (!CreateSocket(arp_opcode)) {
56 LOG(ERROR) << "Could not open ARP socket.";
57 Stop();
58 return false;
59 }
60 return true;
61 }
62
Stop()63 void ArpClient::Stop() {
64 socket_closer_.reset();
65 }
66
67
CreateSocket(uint16_t arp_opcode)68 bool ArpClient::CreateSocket(uint16_t arp_opcode) {
69 int socket = sockets_->Socket(PF_PACKET, SOCK_DGRAM, htons(ETHERTYPE_ARP));
70 if (socket == -1) {
71 PLOG(ERROR) << "Could not create ARP socket";
72 return false;
73 }
74 socket_ = socket;
75 socket_closer_.reset(new ScopedSocketCloser(sockets_.get(), socket_));
76
77 // Create a packet filter incoming ARP packets.
78 const sock_filter arp_filter[] = {
79 // If a packet contains the ARP opcode we are looking for...
80 BPF_STMT(BPF_LD | BPF_H | BPF_ABS, kArpOpOffset),
81 BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, arp_opcode, 0, 1),
82 // Return the the packet (up to largest expected packet size).
83 BPF_STMT(BPF_RET | BPF_K, kMaxArpPacketLength),
84 // Otherwise, drop it.
85 BPF_STMT(BPF_RET | BPF_K, 0),
86 };
87
88 sock_fprog pf;
89 pf.filter = const_cast<sock_filter*>(arp_filter);
90 pf.len = arraysize(arp_filter);
91 if (sockets_->AttachFilter(socket_, &pf) != 0) {
92 PLOG(ERROR) << "Could not attach packet filter";
93 return false;
94 }
95
96 if (sockets_->SetNonBlocking(socket_) != 0) {
97 PLOG(ERROR) << "Could not set socket to be non-blocking";
98 return false;
99 }
100
101 sockaddr_ll socket_address;
102 memset(&socket_address, 0, sizeof(socket_address));
103 socket_address.sll_family = AF_PACKET;
104 socket_address.sll_protocol = htons(ETHERTYPE_ARP);
105 socket_address.sll_ifindex = interface_index_;
106
107 if (sockets_->Bind(socket_,
108 reinterpret_cast<struct sockaddr*>(&socket_address),
109 sizeof(socket_address)) != 0) {
110 PLOG(ERROR) << "Could not bind socket to interface";
111 return false;
112 }
113
114 return true;
115 }
116
ReceivePacket(ArpPacket * packet,ByteString * sender) const117 bool ArpClient::ReceivePacket(ArpPacket* packet, ByteString* sender) const {
118 ByteString payload(kMaxArpPacketLength);
119 sockaddr_ll socket_address;
120 memset(&socket_address, 0, sizeof(socket_address));
121 socklen_t socklen = sizeof(socket_address);
122 int result = sockets_->RecvFrom(
123 socket_,
124 payload.GetData(),
125 payload.GetLength(),
126 0,
127 reinterpret_cast<struct sockaddr*>(&socket_address),
128 &socklen);
129 if (result < 0) {
130 PLOG(ERROR) << "Socket recvfrom failed";
131 return false;
132 }
133
134 payload.Resize(result);
135 if (!packet->Parse(payload)) {
136 LOG(ERROR) << "Failed to parse ARP packet.";
137 return false;
138 }
139
140 // The socket address returned may only be big enough to contain
141 // the hardware address of the sender.
142 CHECK(socklen >=
143 sizeof(socket_address) - sizeof(socket_address.sll_addr) + ETH_ALEN);
144 CHECK(socket_address.sll_halen == ETH_ALEN);
145 *sender = ByteString(
146 reinterpret_cast<const unsigned char*>(&socket_address.sll_addr),
147 socket_address.sll_halen);
148 return true;
149 }
150
TransmitRequest(const ArpPacket & packet) const151 bool ArpClient::TransmitRequest(const ArpPacket& packet) const {
152 ByteString payload;
153 if (!packet.FormatRequest(&payload)) {
154 return false;
155 }
156
157 sockaddr_ll socket_address;
158 memset(&socket_address, 0, sizeof(socket_address));
159 socket_address.sll_family = AF_PACKET;
160 socket_address.sll_protocol = htons(ETHERTYPE_ARP);
161 socket_address.sll_hatype = ARPHRD_ETHER;
162 socket_address.sll_halen = ETH_ALEN;
163 socket_address.sll_ifindex = interface_index_;
164
165 ByteString remote_address = packet.remote_mac_address();
166 CHECK(sizeof(socket_address.sll_addr) >= remote_address.GetLength());
167 if (remote_address.IsZero()) {
168 // If the destination MAC address is unspecified, send the packet
169 // to the broadcast (all-ones) address.
170 remote_address.BitwiseInvert();
171 }
172 memcpy(&socket_address.sll_addr, remote_address.GetConstData(),
173 remote_address.GetLength());
174
175 int result = sockets_->SendTo(
176 socket_,
177 payload.GetConstData(),
178 payload.GetLength(),
179 0,
180 reinterpret_cast<struct sockaddr*>(&socket_address),
181 sizeof(socket_address));
182 const int expected_result = static_cast<int>(payload.GetLength());
183 if (result != expected_result) {
184 if (result < 0) {
185 PLOG(ERROR) << "Socket sendto failed";
186 } else if (result < static_cast<int>(payload.GetLength())) {
187 LOG(ERROR) << "Socket sendto returned "
188 << result
189 << " which is different from expected result "
190 << expected_result;
191 }
192 return false;
193 }
194
195 return true;
196 }
197
198 } // namespace shill
199