1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/libs/net/netlink_client.h"
17
#include <errno.h>
#include <linux/rtnetlink.h>
#include <linux/sockios.h>
#include <net/if.h>
#include <string.h>
#include <sys/socket.h>

#include "common/libs/fs/shared_fd.h"
#include "common/libs/glog/logging.h"
26
27 namespace cvd {
28 namespace {
29 // NetlinkClient implementation.
30 // Talks to libnetlink to apply network changes.
31 class NetlinkClientImpl : public NetlinkClient {
32 public:
33 NetlinkClientImpl() = default;
34 virtual ~NetlinkClientImpl() = default;
35
36 virtual bool Send(const NetlinkRequest& message);
37
38 // Initialize NetlinkClient instance.
39 // Open netlink channel and initialize interface list.
40 // Parameter |type| specifies which netlink target to address, eg.
41 // NETLINK_ROUTE.
42 // Returns true, if initialization was successful.
43 bool OpenNetlink(int type);
44
45 private:
46 bool CheckResponse(uint32_t seq_no);
47
48 SharedFD netlink_fd_;
49 sockaddr_nl address_;
50 };
51
CheckResponse(uint32_t seq_no)52 bool NetlinkClientImpl::CheckResponse(uint32_t seq_no) {
53 uint32_t len;
54 char buf[4096];
55 struct iovec iov = { buf, sizeof(buf) };
56 struct sockaddr_nl sa;
57 struct msghdr msg = { &sa, sizeof(sa), &iov, 1, NULL, 0, 0 };
58 struct nlmsghdr *nh;
59
60 int result = netlink_fd_->RecvMsg(&msg, 0);
61 if (result < 0) {
62 LOG(ERROR) << "Netlink error: " << strerror(errno);
63 return false;
64 }
65
66 len = static_cast<uint32_t>(result);
67 LOG(INFO) << "Received netlink response (" << len << " bytes)";
68
69 for (nh = reinterpret_cast<nlmsghdr*>(buf);
70 NLMSG_OK(nh, len);
71 nh = NLMSG_NEXT(nh, len)) {
72 if (nh->nlmsg_seq != seq_no) {
73 // This really shouldn't happen. If we see this, it means somebody is
74 // issuing netlink requests using the same socket as us, and ignoring
75 // responses.
76 LOG(WARNING) << "Sequence number mismatch: "
77 << nh->nlmsg_seq << " != " << seq_no;
78 continue;
79 }
80
81 // This is the end of multi-part message.
82 // It indicates there's nothing more netlink wants to tell us.
83 // It also means we failed to find the response to our request.
84 if (nh->nlmsg_type == NLMSG_DONE)
85 break;
86
87 // This is the 'nlmsgerr' package carrying response to our request.
88 // It carries an 'error' value (errno) along with the netlink header info
89 // that caused this error.
90 if (nh->nlmsg_type == NLMSG_ERROR) {
91 nlmsgerr* err = reinterpret_cast<nlmsgerr*>(nh + 1);
92 if (err->error < 0) {
93 LOG(ERROR) << "Failed to complete netlink request: "
94 << "Netlink error: " << err->error
95 << ", errno: " << strerror(errno);
96 return false;
97 }
98 return true;
99 }
100 }
101
102 LOG(ERROR) << "No response from netlink.";
103 return false;
104 }
105
Send(const NetlinkRequest & message)106 bool NetlinkClientImpl::Send(const NetlinkRequest& message) {
107 struct sockaddr_nl netlink_addr;
108 struct iovec netlink_iov = {
109 message.RequestData(),
110 message.RequestLength()
111 };
112 struct msghdr msg;
113 memset(&msg, 0, sizeof(msg));
114 memset(&netlink_addr, 0, sizeof(netlink_addr));
115
116 msg.msg_name = &address_;
117 msg.msg_namelen = sizeof(address_);
118 msg.msg_iov = &netlink_iov;
119 msg.msg_iovlen = sizeof(netlink_iov) / sizeof(iovec);
120
121 if (netlink_fd_->SendMsg(&msg, 0) < 0) {
122 LOG(ERROR) << "Failed to send netlink message: "
123 << strerror(errno);
124
125 return false;
126 }
127
128 return CheckResponse(message.SeqNo());
129 }
130
OpenNetlink(int type)131 bool NetlinkClientImpl::OpenNetlink(int type) {
132 netlink_fd_ = SharedFD::Socket(AF_NETLINK, SOCK_RAW, type);
133 if (!netlink_fd_->IsOpen()) return false;
134
135 address_.nl_family = AF_NETLINK;
136 address_.nl_groups = 0;
137
138 netlink_fd_->Bind(reinterpret_cast<sockaddr*>(&address_), sizeof(address_));
139
140 return true;
141 }
142
143 class NetlinkClientFactoryImpl : public NetlinkClientFactory {
144 public:
145 NetlinkClientFactoryImpl() = default;
146 ~NetlinkClientFactoryImpl() override = default;
147
New(int type)148 std::unique_ptr<NetlinkClient> New(int type) override {
149 auto client_raw = new NetlinkClientImpl();
150 // Use RVO when possible.
151 std::unique_ptr<NetlinkClient> client(client_raw);
152
153 if (!client_raw->OpenNetlink(type)) {
154 // Note: deletes client_raw.
155 client.reset();
156 }
157 return client;
158 }
159 };
160
161 } // namespace
162
Default()163 NetlinkClientFactory* NetlinkClientFactory::Default() {
164 static NetlinkClientFactory &factory = *new NetlinkClientFactoryImpl();
165 return &factory;
166 }
167
168 } // namespace cvd
169