// Copyright 2024 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/base/network_cost_change_notifier_win.h" #include #include #include "base/check.h" #include "base/no_destructor.h" #include "base/task/bind_post_task.h" #include "base/task/sequenced_task_runner.h" #include "base/task/thread_pool.h" #include "base/threading/scoped_thread_priority.h" #include "base/win/com_init_util.h" using Microsoft::WRL::ComPtr; namespace net { namespace { NetworkChangeNotifier::ConnectionCost ConnectionCostFromNlmConnectionCost( DWORD connection_cost_flags) { if (connection_cost_flags == NLM_CONNECTION_COST_UNKNOWN) { return NetworkChangeNotifier::CONNECTION_COST_UNKNOWN; } else if ((connection_cost_flags & NLM_CONNECTION_COST_UNRESTRICTED) != 0) { return NetworkChangeNotifier::CONNECTION_COST_UNMETERED; } else { return NetworkChangeNotifier::CONNECTION_COST_METERED; } } NetworkCostChangeNotifierWin::CoCreateInstanceCallback& GetCoCreateInstanceCallback() { static base::NoDestructor< NetworkCostChangeNotifierWin::CoCreateInstanceCallback> co_create_instance_callback{base::BindRepeating(&CoCreateInstance)}; return *co_create_instance_callback; } } // namespace // This class is used as an event sink to register for notifications from the // `INetworkCostManagerEvents` interface. In particular, we are focused on // getting notified when the connection cost changes. class NetworkCostManagerEventSinkWin final : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, INetworkCostManagerEvents> { public: static HRESULT CreateInstance( INetworkCostManager* network_cost_manager, base::RepeatingClosure cost_changed_callback, ComPtr* result) { ComPtr instance = Microsoft::WRL::Make( cost_changed_callback); HRESULT hr = instance->RegisterForNotifications(network_cost_manager); if (hr != S_OK) { return hr; } *result = instance; return S_OK; } void UnRegisterForNotifications() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (event_sink_connection_point_) { event_sink_connection_point_->Unadvise(event_sink_connection_cookie_); event_sink_connection_point_.Reset(); } } // Implement the INetworkCostManagerEvents interface. HRESULT __stdcall CostChanged(DWORD /*cost*/, NLM_SOCKADDR* /*socket_address*/) final { // It is possible to get multiple notifications in a short period of time. // Rather than worrying about whether this notification represents the // latest, just notify the owner who can get the current value from the // INetworkCostManager so we know that we're actually getting the correct // value. cost_changed_callback_.Run(); return S_OK; } HRESULT __stdcall DataPlanStatusChanged( NLM_SOCKADDR* /*socket_address*/) final { return S_OK; } NetworkCostManagerEventSinkWin(base::RepeatingClosure cost_changed_callback) : cost_changed_callback_(cost_changed_callback) {} NetworkCostManagerEventSinkWin(const NetworkCostManagerEventSinkWin&) = delete; NetworkCostManagerEventSinkWin& operator=( const NetworkCostManagerEventSinkWin&) = delete; private: ~NetworkCostManagerEventSinkWin() final = default; HRESULT RegisterForNotifications(INetworkCostManager* cost_manager) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); base::win::AssertComInitialized(); base::win::AssertComApartmentType(base::win::ComApartmentType::STA); ComPtr this_event_sink_unknown; HRESULT hr = QueryInterface(IID_PPV_ARGS(&this_event_sink_unknown)); // `NetworkCostManagerEventSinkWin::QueryInterface` for `IUnknown` must // succeed since it is implemented by this class. CHECK_EQ(hr, S_OK); ComPtr connection_point_container; hr = cost_manager->QueryInterface(IID_PPV_ARGS(&connection_point_container)); if (hr != S_OK) { return hr; } Microsoft::WRL::ComPtr event_sink_connection_point; hr = connection_point_container->FindConnectionPoint( IID_INetworkCostManagerEvents, &event_sink_connection_point); if (hr != S_OK) { return hr; } hr = event_sink_connection_point->Advise(this_event_sink_unknown.Get(), &event_sink_connection_cookie_); if (hr != S_OK) { return hr; } CHECK_EQ(event_sink_connection_point_, nullptr); event_sink_connection_point_ = event_sink_connection_point; return S_OK; } base::RepeatingClosure cost_changed_callback_; // The following members must be accessed on the sequence from // `sequence_checker_` SEQUENCE_CHECKER(sequence_checker_); DWORD event_sink_connection_cookie_ = 0; Microsoft::WRL::ComPtr event_sink_connection_point_; }; // static base::SequenceBound NetworkCostChangeNotifierWin::CreateInstance( CostChangedCallback cost_changed_callback) { scoped_refptr com_best_effort_task_runner = base::ThreadPool::CreateCOMSTATaskRunner( {base::MayBlock(), base::TaskPriority::BEST_EFFORT, base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}); return base::SequenceBound( com_best_effort_task_runner, // Ensure `cost_changed_callback` runs on the sequence of the creator and // owner of `NetworkCostChangeNotifierWin`. base::BindPostTask(base::SequencedTaskRunner::GetCurrentDefault(), cost_changed_callback)); } NetworkCostChangeNotifierWin::NetworkCostChangeNotifierWin( CostChangedCallback cost_changed_callback) : cost_changed_callback_(cost_changed_callback) { StartWatching(); } NetworkCostChangeNotifierWin::~NetworkCostChangeNotifierWin() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); StopWatching(); } void NetworkCostChangeNotifierWin::StartWatching() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (base::win::GetVersion() < kSupportedOsVersion) { return; } base::win::AssertComInitialized(); base::win::AssertComApartmentType(base::win::ComApartmentType::STA); SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY(); // Create `INetworkListManager` using `CoCreateInstance()`. Tests may provide // a fake implementation of `INetworkListManager` through an // `OverrideCoCreateInstanceForTesting()`. ComPtr cost_manager; HRESULT hr = GetCoCreateInstanceCallback().Run( CLSID_NetworkListManager, /*unknown_outer=*/nullptr, CLSCTX_ALL, IID_INetworkCostManager, &cost_manager); if (hr != S_OK) { return; } // Subscribe to cost changed events. hr = NetworkCostManagerEventSinkWin::CreateInstance( cost_manager.Get(), // Cost changed callbacks must run on this sequence to get the new cost // from `INetworkCostManager`. base::BindPostTask( base::SequencedTaskRunner::GetCurrentDefault(), base::BindRepeating(&NetworkCostChangeNotifierWin::HandleCostChanged, weak_ptr_factory_.GetWeakPtr())), &cost_manager_event_sink_); if (hr != S_OK) { return; } // Set the initial cost and inform observers of the initial value. cost_manager_ = cost_manager; HandleCostChanged(); } void NetworkCostChangeNotifierWin::StopWatching() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (cost_manager_event_sink_) { cost_manager_event_sink_->UnRegisterForNotifications(); cost_manager_event_sink_.Reset(); } cost_manager_.Reset(); } void NetworkCostChangeNotifierWin::HandleCostChanged() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DWORD connection_cost_flags; HRESULT hr = cost_manager_->GetCost(&connection_cost_flags, /*destination_ip_address=*/nullptr); if (hr != S_OK) { connection_cost_flags = NLM_CONNECTION_COST_UNKNOWN; } NetworkChangeNotifier::ConnectionCost changed_cost = ConnectionCostFromNlmConnectionCost(connection_cost_flags); cost_changed_callback_.Run(changed_cost); } // static void NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting( CoCreateInstanceCallback callback_for_testing) { GetCoCreateInstanceCallback() = callback_for_testing; } } // namespace net