1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <libnl++/Socket.h>
18
19 #include <libnl++/printer.h>
20
21 #include <android-base/logging.h>
22
23 namespace android::nl {
24
/**
 * Debug switch: when true, send() and receiveFrom() log every outbound/inbound Netlink
 * message at VERBOSE level. Disabled by default.
 */
static constexpr bool kSuperVerbose{false};
29
Socket(int protocol,unsigned pid,uint32_t groups)30 Socket::Socket(int protocol, unsigned pid, uint32_t groups) : mProtocol(protocol) {
31 mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
32 if (!mFd.ok()) {
33 PLOG(ERROR) << "Can't open Netlink socket";
34 mFailed = true;
35 return;
36 }
37
38 sockaddr_nl sa = {};
39 sa.nl_family = AF_NETLINK;
40 sa.nl_pid = pid;
41 sa.nl_groups = groups;
42
43 if (bind(mFd.get(), reinterpret_cast<sockaddr*>(&sa), sizeof(sa)) < 0) {
44 PLOG(ERROR) << "Can't bind Netlink socket";
45 mFd.reset();
46 mFailed = true;
47 }
48 }
49
clearPollErr()50 void Socket::clearPollErr() {
51 sockaddr_nl sa = {};
52 socklen_t saLen = sizeof(sa);
53 const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), mReceiveBuffer.size(), 0,
54 reinterpret_cast<sockaddr*>(&sa), &saLen);
55 if (errno != EINVAL) {
56 PLOG(WARNING) << "clearPollError() caught unexpected error: ";
57 }
58 CHECK_LE(bytesReceived, 0) << "clearPollError() didn't find an error!";
59 }
60
send(const Buffer<nlmsghdr> & msg,const sockaddr_nl & sa)61 bool Socket::send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa) {
62 if constexpr (kSuperVerbose) {
63 LOG(VERBOSE) << (mFailed ? "(not) " : "") << "sending to " << sa.nl_pid << ": "
64 << toString(msg, mProtocol);
65 }
66 if (mFailed) return false;
67
68 mSeq = msg->nlmsg_seq;
69 const auto rawMsg = msg.getRaw();
70 const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
71 reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
72 if (bytesSent < 0) {
73 PLOG(ERROR) << "Can't send Netlink message";
74 return false;
75 } else if (size_t(bytesSent) != rawMsg.len()) {
76 LOG(ERROR) << "Can't send Netlink message: truncated message";
77 return false;
78 }
79 return true;
80 }
81
send(const Buffer<nlmsghdr> & msg,uint32_t destination)82 bool Socket::send(const Buffer<nlmsghdr>& msg, uint32_t destination) {
83 sockaddr_nl sa = {.nl_family = AF_NETLINK, .nl_pad = 0, .nl_pid = destination, .nl_groups = 0};
84 return send(msg, sa);
85 }
86
increaseReceiveBuffer(size_t maxSize)87 bool Socket::increaseReceiveBuffer(size_t maxSize) {
88 if (maxSize == 0) {
89 LOG(ERROR) << "Maximum receive size should not be zero";
90 return false;
91 }
92
93 if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);
94 return true;
95 }
96
receive(size_t maxSize)97 std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
98 return receiveFrom(maxSize).first;
99 }
100
receiveFrom(size_t maxSize)101 std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) {
102 if (mFailed) return {std::nullopt, {}};
103
104 if (!increaseReceiveBuffer(maxSize)) return {std::nullopt, {}};
105
106 sockaddr_nl sa = {};
107 socklen_t saLen = sizeof(sa);
108 const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), maxSize, MSG_TRUNC,
109 reinterpret_cast<sockaddr*>(&sa), &saLen);
110
111 if (bytesReceived <= 0) {
112 PLOG(ERROR) << "Failed to receive Netlink message";
113 return {std::nullopt, {}};
114 } else if (size_t(bytesReceived) > maxSize) {
115 PLOG(ERROR) << "Received data larger than maximum receive size: " //
116 << bytesReceived << " > " << maxSize;
117 return {std::nullopt, {}};
118 }
119
120 Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(mReceiveBuffer.data()), bytesReceived);
121 if constexpr (kSuperVerbose) {
122 LOG(VERBOSE) << "received from " << sa.nl_pid << ": " << toString(msg, mProtocol);
123 }
124 long headerByteTotal = 0;
125 for (const auto hdr : msg) {
126 headerByteTotal += hdr->nlmsg_len;
127 }
128 if (bytesReceived != headerByteTotal) {
129 LOG(ERROR) << "received " << bytesReceived << " bytes, header claims " << headerByteTotal;
130 }
131 return {msg, sa};
132 }
133
receiveAck(uint32_t seq)134 bool Socket::receiveAck(uint32_t seq) {
135 const auto nlerr = receive<nlmsgerr>({NLMSG_ERROR});
136 if (!nlerr.has_value()) return false;
137
138 if (nlerr->data.msg.nlmsg_seq != seq) {
139 LOG(ERROR) << "Received ACK for a different message (" << nlerr->data.msg.nlmsg_seq
140 << ", expected " << seq << "). Multi-message tracking is not implemented.";
141 return false;
142 }
143
144 if (nlerr->data.error == 0) return true;
145
146 LOG(WARNING) << "Received Netlink error message: " << strerror(-nlerr->data.error);
147 return false;
148 }
149
receive(const std::set<nlmsgtype_t> & msgtypes,size_t maxSize)150 std::optional<Buffer<nlmsghdr>> Socket::receive(const std::set<nlmsgtype_t>& msgtypes,
151 size_t maxSize) {
152 if (mFailed || !increaseReceiveBuffer(maxSize)) return std::nullopt;
153
154 for (const auto rawMsg : *this) {
155 if (msgtypes.count(rawMsg->nlmsg_type) == 0) {
156 LOG(WARNING) << "Received (and ignored) unexpected Netlink message of type "
157 << rawMsg->nlmsg_type;
158 continue;
159 }
160
161 return rawMsg;
162 }
163
164 return std::nullopt;
165 }
166
getPid()167 std::optional<unsigned> Socket::getPid() {
168 if (mFailed) return std::nullopt;
169
170 sockaddr_nl sa = {};
171 socklen_t sasize = sizeof(sa);
172 if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
173 PLOG(ERROR) << "Failed to get PID of Netlink socket";
174 return std::nullopt;
175 }
176 return sa.nl_pid;
177 }
178
preparePoll(short events)179 pollfd Socket::preparePoll(short events) {
180 CHECK(mFd.get() > 0) << "Netlink socket fd is invalid!";
181 return {mFd.get(), events, 0};
182 }
183
addMembership(unsigned group)184 bool Socket::addMembership(unsigned group) {
185 const auto res =
186 setsockopt(mFd.get(), SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &group, sizeof(group));
187 if (res < 0) {
188 PLOG(ERROR) << "Failed joining multicast group " << group;
189 return false;
190 }
191 return true;
192 }
193
dropMembership(unsigned group)194 bool Socket::dropMembership(unsigned group) {
195 const auto res =
196 setsockopt(mFd.get(), SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, &group, sizeof(group));
197 if (res < 0) {
198 PLOG(ERROR) << "Failed leaving multicast group " << group;
199 return false;
200 }
201 return true;
202 }
203
receive_iterator(Socket & socket,bool end)204 Socket::receive_iterator::receive_iterator(Socket& socket, bool end)
205 : mSocket(socket), mIsEnd(end) {
206 if (!end) receive();
207 }
208
operator ++()209 Socket::receive_iterator Socket::receive_iterator::operator++() {
210 CHECK(!mIsEnd) << "Trying to increment end iterator";
211 ++mCurrent;
212 if (mCurrent.isEnd()) receive();
213 return *this;
214 }
215
operator ==(const receive_iterator & other) const216 bool Socket::receive_iterator::operator==(const receive_iterator& other) const {
217 if (mIsEnd != other.mIsEnd) return false;
218 if (mIsEnd && other.mIsEnd) return true;
219 return mCurrent == other.mCurrent;
220 }
221
operator *() const222 const Buffer<nlmsghdr>& Socket::receive_iterator::operator*() const {
223 CHECK(!mIsEnd) << "Trying to dereference end iterator";
224 return *mCurrent;
225 }
226
receive()227 void Socket::receive_iterator::receive() {
228 CHECK(!mIsEnd) << "Trying to receive on end iterator";
229 CHECK(mCurrent.isEnd()) << "Trying to receive without draining previous read";
230
231 const auto buf = mSocket.receive();
232 if (buf.has_value()) {
233 mCurrent = buf->begin();
234 } else {
235 mIsEnd = true;
236 }
237 }
238
begin()239 Socket::receive_iterator Socket::begin() {
240 return {*this, false};
241 }
242
end()243 Socket::receive_iterator Socket::end() {
244 return {*this, true};
245 }
246
247 } // namespace android::nl
248