1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/dns/address_sorter.h"
11
12 #include <winsock2.h>
13
14 #include <algorithm>
15 #include <utility>
16 #include <vector>
17
18 #include "base/functional/bind.h"
19 #include "base/location.h"
20 #include "base/logging.h"
21 #include "base/memory/free_deleter.h"
22 #include "base/task/thread_pool.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/winsock_init.h"
26
27 namespace net {
28
29 namespace {
30
31 class AddressSorterWin : public AddressSorter {
32 public:
AddressSorterWin()33 AddressSorterWin() {
34 EnsureWinsockInit();
35 }
36
37 AddressSorterWin(const AddressSorterWin&) = delete;
38 AddressSorterWin& operator=(const AddressSorterWin&) = delete;
39
~AddressSorterWin()40 ~AddressSorterWin() override {}
41
42 // AddressSorter:
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const43 void Sort(const std::vector<IPEndPoint>& endpoints,
44 CallbackType callback) const override {
45 DCHECK(!endpoints.empty());
46 Job::Start(endpoints, std::move(callback));
47 }
48
49 private:
50 // Executes the SIO_ADDRESS_LIST_SORT ioctl asynchronously, and
51 // performs the necessary conversions to/from `std::vector<IPEndPoint>`.
52 class Job : public base::RefCountedThreadSafe<Job> {
53 public:
Start(const std::vector<IPEndPoint> & endpoints,CallbackType callback)54 static void Start(const std::vector<IPEndPoint>& endpoints,
55 CallbackType callback) {
56 auto job = base::WrapRefCounted(new Job(endpoints, std::move(callback)));
57 base::ThreadPool::PostTaskAndReply(
58 FROM_HERE,
59 {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
60 base::BindOnce(&Job::Run, job),
61 base::BindOnce(&Job::OnComplete, job));
62 }
63
64 Job(const Job&) = delete;
65 Job& operator=(const Job&) = delete;
66
67 private:
68 friend class base::RefCountedThreadSafe<Job>;
69
Job(const std::vector<IPEndPoint> & endpoints,CallbackType callback)70 Job(const std::vector<IPEndPoint>& endpoints, CallbackType callback)
71 : callback_(std::move(callback)),
72 buffer_size_((sizeof(SOCKET_ADDRESS_LIST) +
73 base::CheckedNumeric<DWORD>(endpoints.size()) *
74 (sizeof(SOCKET_ADDRESS) + sizeof(SOCKADDR_STORAGE)))
75 .ValueOrDie<DWORD>()),
76 input_buffer_(
77 reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))),
78 output_buffer_(
79 reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))) {
80 input_buffer_->iAddressCount = base::checked_cast<INT>(endpoints.size());
81 SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>(
82 input_buffer_->Address + input_buffer_->iAddressCount);
83
84 for (size_t i = 0; i < endpoints.size(); ++i) {
85 IPEndPoint ipe = endpoints[i];
86 // Addresses must be sockaddr_in6.
87 if (ipe.address().IsIPv4()) {
88 ipe = IPEndPoint(ConvertIPv4ToIPv4MappedIPv6(ipe.address()),
89 ipe.port());
90 }
91
92 struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i);
93 socklen_t addr_len = sizeof(SOCKADDR_STORAGE);
94 bool result = ipe.ToSockAddr(addr, &addr_len);
95 DCHECK(result);
96 input_buffer_->Address[i].lpSockaddr = addr;
97 input_buffer_->Address[i].iSockaddrLength = addr_len;
98 }
99 }
100
~Job()101 ~Job() {}
102
103 // Executed asynchronously in ThreadPool.
Run()104 void Run() {
105 SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
106 if (sock == INVALID_SOCKET)
107 return;
108 DWORD result_size = 0;
109 int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(),
110 buffer_size_, output_buffer_.get(), buffer_size_,
111 &result_size, nullptr, nullptr);
112 if (result == SOCKET_ERROR) {
113 LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError();
114 } else {
115 success_ = true;
116 }
117 closesocket(sock);
118 }
119
120 // Executed on the calling thread.
OnComplete()121 void OnComplete() {
122 std::vector<IPEndPoint> sorted;
123 if (success_) {
124 sorted.reserve(output_buffer_->iAddressCount);
125 for (int i = 0; i < output_buffer_->iAddressCount; ++i) {
126 IPEndPoint ipe;
127 bool result =
128 ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr,
129 output_buffer_->Address[i].iSockaddrLength);
130 DCHECK(result) << "Unable to roundtrip between IPEndPoint and "
131 << "SOCKET_ADDRESS!";
132 // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works.
133 if (ipe.address().IsIPv4MappedIPv6()) {
134 ipe = IPEndPoint(ConvertIPv4MappedIPv6ToIPv4(ipe.address()),
135 ipe.port());
136 }
137 sorted.push_back(ipe);
138 }
139 }
140 std::move(callback_).Run(success_, std::move(sorted));
141 }
142
143 CallbackType callback_;
144 const DWORD buffer_size_;
145 std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> input_buffer_;
146 std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> output_buffer_;
147 bool success_ = false;
148 };
149 };
150
151 } // namespace
152
153 // static
CreateAddressSorter()154 std::unique_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
155 return std::make_unique<AddressSorterWin>();
156 }
157
158 } // namespace net
159