• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (C) 2015 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
#include "shill/net/netlink_packet.h"

#include <algorithm>
#include <cstring>

#include <base/logging.h>

#include "shill/net/byte_string.h"
24 
25 namespace shill {
26 
NetlinkPacket(const unsigned char * buf,size_t len)27 NetlinkPacket::NetlinkPacket(const unsigned char* buf, size_t len)
28     : consumed_bytes_(0) {
29   if (!buf || len < sizeof(header_)) {
30     LOG(ERROR) << "Cannot retrieve header.";
31     return;
32   }
33 
34   memcpy(&header_, buf, sizeof(header_));
35   if (len < header_.nlmsg_len || header_.nlmsg_len < sizeof(header_)) {
36     LOG(ERROR) << "Discarding incomplete / invalid message.";
37     return;
38   }
39 
40   payload_.reset(
41       new ByteString(buf + sizeof(header_), len - sizeof(header_)));
42 }
43 
~NetlinkPacket()44 NetlinkPacket::~NetlinkPacket() {
45 }
46 
IsValid() const47 bool NetlinkPacket::IsValid() const {
48   return payload_ != nullptr;
49 }
50 
GetLength() const51 size_t NetlinkPacket::GetLength() const {
52   return GetNlMsgHeader().nlmsg_len;
53 }
54 
GetMessageType() const55 uint16_t NetlinkPacket::GetMessageType() const {
56   return GetNlMsgHeader().nlmsg_type;
57 }
58 
GetMessageSequence() const59 uint32_t NetlinkPacket::GetMessageSequence() const {
60   return GetNlMsgHeader().nlmsg_seq;
61 }
62 
GetRemainingLength() const63 size_t NetlinkPacket::GetRemainingLength() const {
64   return GetPayload().GetLength() - consumed_bytes_;
65 }
66 
GetPayload() const67 const ByteString& NetlinkPacket::GetPayload() const {
68   CHECK(IsValid());
69   return *payload_.get();
70 }
71 
ConsumeAttributes(const AttributeList::NewFromIdMethod & factory,const AttributeListRefPtr & attributes)72 bool NetlinkPacket::ConsumeAttributes(
73     const AttributeList::NewFromIdMethod& factory,
74     const AttributeListRefPtr& attributes) {
75   bool result = attributes->Decode(GetPayload(), consumed_bytes_, factory);
76   consumed_bytes_ = GetPayload().GetLength();
77   return result;
78 }
79 
ConsumeData(size_t len,void * data)80 bool NetlinkPacket::ConsumeData(size_t len, void* data) {
81   if (GetRemainingLength() < len) {
82     LOG(ERROR) << "Not enough bytes remaining.";
83     return false;
84   }
85 
86   memcpy(data, payload_->GetData() + consumed_bytes_, len);
87   consumed_bytes_ = std::min(payload_->GetLength(),
88                              consumed_bytes_ + NLMSG_ALIGN(len));
89   return true;
90 }
91 
92 
GetNlMsgHeader() const93 const nlmsghdr& NetlinkPacket::GetNlMsgHeader() const {
94   CHECK(IsValid());
95   return header_;
96 }
97 
GetGenlMsgHdr(genlmsghdr * header) const98 bool NetlinkPacket::GetGenlMsgHdr(genlmsghdr* header) const {
99   if (GetPayload().GetLength() < sizeof(*header)) {
100     return false;
101   }
102   memcpy(header, payload_->GetConstData(), sizeof(*header));
103   return true;
104 }
105 
MutableNetlinkPacket(const unsigned char * buf,size_t len)106 MutableNetlinkPacket::MutableNetlinkPacket(const unsigned char* buf, size_t len)
107     : NetlinkPacket(buf, len) {
108 }
109 
~MutableNetlinkPacket()110 MutableNetlinkPacket::~MutableNetlinkPacket() {
111 }
112 
ResetConsumedBytes()113 void MutableNetlinkPacket::ResetConsumedBytes() {
114   set_consumed_bytes(0);
115 }
116 
GetMutableHeader()117 nlmsghdr* MutableNetlinkPacket::GetMutableHeader() {
118   CHECK(IsValid());
119   return mutable_header();
120 }
121 
GetMutablePayload()122 ByteString* MutableNetlinkPacket::GetMutablePayload() {
123   CHECK(IsValid());
124   return mutable_payload();
125 }
126 
SetMessageType(uint16_t type)127 void MutableNetlinkPacket::SetMessageType(uint16_t type) {
128   mutable_header()->nlmsg_type = type;
129 }
130 
SetMessageSequence(uint32_t sequence)131 void MutableNetlinkPacket::SetMessageSequence(uint32_t sequence) {
132   mutable_header()->nlmsg_seq = sequence;
133 }
134 
135 }  // namespace shill.
136