1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "net/base/network_cost_change_notifier_win.h"

#include <wrl.h>
#include <wrl/client.h>

#include <utility>

#include "base/check.h"
#include "base/no_destructor.h"
#include "base/task/bind_post_task.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/threading/scoped_thread_priority.h"
#include "base/win/com_init_util.h"
17
18 using Microsoft::WRL::ComPtr;
19
20 namespace net {
21
22 namespace {
23
ConnectionCostFromNlmConnectionCost(DWORD connection_cost_flags)24 NetworkChangeNotifier::ConnectionCost ConnectionCostFromNlmConnectionCost(
25 DWORD connection_cost_flags) {
26 if (connection_cost_flags == NLM_CONNECTION_COST_UNKNOWN) {
27 return NetworkChangeNotifier::CONNECTION_COST_UNKNOWN;
28 } else if ((connection_cost_flags & NLM_CONNECTION_COST_UNRESTRICTED) != 0) {
29 return NetworkChangeNotifier::CONNECTION_COST_UNMETERED;
30 } else {
31 return NetworkChangeNotifier::CONNECTION_COST_METERED;
32 }
33 }
34
35 NetworkCostChangeNotifierWin::CoCreateInstanceCallback&
GetCoCreateInstanceCallback()36 GetCoCreateInstanceCallback() {
37 static base::NoDestructor<
38 NetworkCostChangeNotifierWin::CoCreateInstanceCallback>
39 co_create_instance_callback{base::BindRepeating(&CoCreateInstance)};
40 return *co_create_instance_callback;
41 }
42
43 } // namespace
44
45 // This class is used as an event sink to register for notifications from the
46 // `INetworkCostManagerEvents` interface. In particular, we are focused on
47 // getting notified when the connection cost changes.
48 class NetworkCostManagerEventSinkWin final
49 : public Microsoft::WRL::RuntimeClass<
50 Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
51 INetworkCostManagerEvents> {
52 public:
CreateInstance(INetworkCostManager * network_cost_manager,base::RepeatingClosure cost_changed_callback,ComPtr<NetworkCostManagerEventSinkWin> * result)53 static HRESULT CreateInstance(
54 INetworkCostManager* network_cost_manager,
55 base::RepeatingClosure cost_changed_callback,
56 ComPtr<NetworkCostManagerEventSinkWin>* result) {
57 ComPtr<NetworkCostManagerEventSinkWin> instance =
58 Microsoft::WRL::Make<net::NetworkCostManagerEventSinkWin>(
59 cost_changed_callback);
60 HRESULT hr = instance->RegisterForNotifications(network_cost_manager);
61 if (hr != S_OK) {
62 return hr;
63 }
64
65 *result = instance;
66 return S_OK;
67 }
68
UnRegisterForNotifications()69 void UnRegisterForNotifications() {
70 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
71
72 if (event_sink_connection_point_) {
73 event_sink_connection_point_->Unadvise(event_sink_connection_cookie_);
74 event_sink_connection_point_.Reset();
75 }
76 }
77
78 // Implement the INetworkCostManagerEvents interface.
CostChanged(DWORD,NLM_SOCKADDR *)79 HRESULT __stdcall CostChanged(DWORD /*cost*/,
80 NLM_SOCKADDR* /*socket_address*/) final {
81 // It is possible to get multiple notifications in a short period of time.
82 // Rather than worrying about whether this notification represents the
83 // latest, just notify the owner who can get the current value from the
84 // INetworkCostManager so we know that we're actually getting the correct
85 // value.
86 cost_changed_callback_.Run();
87 return S_OK;
88 }
89
DataPlanStatusChanged(NLM_SOCKADDR *)90 HRESULT __stdcall DataPlanStatusChanged(
91 NLM_SOCKADDR* /*socket_address*/) final {
92 return S_OK;
93 }
94
NetworkCostManagerEventSinkWin(base::RepeatingClosure cost_changed_callback)95 NetworkCostManagerEventSinkWin(base::RepeatingClosure cost_changed_callback)
96 : cost_changed_callback_(cost_changed_callback) {}
97
98 NetworkCostManagerEventSinkWin(const NetworkCostManagerEventSinkWin&) =
99 delete;
100 NetworkCostManagerEventSinkWin& operator=(
101 const NetworkCostManagerEventSinkWin&) = delete;
102
103 private:
104 ~NetworkCostManagerEventSinkWin() final = default;
105
RegisterForNotifications(INetworkCostManager * cost_manager)106 HRESULT RegisterForNotifications(INetworkCostManager* cost_manager) {
107 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
108
109 base::win::AssertComInitialized();
110 base::win::AssertComApartmentType(base::win::ComApartmentType::STA);
111
112 ComPtr<IUnknown> this_event_sink_unknown;
113 HRESULT hr = QueryInterface(IID_PPV_ARGS(&this_event_sink_unknown));
114
115 // `NetworkCostManagerEventSinkWin::QueryInterface` for `IUnknown` must
116 // succeed since it is implemented by this class.
117 CHECK_EQ(hr, S_OK);
118
119 ComPtr<IConnectionPointContainer> connection_point_container;
120 hr =
121 cost_manager->QueryInterface(IID_PPV_ARGS(&connection_point_container));
122 if (hr != S_OK) {
123 return hr;
124 }
125
126 Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point;
127 hr = connection_point_container->FindConnectionPoint(
128 IID_INetworkCostManagerEvents, &event_sink_connection_point);
129 if (hr != S_OK) {
130 return hr;
131 }
132
133 hr = event_sink_connection_point->Advise(this_event_sink_unknown.Get(),
134 &event_sink_connection_cookie_);
135 if (hr != S_OK) {
136 return hr;
137 }
138
139 CHECK_EQ(event_sink_connection_point_, nullptr);
140 event_sink_connection_point_ = event_sink_connection_point;
141 return S_OK;
142 }
143
144 base::RepeatingClosure cost_changed_callback_;
145
146 // The following members must be accessed on the sequence from
147 // `sequence_checker_`
148 SEQUENCE_CHECKER(sequence_checker_);
149 DWORD event_sink_connection_cookie_ = 0;
150 Microsoft::WRL::ComPtr<IConnectionPoint> event_sink_connection_point_;
151 };
152
153 // static
154 base::SequenceBound<NetworkCostChangeNotifierWin>
CreateInstance(CostChangedCallback cost_changed_callback)155 NetworkCostChangeNotifierWin::CreateInstance(
156 CostChangedCallback cost_changed_callback) {
157 scoped_refptr<base::SequencedTaskRunner> com_best_effort_task_runner =
158 base::ThreadPool::CreateCOMSTATaskRunner(
159 {base::MayBlock(), base::TaskPriority::BEST_EFFORT,
160 base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN});
161
162 return base::SequenceBound<NetworkCostChangeNotifierWin>(
163 com_best_effort_task_runner,
164 // Ensure `cost_changed_callback` runs on the sequence of the creator and
165 // owner of `NetworkCostChangeNotifierWin`.
166 base::BindPostTask(base::SequencedTaskRunner::GetCurrentDefault(),
167 cost_changed_callback));
168 }
169
NetworkCostChangeNotifierWin(CostChangedCallback cost_changed_callback)170 NetworkCostChangeNotifierWin::NetworkCostChangeNotifierWin(
171 CostChangedCallback cost_changed_callback)
172 : cost_changed_callback_(cost_changed_callback) {
173 StartWatching();
174 }
175
~NetworkCostChangeNotifierWin()176 NetworkCostChangeNotifierWin::~NetworkCostChangeNotifierWin() {
177 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
178 StopWatching();
179 }
180
StartWatching()181 void NetworkCostChangeNotifierWin::StartWatching() {
182 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
183
184 if (base::win::GetVersion() < kSupportedOsVersion) {
185 return;
186 }
187
188 base::win::AssertComInitialized();
189 base::win::AssertComApartmentType(base::win::ComApartmentType::STA);
190
191 SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY();
192
193 // Create `INetworkListManager` using `CoCreateInstance()`. Tests may provide
194 // a fake implementation of `INetworkListManager` through an
195 // `OverrideCoCreateInstanceForTesting()`.
196 ComPtr<INetworkCostManager> cost_manager;
197 HRESULT hr = GetCoCreateInstanceCallback().Run(
198 CLSID_NetworkListManager, /*unknown_outer=*/nullptr, CLSCTX_ALL,
199 IID_INetworkCostManager, &cost_manager);
200 if (hr != S_OK) {
201 return;
202 }
203
204 // Subscribe to cost changed events.
205 hr = NetworkCostManagerEventSinkWin::CreateInstance(
206 cost_manager.Get(),
207 // Cost changed callbacks must run on this sequence to get the new cost
208 // from `INetworkCostManager`.
209 base::BindPostTask(
210 base::SequencedTaskRunner::GetCurrentDefault(),
211 base::BindRepeating(&NetworkCostChangeNotifierWin::HandleCostChanged,
212 weak_ptr_factory_.GetWeakPtr())),
213 &cost_manager_event_sink_);
214
215 if (hr != S_OK) {
216 return;
217 }
218
219 // Set the initial cost and inform observers of the initial value.
220 cost_manager_ = cost_manager;
221 HandleCostChanged();
222 }
223
StopWatching()224 void NetworkCostChangeNotifierWin::StopWatching() {
225 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
226
227 if (cost_manager_event_sink_) {
228 cost_manager_event_sink_->UnRegisterForNotifications();
229 cost_manager_event_sink_.Reset();
230 }
231
232 cost_manager_.Reset();
233 }
234
HandleCostChanged()235 void NetworkCostChangeNotifierWin::HandleCostChanged() {
236 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
237
238 DWORD connection_cost_flags;
239 HRESULT hr = cost_manager_->GetCost(&connection_cost_flags,
240 /*destination_ip_address=*/nullptr);
241 if (hr != S_OK) {
242 connection_cost_flags = NLM_CONNECTION_COST_UNKNOWN;
243 }
244
245 NetworkChangeNotifier::ConnectionCost changed_cost =
246 ConnectionCostFromNlmConnectionCost(connection_cost_flags);
247
248 cost_changed_callback_.Run(changed_cost);
249 }
250
251 // static
OverrideCoCreateInstanceForTesting(CoCreateInstanceCallback callback_for_testing)252 void NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
253 CoCreateInstanceCallback callback_for_testing) {
254 GetCoCreateInstanceCallback() = callback_for_testing;
255 }
256
257 } // namespace net
258