// // Copyright (C) 2013 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include "shill/net/netlink_socket.h" #include #include #include #include #include #include "shill/net/byte_string.h" #include "shill/net/mock_sockets.h" #include "shill/net/netlink_message.h" using std::min; using std::string; using testing::_; using testing::Invoke; using testing::Return; using testing::Test; namespace shill { class NetlinkSocketTest; const int kFakeFd = 99; class NetlinkSocketTest : public Test { public: NetlinkSocketTest() {} virtual ~NetlinkSocketTest() {} virtual void SetUp() { mock_sockets_ = new MockSockets(); netlink_socket_.sockets_.reset(mock_sockets_); } virtual void InitializeSocket(int fd) { EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC)) .WillOnce(Return(fd)); EXPECT_CALL(*mock_sockets_, SetReceiveBuffer( fd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0)); EXPECT_CALL(*mock_sockets_, Bind(fd, _, sizeof(struct sockaddr_nl))) .WillOnce(Return(0)); EXPECT_TRUE(netlink_socket_.Init()); } protected: MockSockets* mock_sockets_; // Owned by netlink_socket_. NetlinkSocket netlink_socket_; }; class FakeSocketRead { public: explicit FakeSocketRead(const ByteString& next_read_string) { next_read_string_ = next_read_string; } // Copies |len| bytes of |next_read_string_| into |buf| and clears // |next_read_string_|. ssize_t FakeSuccessfulRead(int sockfd, void* buf, size_t len, int flags, struct sockaddr* src_addr, socklen_t* addrlen) { if (!buf) { return -1; } int read_bytes = min(len, next_read_string_.GetLength()); memcpy(buf, next_read_string_.GetConstData(), read_bytes); next_read_string_.Clear(); return read_bytes; } private: ByteString next_read_string_; }; TEST_F(NetlinkSocketTest, InitWorkingTest) { SetUp(); InitializeSocket(kFakeFd); EXPECT_CALL(*mock_sockets_, Close(kFakeFd)); } TEST_F(NetlinkSocketTest, InitBrokenSocketTest) { SetUp(); const int kBadFd = -1; EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC)) .WillOnce(Return(kBadFd)); EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(_, _)).Times(0); EXPECT_CALL(*mock_sockets_, Bind(_, _, _)).Times(0); EXPECT_FALSE(netlink_socket_.Init()); } TEST_F(NetlinkSocketTest, InitBrokenBufferTest) { SetUp(); EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC)) .WillOnce(Return(kFakeFd)); EXPECT_CALL(*mock_sockets_, SetReceiveBuffer( kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(-1)); EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl))) .WillOnce(Return(0)); EXPECT_TRUE(netlink_socket_.Init()); // Destructor. EXPECT_CALL(*mock_sockets_, Close(kFakeFd)); } TEST_F(NetlinkSocketTest, InitBrokenBindTest) { SetUp(); EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC)) .WillOnce(Return(kFakeFd)); EXPECT_CALL(*mock_sockets_, SetReceiveBuffer( kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0)); EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl))) .WillOnce(Return(-1)); EXPECT_CALL(*mock_sockets_, Close(kFakeFd)).WillOnce(Return(0)); EXPECT_FALSE(netlink_socket_.Init()); } TEST_F(NetlinkSocketTest, SendMessageTest) { SetUp(); InitializeSocket(kFakeFd); string message_string("This text is really arbitrary"); ByteString message(message_string.c_str(), message_string.size()); // Good Send. EXPECT_CALL(*mock_sockets_, Send(kFakeFd, message.GetConstData(), message.GetLength(), 0)) .WillOnce(Return(message.GetLength())); EXPECT_TRUE(netlink_socket_.SendMessage(message)); // Short Send. EXPECT_CALL(*mock_sockets_, Send(kFakeFd, message.GetConstData(), message.GetLength(), 0)) .WillOnce(Return(message.GetLength() - 3)); EXPECT_FALSE(netlink_socket_.SendMessage(message)); // Bad Send. EXPECT_CALL(*mock_sockets_, Send(kFakeFd, message.GetConstData(), message.GetLength(), 0)) .WillOnce(Return(-1)); EXPECT_FALSE(netlink_socket_.SendMessage(message)); // Destructor. EXPECT_CALL(*mock_sockets_, Close(kFakeFd)); } TEST_F(NetlinkSocketTest, SequenceNumberTest) { SetUp(); // Just a sequence number. const uint32_t arbitrary_number = 42; netlink_socket_.sequence_number_ = arbitrary_number; EXPECT_EQ(arbitrary_number+1, netlink_socket_.GetSequenceNumber()); // Make sure we don't go to |NetlinkMessage::kBroadcastSequenceNumber|. netlink_socket_.sequence_number_ = NetlinkMessage::kBroadcastSequenceNumber; EXPECT_NE(NetlinkMessage::kBroadcastSequenceNumber, netlink_socket_.GetSequenceNumber()); } TEST_F(NetlinkSocketTest, GoodRecvMessageTest) { SetUp(); InitializeSocket(kFakeFd); ByteString message; static const string next_read_string( "Random text may include things like 'freaking fracking foo'."); static const size_t read_size = next_read_string.size(); ByteString expected_results(next_read_string.c_str(), read_size); FakeSocketRead fake_socket_read(expected_results); // Expect one call to get the size... EXPECT_CALL(*mock_sockets_, RecvFrom(kFakeFd, _, _, MSG_TRUNC | MSG_PEEK, _, _)) .WillOnce(Return(read_size)); // ...and expect a second call to get the data. EXPECT_CALL(*mock_sockets_, RecvFrom(kFakeFd, _, read_size, 0, _, _)) .WillOnce(Invoke(&fake_socket_read, &FakeSocketRead::FakeSuccessfulRead)); EXPECT_TRUE(netlink_socket_.RecvMessage(&message)); EXPECT_TRUE(message.Equals(expected_results)); // Destructor. EXPECT_CALL(*mock_sockets_, Close(kFakeFd)); } TEST_F(NetlinkSocketTest, BadRecvMessageTest) { SetUp(); InitializeSocket(kFakeFd); ByteString message; EXPECT_CALL(*mock_sockets_, RecvFrom(kFakeFd, _, _, _, _, _)) .WillOnce(Return(-1)); EXPECT_FALSE(netlink_socket_.RecvMessage(&message)); EXPECT_CALL(*mock_sockets_, Close(kFakeFd)); } } // namespace shill.