• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/win/fake_network_cost_manager.h"
6 
7 #include <netlistmgr.h>
8 #include <wrl/implements.h>
9 
10 #include <map>
11 
12 #include "base/task/sequenced_task_runner.h"
13 #include "net/base/network_cost_change_notifier_win.h"
14 
15 using Microsoft::WRL::ClassicCom;
16 using Microsoft::WRL::ComPtr;
17 using Microsoft::WRL::RuntimeClass;
18 using Microsoft::WRL::RuntimeClassFlags;
19 
20 namespace net {
21 
22 namespace {
23 
NlmConnectionCostFlagsFromConnectionCost(NetworkChangeNotifier::ConnectionCost source_cost)24 DWORD NlmConnectionCostFlagsFromConnectionCost(
25     NetworkChangeNotifier::ConnectionCost source_cost) {
26   switch (source_cost) {
27     case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNMETERED:
28       return (NLM_CONNECTION_COST_UNRESTRICTED | NLM_CONNECTION_COST_CONGESTED);
29     case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_METERED:
30       return (NLM_CONNECTION_COST_VARIABLE | NLM_CONNECTION_COST_ROAMING |
31               NLM_CONNECTION_COST_APPROACHINGDATALIMIT);
32     case NetworkChangeNotifier::ConnectionCost::CONNECTION_COST_UNKNOWN:
33     default:
34       return NLM_CONNECTION_COST_UNKNOWN;
35   }
36 }
37 
DispatchCostChangedEvent(ComPtr<INetworkCostManagerEvents> event_target,DWORD cost)38 void DispatchCostChangedEvent(ComPtr<INetworkCostManagerEvents> event_target,
39                               DWORD cost) {
40   std::ignore =
41       event_target->CostChanged(cost, /*destination_address=*/nullptr);
42 }
43 
44 }  // namespace
45 
46 // A fake implementation of `INetworkCostManager` that can simulate costs,
47 // changed costs and errors.
48 class FakeNetworkCostManager final
49     : public RuntimeClass<RuntimeClassFlags<ClassicCom>,
50                           INetworkCostManager,
51                           IConnectionPointContainer,
52                           IConnectionPoint> {
53  public:
FakeNetworkCostManager(NetworkChangeNotifier::ConnectionCost connection_cost,NetworkCostManagerStatus error_status)54   FakeNetworkCostManager(NetworkChangeNotifier::ConnectionCost connection_cost,
55                          NetworkCostManagerStatus error_status)
56       : error_status_(error_status), connection_cost_(connection_cost) {}
57 
58   // For each event sink in `event_sinks_`, call
59   // `INetworkCostManagerEvents::CostChanged()` with `changed_cost` on the event
60   // sink's task runner.
PostCostChangedEvents(NetworkChangeNotifier::ConnectionCost changed_cost)61   void PostCostChangedEvents(
62       NetworkChangeNotifier::ConnectionCost changed_cost) {
63     DWORD cost_for_changed_event;
64     std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration>
65         event_sinks_for_changed_event;
66     {
67       base::AutoLock auto_lock(member_lock_);
68       connection_cost_ = changed_cost;
69       cost_for_changed_event =
70           NlmConnectionCostFlagsFromConnectionCost(changed_cost);
71 
72       // Get the snapshot of event sinks to notify.  The snapshot collection
73       // creates a new `ComPtr` for each event sink, which increments each the
74       // event sink's reference count, ensuring that each event sink
75       // remains alive to receive the cost changed event notification.
76       event_sinks_for_changed_event = event_sinks_;
77     }
78 
79     for (const auto& pair : event_sinks_for_changed_event) {
80       const auto& registration = pair.second;
81       registration.event_sink_task_runner_->PostTask(
82           FROM_HERE,
83           base::BindOnce(&DispatchCostChangedEvent, registration.event_sink_,
84                          cost_for_changed_event));
85     }
86   }
87 
88   // Implement the `INetworkCostManager` interface.
89   HRESULT
GetCost(DWORD * cost,NLM_SOCKADDR * destination_ip_address)90   __stdcall GetCost(DWORD* cost,
91                     NLM_SOCKADDR* destination_ip_address) override {
92     if (error_status_ == NetworkCostManagerStatus::kErrorGetCostFailed) {
93       return E_FAIL;
94     }
95 
96     if (destination_ip_address != nullptr) {
97       NOTIMPLEMENTED();
98       return E_NOTIMPL;
99     }
100 
101     {
102       base::AutoLock auto_lock(member_lock_);
103       *cost = NlmConnectionCostFlagsFromConnectionCost(connection_cost_);
104     }
105     return S_OK;
106   }
107 
GetDataPlanStatus(NLM_DATAPLAN_STATUS * data_plan_status,NLM_SOCKADDR * destination_ip_address)108   HRESULT __stdcall GetDataPlanStatus(
109       NLM_DATAPLAN_STATUS* data_plan_status,
110       NLM_SOCKADDR* destination_ip_address) override {
111     NOTIMPLEMENTED();
112     return E_NOTIMPL;
113   }
114 
SetDestinationAddresses(UINT32 length,NLM_SOCKADDR * destination_ip_address_list,VARIANT_BOOL append)115   HRESULT __stdcall SetDestinationAddresses(
116       UINT32 length,
117       NLM_SOCKADDR* destination_ip_address_list,
118       VARIANT_BOOL append) override {
119     NOTIMPLEMENTED();
120     return E_NOTIMPL;
121   }
122 
123   // Implement the `IConnectionPointContainer` interface.
FindConnectionPoint(REFIID connection_point_id,IConnectionPoint ** result)124   HRESULT __stdcall FindConnectionPoint(REFIID connection_point_id,
125                                         IConnectionPoint** result) override {
126     if (error_status_ ==
127         NetworkCostManagerStatus::kErrorFindConnectionPointFailed) {
128       return E_ABORT;
129     }
130 
131     if (connection_point_id != IID_INetworkCostManagerEvents) {
132       return E_NOINTERFACE;
133     }
134 
135     *result = static_cast<IConnectionPoint*>(this);
136     AddRef();
137     return S_OK;
138   }
139 
EnumConnectionPoints(IEnumConnectionPoints ** results)140   HRESULT __stdcall EnumConnectionPoints(
141       IEnumConnectionPoints** results) override {
142     NOTIMPLEMENTED();
143     return E_NOTIMPL;
144   }
145 
146   // Implement the `IConnectionPoint` interface.
Advise(IUnknown * event_sink,DWORD * event_sink_cookie)147   HRESULT __stdcall Advise(IUnknown* event_sink,
148                            DWORD* event_sink_cookie) override {
149     if (error_status_ == NetworkCostManagerStatus::kErrorAdviseFailed) {
150       return E_NOT_VALID_STATE;
151     }
152 
153     ComPtr<INetworkCostManagerEvents> cost_manager_event_sink;
154     HRESULT hr =
155         event_sink->QueryInterface(IID_PPV_ARGS(&cost_manager_event_sink));
156     if (hr != S_OK) {
157       return hr;
158     }
159 
160     base::AutoLock auto_lock(member_lock_);
161 
162     event_sinks_[next_event_sink_cookie_] = {
163         cost_manager_event_sink,
164         base::SequencedTaskRunner::GetCurrentDefault()};
165 
166     *event_sink_cookie = next_event_sink_cookie_;
167     ++next_event_sink_cookie_;
168 
169     return S_OK;
170   }
171 
Unadvise(DWORD event_sink_cookie)172   HRESULT __stdcall Unadvise(DWORD event_sink_cookie) override {
173     base::AutoLock auto_lock(member_lock_);
174 
175     auto it = event_sinks_.find(event_sink_cookie);
176     if (it == event_sinks_.end()) {
177       return ERROR_NOT_FOUND;
178     }
179 
180     event_sinks_.erase(it);
181     return S_OK;
182   }
183 
GetConnectionInterface(IID * result)184   HRESULT __stdcall GetConnectionInterface(IID* result) override {
185     NOTIMPLEMENTED();
186     return E_NOTIMPL;
187   }
188 
GetConnectionPointContainer(IConnectionPointContainer ** result)189   HRESULT __stdcall GetConnectionPointContainer(
190       IConnectionPointContainer** result) override {
191     NOTIMPLEMENTED();
192     return E_NOTIMPL;
193   }
194 
EnumConnections(IEnumConnections ** result)195   HRESULT __stdcall EnumConnections(IEnumConnections** result) override {
196     NOTIMPLEMENTED();
197     return E_NOTIMPL;
198   }
199 
200   // Implement the `IUnknown` interface.
QueryInterface(REFIID interface_id,void ** result)201   HRESULT __stdcall QueryInterface(REFIID interface_id,
202                                    void** result) override {
203     if (error_status_ == NetworkCostManagerStatus::kErrorQueryInterfaceFailed) {
204       return E_NOINTERFACE;
205     }
206     return RuntimeClass<RuntimeClassFlags<ClassicCom>, INetworkCostManager,
207                         IConnectionPointContainer,
208                         IConnectionPoint>::QueryInterface(interface_id, result);
209   }
210 
211   FakeNetworkCostManager(const FakeNetworkCostManager&) = delete;
212   FakeNetworkCostManager& operator=(const FakeNetworkCostManager&) = delete;
213 
214  private:
215   // The error state for this `FakeNetworkCostManager` to simulate.  Cannot be
216   // changed.
217   const NetworkCostManagerStatus error_status_;
218 
219   // Synchronizes access to all members below.
220   base::Lock member_lock_;
221 
222   NetworkChangeNotifier::ConnectionCost connection_cost_
223       GUARDED_BY(member_lock_);
224 
225   DWORD next_event_sink_cookie_ GUARDED_BY(member_lock_) = 0;
226 
227   struct EventSinkRegistration {
228     ComPtr<INetworkCostManagerEvents> event_sink_;
229     scoped_refptr<base::SequencedTaskRunner> event_sink_task_runner_;
230   };
231   std::map</*event_sink_cookie=*/DWORD, EventSinkRegistration> event_sinks_
232       GUARDED_BY(member_lock_);
233 };
234 
FakeNetworkCostManagerEnvironment()235 FakeNetworkCostManagerEnvironment::FakeNetworkCostManagerEnvironment() {
236   // Set up `NetworkCostChangeNotifierWin` to use the fake OS APIs.
237   NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
238       base::BindRepeating(
239           &FakeNetworkCostManagerEnvironment::FakeCoCreateInstance,
240           base::Unretained(this)));
241 }
242 
~FakeNetworkCostManagerEnvironment()243 FakeNetworkCostManagerEnvironment::~FakeNetworkCostManagerEnvironment() {
244   // Restore `NetworkCostChangeNotifierWin` to use the real OS APIs.
245   NetworkCostChangeNotifierWin::OverrideCoCreateInstanceForTesting(
246       base::BindRepeating(&CoCreateInstance));
247 }
248 
FakeCoCreateInstance(REFCLSID class_id,LPUNKNOWN outer_aggregate,DWORD context_flags,REFIID interface_id,LPVOID * result)249 HRESULT FakeNetworkCostManagerEnvironment::FakeCoCreateInstance(
250     REFCLSID class_id,
251     LPUNKNOWN outer_aggregate,
252     DWORD context_flags,
253     REFIID interface_id,
254     LPVOID* result) {
255   NetworkChangeNotifier::ConnectionCost connection_cost_for_new_instance;
256   NetworkCostManagerStatus error_status_for_new_instance;
257   {
258     base::AutoLock auto_lock(member_lock_);
259     connection_cost_for_new_instance = connection_cost_;
260     error_status_for_new_instance = error_status_;
261   }
262 
263   if (error_status_for_new_instance ==
264       NetworkCostManagerStatus::kErrorCoCreateInstanceFailed) {
265     return E_ACCESSDENIED;
266   }
267 
268   if (class_id != CLSID_NetworkListManager) {
269     return E_NOINTERFACE;
270   }
271 
272   if (interface_id != IID_INetworkCostManager) {
273     return E_NOINTERFACE;
274   }
275 
276   ComPtr<FakeNetworkCostManager> instance =
277       Microsoft::WRL::Make<FakeNetworkCostManager>(
278           connection_cost_for_new_instance, error_status_for_new_instance);
279   {
280     base::AutoLock auto_lock(member_lock_);
281     fake_network_cost_managers_.push_back(instance);
282   }
283   *result = instance.Detach();
284   return S_OK;
285 }
286 
SetCost(NetworkChangeNotifier::ConnectionCost value)287 void FakeNetworkCostManagerEnvironment::SetCost(
288     NetworkChangeNotifier::ConnectionCost value) {
289   // Update the cost for each `INetworkCostManager` instance in
290   // `fake_network_cost_managers_`.
291   std::vector<Microsoft::WRL::ComPtr<FakeNetworkCostManager>>
292       fake_network_cost_managers_for_change_event;
293   {
294     base::AutoLock auto_lock(member_lock_);
295     connection_cost_ = value;
296     fake_network_cost_managers_for_change_event = fake_network_cost_managers_;
297   }
298 
299   for (const auto& network_cost_manager :
300        fake_network_cost_managers_for_change_event) {
301     network_cost_manager->PostCostChangedEvents(/*connection_cost=*/value);
302   }
303 }
304 
SimulateError(NetworkCostManagerStatus error_status)305 void FakeNetworkCostManagerEnvironment::SimulateError(
306     NetworkCostManagerStatus error_status) {
307   base::AutoLock auto_lock(member_lock_);
308   error_status_ = error_status;
309 }
310 
311 }  // namespace net
312