// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "discovery/public/dns_sd_service_watcher.h" #include #include "gmock/gmock.h" #include "gtest/gtest.h" using testing::_; using testing::ContainerEq; using testing::IsSubsetOf; using testing::IsSupersetOf; using testing::StrictMock; namespace openscreen { namespace discovery { namespace { std::vector ConvertRefs( const std::vector>& value) { std::vector strings; // This loop is required to unwrap reference_wrapper objects. for (const std::string& val : value) { strings.push_back(val); } return strings; } static const IPAddress kAddressV4(192, 168, 0, 0); static const IPEndpoint kEndpointV4{kAddressV4, 0}; constexpr char kCastServiceId[] = "_googlecast._tcp"; constexpr char kCastDomainId[] = "local"; constexpr NetworkInterfaceIndex kNetworkInterface = 0; class MockDnsSdService : public DnsSdService { public: MockDnsSdService() : querier_(this) {} DnsSdQuerier* GetQuerier() override { return &querier_; } DnsSdPublisher* GetPublisher() override { return nullptr; } MOCK_METHOD2(StartQuery, void(const std::string& service, DnsSdQuerier::Callback* cb)); MOCK_METHOD2(StopQuery, void(const std::string& service, DnsSdQuerier::Callback* cb)); MOCK_METHOD1(ReinitializeQueries, void(const std::string& service)); private: class MockQuerier : public DnsSdQuerier { public: explicit MockQuerier(MockDnsSdService* service) : mock_service_(service) { OSP_DCHECK(service); } void StartQuery(const std::string& service, DnsSdQuerier::Callback* cb) override { mock_service_->StartQuery(service, cb); } void StopQuery(const std::string& service, DnsSdQuerier::Callback* cb) override { mock_service_->StopQuery(service, cb); } void ReinitializeQueries(const std::string& service) override { mock_service_->ReinitializeQueries(service); } private: MockDnsSdService* const mock_service_; }; MockQuerier querier_; }; } // namespace class TestServiceWatcher : public DnsSdServiceWatcher { public: using DnsSdServiceWatcher::ConstRefT; explicit TestServiceWatcher(MockDnsSdService* service) : DnsSdServiceWatcher( service, kCastServiceId, [this](const DnsSdInstance& instance) { return Convert(instance); }, [this](std::vector ref, ConstRefT service, ServicesUpdatedState state) { Callback(std::move(ref)); }) {} MOCK_METHOD1(Callback, void(std::vector)); using DnsSdServiceWatcher::OnEndpointCreated; using DnsSdServiceWatcher::OnEndpointUpdated; using DnsSdServiceWatcher::OnEndpointDeleted; private: std::string Convert(const DnsSdInstance& instance) { return instance.instance_id(); } }; class DnsSdServiceWatcherTests : public testing::Test { public: DnsSdServiceWatcherTests() : watcher_(&service_) { // Start service discovery, since all other tests need it EXPECT_FALSE(watcher_.is_running()); EXPECT_CALL(service_, StartQuery(kCastServiceId, _)); watcher_.StartDiscovery(); testing::Mock::VerifyAndClearExpectations(&service_); } protected: void CreateNewInstance(const DnsSdInstanceEndpoint& record) { const std::vector services_before = ConvertRefs(watcher_.GetServices()); const size_t count = services_before.size(); std::vector callbacked_services; EXPECT_CALL(watcher_, Callback(_)) .WillOnce([services = &callbacked_services]( std::vector value) { *services = ConvertRefs(value); }); watcher_.OnEndpointCreated(record); testing::Mock::VerifyAndClearExpectations(&watcher_); std::vector fetched_services = ConvertRefs(watcher_.GetServices()); EXPECT_EQ(fetched_services.size(), count + 1); EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); EXPECT_THAT(fetched_services, IsSupersetOf(services_before)); } void CreateExistingInstance(const DnsSdInstanceEndpoint& record) { const std::vector services_before = ConvertRefs(watcher_.GetServices()); const size_t count = services_before.size(); std::vector callbacked_services; EXPECT_CALL(watcher_, Callback(_)) .WillOnce([services = &callbacked_services]( std::vector value) { *services = ConvertRefs(value); }); watcher_.OnEndpointCreated(record); testing::Mock::VerifyAndClearExpectations(&watcher_); const std::vector fetched_services = ConvertRefs(watcher_.GetServices()); EXPECT_EQ(fetched_services.size(), count); EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); EXPECT_THAT(fetched_services, ContainerEq(services_before)); } void UpdateExistingInstance(const DnsSdInstanceEndpoint& record) { const std::vector services_before = ConvertRefs(watcher_.GetServices()); const size_t count = services_before.size(); std::vector callbacked_services; EXPECT_CALL(watcher_, Callback(_)) .WillOnce([services = &callbacked_services]( std::vector value) { *services = ConvertRefs(value); }); watcher_.OnEndpointUpdated(record); testing::Mock::VerifyAndClearExpectations(&watcher_); const std::vector fetched_services = ConvertRefs(watcher_.GetServices()); EXPECT_EQ(fetched_services.size(), count); EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); EXPECT_THAT(fetched_services, ContainerEq(services_before)); } void DeleteExistingInstance(const DnsSdInstanceEndpoint& record) { const std::vector services_before = ConvertRefs(watcher_.GetServices()); const size_t count = services_before.size(); std::vector callbacked_services; EXPECT_CALL(watcher_, Callback(_)) .WillOnce([services = &callbacked_services]( std::vector value) { *services = ConvertRefs(value); }); watcher_.OnEndpointDeleted(record); testing::Mock::VerifyAndClearExpectations(&watcher_); const std::vector fetched_services = ConvertRefs(watcher_.GetServices()); EXPECT_EQ(fetched_services.size(), count - 1); } void UpdateNonExistingInstance(const DnsSdInstanceEndpoint& record) { const std::vector services_before = ConvertRefs(watcher_.GetServices()); const size_t count = services_before.size(); EXPECT_CALL(watcher_, Callback(_)).Times(0); watcher_.OnEndpointUpdated(record); testing::Mock::VerifyAndClearExpectations(&watcher_); const std::vector fetched_services = ConvertRefs(watcher_.GetServices()); EXPECT_EQ(fetched_services.size(), count); EXPECT_THAT(services_before, ContainerEq(fetched_services)); } void DeleteNonExistingInstance(const DnsSdInstanceEndpoint& record) { const std::vector services_before = ConvertRefs(watcher_.GetServices()); const size_t count = services_before.size(); EXPECT_CALL(watcher_, Callback(_)).Times(0); watcher_.OnEndpointDeleted(record); testing::Mock::VerifyAndClearExpectations(&watcher_); const std::vector fetched_services = ConvertRefs(watcher_.GetServices()); EXPECT_EQ(fetched_services.size(), count); EXPECT_THAT(services_before, ContainerEq(fetched_services)); } bool ContainsService(const DnsSdInstanceEndpoint& record) { const std::string& service = record.instance_id(); const std::vector services = watcher_.GetServices(); return std::find_if(services.begin(), services.end(), [&service](const std::string& ref) { return service == ref; }) != services.end(); } StrictMock service_; StrictMock watcher_; std::vector fetched_services; }; TEST_F(DnsSdServiceWatcherTests, StartStopDiscoveryWorks) { EXPECT_TRUE(watcher_.is_running()); EXPECT_CALL(service_, StopQuery(kCastServiceId, _)); watcher_.StopDiscovery(); EXPECT_FALSE(watcher_.is_running()); } TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) { StrictMock service; StrictMock watcher(&service); EXPECT_FALSE(watcher.DiscoverNow().ok()); EXPECT_FALSE(watcher.ForceRefresh().ok()); } TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) { const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId, DnsSdTxtRecord{}, kNetworkInterface, kEndpointV4); CreateNewInstance(record); // Refresh services. EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId)); EXPECT_TRUE(watcher_.DiscoverNow().ok()); EXPECT_EQ(watcher_.GetServices().size(), size_t{1}); testing::Mock::VerifyAndClearExpectations(&service_); EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId)); EXPECT_TRUE(watcher_.ForceRefresh().ok()); EXPECT_EQ(watcher_.GetServices().size(), size_t{0}); testing::Mock::VerifyAndClearExpectations(&service_); } TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) { const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId, DnsSdTxtRecord{}, kNetworkInterface, kEndpointV4); const DnsSdInstanceEndpoint record2("Instance2", kCastServiceId, kCastDomainId, DnsSdTxtRecord{}, kNetworkInterface, kEndpointV4); EXPECT_FALSE(ContainsService(record)); EXPECT_FALSE(ContainsService(record2)); CreateNewInstance(record); EXPECT_TRUE(ContainsService(record)); EXPECT_FALSE(ContainsService(record2)); CreateExistingInstance(record); EXPECT_TRUE(ContainsService(record)); EXPECT_FALSE(ContainsService(record2)); UpdateNonExistingInstance(record2); EXPECT_TRUE(ContainsService(record)); EXPECT_FALSE(ContainsService(record2)); DeleteNonExistingInstance(record2); EXPECT_TRUE(ContainsService(record)); EXPECT_FALSE(ContainsService(record2)); CreateNewInstance(record2); EXPECT_TRUE(ContainsService(record)); EXPECT_TRUE(ContainsService(record2)); UpdateExistingInstance(record2); EXPECT_TRUE(ContainsService(record)); EXPECT_TRUE(ContainsService(record2)); UpdateExistingInstance(record); EXPECT_TRUE(ContainsService(record)); EXPECT_TRUE(ContainsService(record2)); DeleteExistingInstance(record); EXPECT_FALSE(ContainsService(record)); EXPECT_TRUE(ContainsService(record2)); UpdateNonExistingInstance(record); EXPECT_FALSE(ContainsService(record)); EXPECT_TRUE(ContainsService(record2)); DeleteNonExistingInstance(record); EXPECT_FALSE(ContainsService(record)); EXPECT_TRUE(ContainsService(record2)); DeleteExistingInstance(record2); EXPECT_FALSE(ContainsService(record)); EXPECT_FALSE(ContainsService(record2)); } } // namespace discovery } // namespace openscreen