• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/base/network_cost_change_notifier_win.h"
6 
7 #include <wrl.h>
8 #include <wrl/client.h>
9 
10 #include "base/check.h"
11 #include "base/no_destructor.h"
12 #include "base/task/bind_post_task.h"
13 #include "base/task/sequenced_task_runner.h"
14 #include "base/task/thread_pool.h"
15 #include "base/threading/scoped_thread_priority.h"
16 #include "base/win/com_init_util.h"
17 
18 using Microsoft::WRL::ComPtr;
19 
20 namespace net {
21 
22 namespace {
23 
ConnectionCostFromNlmConnectionCost(DWORD connection_cost_flags)24 NetworkChangeNotifier::ConnectionCost ConnectionCostFromNlmConnectionCost(
25     DWORD connection_cost_flags) {
26   if (connection_cost_flags == NLM_CONNECTION_COST_UNKNOWN) {
27     return NetworkChangeNotifier::CONNECTION_COST_UNKNOWN;
28   } else if ((connection_cost_flags & NLM_CONNECTION_COST_UNRESTRICTED) != 0) {
29     return NetworkChangeNotifier::CONNECTION_COST_UNMETERED;
30   } else {
31     return NetworkChangeNotifier::CONNECTION_COST_METERED;
32   }
33 }
34 
35 NetworkCostChangeNotifierWin::CoCreateInstanceCallback&
GetCoCreateInstanceCallback()36 GetCoCreateInstanceCallback() {
37   static base::NoDestructor<
38       NetworkCostChangeNotifierWin::CoCreateInstanceCallback>
39       co_create_instance_callback{base::BindRepeating(&CoCreateInstance)};
40   return *co_create_instance_callback;
41 }
42 
43 }  // namespace
44 
45 // This class is used as an event sink to register for notifications from the
46 // `INetworkCostManagerEvents` interface. In particular, we are focused on
47 // getting notified when the connection cost changes.
48 class NetworkCostManagerEventSinkWin final
49     : public Microsoft::WRL::RuntimeClass<
50           Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
51           INetworkCostManagerEvents> {
52  public:
CreateInstance(INetworkCostManager * network_cost_manager,base::RepeatingClosure cost_changed_callback,ComPtr<NetworkCostManagerEventSinkWin> * result)53   static HRESULT CreateInstance(
54       INetworkCostManager* network_cost_manager,
55       base::RepeatingClosure cost_changed_callback,
56       ComPtr<NetworkCostManagerEventSinkWin>* result) {
57     ComPtr<NetworkCostManagerEventSinkWin> instance =
58         Microsoft::WRL::Make<net::NetworkCostManagerEventSinkWin>(
59             cost_changed_callback);
60     HRESULT hr = instance->RegisterForNotifications(network_cost_manager);
61     if (hr != S_OK) {
62       return hr;
63     }
64 
65     *result = instance;
66     return S_OK;
67   }
68 
UnRegisterForNotifications()69   void UnRegisterForNotifications() {
70     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
71 
72     if (event_sink_connection_point_) {
73       event_sink_connection_point_->Unadvise(event_sink_connection_cookie_);
74       event_sink_connection_point_.Reset();
75     }
76   }
77 
78   // Implement the INetworkCostManagerEvents interface.
CostChanged(DWORD,NLM_SOCKADDR *)79   HRESULT __stdcall CostChanged(DWORD /*cost*/,
80                                 NLM_SOCKADDR* /*socket_address*/) final {
81     // It is possible to get multiple notifications in a short period of time.
82     // Rather than worrying about whether this notification represents the
83     // latest, just notify the owner who can get the current value from the
84     // INetworkCostManager so we know that we're actually getting the correct
85     // value.
86     cost_changed_callback_.Run();
87     return S_OK;
88   }
89 
DataPlanStatusChanged(NLM_SOCKADDR *)90   HRESULT __stdcall DataPlanStatusChanged(
91       NLM_SOCKADDR* /*socket_address*/) final {
92     return S_OK;
93   }
94 
NetworkCostManagerEventSinkWin(base::RepeatingClosure cost_changed_callback)95   NetworkCostManagerEventSinkWin(base::RepeatingClosure cost_changed_callback)
96       : cost_changed_callback_(cost_changed_callback) {}
97 
98   NetworkCostManagerEventSinkWin(const NetworkCostManagerEventSinkWin&) =
99       delete;
100   NetworkCostManagerEventSinkWin& operator=(
101       const NetworkCostManagerEventSinkWin&) = delete;
102 
103  private:
104   ~NetworkCostManagerEventSinkWin() final = default;
105 
RegisterForNotifications(INetworkCostManager * cost_manager)106   HRESULT RegisterForNotifications(INetworkCostManager* cost_manager) {
107     DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
108 
109     base::win::AssertComInitialized();
110     base::win::AssertComApartmentType(base::win::ComApartmentType::STA);
111 
112     ComPtr<IUnknown> this_event_sink_unknown;
113     HRESULT hr = QueryInterface(IID_PPV_ARGS(&this_event_sink_unknown));
114 
115     // `NetworkCostManagerEventSinkWin::QueryInterface` for `IUnknown` must
116     // succeed since it is implemented by this class.
117     CHECK_EQ(hr, S_OK);
118 
119     ComPtr<IConnectionPointContainer> connection_point_container;
120     hr =
121         cost_manager->QueryInterface(IID_PPV_ARGS(&connection_point_container));
122     if (hr != S_OK) {
123       return hr;
124     }
125 
126     Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point;
127     hr = connection_point_container->FindConnectionPoint(
128         IID_INetworkCostManagerEvents, &event_sink_connection_point);
129     if (hr != S_OK) {
130       return hr;
131     }
132 
133     hr = event_sink_connection_point->Advise(this_event_sink_unknown.Get(),
134                                              &event_sink_connection_cookie_);
135     if (hr != S_OK) {
136       return hr;
137     }
138 
139     CHECK_EQ(event_sink_connection_point_, nullptr);
140     event_sink_connection_point_ = event_sink_connection_point;
141     return S_OK;
142   }
143 
144   base::RepeatingClosure cost_changed_callback_;
145 
146   // The following members must be accessed on the sequence from
147   // `sequence_checker_`
148   SEQUENCE_CHECKER(sequence_checker_);
149   DWORD event_sink_connection_cookie_ = 0;
150   Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point_;
151 };
152 
153 // static
154 base::SequenceBound<NetworkCostChangeNotifierWin>
CreateInstance(CostChangedCallback cost_changed_callback)155 NetworkCostChangeNotifierWin::CreateInstance(
156     CostChangedCallback cost_changed_callback) {
157   scoped_refptr<base::SequencedTaskRunner> com_best_effort_task_runner =
158       base::ThreadPool::CreateCOMSTATaskRunner(
159           {base::MayBlock(), base::TaskPriority::BEST_EFFORT,
160            base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN});
161 
162   return base::SequenceBound<NetworkCostChangeNotifierWin>(
163       com_best_effort_task_runner,
164       // Ensure `cost_changed_callback` runs on the sequence of the creator and
165       // owner of `NetworkCostChangeNotifierWin`.
166       base::BindPostTask(base::SequencedTaskRunner::GetCurrentDefault(),
167                          cost_changed_callback));
168 }
169 
NetworkCostChangeNotifierWin(CostChangedCallback cost_changed_callback)170 NetworkCostChangeNotifierWin::NetworkCostChangeNotifierWin(
171     CostChangedCallback cost_changed_callback)
172     : cost_changed_callback_(cost_changed_callback) {
173   StartWatching();
174 }
175 
~NetworkCostChangeNotifierWin()176 NetworkCostChangeNotifierWin::~NetworkCostChangeNotifierWin() {
177   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
178   StopWatching();
179 }
180 
StartWatching()181 void NetworkCostChangeNotifierWin::StartWatching() {
182   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
183 
184   if (base::win::GetVersion() < kSupportedOsVersion) {
185     return;
186   }
187 
188   base::win::AssertComInitialized();
189   base::win::AssertComApartmentType(base::win::ComApartmentType::STA);
190 
191   SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY();
192 
193   // Create `INetworkListManager` using `CoCreateInstance()`.  Tests may provide
194   // a fake implementation of `INetworkListManager` through an
195   // `OverrideCoCreateInstanceForTesting()`.
196   ComPtr<INetworkCostManager> cost_manager;
197   HRESULT hr = GetCoCreateInstanceCallback().Run(
198       CLSID_NetworkListManager, /*unknown_outer=*/nullptr, CLSCTX_ALL,
199       IID_INetworkCostManager, &cost_manager);
200   if (hr != S_OK) {
201     return;
202   }
203 
204   // Subscribe to cost changed events.
205   hr = NetworkCostManagerEventSinkWin::CreateInstance(
206       cost_manager.Get(),
207       // Cost changed callbacks must run on this sequence to get the new cost
208       // from `INetworkCostManager`.
209       base::BindPostTask(
210           base::SequencedTaskRunner::GetCurrentDefault(),
211           base::BindRepeating(&NetworkCostChangeNotifierWin::HandleCostChanged,
212                               weak_ptr_factory_.GetWeakPtr())),
213       &cost_manager_event_sink_);
214 
215   if (hr != S_OK) {
216     return;
217   }
218 
219   // Set the initial cost and inform observers of the initial value.
220   cost_manager_ = cost_manager;
221   HandleCostChanged();
222 }
223 
StopWatching()224 void NetworkCostChangeNotifierWin::StopWatching() {
225   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
226 
227   if (cost_manager_event_sink_) {
228     cost_manager_event_sink_->UnRegisterForNotifications();
229     cost_manager_event_sink_.Reset();
230   }
231 
232   cost_manager_.Reset();
233 }
234 
HandleCostChanged()235 void NetworkCostChangeNotifierWin::HandleCostChanged() {
236   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
237 
238   DWORD connection_cost_flags;
239   HRESULT hr = cost_manager_->GetCost(&connection_cost_flags,
240                                       /*destination_ip_address=*/nullptr);
241   if (hr != S_OK) {
242     connection_cost_flags = NLM_CONNECTION_COST_UNKNOWN;
243   }
244 
245   NetworkChangeNotifier::ConnectionCost changed_cost =
246       ConnectionCostFromNlmConnectionCost(connection_cost_flags);
247 
248   cost_changed_callback_.Run(changed_cost);
249 }
250 
251 // static
OverrideCoCreateInstanceForTesting(CoCreateInstanceCallback callback_for_testing)252 void NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
253     CoCreateInstanceCallback callback_for_testing) {
254   GetCoCreateInstanceCallback() = callback_for_testing;
255 }
256 
257 }  // namespace net
258