1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/net/netlink_message.h"
18
19 #include <limits.h>
20
21 #include <algorithm>
22 #include <map>
23 #include <memory>
24 #include <string>
25
26 #include <base/format_macros.h>
27 #include <base/logging.h>
28 #include <base/stl_util.h>
29 #include <base/strings/stringprintf.h>
30
31 #include "shill/net/netlink_packet.h"
32
33 using base::StringAppendF;
34 using base::StringPrintf;
35 using std::map;
36 using std::min;
37 using std::string;
38
39 namespace shill {
40
41 const uint32_t NetlinkMessage::kBroadcastSequenceNumber = 0;
42 const uint16_t NetlinkMessage::kIllegalMessageType = UINT16_MAX;
43
44 // NetlinkMessage
45
EncodeHeader(uint32_t sequence_number)46 ByteString NetlinkMessage::EncodeHeader(uint32_t sequence_number) {
47 ByteString result;
48 if (message_type_ == kIllegalMessageType) {
49 LOG(ERROR) << "Message type not set";
50 return result;
51 }
52 sequence_number_ = sequence_number;
53 if (sequence_number_ == kBroadcastSequenceNumber) {
54 LOG(ERROR) << "Couldn't get a legal sequence number";
55 return result;
56 }
57
58 // Build netlink header.
59 nlmsghdr header;
60 size_t nlmsghdr_with_pad = NLMSG_ALIGN(sizeof(header));
61 header.nlmsg_len = nlmsghdr_with_pad;
62 header.nlmsg_type = message_type_;
63 header.nlmsg_flags = NLM_F_REQUEST | flags_;
64 header.nlmsg_seq = sequence_number_;
65 header.nlmsg_pid = getpid();
66
67 // Netlink header + pad.
68 result.Append(ByteString(reinterpret_cast<unsigned char*>(&header),
69 sizeof(header)));
70 result.Resize(nlmsghdr_with_pad); // Zero-fill pad space (if any).
71 return result;
72 }
73
InitAndStripHeader(NetlinkPacket * packet)74 bool NetlinkMessage::InitAndStripHeader(NetlinkPacket* packet) {
75 const nlmsghdr& header = packet->GetNlMsgHeader();
76 message_type_ = header.nlmsg_type;
77 flags_ = header.nlmsg_flags;
78 sequence_number_ = header.nlmsg_seq;
79
80 return true;
81 }
82
InitFromPacket(NetlinkPacket * packet,NetlinkMessage::MessageContext context)83 bool NetlinkMessage::InitFromPacket(NetlinkPacket* packet,
84 NetlinkMessage::MessageContext context) {
85 if (!packet) {
86 LOG(ERROR) << "Null |packet| parameter";
87 return false;
88 }
89 if (!InitAndStripHeader(packet)) {
90 return false;
91 }
92 return true;
93 }
94
95 // static
PrintBytes(int log_level,const unsigned char * buf,size_t num_bytes)96 void NetlinkMessage::PrintBytes(int log_level, const unsigned char* buf,
97 size_t num_bytes) {
98 VLOG(log_level) << "Netlink Message -- Examining Bytes";
99 if (!buf) {
100 VLOG(log_level) << "<NULL Buffer>";
101 return;
102 }
103
104 if (num_bytes >= sizeof(nlmsghdr)) {
105 PrintHeader(log_level, reinterpret_cast<const nlmsghdr*>(buf));
106 buf += sizeof(nlmsghdr);
107 num_bytes -= sizeof(nlmsghdr);
108 } else {
109 VLOG(log_level) << "Not enough bytes (" << num_bytes
110 << ") for a complete nlmsghdr (requires "
111 << sizeof(nlmsghdr) << ").";
112 }
113
114 PrintPayload(log_level, buf, num_bytes);
115 }
116
117 // static
PrintPacket(int log_level,const NetlinkPacket & packet)118 void NetlinkMessage::PrintPacket(int log_level, const NetlinkPacket& packet) {
119 VLOG(log_level) << "Netlink Message -- Examining Packet";
120 if (!packet.IsValid()) {
121 VLOG(log_level) << "<Invalid Buffer>";
122 return;
123 }
124
125 PrintHeader(log_level, &packet.GetNlMsgHeader());
126 const ByteString& payload = packet.GetPayload();
127 PrintPayload(log_level, payload.GetConstData(), payload.GetLength());
128 }
129
130 // static
PrintHeader(int log_level,const nlmsghdr * header)131 void NetlinkMessage::PrintHeader(int log_level, const nlmsghdr* header) {
132 const unsigned char* buf = reinterpret_cast<const unsigned char*>(header);
133 VLOG(log_level) << StringPrintf(
134 "len: %02x %02x %02x %02x = %u bytes",
135 buf[0], buf[1], buf[2], buf[3], header->nlmsg_len);
136
137 VLOG(log_level) << StringPrintf(
138 "type | flags: %02x %02x %02x %02x - type:%u flags:%s%s%s%s%s",
139 buf[4], buf[5], buf[6], buf[7], header->nlmsg_type,
140 ((header->nlmsg_flags & NLM_F_REQUEST) ? " REQUEST" : ""),
141 ((header->nlmsg_flags & NLM_F_MULTI) ? " MULTI" : ""),
142 ((header->nlmsg_flags & NLM_F_ACK) ? " ACK" : ""),
143 ((header->nlmsg_flags & NLM_F_ECHO) ? " ECHO" : ""),
144 ((header->nlmsg_flags & NLM_F_DUMP_INTR) ? " BAD-SEQ" : ""));
145
146 VLOG(log_level) << StringPrintf(
147 "sequence: %02x %02x %02x %02x = %u",
148 buf[8], buf[9], buf[10], buf[11], header->nlmsg_seq);
149 VLOG(log_level) << StringPrintf(
150 "pid: %02x %02x %02x %02x = %u",
151 buf[12], buf[13], buf[14], buf[15], header->nlmsg_pid);
152 }
153
154 // static
PrintPayload(int log_level,const unsigned char * buf,size_t num_bytes)155 void NetlinkMessage::PrintPayload(int log_level, const unsigned char* buf,
156 size_t num_bytes) {
157 while (num_bytes) {
158 string output;
159 size_t bytes_this_row = min(num_bytes, static_cast<size_t>(32));
160 for (size_t i = 0; i < bytes_this_row; ++i) {
161 StringAppendF(&output, " %02x", *buf++);
162 }
163 VLOG(log_level) << output;
164 num_bytes -= bytes_this_row;
165 }
166 }
167
168 // ErrorAckMessage.
169
170 const uint16_t ErrorAckMessage::kMessageType = NLMSG_ERROR;
171
InitFromPacket(NetlinkPacket * packet,NetlinkMessage::MessageContext context)172 bool ErrorAckMessage::InitFromPacket(NetlinkPacket* packet,
173 NetlinkMessage::MessageContext context) {
174 if (!packet) {
175 LOG(ERROR) << "Null |const_msg| parameter";
176 return false;
177 }
178 if (!InitAndStripHeader(packet)) {
179 return false;
180 }
181
182 // Get the error code from the payload.
183 return packet->ConsumeData(sizeof(error_), &error_);
184 }
185
Encode(uint32_t sequence_number)186 ByteString ErrorAckMessage::Encode(uint32_t sequence_number) {
187 LOG(ERROR) << "We're not supposed to send errors or Acks to the kernel";
188 return ByteString();
189 }
190
ToString() const191 string ErrorAckMessage::ToString() const {
192 string output;
193 if (error()) {
194 StringAppendF(&output, "NETLINK_ERROR 0x%" PRIx32 ": %s",
195 -error_, strerror(-error_));
196 } else {
197 StringAppendF(&output, "ACK");
198 }
199 return output;
200 }
201
Print(int header_log_level,int) const202 void ErrorAckMessage::Print(int header_log_level,
203 int /*detail_log_level*/) const {
204 VLOG(header_log_level) << ToString();
205 }
206
207 // NoopMessage.
208
209 const uint16_t NoopMessage::kMessageType = NLMSG_NOOP;
210
Encode(uint32_t sequence_number)211 ByteString NoopMessage::Encode(uint32_t sequence_number) {
212 LOG(ERROR) << "We're not supposed to send NOOP to the kernel";
213 return ByteString();
214 }
215
Print(int header_log_level,int) const216 void NoopMessage::Print(int header_log_level, int /*detail_log_level*/) const {
217 VLOG(header_log_level) << ToString();
218 }
219
220 // DoneMessage.
221
222 const uint16_t DoneMessage::kMessageType = NLMSG_DONE;
223
Encode(uint32_t sequence_number)224 ByteString DoneMessage::Encode(uint32_t sequence_number) {
225 return EncodeHeader(sequence_number);
226 }
227
Print(int header_log_level,int) const228 void DoneMessage::Print(int header_log_level, int /*detail_log_level*/) const {
229 VLOG(header_log_level) << ToString();
230 }
231
232 // OverrunMessage.
233
234 const uint16_t OverrunMessage::kMessageType = NLMSG_OVERRUN;
235
Encode(uint32_t sequence_number)236 ByteString OverrunMessage::Encode(uint32_t sequence_number) {
237 LOG(ERROR) << "We're not supposed to send Overruns to the kernel";
238 return ByteString();
239 }
240
Print(int header_log_level,int) const241 void OverrunMessage::Print(int header_log_level,
242 int /*detail_log_level*/) const {
243 VLOG(header_log_level) << ToString();
244 }
245
246 // UnknownMessage.
247
Encode(uint32_t sequence_number)248 ByteString UnknownMessage::Encode(uint32_t sequence_number) {
249 LOG(ERROR) << "We're not supposed to send UNKNOWN messages to the kernel";
250 return ByteString();
251 }
252
Print(int header_log_level,int) const253 void UnknownMessage::Print(int header_log_level,
254 int /*detail_log_level*/) const {
255 int total_bytes = message_body_.GetLength();
256 const uint8_t* const_data = message_body_.GetConstData();
257
258 string output = StringPrintf("%d bytes:", total_bytes);
259 for (int i = 0; i < total_bytes; ++i) {
260 StringAppendF(&output, " 0x%02x", const_data[i]);
261 }
262 VLOG(header_log_level) << output;
263 }
264
265 //
266 // Factory class.
267 //
268
AddFactoryMethod(uint16_t message_type,FactoryMethod factory)269 bool NetlinkMessageFactory::AddFactoryMethod(uint16_t message_type,
270 FactoryMethod factory) {
271 if (ContainsKey(factories_, message_type)) {
272 LOG(WARNING) << "Message type " << message_type << " already exists.";
273 return false;
274 }
275 if (message_type == NetlinkMessage::kIllegalMessageType) {
276 LOG(ERROR) << "Not installing factory for illegal message type.";
277 return false;
278 }
279 factories_[message_type] = factory;
280 return true;
281 }
282
CreateMessage(NetlinkPacket * packet,NetlinkMessage::MessageContext context) const283 NetlinkMessage* NetlinkMessageFactory::CreateMessage(
284 NetlinkPacket* packet, NetlinkMessage::MessageContext context) const {
285 std::unique_ptr<NetlinkMessage> message;
286
287 auto message_type = packet->GetMessageType();
288 if (message_type == NoopMessage::kMessageType) {
289 message.reset(new NoopMessage());
290 } else if (message_type == DoneMessage::kMessageType) {
291 message.reset(new DoneMessage());
292 } else if (message_type == OverrunMessage::kMessageType) {
293 message.reset(new OverrunMessage());
294 } else if (message_type == ErrorAckMessage::kMessageType) {
295 message.reset(new ErrorAckMessage());
296 } else if (ContainsKey(factories_, message_type)) {
297 map<uint16_t, FactoryMethod>::const_iterator factory;
298 factory = factories_.find(message_type);
299 message.reset(factory->second.Run(*packet));
300 }
301
302 // If no factory exists for this message _or_ if a factory exists but it
303 // failed, there'll be no message. Handle either of those cases, by
304 // creating an |UnknownMessage|.
305 if (!message) {
306 message.reset(new UnknownMessage(message_type, packet->GetPayload()));
307 }
308
309 if (!message->InitFromPacket(packet, context)) {
310 LOG(ERROR) << "Message did not initialize properly";
311 return nullptr;
312 }
313
314 return message.release();
315 }
316
317 } // namespace shill.
318