• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (C) 2012 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/net/netlink_message.h"
18 
19 #include <limits.h>
20 
21 #include <algorithm>
22 #include <map>
23 #include <memory>
24 #include <string>
25 
26 #include <base/format_macros.h>
27 #include <base/logging.h>
28 #include <base/stl_util.h>
29 #include <base/strings/stringprintf.h>
30 
31 #include "shill/net/netlink_packet.h"
32 
33 using base::StringAppendF;
34 using base::StringPrintf;
35 using std::map;
36 using std::min;
37 using std::string;
38 
39 namespace shill {
40 
41 const uint32_t NetlinkMessage::kBroadcastSequenceNumber = 0;
42 const uint16_t NetlinkMessage::kIllegalMessageType = UINT16_MAX;
43 
44 // NetlinkMessage
45 
EncodeHeader(uint32_t sequence_number)46 ByteString NetlinkMessage::EncodeHeader(uint32_t sequence_number) {
47   ByteString result;
48   if (message_type_ == kIllegalMessageType) {
49     LOG(ERROR) << "Message type not set";
50     return result;
51   }
52   sequence_number_ = sequence_number;
53   if (sequence_number_ == kBroadcastSequenceNumber) {
54     LOG(ERROR) << "Couldn't get a legal sequence number";
55     return result;
56   }
57 
58   // Build netlink header.
59   nlmsghdr header;
60   size_t nlmsghdr_with_pad = NLMSG_ALIGN(sizeof(header));
61   header.nlmsg_len = nlmsghdr_with_pad;
62   header.nlmsg_type = message_type_;
63   header.nlmsg_flags = NLM_F_REQUEST | flags_;
64   header.nlmsg_seq = sequence_number_;
65   header.nlmsg_pid = getpid();
66 
67   // Netlink header + pad.
68   result.Append(ByteString(reinterpret_cast<unsigned char*>(&header),
69                            sizeof(header)));
70   result.Resize(nlmsghdr_with_pad);  // Zero-fill pad space (if any).
71   return result;
72 }
73 
InitAndStripHeader(NetlinkPacket * packet)74 bool NetlinkMessage::InitAndStripHeader(NetlinkPacket* packet) {
75   const nlmsghdr& header = packet->GetNlMsgHeader();
76   message_type_ = header.nlmsg_type;
77   flags_ = header.nlmsg_flags;
78   sequence_number_ = header.nlmsg_seq;
79 
80   return true;
81 }
82 
InitFromPacket(NetlinkPacket * packet,NetlinkMessage::MessageContext context)83 bool NetlinkMessage::InitFromPacket(NetlinkPacket* packet,
84                                     NetlinkMessage::MessageContext context) {
85   if (!packet) {
86     LOG(ERROR) << "Null |packet| parameter";
87     return false;
88   }
89   if (!InitAndStripHeader(packet)) {
90     return false;
91   }
92   return true;
93 }
94 
95 // static
PrintBytes(int log_level,const unsigned char * buf,size_t num_bytes)96 void NetlinkMessage::PrintBytes(int log_level, const unsigned char* buf,
97                                 size_t num_bytes) {
98   VLOG(log_level) << "Netlink Message -- Examining Bytes";
99   if (!buf) {
100     VLOG(log_level) << "<NULL Buffer>";
101     return;
102   }
103 
104   if (num_bytes >= sizeof(nlmsghdr)) {
105       PrintHeader(log_level, reinterpret_cast<const nlmsghdr*>(buf));
106       buf += sizeof(nlmsghdr);
107       num_bytes -= sizeof(nlmsghdr);
108   } else {
109     VLOG(log_level) << "Not enough bytes (" << num_bytes
110                     << ") for a complete nlmsghdr (requires "
111                     << sizeof(nlmsghdr) << ").";
112   }
113 
114   PrintPayload(log_level, buf, num_bytes);
115 }
116 
117 // static
PrintPacket(int log_level,const NetlinkPacket & packet)118 void NetlinkMessage::PrintPacket(int log_level, const NetlinkPacket& packet) {
119   VLOG(log_level) << "Netlink Message -- Examining Packet";
120   if (!packet.IsValid()) {
121     VLOG(log_level) << "<Invalid Buffer>";
122     return;
123   }
124 
125   PrintHeader(log_level, &packet.GetNlMsgHeader());
126   const ByteString& payload = packet.GetPayload();
127   PrintPayload(log_level, payload.GetConstData(), payload.GetLength());
128 }
129 
130 // static
PrintHeader(int log_level,const nlmsghdr * header)131 void NetlinkMessage::PrintHeader(int log_level, const nlmsghdr* header) {
132   const unsigned char* buf = reinterpret_cast<const unsigned char*>(header);
133   VLOG(log_level) << StringPrintf(
134       "len:          %02x %02x %02x %02x = %u bytes",
135       buf[0], buf[1], buf[2], buf[3], header->nlmsg_len);
136 
137   VLOG(log_level) << StringPrintf(
138       "type | flags: %02x %02x %02x %02x - type:%u flags:%s%s%s%s%s",
139       buf[4], buf[5], buf[6], buf[7], header->nlmsg_type,
140       ((header->nlmsg_flags & NLM_F_REQUEST) ? " REQUEST" : ""),
141       ((header->nlmsg_flags & NLM_F_MULTI) ? " MULTI" : ""),
142       ((header->nlmsg_flags & NLM_F_ACK) ? " ACK" : ""),
143       ((header->nlmsg_flags & NLM_F_ECHO) ? " ECHO" : ""),
144       ((header->nlmsg_flags & NLM_F_DUMP_INTR) ? " BAD-SEQ" : ""));
145 
146   VLOG(log_level) << StringPrintf(
147       "sequence:     %02x %02x %02x %02x = %u",
148       buf[8], buf[9], buf[10], buf[11], header->nlmsg_seq);
149   VLOG(log_level) << StringPrintf(
150       "pid:          %02x %02x %02x %02x = %u",
151       buf[12], buf[13], buf[14], buf[15], header->nlmsg_pid);
152 }
153 
154 // static
PrintPayload(int log_level,const unsigned char * buf,size_t num_bytes)155 void NetlinkMessage::PrintPayload(int log_level, const unsigned char* buf,
156                                   size_t num_bytes) {
157   while (num_bytes) {
158     string output;
159     size_t bytes_this_row = min(num_bytes, static_cast<size_t>(32));
160     for (size_t i = 0; i < bytes_this_row; ++i) {
161       StringAppendF(&output, " %02x", *buf++);
162     }
163     VLOG(log_level) << output;
164     num_bytes -= bytes_this_row;
165   }
166 }
167 
168 // ErrorAckMessage.
169 
170 const uint16_t ErrorAckMessage::kMessageType = NLMSG_ERROR;
171 
InitFromPacket(NetlinkPacket * packet,NetlinkMessage::MessageContext context)172 bool ErrorAckMessage::InitFromPacket(NetlinkPacket* packet,
173                                      NetlinkMessage::MessageContext context) {
174   if (!packet) {
175     LOG(ERROR) << "Null |const_msg| parameter";
176     return false;
177   }
178   if (!InitAndStripHeader(packet)) {
179     return false;
180   }
181 
182   // Get the error code from the payload.
183   return packet->ConsumeData(sizeof(error_), &error_);
184 }
185 
Encode(uint32_t sequence_number)186 ByteString ErrorAckMessage::Encode(uint32_t sequence_number) {
187   LOG(ERROR) << "We're not supposed to send errors or Acks to the kernel";
188   return ByteString();
189 }
190 
ToString() const191 string ErrorAckMessage::ToString() const {
192   string output;
193   if (error()) {
194     StringAppendF(&output, "NETLINK_ERROR 0x%" PRIx32 ": %s",
195                   -error_, strerror(-error_));
196   } else {
197     StringAppendF(&output, "ACK");
198   }
199   return output;
200 }
201 
Print(int header_log_level,int) const202 void ErrorAckMessage::Print(int header_log_level,
203                             int /*detail_log_level*/) const {
204   VLOG(header_log_level) << ToString();
205 }
206 
207 // NoopMessage.
208 
209 const uint16_t NoopMessage::kMessageType = NLMSG_NOOP;
210 
Encode(uint32_t sequence_number)211 ByteString NoopMessage::Encode(uint32_t sequence_number) {
212   LOG(ERROR) << "We're not supposed to send NOOP to the kernel";
213   return ByteString();
214 }
215 
Print(int header_log_level,int) const216 void NoopMessage::Print(int header_log_level, int /*detail_log_level*/) const {
217   VLOG(header_log_level) << ToString();
218 }
219 
220 // DoneMessage.
221 
222 const uint16_t DoneMessage::kMessageType = NLMSG_DONE;
223 
Encode(uint32_t sequence_number)224 ByteString DoneMessage::Encode(uint32_t sequence_number) {
225   return EncodeHeader(sequence_number);
226 }
227 
Print(int header_log_level,int) const228 void DoneMessage::Print(int header_log_level, int /*detail_log_level*/) const {
229   VLOG(header_log_level) << ToString();
230 }
231 
232 // OverrunMessage.
233 
234 const uint16_t OverrunMessage::kMessageType = NLMSG_OVERRUN;
235 
Encode(uint32_t sequence_number)236 ByteString OverrunMessage::Encode(uint32_t sequence_number) {
237   LOG(ERROR) << "We're not supposed to send Overruns to the kernel";
238   return ByteString();
239 }
240 
Print(int header_log_level,int) const241 void OverrunMessage::Print(int header_log_level,
242                            int /*detail_log_level*/) const {
243   VLOG(header_log_level) << ToString();
244 }
245 
246 // UnknownMessage.
247 
Encode(uint32_t sequence_number)248 ByteString UnknownMessage::Encode(uint32_t sequence_number) {
249   LOG(ERROR) << "We're not supposed to send UNKNOWN messages to the kernel";
250   return ByteString();
251 }
252 
Print(int header_log_level,int) const253 void UnknownMessage::Print(int header_log_level,
254                            int /*detail_log_level*/) const {
255   int total_bytes = message_body_.GetLength();
256   const uint8_t* const_data = message_body_.GetConstData();
257 
258   string output = StringPrintf("%d bytes:", total_bytes);
259   for (int i = 0; i < total_bytes; ++i) {
260     StringAppendF(&output, " 0x%02x", const_data[i]);
261   }
262   VLOG(header_log_level) << output;
263 }
264 
265 //
266 // Factory class.
267 //
268 
AddFactoryMethod(uint16_t message_type,FactoryMethod factory)269 bool NetlinkMessageFactory::AddFactoryMethod(uint16_t message_type,
270                                              FactoryMethod factory) {
271   if (ContainsKey(factories_, message_type)) {
272     LOG(WARNING) << "Message type " << message_type << " already exists.";
273     return false;
274   }
275   if (message_type == NetlinkMessage::kIllegalMessageType) {
276     LOG(ERROR) << "Not installing factory for illegal message type.";
277     return false;
278   }
279   factories_[message_type] = factory;
280   return true;
281 }
282 
CreateMessage(NetlinkPacket * packet,NetlinkMessage::MessageContext context) const283 NetlinkMessage* NetlinkMessageFactory::CreateMessage(
284     NetlinkPacket* packet, NetlinkMessage::MessageContext context) const {
285   std::unique_ptr<NetlinkMessage> message;
286 
287   auto message_type = packet->GetMessageType();
288   if (message_type == NoopMessage::kMessageType) {
289     message.reset(new NoopMessage());
290   } else if (message_type == DoneMessage::kMessageType) {
291     message.reset(new DoneMessage());
292   } else if (message_type == OverrunMessage::kMessageType) {
293     message.reset(new OverrunMessage());
294   } else if (message_type == ErrorAckMessage::kMessageType) {
295     message.reset(new ErrorAckMessage());
296   } else if (ContainsKey(factories_, message_type)) {
297     map<uint16_t, FactoryMethod>::const_iterator factory;
298     factory = factories_.find(message_type);
299     message.reset(factory->second.Run(*packet));
300   }
301 
302   // If no factory exists for this message _or_ if a factory exists but it
303   // failed, there'll be no message.  Handle either of those cases, by
304   // creating an |UnknownMessage|.
305   if (!message) {
306     message.reset(new UnknownMessage(message_type, packet->GetPayload()));
307   }
308 
309   if (!message->InitFromPacket(packet, context)) {
310     LOG(ERROR) << "Message did not initialize properly";
311     return nullptr;
312   }
313 
314   return message.release();
315 }
316 
317 }  // namespace shill.
318