1 // Copyright 2021 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/dns/public/win_dns_system_settings.h"
11
12 #include "testing/gtest/include/gtest/gtest.h"
13
14 namespace net {
15
16 namespace {
17
// Test-side description of one network adapter, turned into the Windows
// IP_ADAPTER_ADDRESSES representation by CreateAdapterAddresses().
// Arrays of AdapterInfo are terminated by an entry whose |if_type| is 0
// (written as `{0}` in the tests below).
//
// Field order matters: the tests build these with designated initializers,
// which must name fields in declaration order.
struct AdapterInfo {
  // Copied into IP_ADAPTER_ADDRESSES::IfType; 0 marks the end of an array.
  IFTYPE if_type;
  // Copied into IP_ADAPTER_ADDRESSES::OperStatus.
  IF_OPER_STATUS oper_status;
  // Stored (via const_cast) in IP_ADAPTER_ADDRESSES::DnsSuffix; must outlive
  // the generated adapter list. String literals satisfy this.
  const WCHAR* dns_suffix;
  // IP literals for this adapter's DNS servers. The list ends at the first
  // empty string; filling all four slots relies on the consumer also bounding
  // its scan by the array length.
  std::string dns_server_addresses[4];  // Empty string indicates end.
  // ports[j] is the port paired with dns_server_addresses[j]; entries left out
  // of a brace initializer are value-initialized to 0.
  uint16_t ports[4];
};
25
CreateAdapterAddresses(const AdapterInfo * infos)26 std::unique_ptr<IP_ADAPTER_ADDRESSES, base::FreeDeleter> CreateAdapterAddresses(
27 const AdapterInfo* infos) {
28 size_t num_adapters = 0;
29 size_t num_addresses = 0;
30 for (size_t i = 0; infos[i].if_type; ++i) {
31 ++num_adapters;
32 for (size_t j = 0; !infos[i].dns_server_addresses[j].empty(); ++j) {
33 ++num_addresses;
34 }
35 }
36
37 size_t heap_size = num_adapters * sizeof(IP_ADAPTER_ADDRESSES) +
38 num_addresses * (sizeof(IP_ADAPTER_DNS_SERVER_ADDRESS) +
39 sizeof(struct sockaddr_storage));
40 std::unique_ptr<IP_ADAPTER_ADDRESSES, base::FreeDeleter> heap(
41 static_cast<IP_ADAPTER_ADDRESSES*>(malloc(heap_size)));
42 CHECK(heap.get());
43 memset(heap.get(), 0, heap_size);
44
45 IP_ADAPTER_ADDRESSES* adapters = heap.get();
46 IP_ADAPTER_DNS_SERVER_ADDRESS* addresses =
47 reinterpret_cast<IP_ADAPTER_DNS_SERVER_ADDRESS*>(adapters + num_adapters);
48 struct sockaddr_storage* storage =
49 reinterpret_cast<struct sockaddr_storage*>(addresses + num_addresses);
50
51 for (size_t i = 0; i < num_adapters; ++i) {
52 const AdapterInfo& info = infos[i];
53 IP_ADAPTER_ADDRESSES* adapter = adapters + i;
54 if (i + 1 < num_adapters)
55 adapter->Next = adapter + 1;
56 adapter->IfType = info.if_type;
57 adapter->OperStatus = info.oper_status;
58 adapter->DnsSuffix = const_cast<PWCHAR>(info.dns_suffix);
59 IP_ADAPTER_DNS_SERVER_ADDRESS* address = nullptr;
60 for (size_t j = 0; !info.dns_server_addresses[j].empty(); ++j) {
61 --num_addresses;
62 if (j == 0) {
63 address = adapter->FirstDnsServerAddress = addresses + num_addresses;
64 } else {
65 // Note that |address| is moving backwards.
66 address = address->Next = address - 1;
67 }
68 IPAddress ip;
69 CHECK(ip.AssignFromIPLiteral(info.dns_server_addresses[j]));
70 IPEndPoint ipe = IPEndPoint(ip, info.ports[j]);
71 address->Address.lpSockaddr =
72 reinterpret_cast<LPSOCKADDR>(storage + num_addresses);
73 socklen_t length = sizeof(struct sockaddr_storage);
74 CHECK(ipe.ToSockAddr(address->Address.lpSockaddr, &length));
75 address->Address.iSockaddrLength = static_cast<int>(length);
76 }
77 }
78
79 return heap;
80 }
81
// Adapters that list no DNS servers at all: the test expects a present but
// empty nameserver list (not std::nullopt).
TEST(WinDnsSystemSettings, GetAllNameServersEmpty) {
  AdapterInfo infos[3] = {
      {
          .if_type = IF_TYPE_USB,
          .oper_status = IfOperStatusUp,
          .dns_suffix = L"example.com",
          .dns_server_addresses = {},
      },
      {
          .if_type = IF_TYPE_USB,
          .oper_status = IfOperStatusUp,
          .dns_suffix = L"foo.bar",
          .dns_server_addresses = {},
      },
      {0}};  // Terminator: if_type == 0 ends the adapter array.

  WinDnsSystemSettings settings;
  settings.addresses = CreateAdapterAddresses(infos);
  std::optional<std::vector<IPEndPoint>> nameservers =
      settings.GetAllNameservers();
  // ASSERT rather than EXPECT: the next line dereferences the optional, so a
  // missing value must stop the test instead of crashing it.
  ASSERT_TRUE(nameservers.has_value());
  EXPECT_TRUE(nameservers->empty());
}
105
// Adapters whose only DNS servers are fec0:0:0:ffff::1, ::2 and ::3 (the
// site-local addresses used for IPv6 DNS stateless discovery). The test
// expects these to be dropped, which leaves a present but empty list.
TEST(WinDnsSystemSettings, GetAllNameServersStatelessDiscoveryAdresses) {
  AdapterInfo infos[3] = {
      {
          .if_type = IF_TYPE_USB,
          .oper_status = IfOperStatusUp,
          .dns_suffix = L"example.com",
          .dns_server_addresses = {"fec0:0:0:ffff::1", "fec0:0:0:ffff::2"},
      },
      {
          .if_type = IF_TYPE_USB,
          .oper_status = IfOperStatusUp,
          .dns_suffix = L"foo.bar",
          .dns_server_addresses = {"fec0:0:0:ffff::3"},
      },
      {0}};  // Terminator: if_type == 0 ends the adapter array.

  WinDnsSystemSettings settings;
  settings.addresses = CreateAdapterAddresses(infos);
  std::optional<std::vector<IPEndPoint>> nameservers =
      settings.GetAllNameservers();
  // ASSERT rather than EXPECT: the next line dereferences the optional.
  ASSERT_TRUE(nameservers.has_value());
  EXPECT_TRUE(nameservers->empty());
}
129
// Mixed IPv4/IPv6 servers spread over two adapters, each with its own port.
// The test expects all four back in adapter order and then list order, with
// ports preserved.
TEST(WinDnsSystemSettings, GetAllNameServersValid) {
  AdapterInfo infos[3] = {
      {.if_type = IF_TYPE_USB,
       .oper_status = IfOperStatusUp,
       .dns_suffix = L"example.com",
       .dns_server_addresses = {"8.8.8.8", "10.0.0.10"},
       .ports = {11, 22}},
      {.if_type = IF_TYPE_USB,
       .oper_status = IfOperStatusUp,
       .dns_suffix = L"foo.bar",
       .dns_server_addresses = {"2001:ffff::1111",
                                "aaaa:bbbb:cccc:dddd:eeee:ffff:0:1"},
       .ports = {33, 44}},
      {0}};  // Terminator: if_type == 0 ends the adapter array.

  WinDnsSystemSettings settings;
  settings.addresses = CreateAdapterAddresses(infos);
  std::optional<std::vector<IPEndPoint>> nameservers =
      settings.GetAllNameservers();
  // ASSERT rather than EXPECT: the indexing below is only valid when the
  // optional holds a value and the vector has exactly four elements.
  ASSERT_TRUE(nameservers.has_value());
  const std::vector<IPEndPoint>& servers = *nameservers;
  ASSERT_EQ(4u, servers.size());
  EXPECT_EQ(servers[0].ToString(), "8.8.8.8:11");
  EXPECT_EQ(servers[1].ToString(), "10.0.0.10:22");
  EXPECT_EQ(servers[2].ToString(), "[2001:ffff::1111]:33");
  EXPECT_EQ(servers[3].ToString(), "[aaaa:bbbb:cccc:dddd:eeee:ffff:0:1]:44");
}
157 } // namespace
158
159 } // namespace net
160