1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/test/win/fake_network_cost_manager.h"
6
7 #include <netlistmgr.h>
8 #include <wrl/implements.h>
9
10 #include <map>
11
12 #include "base/task/sequenced_task_runner.h"
13 #include "net/base/network_cost_change_notifier_win.h"
14
15 using Microsoft::WRL::ClassicCom;
16 using Microsoft::WRL::ComPtr;
17 using Microsoft::WRL::RuntimeClass;
18 using Microsoft::WRL::RuntimeClassFlags;
19
20 namespace net {
21
22 namespace {
23
NlmConnectionCostFlagsFromConnectionCost(NetworkChangeNotifier::ConnectionCost source_cost)24 DWORD NlmConnectionCostFlagsFromConnectionCost(
25 NetworkChangeNotifier::ConnectionCost source_cost) {
26 switch (source_cost) {
27 case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNMETERED:
28 return (NLM_CONNECTION_COST_UNRESTRICTED | NLM_CONNECTION_COST_CONGESTED);
29 case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_METERED:
30 return (NLM_CONNECTION_COST_VARIABLE | NLM_CONNECTION_COST_ROAMING |
31 NLM_CONNECTION_COST_APPROACHINGDATALIMIT);
32 case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNKNOWN:
33 default:
34 return NLM_CONNECTION_COST_UNKNOWN;
35 }
36 }
37
DispatchCostChangedEvent(ComPtr<INetworkCostManagerEvents> event_target,DWORD cost)38 void DispatchCostChangedEvent(ComPtr<INetworkCostManagerEvents> event_target,
39 DWORD cost) {
40 std::ignore =
41 event_target->CostChanged(cost, /*destination_address=*/nullptr);
42 }
43
44 } // namespace
45
46 // A fake implementation of `INetworkCostManager` that can simulate costs,
47 // changed costs and errors.
48 class FakeNetworkCostManager final
49 : public RuntimeClass<RuntimeClassFlags<ClassicCom>,
50 INetworkCostManager,
51 IConnectionPointContainer,
52 IConnectionPoint> {
53 public:
FakeNetworkCostManager(NetworkChangeNotifier::ConnectionCost connection_cost,NetworkCostManagerStatus error_status)54 FakeNetworkCostManager(NetworkChangeNotifier::ConnectionCost connection_cost,
55 NetworkCostManagerStatus error_status)
56 : error_status_(error_status), connection_cost_(connection_cost) {}
57
58 // For each event sink in `event_sinks_`, call
59 // `INetworkCostManagerEvents::CostChanged()` with `changed_cost` on the event
60 // sink's task runner.
PostCostChangedEvents(NetworkChangeNotifier::ConnectionCost changed_cost)61 void PostCostChangedEvents(
62 NetworkChangeNotifier::ConnectionCost changed_cost) {
63 DWORD cost_for_changed_event;
64 std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration>
65 event_sinks_for_changed_event;
66 {
67 base::AutoLock auto_lock(member_lock_);
68 connection_cost_ = changed_cost;
69 cost_for_changed_event =
70 NlmConnectionCostFlagsFromConnectionCost(changed_cost);
71
72 // Get the snapshot of event sinks to notify. The snapshot collection
73 // creates a new `ComPtr` for each event sink, which increments each the
74 // event sink's reference count, ensuring that each event sink
75 // remains alive to receive the cost changed event notification.
76 event_sinks_for_changed_event = event_sinks_;
77 }
78
79 for (const auto& pair : event_sinks_for_changed_event) {
80 const auto& registration = pair.second;
81 registration.event_sink_task_runner_->PostTask(
82 FROM_HERE,
83 base::BindOnce(&DispatchCostChangedEvent, registration.event_sink_,
84 cost_for_changed_event));
85 }
86 }
87
88 // Implement the `INetworkCostManager` interface.
89 HRESULT
GetCost(DWORD * cost,NLM_SOCKADDR * destination_ip_address)90 __stdcall GetCost(DWORD* cost,
91 NLM_SOCKADDR* destination_ip_address) override {
92 if (error_status_ == NetworkCostManagerStatus::kErrorGetCostFailed) {
93 return E_FAIL;
94 }
95
96 if (destination_ip_address != nullptr) {
97 NOTIMPLEMENTED();
98 return E_NOTIMPL;
99 }
100
101 {
102 base::AutoLock auto_lock(member_lock_);
103 *cost = NlmConnectionCostFlagsFromConnectionCost(connection_cost_);
104 }
105 return S_OK;
106 }
107
GetDataPlanStatus(NLM_DATAPLAN_STATUS * data_plan_status,NLM_SOCKADDR * destination_ip_address)108 HRESULT __stdcall GetDataPlanStatus(
109 NLM_DATAPLAN_STATUS* data_plan_status,
110 NLM_SOCKADDR* destination_ip_address) override {
111 NOTIMPLEMENTED();
112 return E_NOTIMPL;
113 }
114
SetDestinationAddresses(UINT32 length,NLM_SOCKADDR * destination_ip_address_list,VARIANT_BOOL append)115 HRESULT __stdcall SetDestinationAddresses(
116 UINT32 length,
117 NLM_SOCKADDR* destination_ip_address_list,
118 VARIANT_BOOL append) override {
119 NOTIMPLEMENTED();
120 return E_NOTIMPL;
121 }
122
123 // Implement the `IConnectionPointContainer` interface.
FindConnectionPoint(REFIID connection_point_id,IConnectionPoint ** result)124 HRESULT __stdcall FindConnectionPoint(REFIID connection_point_id,
125 IConnectionPoint** result) override {
126 if (error_status_ ==
127 NetworkCostManagerStatus::kErrorFindConnectionPointFailed) {
128 return E_ABORT;
129 }
130
131 if (connection_point_id != IID_INetworkCostManagerEvents) {
132 return E_NOINTERFACE;
133 }
134
135 *result = static_cast<IConnectionPoint*>(this);
136 AddRef();
137 return S_OK;
138 }
139
EnumConnectionPoints(IEnumConnectionPoints ** results)140 HRESULT __stdcall EnumConnectionPoints(
141 IEnumConnectionPoints** results) override {
142 NOTIMPLEMENTED();
143 return E_NOTIMPL;
144 }
145
146 // Implement the `IConnectionPoint` interface.
Advise(IUnknown * event_sink,DWORD * event_sink_cookie)147 HRESULT __stdcall Advise(IUnknown* event_sink,
148 DWORD* event_sink_cookie) override {
149 if (error_status_ == NetworkCostManagerStatus::kErrorAdviseFailed) {
150 return E_NOT_VALID_STATE;
151 }
152
153 ComPtr<INetworkCostManagerEvents> cost_manager_event_sink;
154 HRESULT hr =
155 event_sink->QueryInterface(IID_PPV_ARGS(&cost_manager_event_sink));
156 if (hr != S_OK) {
157 return hr;
158 }
159
160 base::AutoLock auto_lock(member_lock_);
161
162 event_sinks_[next_event_sink_cookie_] = {
163 cost_manager_event_sink,
164 base::SequencedTaskRunner::GetCurrentDefault()};
165
166 *event_sink_cookie = next_event_sink_cookie_;
167 ++next_event_sink_cookie_;
168
169 return S_OK;
170 }
171
Unadvise(DWORD event_sink_cookie)172 HRESULT __stdcall Unadvise(DWORD event_sink_cookie) override {
173 base::AutoLock auto_lock(member_lock_);
174
175 auto it = event_sinks_.find(event_sink_cookie);
176 if (it == event_sinks_.end()) {
177 return ERROR_NOT_FOUND;
178 }
179
180 event_sinks_.erase(it);
181 return S_OK;
182 }
183
GetConnectionInterface(IID * result)184 HRESULT __stdcall GetConnectionInterface(IID* result) override {
185 NOTIMPLEMENTED();
186 return E_NOTIMPL;
187 }
188
GetConnectionPointContainer(IConnectionPointContainer ** result)189 HRESULT __stdcall GetConnectionPointContainer(
190 IConnectionPointContainer** result) override {
191 NOTIMPLEMENTED();
192 return E_NOTIMPL;
193 }
194
EnumConnections(IEnumConnections ** result)195 HRESULT __stdcall EnumConnections(IEnumConnections** result) override {
196 NOTIMPLEMENTED();
197 return E_NOTIMPL;
198 }
199
200 // Implement the `IUnknown` interface.
QueryInterface(REFIID interface_id,void ** result)201 HRESULT __stdcall QueryInterface(REFIID interface_id,
202 void** result) override {
203 if (error_status_ == NetworkCostManagerStatus::kErrorQueryInterfaceFailed) {
204 return E_NOINTERFACE;
205 }
206 return RuntimeClass<RuntimeClassFlags<ClassicCom>, INetworkCostManager,
207 IConnectionPointContainer,
208 IConnectionPoint>::QueryInterface(interface_id, result);
209 }
210
211 FakeNetworkCostManager(const FakeNetworkCostManager&) = delete;
212 FakeNetworkCostManager& operator=(const FakeNetworkCostManager&) = delete;
213
214 private:
215 // The error state for this `FakeNetworkCostManager` to simulate. Cannot be
216 // changed.
217 const NetworkCostManagerStatus error_status_;
218
219 // Synchronizes access to all members below.
220 base::Lock member_lock_;
221
222 NetworkChangeNotifier::ConnectionCost connection_cost_
223 GUARDED_BY(member_lock_);
224
225 DWORD next_event_sink_cookie_ GUARDED_BY(member_lock_) = 0;
226
227 struct EventSinkRegistration {
228 ComPtr<INetworkCostManagerEvents> event_sink_;
229 scoped_refptr<base::SequencedTaskRunner> event_sink_task_runner_;
230 };
231 std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration> event_sinks_
232 GUARDED_BY(member_lock_);
233 };
234
FakeNetworkCostManagerEnvironment()235 FakeNetworkCostManagerEnvironment::FakeNetworkCostManagerEnvironment() {
236 // Set up `NetworkCostChangeNotifierWin` to use the fake OS APIs.
237 NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
238 base::BindRepeating(
239 &FakeNetworkCostManagerEnvironment::FakeCoCreateInstance,
240 base::Unretained(this)));
241 }
242
~FakeNetworkCostManagerEnvironment()243 FakeNetworkCostManagerEnvironment::~FakeNetworkCostManagerEnvironment() {
244 // Restore `NetworkCostChangeNotifierWin` to use the real OS APIs.
245 NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
246 base::BindRepeating(&CoCreateInstance));
247 }
248
FakeCoCreateInstance(REFCLSID class_id,LPUNKNOWN outer_aggregate,DWORD context_flags,REFIID interface_id,LPVOID * result)249 HRESULT FakeNetworkCostManagerEnvironment::FakeCoCreateInstance(
250 REFCLSID class_id,
251 LPUNKNOWN outer_aggregate,
252 DWORD context_flags,
253 REFIID interface_id,
254 LPVOID* result) {
255 NetworkChangeNotifier::ConnectionCost connection_cost_for_new_instance;
256 NetworkCostManagerStatus error_status_for_new_instance;
257 {
258 base::AutoLock auto_lock(member_lock_);
259 connection_cost_for_new_instance = connection_cost_;
260 error_status_for_new_instance = error_status_;
261 }
262
263 if (error_status_for_new_instance ==
264 NetworkCostManagerStatus::kErrorCoCreateInstanceFailed) {
265 return E_ACCESSDENIED;
266 }
267
268 if (class_id != CLSID_NetworkListManager) {
269 return E_NOINTERFACE;
270 }
271
272 if (interface_id != IID_INetworkCostManager) {
273 return E_NOINTERFACE;
274 }
275
276 ComPtr<FakeNetworkCostManager> instance =
277 Microsoft::WRL::Make<FakeNetworkCostManager>(
278 connection_cost_for_new_instance, error_status_for_new_instance);
279 {
280 base::AutoLock auto_lock(member_lock_);
281 fake_network_cost_managers_.push_back(instance);
282 }
283 *result = instance.Detach();
284 return S_OK;
285 }
286
SetCost(NetworkChangeNotifier::ConnectionCost value)287 void FakeNetworkCostManagerEnvironment::SetCost(
288 NetworkChangeNotifier::ConnectionCost value) {
289 // Update the cost for each `INetworkCostManager` instance in
290 // `fake_network_cost_managers_`.
291 std::vector<Microsoft::WRL::ComPtr<FakeNetworkCostManager>>
292 fake_network_cost_managers_for_change_event;
293 {
294 base::AutoLock auto_lock(member_lock_);
295 connection_cost_ = value;
296 fake_network_cost_managers_for_change_event = fake_network_cost_managers_;
297 }
298
299 for (const auto& network_cost_manager :
300 fake_network_cost_managers_for_change_event) {
301 network_cost_manager->PostCostChangedEvents(/*connection_cost=*/value);
302 }
303 }
304
SimulateError(NetworkCostManagerStatus error_status)305 void FakeNetworkCostManagerEnvironment::SimulateError(
306 NetworkCostManagerStatus error_status) {
307 base::AutoLock auto_lock(member_lock_);
308 error_status_ = error_status;
309 }
310
311 } // namespace net
312