• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include <libnl++/Socket.h>
18 
19 #include <libnl++/printer.h>
20 
21 #include <android-base/logging.h>
22 
23 namespace android::nl {
24 
/**
 * Debug switch: when true, every outbound and inbound Netlink message is pretty-printed at
 * VERBOSE log level. It is only read inside `if constexpr`, so with the default of false the
 * printing code is not compiled in at all.
 */
static constexpr bool kSuperVerbose{false};
29 
Socket(int protocol,unsigned pid,uint32_t groups)30 Socket::Socket(int protocol, unsigned pid, uint32_t groups) : mProtocol(protocol) {
31     mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
32     if (!mFd.ok()) {
33         PLOG(ERROR) << "Can't open Netlink socket";
34         mFailed = true;
35         return;
36     }
37 
38     sockaddr_nl sa = {};
39     sa.nl_family = AF_NETLINK;
40     sa.nl_pid = pid;
41     sa.nl_groups = groups;
42 
43     if (bind(mFd.get(), reinterpret_cast<sockaddr*>(&sa), sizeof(sa)) < 0) {
44         PLOG(ERROR) << "Can't bind Netlink socket";
45         mFd.reset();
46         mFailed = true;
47     }
48 }
49 
send(const Buffer<nlmsghdr> & msg,const sockaddr_nl & sa)50 bool Socket::send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa) {
51     if constexpr (kSuperVerbose) {
52         LOG(VERBOSE) << (mFailed ? "(not) " : "") << "sending to " << sa.nl_pid << ": "
53                      << toString(msg, mProtocol);
54     }
55     if (mFailed) return false;
56 
57     mSeq = msg->nlmsg_seq;
58     const auto rawMsg = msg.getRaw();
59     const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
60                                   reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
61     if (bytesSent < 0) {
62         PLOG(ERROR) << "Can't send Netlink message";
63         return false;
64     } else if (size_t(bytesSent) != rawMsg.len()) {
65         LOG(ERROR) << "Can't send Netlink message: truncated message";
66         return false;
67     }
68     return true;
69 }
70 
send(const Buffer<nlmsghdr> & msg,uint32_t destination)71 bool Socket::send(const Buffer<nlmsghdr>& msg, uint32_t destination) {
72     sockaddr_nl sa = {.nl_family = AF_NETLINK, .nl_pad = 0, .nl_pid = destination, .nl_groups = 0};
73     return send(msg, sa);
74 }
75 
increaseReceiveBuffer(size_t maxSize)76 bool Socket::increaseReceiveBuffer(size_t maxSize) {
77     if (maxSize == 0) {
78         LOG(ERROR) << "Maximum receive size should not be zero";
79         return false;
80     }
81 
82     if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);
83     return true;
84 }
85 
receive(size_t maxSize)86 std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
87     return receiveFrom(maxSize).first;
88 }
89 
receiveFrom(size_t maxSize)90 std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) {
91     if (mFailed) return {std::nullopt, {}};
92 
93     if (!increaseReceiveBuffer(maxSize)) return {std::nullopt, {}};
94 
95     sockaddr_nl sa = {};
96     socklen_t saLen = sizeof(sa);
97     const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), maxSize, MSG_TRUNC,
98                                         reinterpret_cast<sockaddr*>(&sa), &saLen);
99 
100     if (bytesReceived <= 0) {
101         PLOG(ERROR) << "Failed to receive Netlink message";
102         return {std::nullopt, {}};
103     } else if (size_t(bytesReceived) > maxSize) {
104         PLOG(ERROR) << "Received data larger than maximum receive size: "  //
105                     << bytesReceived << " > " << maxSize;
106         return {std::nullopt, {}};
107     }
108 
109     Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(mReceiveBuffer.data()), bytesReceived);
110     if constexpr (kSuperVerbose) {
111         LOG(VERBOSE) << "received from " << sa.nl_pid << ": " << toString(msg, mProtocol);
112     }
113     return {msg, sa};
114 }
115 
receiveAck(uint32_t seq)116 bool Socket::receiveAck(uint32_t seq) {
117     const auto nlerr = receive<nlmsgerr>({NLMSG_ERROR});
118     if (!nlerr.has_value()) return false;
119 
120     if (nlerr->data.msg.nlmsg_seq != seq) {
121         LOG(ERROR) << "Received ACK for a different message (" << nlerr->data.msg.nlmsg_seq
122                    << ", expected " << seq << "). Multi-message tracking is not implemented.";
123         return false;
124     }
125 
126     if (nlerr->data.error == 0) return true;
127 
128     LOG(WARNING) << "Received Netlink error message: " << strerror(-nlerr->data.error);
129     return false;
130 }
131 
receive(const std::set<nlmsgtype_t> & msgtypes,size_t maxSize)132 std::optional<Buffer<nlmsghdr>> Socket::receive(const std::set<nlmsgtype_t>& msgtypes,
133                                                 size_t maxSize) {
134     if (mFailed || !increaseReceiveBuffer(maxSize)) return std::nullopt;
135 
136     for (const auto rawMsg : *this) {
137         if (msgtypes.count(rawMsg->nlmsg_type) == 0) {
138             LOG(WARNING) << "Received (and ignored) unexpected Netlink message of type "
139                          << rawMsg->nlmsg_type;
140             continue;
141         }
142 
143         return rawMsg;
144     }
145 
146     return std::nullopt;
147 }
148 
getPid()149 std::optional<unsigned> Socket::getPid() {
150     if (mFailed) return std::nullopt;
151 
152     sockaddr_nl sa = {};
153     socklen_t sasize = sizeof(sa);
154     if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
155         PLOG(ERROR) << "Failed to get PID of Netlink socket";
156         return std::nullopt;
157     }
158     return sa.nl_pid;
159 }
160 
preparePoll(short events)161 pollfd Socket::preparePoll(short events) {
162     return {mFd.get(), events, 0};
163 }
164 
addMembership(unsigned group)165 bool Socket::addMembership(unsigned group) {
166     const auto res =
167             setsockopt(mFd.get(), SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &group, sizeof(group));
168     if (res < 0) {
169         PLOG(ERROR) << "Failed joining multicast group " << group;
170         return false;
171     }
172     return true;
173 }
174 
dropMembership(unsigned group)175 bool Socket::dropMembership(unsigned group) {
176     const auto res =
177             setsockopt(mFd.get(), SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, &group, sizeof(group));
178     if (res < 0) {
179         PLOG(ERROR) << "Failed leaving multicast group " << group;
180         return false;
181     }
182     return true;
183 }
184 
receive_iterator(Socket & socket,bool end)185 Socket::receive_iterator::receive_iterator(Socket& socket, bool end)
186     : mSocket(socket), mIsEnd(end) {
187     if (!end) receive();
188 }
189 
operator ++()190 Socket::receive_iterator Socket::receive_iterator::operator++() {
191     CHECK(!mIsEnd) << "Trying to increment end iterator";
192     ++mCurrent;
193     if (mCurrent.isEnd()) receive();
194     return *this;
195 }
196 
operator ==(const receive_iterator & other) const197 bool Socket::receive_iterator::operator==(const receive_iterator& other) const {
198     if (mIsEnd != other.mIsEnd) return false;
199     if (mIsEnd && other.mIsEnd) return true;
200     return mCurrent == other.mCurrent;
201 }
202 
operator *() const203 const Buffer<nlmsghdr>& Socket::receive_iterator::operator*() const {
204     CHECK(!mIsEnd) << "Trying to dereference end iterator";
205     return *mCurrent;
206 }
207 
receive()208 void Socket::receive_iterator::receive() {
209     CHECK(!mIsEnd) << "Trying to receive on end iterator";
210     CHECK(mCurrent.isEnd()) << "Trying to receive without draining previous read";
211 
212     const auto buf = mSocket.receive();
213     if (buf.has_value()) {
214         mCurrent = buf->begin();
215     } else {
216         mIsEnd = true;
217     }
218 }
219 
begin()220 Socket::receive_iterator Socket::begin() {
221     return {*this, false};
222 }
223 
end()224 Socket::receive_iterator Socket::end() {
225     return {*this, true};
226 }
227 
228 }  // namespace android::nl
229