1 //
2 // Copyright (C) 2015 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "shill/net/netlink_packet.h"
18
19 #include <algorithm>
20
21 #include <base/logging.h>
22
23 #include "shill/net/byte_string.h"
24
25 namespace shill {
26
NetlinkPacket(const unsigned char * buf,size_t len)27 NetlinkPacket::NetlinkPacket(const unsigned char* buf, size_t len)
28 : consumed_bytes_(0) {
29 if (!buf || len < sizeof(header_)) {
30 LOG(ERROR) << "Cannot retrieve header.";
31 return;
32 }
33
34 memcpy(&header_, buf, sizeof(header_));
35 if (len < header_.nlmsg_len || header_.nlmsg_len < sizeof(header_)) {
36 LOG(ERROR) << "Discarding incomplete / invalid message.";
37 return;
38 }
39
40 payload_.reset(
41 new ByteString(buf + sizeof(header_), len - sizeof(header_)));
42 }
43
~NetlinkPacket()44 NetlinkPacket::~NetlinkPacket() {
45 }
46
IsValid() const47 bool NetlinkPacket::IsValid() const {
48 return payload_ != nullptr;
49 }
50
GetLength() const51 size_t NetlinkPacket::GetLength() const {
52 return GetNlMsgHeader().nlmsg_len;
53 }
54
GetMessageType() const55 uint16_t NetlinkPacket::GetMessageType() const {
56 return GetNlMsgHeader().nlmsg_type;
57 }
58
GetMessageSequence() const59 uint32_t NetlinkPacket::GetMessageSequence() const {
60 return GetNlMsgHeader().nlmsg_seq;
61 }
62
GetRemainingLength() const63 size_t NetlinkPacket::GetRemainingLength() const {
64 return GetPayload().GetLength() - consumed_bytes_;
65 }
66
GetPayload() const67 const ByteString& NetlinkPacket::GetPayload() const {
68 CHECK(IsValid());
69 return *payload_.get();
70 }
71
ConsumeAttributes(const AttributeList::NewFromIdMethod & factory,const AttributeListRefPtr & attributes)72 bool NetlinkPacket::ConsumeAttributes(
73 const AttributeList::NewFromIdMethod& factory,
74 const AttributeListRefPtr& attributes) {
75 bool result = attributes->Decode(GetPayload(), consumed_bytes_, factory);
76 consumed_bytes_ = GetPayload().GetLength();
77 return result;
78 }
79
ConsumeData(size_t len,void * data)80 bool NetlinkPacket::ConsumeData(size_t len, void* data) {
81 if (GetRemainingLength() < len) {
82 LOG(ERROR) << "Not enough bytes remaining.";
83 return false;
84 }
85
86 memcpy(data, payload_->GetData() + consumed_bytes_, len);
87 consumed_bytes_ = std::min(payload_->GetLength(),
88 consumed_bytes_ + NLMSG_ALIGN(len));
89 return true;
90 }
91
92
GetNlMsgHeader() const93 const nlmsghdr& NetlinkPacket::GetNlMsgHeader() const {
94 CHECK(IsValid());
95 return header_;
96 }
97
GetGenlMsgHdr(genlmsghdr * header) const98 bool NetlinkPacket::GetGenlMsgHdr(genlmsghdr* header) const {
99 if (GetPayload().GetLength() < sizeof(*header)) {
100 return false;
101 }
102 memcpy(header, payload_->GetConstData(), sizeof(*header));
103 return true;
104 }
105
MutableNetlinkPacket(const unsigned char * buf,size_t len)106 MutableNetlinkPacket::MutableNetlinkPacket(const unsigned char* buf, size_t len)
107 : NetlinkPacket(buf, len) {
108 }
109
~MutableNetlinkPacket()110 MutableNetlinkPacket::~MutableNetlinkPacket() {
111 }
112
ResetConsumedBytes()113 void MutableNetlinkPacket::ResetConsumedBytes() {
114 set_consumed_bytes(0);
115 }
116
GetMutableHeader()117 nlmsghdr* MutableNetlinkPacket::GetMutableHeader() {
118 CHECK(IsValid());
119 return mutable_header();
120 }
121
GetMutablePayload()122 ByteString* MutableNetlinkPacket::GetMutablePayload() {
123 CHECK(IsValid());
124 return mutable_payload();
125 }
126
SetMessageType(uint16_t type)127 void MutableNetlinkPacket::SetMessageType(uint16_t type) {
128 mutable_header()->nlmsg_type = type;
129 }
130
SetMessageSequence(uint32_t sequence)131 void MutableNetlinkPacket::SetMessageSequence(uint32_t sequence) {
132 mutable_header()->nlmsg_seq = sequence;
133 }
134
135 } // namespace shill.
136