• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 #include "net/dns/address_sorter.h"
11 
12 #include <winsock2.h>
13 
14 #include <algorithm>
15 #include <utility>
16 #include <vector>
17 
18 #include "base/functional/bind.h"
19 #include "base/location.h"
20 #include "base/logging.h"
21 #include "base/memory/free_deleter.h"
22 #include "base/task/thread_pool.h"
23 #include "net/base/ip_address.h"
24 #include "net/base/ip_endpoint.h"
25 #include "net/base/winsock_init.h"
26 
27 namespace net {
28 
29 namespace {
30 
31 class AddressSorterWin : public AddressSorter {
32  public:
AddressSorterWin()33   AddressSorterWin() {
34     EnsureWinsockInit();
35   }
36 
37   AddressSorterWin(const AddressSorterWin&) = delete;
38   AddressSorterWin& operator=(const AddressSorterWin&) = delete;
39 
~AddressSorterWin()40   ~AddressSorterWin() override {}
41 
42   // AddressSorter:
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const43   void Sort(const std::vector<IPEndPoint>& endpoints,
44             CallbackType callback) const override {
45     DCHECK(!endpoints.empty());
46     Job::Start(endpoints, std::move(callback));
47   }
48 
49  private:
50   // Executes the SIO_ADDRESS_LIST_SORT ioctl asynchronously, and
51   // performs the necessary conversions to/from `std::vector<IPEndPoint>`.
52   class Job : public base::RefCountedThreadSafe<Job> {
53    public:
Start(const std::vector<IPEndPoint> & endpoints,CallbackType callback)54     static void Start(const std::vector<IPEndPoint>& endpoints,
55                       CallbackType callback) {
56       auto job = base::WrapRefCounted(new Job(endpoints, std::move(callback)));
57       base::ThreadPool::PostTaskAndReply(
58           FROM_HERE,
59           {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
60           base::BindOnce(&Job::Run, job),
61           base::BindOnce(&Job::OnComplete, job));
62     }
63 
64     Job(const Job&) = delete;
65     Job& operator=(const Job&) = delete;
66 
67    private:
68     friend class base::RefCountedThreadSafe<Job>;
69 
Job(const std::vector<IPEndPoint> & endpoints,CallbackType callback)70     Job(const std::vector<IPEndPoint>& endpoints, CallbackType callback)
71         : callback_(std::move(callback)),
72           buffer_size_((sizeof(SOCKET_ADDRESS_LIST) +
73                         base::CheckedNumeric<DWORD>(endpoints.size()) *
74                             (sizeof(SOCKET_ADDRESS) + sizeof(SOCKADDR_STORAGE)))
75                            .ValueOrDie<DWORD>()),
76           input_buffer_(
77               reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))),
78           output_buffer_(
79               reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))) {
80       input_buffer_->iAddressCount = base::checked_cast<INT>(endpoints.size());
81       SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>(
82           input_buffer_->Address + input_buffer_->iAddressCount);
83 
84       for (size_t i = 0; i < endpoints.size(); ++i) {
85         IPEndPoint ipe = endpoints[i];
86         // Addresses must be sockaddr_in6.
87         if (ipe.address().IsIPv4()) {
88           ipe = IPEndPoint(ConvertIPv4ToIPv4MappedIPv6(ipe.address()),
89                            ipe.port());
90         }
91 
92         struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i);
93         socklen_t addr_len = sizeof(SOCKADDR_STORAGE);
94         bool result = ipe.ToSockAddr(addr, &addr_len);
95         DCHECK(result);
96         input_buffer_->Address[i].lpSockaddr = addr;
97         input_buffer_->Address[i].iSockaddrLength = addr_len;
98       }
99     }
100 
~Job()101     ~Job() {}
102 
103     // Executed asynchronously in ThreadPool.
Run()104     void Run() {
105       SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
106       if (sock == INVALID_SOCKET)
107         return;
108       DWORD result_size = 0;
109       int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(),
110                             buffer_size_, output_buffer_.get(), buffer_size_,
111                             &result_size, nullptr, nullptr);
112       if (result == SOCKET_ERROR) {
113         LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError();
114       } else {
115         success_ = true;
116       }
117       closesocket(sock);
118     }
119 
120     // Executed on the calling thread.
OnComplete()121     void OnComplete() {
122       std::vector<IPEndPoint> sorted;
123       if (success_) {
124         sorted.reserve(output_buffer_->iAddressCount);
125         for (int i = 0; i < output_buffer_->iAddressCount; ++i) {
126           IPEndPoint ipe;
127           bool result =
128               ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr,
129                                output_buffer_->Address[i].iSockaddrLength);
130           DCHECK(result) << "Unable to roundtrip between IPEndPoint and "
131                          << "SOCKET_ADDRESS!";
132           // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works.
133           if (ipe.address().IsIPv4MappedIPv6()) {
134             ipe = IPEndPoint(ConvertIPv4MappedIPv6ToIPv4(ipe.address()),
135                              ipe.port());
136           }
137           sorted.push_back(ipe);
138         }
139       }
140       std::move(callback_).Run(success_, std::move(sorted));
141     }
142 
143     CallbackType callback_;
144     const DWORD buffer_size_;
145     std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> input_buffer_;
146     std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> output_buffer_;
147     bool success_ = false;
148   };
149 };
150 
151 }  // namespace
152 
153 // static
CreateAddressSorter()154 std::unique_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
155   return std::make_unique<AddressSorterWin>();
156 }
157 
158 }  // namespace net
159