• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/net/netlink_socket.h"
18 
19 #include <string>
20 
21 #include <linux/if_packet.h>
22 #include <linux/netlink.h>
23 #include <sys/socket.h>
24 
25 #include <base/logging.h>
26 
27 #include "shill/net/netlink_message.h"
28 #include "shill/net/sockets.h"
29 
// SOL_NETLINK is the setsockopt() level for netlink-specific options such as
// NETLINK_ADD_MEMBERSHIP. Some system linux/socket.h versions used to build
// this file lack it, so supply the value (270) only when no header has
// already defined it.
#ifndef SOL_NETLINK
#define SOL_NETLINK 270
#endif
32 
33 namespace shill {
34 
35 // Keep this large enough to avoid overflows on IPv6 SNM routing update spikes
36 const int NetlinkSocket::kReceiveBufferSize = 512 * 1024;
37 
NetlinkSocket()38 NetlinkSocket::NetlinkSocket() : sequence_number_(0), file_descriptor_(-1) {}
39 
~NetlinkSocket()40 NetlinkSocket::~NetlinkSocket() {
41   if (sockets_ && (file_descriptor_ >= 0)) {
42     sockets_->Close(file_descriptor_);
43   }
44 }
45 
Init()46 bool NetlinkSocket::Init() {
47   // Allows for a test to set |sockets_| before calling |Init|.
48   if (sockets_) {
49     LOG(INFO) << "|sockets_| already has a value -- this must be a test.";
50   } else {
51     sockets_.reset(new Sockets);
52   }
53 
54   // The following is stolen directly from RTNLHandler.
55   // TODO(wdg): refactor this and RTNLHandler together to use common code.
56   // crbug.com/221940
57 
58   file_descriptor_ = sockets_->Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC);
59   if (file_descriptor_ < 0) {
60     LOG(ERROR) << "Failed to open netlink socket";
61     return false;
62   }
63 
64   if (sockets_->SetReceiveBuffer(file_descriptor_, kReceiveBufferSize)) {
65     LOG(ERROR) << "Failed to increase receive buffer size";
66   }
67 
68   struct sockaddr_nl addr;
69   memset(&addr, 0, sizeof(addr));
70   addr.nl_family = AF_NETLINK;
71 
72   if (sockets_->Bind(file_descriptor_,
73                     reinterpret_cast<struct sockaddr*>(&addr),
74                     sizeof(addr)) < 0) {
75     sockets_->Close(file_descriptor_);
76     file_descriptor_ = -1;
77     LOG(ERROR) << "Netlink socket bind failed";
78     return false;
79   }
80   VLOG(2) << "Netlink socket started";
81 
82   return true;
83 }
84 
RecvMessage(ByteString * message)85 bool NetlinkSocket::RecvMessage(ByteString* message) {
86   if (!message) {
87     LOG(ERROR) << "Null |message|";
88     return false;
89   }
90 
91   // Determine the amount of data currently waiting.
92   const size_t kDummyReadByteCount = 1;
93   ByteString dummy_read(kDummyReadByteCount);
94   ssize_t result;
95   result = sockets_->RecvFrom(
96       file_descriptor_,
97       dummy_read.GetData(),
98       dummy_read.GetLength(),
99       MSG_TRUNC | MSG_PEEK,
100       nullptr,
101       nullptr);
102   if (result < 0) {
103     PLOG(ERROR) << "Socket recvfrom failed.";
104     return false;
105   }
106 
107   // Read the data that was waiting when we did our previous read.
108   message->Resize(result);
109   result = sockets_->RecvFrom(
110       file_descriptor_,
111       message->GetData(),
112       message->GetLength(),
113       0,
114       nullptr,
115       nullptr);
116   if (result < 0) {
117     PLOG(ERROR) << "Second socket recvfrom failed.";
118     return false;
119   }
120   return true;
121 }
122 
SendMessage(const ByteString & out_msg)123 bool NetlinkSocket::SendMessage(const ByteString& out_msg) {
124   ssize_t result = sockets_->Send(file_descriptor(), out_msg.GetConstData(),
125                                   out_msg.GetLength(), 0);
126   if (!result) {
127     PLOG(ERROR) << "Send failed.";
128     return false;
129   }
130   if (result != static_cast<ssize_t>(out_msg.GetLength())) {
131     LOG(ERROR) << "Only sent " << result << " bytes out of "
132                << out_msg.GetLength() << ".";
133     return false;
134   }
135 
136   return true;
137 }
138 
SubscribeToEvents(uint32_t group_id)139 bool NetlinkSocket::SubscribeToEvents(uint32_t group_id) {
140   int err = setsockopt(file_descriptor_, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP,
141                        &group_id, sizeof(group_id));
142   if (err < 0) {
143     PLOG(ERROR) << "setsockopt didn't work.";
144     return false;
145   }
146   return true;
147 }
148 
GetSequenceNumber()149 uint32_t NetlinkSocket::GetSequenceNumber() {
150   if (++sequence_number_ == NetlinkMessage::kBroadcastSequenceNumber)
151     ++sequence_number_;
152   return sequence_number_;
153 }
154 
155 }  // namespace shill.
156