1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 #include "discovery/mdns/mdns_probe.h"
5
6 #include <memory>
7 #include <utility>
8
9 #include "discovery/common/config.h"
10 #include "discovery/mdns/mdns_probe_manager.h"
11 #include "discovery/mdns/mdns_querier.h"
12 #include "discovery/mdns/mdns_random.h"
13 #include "discovery/mdns/mdns_receiver.h"
14 #include "discovery/mdns/mdns_sender.h"
15 #include "gmock/gmock.h"
16 #include "gtest/gtest.h"
17 #include "platform/test/fake_clock.h"
18 #include "platform/test/fake_task_runner.h"
19 #include "platform/test/fake_udp_socket.h"
20
21 using testing::_;
22 using testing::Invoke;
23 using testing::Return;
24 using testing::StrictMock;
25
26 namespace openscreen {
27 namespace discovery {
28
29 class MockMdnsSender : public MdnsSender {
30 public:
MockMdnsSender(UdpSocket * socket)31 explicit MockMdnsSender(UdpSocket* socket) : MdnsSender(socket) {}
32 MOCK_METHOD1(SendMulticast, Error(const MdnsMessage& message));
33 MOCK_METHOD2(SendMessage,
34 Error(const MdnsMessage& message, const IPEndpoint& endpoint));
35 };
36
37 class MockObserver : public MdnsProbeImpl::Observer {
38 public:
39 MOCK_METHOD1(OnProbeSuccess, void(MdnsProbe*));
40 MOCK_METHOD1(OnProbeFailure, void(MdnsProbe*));
41 };
42
43 class MdnsProbeTests : public testing::Test {
44 public:
MdnsProbeTests()45 MdnsProbeTests()
46 : clock_(Clock::now()),
47 task_runner_(&clock_),
48 socket_(&task_runner_),
49 sender_(&socket_),
50 receiver_(config_) {
51 EXPECT_EQ(task_runner_.delayed_task_count(), 0);
52 probe_ = CreateProbe();
53 EXPECT_EQ(task_runner_.delayed_task_count(), 1);
54 }
55
56 protected:
CreateProbe()57 std::unique_ptr<MdnsProbeImpl> CreateProbe() {
58 return std::make_unique<MdnsProbeImpl>(&sender_, &receiver_, &random_,
59 &task_runner_, FakeClock::now,
60 &observer_, name_, address_v4_);
61 }
62
CreateMessage(const DomainName & domain)63 MdnsMessage CreateMessage(const DomainName& domain) {
64 MdnsMessage message(0, MessageType::Response);
65 SrvRecordRdata rdata(0, 0, 80, domain);
66 MdnsRecord record(std::move(domain), DnsType::kSRV, DnsClass::kIN,
67 RecordType::kUnique, std::chrono::seconds(1),
68 std::move(rdata));
69 message.AddAnswer(record);
70 return message;
71 }
72
OnMessageReceived(const MdnsMessage & message)73 void OnMessageReceived(const MdnsMessage& message) {
74 probe_->OnMessageReceived(message);
75 }
76
77 Config config_;
78 FakeClock clock_;
79 FakeTaskRunner task_runner_;
80 FakeUdpSocket socket_;
81 StrictMock<MockMdnsSender> sender_;
82 MdnsReceiver receiver_;
83 MdnsRandom random_;
84 StrictMock<MockObserver> observer_;
85
86 std::unique_ptr<MdnsProbeImpl> probe_;
87
88 const DomainName name_{"test", "_googlecast", "_tcp", "local"};
89 const DomainName name2_{"test2", "_googlecast", "_tcp", "local"};
90
91 const IPAddress address_v4_{192, 168, 0, 0};
92 const IPEndpoint endpoint_v4_{address_v4_, 80};
93 };
94
// With no conflicting traffic, the probe sends one multicast query per delay
// period three times and then reports success on the following period.
TEST_F(MdnsProbeTests, TestNoCancelationFlow) {
  constexpr int kQueriesBeforeSuccess = 3;
  for (int round = 0; round < kQueriesBeforeSuccess; ++round) {
    // Each period emits exactly one query and leaves one delayed task queued.
    EXPECT_CALL(sender_, SendMulticast(_));
    clock_.Advance(kDelayBetweenProbeQueries);
    EXPECT_EQ(task_runner_.delayed_task_count(), 1);
    testing::Mock::VerifyAndClearExpectations(&sender_);
  }

  // The period after the last query reports success and schedules nothing.
  EXPECT_CALL(observer_, OnProbeSuccess(probe_.get())).Times(1);
  clock_.Advance(kDelayBetweenProbeQueries);
  EXPECT_EQ(task_runner_.delayed_task_count(), 0);
}
115
// A response carrying a record for the probed name makes the probe report
// failure to its observer exactly once.
TEST_F(MdnsProbeTests, CancelationWhenMatchingMessageReceived) {
  MdnsProbe* const probe = probe_.get();
  EXPECT_CALL(observer_, OnProbeFailure(probe));

  const MdnsMessage conflicting = CreateMessage(name_);
  OnMessageReceived(conflicting);
}
120
// A response about a different name must not report failure (the strict
// observer would flag any call). The probe keeps querying on schedule.
TEST_F(MdnsProbeTests, TestNoCancelationOnUnrelatedMessages) {
  const MdnsMessage unrelated = CreateMessage(name2_);
  OnMessageReceived(unrelated);

  // The next delay period still produces a query and re-arms one task.
  EXPECT_CALL(sender_, SendMulticast(_));
  clock_.Advance(kDelayBetweenProbeQueries);
  EXPECT_EQ(task_runner_.delayed_task_count(), 1);
  testing::Mock::VerifyAndClearExpectations(&sender_);
}
129
130 } // namespace discovery
131 } // namespace openscreen
132