• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The gRPC Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <grpc/support/port_platform.h>
16 
17 #include "src/core/lib/iomgr/port.h"  // IWYU pragma: keep
18 
19 #if GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
20 
21 #include <ares.h>
22 #include <grpc/support/log_windows.h>
23 #include <winsock2.h>
24 
25 #include "absl/functional/any_invocable.h"
26 #include "absl/log/check.h"
27 #include "absl/status/status.h"
28 #include "absl/strings/str_format.h"
29 #include "src/core/lib/address_utils/sockaddr_utils.h"
30 #include "src/core/lib/event_engine/ares_resolver.h"
31 #include "src/core/lib/event_engine/grpc_polled_fd.h"
32 #include "src/core/lib/event_engine/windows/grpc_polled_fd_windows.h"
33 #include "src/core/lib/event_engine/windows/win_socket.h"
34 #include "src/core/lib/iomgr/error.h"
35 #include "src/core/lib/slice/slice.h"
36 #include "src/core/util/debug_location.h"
37 #include "src/core/util/sync.h"
38 
39 // TODO(apolcyn): remove this hack after fixing upstream.
40 // Our grpc/c-ares code on Windows uses the ares_set_socket_functions API,
41 // which uses "struct iovec" type, which on Windows is defined inside of
42 // a c-ares header that is not public.
43 // See https://github.com/c-ares/c-ares/issues/206.
// Local stand-in for the scatter/gather element type used by the "sendv"
// virtual socket function in this file.
// NOTE(review): the field order and types presumably have to match the
// private c-ares definition mentioned above -- confirm before changing.
struct iovec {
  void* iov_base;  // Start of this segment's bytes.
  size_t iov_len;  // Number of bytes in this segment.
};
48 
49 namespace grpc_event_engine {
50 namespace experimental {
51 namespace {
52 
// Capacity, in bytes, of the buffer that receives the sender address from
// WSARecvFrom (see recv_from_source_addr_).
constexpr int kRecvFromSourceAddrSize = 200;
// Size, in bytes, of the slice allocated for each asynchronous read.
constexpr int kReadBufferSize = 4192;
55 
FlattenIovec(const struct iovec * iov,int iov_count)56 grpc_slice FlattenIovec(const struct iovec* iov, int iov_count) {
57   int total = 0;
58   for (int i = 0; i < iov_count; i++) {
59     total += iov[i].iov_len;
60   }
61   grpc_slice out = GRPC_SLICE_MALLOC(total);
62   size_t cur = 0;
63   for (int i = 0; i < iov_count; i++) {
64     for (size_t k = 0; k < iov[i].iov_len; k++) {
65       GRPC_SLICE_START_PTR(out)
66       [cur++] = (static_cast<char*>(iov[i].iov_base))[k];
67     }
68   }
69   return out;
70 }
71 
72 }  // namespace
73 
74 // c-ares reads and takes action on the error codes of the
75 // "virtual socket operations" in this file, via the WSAGetLastError
76 // APIs. If code in this file wants to set a specific WSA error that
77 // c-ares should read, it must do so by calling SetWSAError() on the
78 // WSAErrorContext instance passed to it. A WSAErrorContext must only be
79 // instantiated at the top of the virtual socket function callstack.
80 class WSAErrorContext {
81  public:
WSAErrorContext()82   explicit WSAErrorContext() {};
83 
~WSAErrorContext()84   ~WSAErrorContext() {
85     if (error_ != 0) {
86       WSASetLastError(error_);
87     }
88   }
89 
90   // Disallow copy and assignment operators
91   WSAErrorContext(const WSAErrorContext&) = delete;
92   WSAErrorContext& operator=(const WSAErrorContext&) = delete;
93 
SetWSAError(int error)94   void SetWSAError(int error) { error_ = error; }
95 
96  private:
97   int error_ = 0;
98 };
99 
100 // c-ares creates its own sockets and is meant to read them when readable and
101 // write them when writeable. To fit this socket usage model into the grpc
102 // windows poller (which gives notifications when attempted reads and writes
103 // are actually fulfilled rather than possible), this GrpcPolledFdWindows
104 // class takes advantage of the ares_set_socket_functions API and acts as a
105 // virtual socket. It holds its own read and write buffers which are written
106 // to and read from c-ares and are used with the grpc windows poller, and it,
107 // e.g., manufactures virtual socket error codes when it e.g. needs to tell
108 // the c-ares library to wait for an async read.
109 class GrpcPolledFdWindows : public GrpcPolledFd {
110  public:
  // Wraps |winsocket| as a c-ares virtual socket.
  //   winsocket: the socket to wrap; ownership is taken.
  //   mu: mutex acquired by the OnIocpReadable / OnIocpWriteable /
  //       OnTcpConnect callbacks; not owned.
  //   address_family: address family of the socket; used by ConnectTCP to
  //       pick the wildcard address to bind to.
  //   socket_type: SOCK_DGRAM or SOCK_STREAM; selects the UDP or TCP paths.
  //   event_engine: used to run the read/write closures; not owned.
  GrpcPolledFdWindows(std::unique_ptr<WinSocket> winsocket,
                      grpc_core::Mutex* mu, int address_family, int socket_type,
                      EventEngine* event_engine)
      // name_ reads winsocket->raw_socket() before winsocket is moved into
      // winsocket_ below. This relies on name_ being declared before
      // winsocket_ in the class, since members initialize in declaration
      // order.
      : name_(absl::StrFormat("c-ares socket: %" PRIdPTR,
                              winsocket->raw_socket())),
        address_family_(address_family),
        socket_type_(socket_type),
        mu_(mu),
        winsocket_(std::move(winsocket)),
        read_buf_(grpc_empty_slice()),
        write_buf_(grpc_empty_slice()),
        // These closures capture `this`; they are handed to winsocket_ via
        // NotifyOnRead / NotifyOnWrite.
        outer_read_closure_([this]() { OnIocpReadable(); }),
        outer_write_closure_([this]() { OnIocpWriteable(); }),
        on_tcp_connect_locked_([this]() { OnTcpConnect(); }),
        event_engine_(event_engine) {}
126 
~GrpcPolledFdWindows()127   ~GrpcPolledFdWindows() override {
128     GRPC_TRACE_LOG(cares_resolver, INFO)
129         << "(EventEngine c-ares resolver) fd:|" << GetName()
130         << "| ~GrpcPolledFdWindows shutdown_called_: " << shutdown_called_;
131     grpc_core::CSliceUnref(read_buf_);
132     grpc_core::CSliceUnref(write_buf_);
133     CHECK(read_closure_ == nullptr);
134     CHECK(write_closure_ == nullptr);
135     if (!shutdown_called_) {
136       winsocket_->Shutdown(DEBUG_LOCATION, "~GrpcPolledFdWindows");
137     }
138   }
139 
RegisterForOnReadableLocked(absl::AnyInvocable<void (absl::Status)> read_closure)140   void RegisterForOnReadableLocked(
141       absl::AnyInvocable<void(absl::Status)> read_closure) override {
142     CHECK(read_closure_ == nullptr);
143     read_closure_ = std::move(read_closure);
144     grpc_core::CSliceUnref(read_buf_);
145     CHECK(!read_buf_has_data_);
146     read_buf_ = GRPC_SLICE_MALLOC(kReadBufferSize);
147     if (connect_done_) {
148       ContinueRegisterForOnReadableLocked();
149     } else {
150       CHECK(pending_continue_register_for_on_readable_locked_ == false);
151       pending_continue_register_for_on_readable_locked_ = true;
152     }
153   }
154 
RegisterForOnWriteableLocked(absl::AnyInvocable<void (absl::Status)> write_closure)155   void RegisterForOnWriteableLocked(
156       absl::AnyInvocable<void(absl::Status)> write_closure) override {
157     if (socket_type_ == SOCK_DGRAM) {
158       GRPC_TRACE_LOG(cares_resolver, INFO)
159           << "(EventEngine c-ares resolver) fd:|" << GetName()
160           << "| RegisterForOnWriteableLocked called";
161     } else {
162       CHECK(socket_type_ == SOCK_STREAM);
163       GRPC_TRACE_LOG(cares_resolver, INFO)
164           << "(EventEngine c-ares resolver) fd:|" << GetName()
165           << "| RegisterForOnWriteableLocked called tcp_write_state_: "
166           << static_cast<int>(tcp_write_state_)
167           << " connect_done_: " << connect_done_;
168     }
169     CHECK(write_closure_ == nullptr);
170     write_closure_ = std::move(write_closure);
171     if (!connect_done_) {
172       CHECK(!pending_continue_register_for_on_writeable_locked_);
173       pending_continue_register_for_on_writeable_locked_ = true;
174     } else {
175       ContinueRegisterForOnWriteableLocked();
176     }
177   }
178 
IsFdStillReadableLocked()179   bool IsFdStillReadableLocked() override { return read_buf_has_data_; }
180 
ShutdownLocked(absl::Status error)181   bool ShutdownLocked(absl::Status error) override {
182     CHECK(!shutdown_called_);
183     if (!absl::IsCancelled(error)) {
184       return false;
185     }
186     GRPC_TRACE_LOG(cares_resolver, INFO) << "(EventEngine c-ares resolver) fd:|"
187                                          << GetName() << "| ShutdownLocked";
188     shutdown_called_ = true;
189     // The socket is disconnected and closed here since this is an external
190     // cancel request, e.g. a timeout. c-ares shouldn't do anything on the
191     // socket after this point except calling close which should then destroy
192     // the GrpcPolledFdWindows object.
193     winsocket_->Shutdown(DEBUG_LOCATION, "GrpcPolledFdWindows::ShutdownLocked");
194     return true;
195   }
196 
GetWrappedAresSocketLocked()197   ares_socket_t GetWrappedAresSocketLocked() override {
198     return winsocket_->raw_socket();
199   }
200 
GetName() const201   const char* GetName() const override { return name_.c_str(); }
202 
RecvFrom(WSAErrorContext * wsa_error_ctx,void * data,ares_socket_t data_len,int,struct sockaddr * from,ares_socklen_t * from_len)203   ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data,
204                         ares_socket_t data_len, int /* flags */,
205                         struct sockaddr* from, ares_socklen_t* from_len) {
206     GRPC_TRACE_LOG(cares_resolver, INFO)
207         << "(EventEngine c-ares resolver) fd:" << GetName()
208         << " RecvFrom called read_buf_has_data:" << read_buf_has_data_
209         << " Current read buf length:" << GRPC_SLICE_LENGTH(read_buf_);
210     if (!read_buf_has_data_) {
211       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
212       return -1;
213     }
214     ares_ssize_t bytes_read = 0;
215     for (size_t i = 0; i < GRPC_SLICE_LENGTH(read_buf_) && i < data_len; i++) {
216       (static_cast<char*>(data))[i] = GRPC_SLICE_START_PTR(read_buf_)[i];
217       bytes_read++;
218     }
219     read_buf_ = grpc_slice_sub_no_ref(read_buf_, bytes_read,
220                                       GRPC_SLICE_LENGTH(read_buf_));
221     if (GRPC_SLICE_LENGTH(read_buf_) == 0) {
222       read_buf_has_data_ = false;
223     }
224     // c-ares overloads this recv_from virtual socket function to receive
225     // data on both UDP and TCP sockets, and from is nullptr for TCP.
226     if (from != nullptr) {
227       CHECK(*from_len >= recv_from_source_addr_len_);
228       memcpy(from, &recv_from_source_addr_, recv_from_source_addr_len_);
229       *from_len = recv_from_source_addr_len_;
230     }
231     return bytes_read;
232   }
233 
SendV(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)234   ares_ssize_t SendV(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
235                      int iov_count) {
236     GRPC_TRACE_LOG(cares_resolver, INFO)
237         << "(EventEngine c-ares resolver) fd:|" << GetName()
238         << "| SendV called connect_done_:" << connect_done_
239         << " wsa_connect_error_:" << wsa_connect_error_;
240     if (!connect_done_) {
241       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
242       return -1;
243     }
244     if (wsa_connect_error_ != 0) {
245       wsa_error_ctx->SetWSAError(wsa_connect_error_);
246       return -1;
247     }
248     switch (socket_type_) {
249       case SOCK_DGRAM:
250         return SendVUDP(wsa_error_ctx, iov, iov_count);
251       case SOCK_STREAM:
252         return SendVTCP(wsa_error_ctx, iov, iov_count);
253       default:
254         abort();
255     }
256   }
257 
Connect(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)258   int Connect(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
259               ares_socklen_t target_len) {
260     switch (socket_type_) {
261       case SOCK_DGRAM:
262         return ConnectUDP(wsa_error_ctx, target, target_len);
263       case SOCK_STREAM:
264         return ConnectTCP(wsa_error_ctx, target, target_len);
265       default:
266         grpc_core::Crash(
267             absl::StrFormat("Unknown socket_type_: %d", socket_type_));
268     }
269   }
270 
271  private:
  // State machine for buffered TCP writes (tcp_write_state_).
  enum WriteState {
    // No write is buffered; SendVTCP() will accept new data.
    WRITE_IDLE,
    // SendVTCP() has copied data into write_buf_ but it has not been
    // submitted to the socket yet.
    WRITE_REQUESTED,
    // ContinueRegisterForOnWriteableLocked() has issued the overlapped send;
    // OnIocpWriteable() expects this state.
    WRITE_PENDING,
    // The overlapped send completed; the next SendVTCP() call checks that
    // the retried data matches write_buf_ and reports it as sent.
    WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY,
  };
278 
ScheduleAndNullReadClosure(absl::Status error)279   void ScheduleAndNullReadClosure(absl::Status error) {
280     event_engine_->Run([read_closure = std::move(read_closure_),
281                         error]() mutable { read_closure(error); });
282     read_closure_ = nullptr;
283   }
284 
ScheduleAndNullWriteClosure(absl::Status error)285   void ScheduleAndNullWriteClosure(absl::Status error) {
286     event_engine_->Run([write_closure = std::move(write_closure_),
287                         error]() mutable { write_closure(error); });
288     write_closure_ = nullptr;
289   }
290 
  // Second half of RegisterForOnReadableLocked(), run once the connect
  // attempt has finished. If connecting failed, the pending read closure is
  // scheduled with that error. Otherwise an overlapped WSARecvFrom into
  // read_buf_ is started, with outer_read_closure_ (OnIocpReadable) registered
  // on the socket's read side.
  void ContinueRegisterForOnReadableLocked() {
    GRPC_TRACE_LOG(cares_resolver, INFO)
        << "(EventEngine c-ares resolver) fd:|" << GetName()
        << "| ContinueRegisterForOnReadableLocked wsa_connect_error_:"
        << wsa_connect_error_;
    CHECK(connect_done_);
    if (wsa_connect_error_ != 0) {
      ScheduleAndNullReadClosure(GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
      return;
    }
    // Receive directly into the slice allocated by
    // RegisterForOnReadableLocked().
    WSABUF buffer;
    buffer.buf = reinterpret_cast<char*>(GRPC_SLICE_START_PTR(read_buf_));
    buffer.len = GRPC_SLICE_LENGTH(read_buf_);
    // Reset to the full capacity of the address buffer before each receive.
    recv_from_source_addr_len_ = sizeof(recv_from_source_addr_);
    DWORD flags = 0;
    // The callback is registered before the receive is issued.
    winsocket_->NotifyOnRead(&outer_read_closure_);
    if (WSARecvFrom(winsocket_->raw_socket(), &buffer, 1, nullptr, &flags,
                    reinterpret_cast<sockaddr*>(recv_from_source_addr_),
                    &recv_from_source_addr_len_,
                    winsocket_->read_info()->overlapped(), nullptr) != 0) {
      int wsa_last_error = WSAGetLastError();
      char* msg = gpr_format_message(wsa_last_error);
      GRPC_TRACE_LOG(cares_resolver, INFO)
          << "(EventEngine c-ares resolver) fd:" << GetName()
          << " ContinueRegisterForOnReadableLocked WSARecvFrom error "
             "code:"
          << wsa_last_error << " msg:" << msg;
      gpr_free(msg);
      // WSA_IO_PENDING means the overlapped receive is in flight. Any other
      // code is a real failure: undo the registration and report it.
      if (wsa_last_error != WSA_IO_PENDING) {
        winsocket_->UnregisterReadCallback();
        ScheduleAndNullReadClosure(
            GRPC_WSA_ERROR(wsa_last_error, "WSARecvFrom"));
        return;
      }
    }
  }
327 
ContinueRegisterForOnWriteableLocked()328   void ContinueRegisterForOnWriteableLocked() {
329     GRPC_TRACE_LOG(cares_resolver, INFO)
330         << "(EventEngine c-ares resolver) fd:|" << GetName()
331         << "| ContinueRegisterForOnWriteableLocked wsa_connect_error_:"
332         << wsa_connect_error_;
333     CHECK(connect_done_);
334     if (wsa_connect_error_ != 0) {
335       ScheduleAndNullWriteClosure(
336           GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
337       return;
338     }
339     if (socket_type_ == SOCK_DGRAM) {
340       ScheduleAndNullWriteClosure(absl::OkStatus());
341       return;
342     }
343     CHECK(socket_type_ == SOCK_STREAM);
344     int wsa_error_code = 0;
345     switch (tcp_write_state_) {
346       case WRITE_IDLE:
347         ScheduleAndNullWriteClosure(absl::OkStatus());
348         break;
349       case WRITE_REQUESTED:
350         tcp_write_state_ = WRITE_PENDING;
351         winsocket_->NotifyOnWrite(&outer_write_closure_);
352         if (SendWriteBuf(nullptr, winsocket_->write_info()->overlapped(),
353                          &wsa_error_code) != 0) {
354           winsocket_->UnregisterWriteCallback();
355           ScheduleAndNullWriteClosure(
356               GRPC_WSA_ERROR(wsa_error_code, "WSASend (overlapped)"));
357           return;
358         }
359         break;
360       case WRITE_PENDING:
361       case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
362         grpc_core::Crash(
363             absl::StrFormat("Invalid tcp_write_state_: %d", tcp_write_state_));
364     }
365   }
366 
SendWriteBuf(LPDWORD bytes_sent_ptr,LPWSAOVERLAPPED overlapped,int * wsa_error_code)367   int SendWriteBuf(LPDWORD bytes_sent_ptr, LPWSAOVERLAPPED overlapped,
368                    int* wsa_error_code) {
369     WSABUF buf;
370     buf.len = GRPC_SLICE_LENGTH(write_buf_);
371     buf.buf = reinterpret_cast<char*>(GRPC_SLICE_START_PTR(write_buf_));
372     DWORD flags = 0;
373     int out = WSASend(winsocket_->raw_socket(), &buf, 1, bytes_sent_ptr, flags,
374                       overlapped, nullptr);
375     *wsa_error_code = WSAGetLastError();
376     GRPC_TRACE_LOG(cares_resolver, INFO)
377         << "(EventEngine c-ares resolver) fd:" << GetName()
378         << " SendWriteBuf WSASend buf.len:" << buf.len << " *bytes_sent_ptr:"
379         << (bytes_sent_ptr != nullptr ? *bytes_sent_ptr : 0)
380         << " overlapped:" << overlapped << " return:" << out
381         << " *wsa_error_code:" << *wsa_error_code;
382     return out;
383   }
384 
SendVUDP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)385   ares_ssize_t SendVUDP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
386                         int iov_count) {
387     // c-ares doesn't handle retryable errors on writes of UDP sockets.
388     // Therefore, the sendv handler for UDP sockets must only attempt
389     // to write everything inline.
390     GRPC_TRACE_LOG(cares_resolver, INFO) << "(EventEngine c-ares resolver) fd:|"
391                                          << GetName() << "| SendVUDP called";
392     CHECK_EQ(GRPC_SLICE_LENGTH(write_buf_), 0);
393     grpc_core::CSliceUnref(write_buf_);
394     write_buf_ = FlattenIovec(iov, iov_count);
395     DWORD bytes_sent = 0;
396     int wsa_error_code = 0;
397     if (SendWriteBuf(&bytes_sent, nullptr, &wsa_error_code) != 0) {
398       grpc_core::CSliceUnref(write_buf_);
399       write_buf_ = grpc_empty_slice();
400       wsa_error_ctx->SetWSAError(wsa_error_code);
401       char* msg = gpr_format_message(wsa_error_code);
402       GRPC_TRACE_LOG(cares_resolver, INFO)
403           << "(EventEngine c-ares resolver) fd:|" << GetName()
404           << "| SendVUDP SendWriteBuf error code:" << wsa_error_code << " msg:|"
405           << msg << "|";
406       gpr_free(msg);
407       return -1;
408     }
409     write_buf_ = grpc_slice_sub_no_ref(write_buf_, bytes_sent,
410                                        GRPC_SLICE_LENGTH(write_buf_));
411     return bytes_sent;
412   }
413 
SendVTCP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)414   ares_ssize_t SendVTCP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
415                         int iov_count) {
416     // The "sendv" handler on TCP sockets buffers up write
417     // requests and returns an artificial WSAEWOULDBLOCK. Writing that buffer
418     // out in the background, and making further send progress in general, will
419     // happen as long as c-ares continues to show interest in writeability on
420     // this fd.
421     GRPC_TRACE_LOG(cares_resolver, INFO)
422         << "(EventEngine c-ares resolver) fd:|" << GetName()
423         << "| SendVTCP called tcp_write_state_:"
424         << static_cast<int>(tcp_write_state_);
425     switch (tcp_write_state_) {
426       case WRITE_IDLE:
427         tcp_write_state_ = WRITE_REQUESTED;
428         grpc_core::CSliceUnref(write_buf_);
429         write_buf_ = FlattenIovec(iov, iov_count);
430         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
431         return -1;
432       case WRITE_REQUESTED:
433       case WRITE_PENDING:
434         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
435         return -1;
436       case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
437         // c-ares is retrying a send on data that we previously returned
438         // WSAEWOULDBLOCK for, but then subsequently wrote out in the
439         // background. Right now, we assume that c-ares is retrying the same
440         // send again. If c-ares still needs to send even more data, we'll get
441         // to it eventually.
442         grpc_slice currently_attempted = FlattenIovec(iov, iov_count);
443         CHECK(GRPC_SLICE_LENGTH(currently_attempted) >=
444               GRPC_SLICE_LENGTH(write_buf_));
445         ares_ssize_t total_sent = 0;
446         for (size_t i = 0; i < GRPC_SLICE_LENGTH(write_buf_); i++) {
447           CHECK(GRPC_SLICE_START_PTR(currently_attempted)[i] ==
448                 GRPC_SLICE_START_PTR(write_buf_)[i]);
449           total_sent++;
450         }
451         grpc_core::CSliceUnref(currently_attempted);
452         tcp_write_state_ = WRITE_IDLE;
453         return total_sent;
454     }
455     grpc_core::Crash(
456         absl::StrFormat("Unknown tcp_write_state_: %d", tcp_write_state_));
457   }
458 
  // Callback that ConnectTCP() registers on the socket's write side. Takes
  // mu_, records the outcome of the connect attempt in wsa_connect_error_,
  // marks the connect as done, and then runs any read/write registrations
  // that were queued while the connect was outstanding. Must run at most
  // once per object.
  void OnTcpConnect() {
    grpc_core::MutexLock lock(mu_);
    GRPC_TRACE_LOG(cares_resolver, INFO)
        << "(EventEngine c-ares resolver) fd:" << GetName()
        << " InnerOnTcpConnectLocked pending_register_for_readable:"
        << pending_continue_register_for_on_readable_locked_
        << " pending_register_for_writeable:"
        << pending_continue_register_for_on_writeable_locked_;
    CHECK(!connect_done_);
    connect_done_ = true;
    CHECK_EQ(wsa_connect_error_, 0);
    if (shutdown_called_) {
      // The socket was shut down while connecting; report the connect as
      // aborted instead of querying the overlapped result.
      wsa_connect_error_ = WSA_OPERATION_ABORTED;
    } else {
      DWORD transferred_bytes = 0;
      DWORD flags;
      // Query the result of the overlapped connect without waiting
      // (fWait == FALSE).
      BOOL wsa_success = WSAGetOverlappedResult(
          winsocket_->raw_socket(), winsocket_->write_info()->overlapped(),
          &transferred_bytes, FALSE, &flags);
      // ConnectTCP() passes no send buffer to ConnectEx, so nothing should
      // have been transferred.
      CHECK_EQ(transferred_bytes, 0);
      if (!wsa_success) {
        wsa_connect_error_ = WSAGetLastError();
        char* msg = gpr_format_message(wsa_connect_error_);
        GRPC_TRACE_LOG(cares_resolver, INFO)
            << "(EventEngine c-ares resolver) fd:" << GetName()
            << " InnerOnTcpConnectLocked WSA overlapped result code:"
            << wsa_connect_error_ << " msg:|" << msg << "|";
        gpr_free(msg);
      }
    }
    // Resume registrations that arrived before the connect finished. The
    // Continue* methods deliver wsa_connect_error_ to the closures if it is
    // set.
    if (pending_continue_register_for_on_readable_locked_) {
      ContinueRegisterForOnReadableLocked();
    }
    if (pending_continue_register_for_on_writeable_locked_) {
      ContinueRegisterForOnWriteableLocked();
    }
  }
496 
ConnectUDP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)497   int ConnectUDP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
498                  ares_socklen_t target_len) {
499     GRPC_TRACE_LOG(cares_resolver, INFO)
500         << "(EventEngine c-ares resolver) fd:" << GetName() << " ConnectUDP";
501     CHECK(!connect_done_);
502     CHECK_EQ(wsa_connect_error_, 0);
503     SOCKET s = winsocket_->raw_socket();
504     int out =
505         WSAConnect(s, target, target_len, nullptr, nullptr, nullptr, nullptr);
506     wsa_connect_error_ = WSAGetLastError();
507     wsa_error_ctx->SetWSAError(wsa_connect_error_);
508     connect_done_ = true;
509     char* msg = gpr_format_message(wsa_connect_error_);
510     GRPC_TRACE_LOG(cares_resolver, INFO)
511         << "(EventEngine c-ares resolver) fd:" << GetName()
512         << " WSAConnect error code:|" << wsa_connect_error_ << "| msg:|" << msg
513         << "|";
514     gpr_free(msg);
515     // c-ares expects a posix-style connect API
516     return out == 0 ? 0 : -1;
517   }
518 
  // TCP connect path: starts an asynchronous connect with ConnectEx.
  // Steps: look up the ConnectEx extension function, bind the socket to the
  // wildcard address of its family, register OnTcpConnect on the socket's
  // write side, then call ConnectEx.
  // Returns 0 if ConnectEx reported success, otherwise -1 with an error
  // recorded on |wsa_error_ctx| (WSAEWOULDBLOCK when the connect is merely in
  // progress). On a non-retryable failure, connect_done_ is set and the error
  // is stored in wsa_connect_error_.
  int ConnectTCP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
                 ares_socklen_t target_len) {
    GRPC_TRACE_LOG(cares_resolver, INFO)
        << "(EventEngine c-ares resolver) fd:" << GetName() << " ConnectTCP";
    // ConnectEx is an extension function; its address is obtained at runtime
    // through WSAIoctl.
    LPFN_CONNECTEX ConnectEx;
    GUID guid = WSAID_CONNECTEX;
    DWORD ioctl_num_bytes;
    SOCKET s = winsocket_->raw_socket();
    if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
                 &ConnectEx, sizeof(ConnectEx), &ioctl_num_bytes, nullptr,
                 nullptr) != 0) {
      int wsa_last_error = WSAGetLastError();
      wsa_error_ctx->SetWSAError(wsa_last_error);
      char* msg = gpr_format_message(wsa_last_error);
      GRPC_TRACE_LOG(cares_resolver, INFO)
          << "(EventEngine c-ares resolver) fd:" << GetName()
          << " WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) error code:"
          << wsa_last_error << " msg:|" << msg << "|";
      gpr_free(msg);
      connect_done_ = true;
      wsa_connect_error_ = wsa_last_error;
      return -1;
    }
    // Bind to the wildcard address (port 0) of the socket's address family
    // before calling ConnectEx.
    grpc_resolved_address wildcard4_addr;
    grpc_resolved_address wildcard6_addr;
    grpc_sockaddr_make_wildcards(0, &wildcard4_addr, &wildcard6_addr);
    grpc_resolved_address* local_address = nullptr;
    if (address_family_ == AF_INET) {
      local_address = &wildcard4_addr;
    } else {
      local_address = &wildcard6_addr;
    }
    if (bind(s, reinterpret_cast<struct sockaddr*>(local_address->addr),
             static_cast<int>(local_address->len)) != 0) {
      int wsa_last_error = WSAGetLastError();
      wsa_error_ctx->SetWSAError(wsa_last_error);
      char* msg = gpr_format_message(wsa_last_error);
      GRPC_TRACE_LOG(cares_resolver, INFO)
          << "(EventEngine c-ares resolver) fd:" << GetName()
          << " bind error code:" << wsa_last_error << " msg:|" << msg << "|";
      gpr_free(msg);
      connect_done_ = true;
      wsa_connect_error_ = wsa_last_error;
      return -1;
    }
    int out = 0;
    // Register an async OnTcpConnect callback here since it is required by the
    // WinSocket API.
    winsocket_->NotifyOnWrite(&on_tcp_connect_locked_);
    // ConnectEx returns a BOOL; 0 (FALSE) means it did not complete
    // successfully right away.
    if (ConnectEx(s, target, target_len, nullptr, 0, nullptr,
                  winsocket_->write_info()->overlapped()) == 0) {
      out = -1;
      int wsa_last_error = WSAGetLastError();
      wsa_error_ctx->SetWSAError(wsa_last_error);
      char* msg = gpr_format_message(wsa_last_error);
      GRPC_TRACE_LOG(cares_resolver, INFO)
          << "(EventEngine c-ares resolver) fd:" << GetName()
          << " ConnectEx error code:" << wsa_last_error << " msg:|" << msg
          << "|";
      gpr_free(msg);
      if (wsa_last_error == WSA_IO_PENDING) {
        // c-ares only understands WSAEINPROGRESS and EWOULDBLOCK error codes on
        // connect, but an async connect on IOCP socket will give
        // WSA_IO_PENDING, so we need to convert.
        wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
      } else {
        // Real failure: OnTcpConnect will not be needed, so undo the
        // registration made above.
        winsocket_->UnregisterWriteCallback();
        // By returning a non-retryable error to c-ares at this point,
        // we're aborting the possibility of any future operations on this fd.
        connect_done_ = true;
        wsa_connect_error_ = wsa_last_error;
        return -1;
      }
    }
    return out;
  }
595 
  // TODO(apolcyn): improve this error handling to be less conservative.
597   // An e.g. ECONNRESET error here should result in errors when
598   // c-ares reads from this socket later, but it shouldn't necessarily cancel
599   // the entire resolution attempt. Doing so will allow the "inject broken
600   // nameserver list" test to pass on Windows.
OnIocpReadable()601   void OnIocpReadable() {
602     grpc_core::MutexLock lock(mu_);
603     absl::Status error;
604     if (winsocket_->read_info()->result().wsa_error != 0) {
605       // WSAEMSGSIZE would be due to receiving more data
606       // than our read buffer's fixed capacity. Assume that
607       // the connection is TCP and read the leftovers
608       // in subsequent c-ares reads.
609       if (winsocket_->read_info()->result().wsa_error != WSAEMSGSIZE) {
610         error = GRPC_WSA_ERROR(winsocket_->read_info()->result().wsa_error,
611                                "OnIocpReadableInner");
612         GRPC_TRACE_LOG(cares_resolver, INFO)
613             << "(EventEngine c-ares resolver) fd:|" << GetName()
614             << "| OnIocpReadableInner winsocket_->read_info.wsa_error "
615                "code:|"
616             << winsocket_->read_info()->result().wsa_error << "| msg:|"
617             << grpc_core::StatusToString(error) << "|";
618       }
619     }
620     if (error.ok()) {
621       read_buf_ = grpc_slice_sub_no_ref(
622           read_buf_, 0, winsocket_->read_info()->result().bytes_transferred);
623       read_buf_has_data_ = true;
624     } else {
625       grpc_core::CSliceUnref(read_buf_);
626       read_buf_ = grpc_empty_slice();
627     }
628     GRPC_TRACE_LOG(cares_resolver, INFO)
629         << "(EventEngine c-ares resolver) fd:|" << GetName()
630         << "| OnIocpReadable finishing. read buf length now:|"
631         << GRPC_SLICE_LENGTH(read_buf_) << "|";
632     ScheduleAndNullReadClosure(error);
633   }
634 
OnIocpWriteable()635   void OnIocpWriteable() {
636     grpc_core::MutexLock lock(mu_);
637     GRPC_TRACE_LOG(cares_resolver, INFO)
638         << "(EventEngine c-ares resolver) OnIocpWriteableInner. fd:|"
639         << GetName() << "|";
640     CHECK(socket_type_ == SOCK_STREAM);
641     absl::Status error;
642     if (winsocket_->write_info()->result().wsa_error != 0) {
643       error = GRPC_WSA_ERROR(winsocket_->write_info()->result().wsa_error,
644                              "OnIocpWriteableInner");
645       GRPC_TRACE_LOG(cares_resolver, INFO)
646           << "(EventEngine c-ares resolver) fd:|" << GetName()
647           << "| OnIocpWriteableInner. winsocket_->write_info.wsa_error "
648              "code:|"
649           << winsocket_->write_info()->result().wsa_error << "| msg:|"
650           << grpc_core::StatusToString(error) << "|";
651     }
652     CHECK(tcp_write_state_ == WRITE_PENDING);
653     if (error.ok()) {
654       tcp_write_state_ = WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY;
655       write_buf_ = grpc_slice_sub_no_ref(
656           write_buf_, 0, winsocket_->write_info()->result().bytes_transferred);
657       GRPC_TRACE_LOG(cares_resolver, INFO)
658           << "(EventEngine c-ares resolver) fd:|" << GetName()
659           << "| OnIocpWriteableInner. bytes transferred:"
660           << winsocket_->write_info()->result().bytes_transferred;
661 
662     } else {
663       grpc_core::CSliceUnref(write_buf_);
664       write_buf_ = grpc_empty_slice();
665     }
666     ScheduleAndNullWriteClosure(error);
667   }
668 
  // Identifier used in trace logs; returned by GetName(). Must stay declared
  // before winsocket_: the constructor builds it from the socket before the
  // socket is moved into winsocket_.
  const std::string name_;
  // Address family given at construction; ConnectTCP() uses it to choose the
  // wildcard address to bind to.
  const int address_family_;
  // SOCK_DGRAM or SOCK_STREAM; selects the UDP or TCP code paths.
  const int socket_type_;
  // Mutex acquired by OnIocpReadable / OnIocpWriteable / OnTcpConnect.
  // Not owned.
  grpc_core::Mutex* mu_;
  // The wrapped socket.
  std::unique_ptr<WinSocket> winsocket_;
  // Sender address written by WSARecvFrom and copied out by RecvFrom().
  char recv_from_source_addr_[kRecvFromSourceAddrSize];
  // Length associated with recv_from_source_addr_. Not initialized here; it
  // is set to the buffer's capacity before each WSARecvFrom.
  ares_socklen_t recv_from_source_addr_len_;
  // Buffer that asynchronous reads fill and RecvFrom() drains.
  grpc_slice read_buf_;
  // True while read_buf_ holds bytes not yet handed to RecvFrom().
  bool read_buf_has_data_ = false;
  // Flattened copy of the data most recently passed to SendV.
  grpc_slice write_buf_;
  // Closures registered through RegisterForOnReadableLocked /
  // RegisterForOnWriteableLocked; null when none is registered.
  absl::AnyInvocable<void(absl::Status)> read_closure_;
  absl::AnyInvocable<void(absl::Status)> write_closure_;
  // Callbacks handed to winsocket_; they invoke OnIocpReadable and
  // OnIocpWriteable respectively.
  AnyInvocableClosure outer_read_closure_;
  AnyInvocableClosure outer_write_closure_;
  // Set once ShutdownLocked() has shut the socket down.
  bool shutdown_called_ = false;
  // State related to TCP sockets
  // Callback handed to winsocket_ by ConnectTCP(); invokes OnTcpConnect.
  AnyInvocableClosure on_tcp_connect_locked_;
  // True once the connect attempt has finished, successfully or not. Also
  // set by ConnectUDP().
  bool connect_done_ = false;
  // WSA error from the connect attempt; 0 if none.
  int wsa_connect_error_ = 0;
  WriteState tcp_write_state_ = WRITE_IDLE;
  // We don't run register_for_{readable,writeable} logic until
  // a socket is connected. In the interim, we queue readable/writeable
  // registrations with the following state.
  bool pending_continue_register_for_on_readable_locked_ = false;
  bool pending_continue_register_for_on_writeable_locked_ = false;
  // This pointer is initialized from the stored pointer inside the shared
  // pointer owned by the AresResolver and should be valid at the time of use.
  EventEngine* event_engine_;
697 };
698 
699 // These virtual socket functions are called from within the c-ares
700 // library. These methods generally dispatch those socket calls to the
701 // appropriate methods. The virtual "socket" and "close" methods are
702 // special and instead create/add and remove/destroy GrpcPolledFdWindows
703 // objects.
704 class CustomSockFuncs {
705  public:
Socket(int af,int type,int protocol,void * user_data)706   static ares_socket_t Socket(int af, int type, int protocol, void* user_data) {
707     if (type != SOCK_DGRAM && type != SOCK_STREAM) {
708       GRPC_TRACE_LOG(cares_resolver, INFO)
709           << "(EventEngine c-ares resolver) Socket called with invalid socket "
710              "type:"
711           << type;
712       return INVALID_SOCKET;
713     }
714     GrpcPolledFdFactoryWindows* self =
715         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
716     SOCKET s = WSASocket(af, type, protocol, nullptr, 0,
717                          IOCP::GetDefaultSocketFlags());
718     if (s == INVALID_SOCKET) {
719       GRPC_TRACE_LOG(cares_resolver, INFO)
720           << "(EventEngine c-ares resolver) WSASocket failed with params af:"
721           << af << " type:" << type << " protocol:" << protocol;
722       return INVALID_SOCKET;
723     }
724     if (type == SOCK_STREAM) {
725       absl::Status error = PrepareSocket(s);
726       if (!error.ok()) {
727         GRPC_TRACE_LOG(cares_resolver, INFO)
728             << "(EventEngine c-ares resolver) WSAIoctl failed with error: "
729             << grpc_core::StatusToString(error);
730         return INVALID_SOCKET;
731       }
732     }
733     auto polled_fd = std::make_unique<GrpcPolledFdWindows>(
734         self->iocp_->Watch(s), self->mu_, af, type, self->event_engine_);
735     GRPC_TRACE_LOG(cares_resolver, INFO)
736         << "(EventEngine c-ares resolver) fd:" << polled_fd->GetName()
737         << " created with params af:" << af << " type:" << type
738         << " protocol:" << protocol;
739     CHECK(self->sockets_.insert({s, std::move(polled_fd)}).second);
740     return s;
741   }
742 
Connect(ares_socket_t as,const struct sockaddr * target,ares_socklen_t target_len,void * user_data)743   static int Connect(ares_socket_t as, const struct sockaddr* target,
744                      ares_socklen_t target_len, void* user_data) {
745     WSAErrorContext wsa_error_ctx;
746     GrpcPolledFdFactoryWindows* self =
747         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
748     auto it = self->sockets_.find(as);
749     CHECK(it != self->sockets_.end());
750     return it->second->Connect(&wsa_error_ctx, target, target_len);
751   }
752 
SendV(ares_socket_t as,const struct iovec * iov,int iovec_count,void * user_data)753   static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov,
754                             int iovec_count, void* user_data) {
755     WSAErrorContext wsa_error_ctx;
756     GrpcPolledFdFactoryWindows* self =
757         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
758     auto it = self->sockets_.find(as);
759     CHECK(it != self->sockets_.end());
760     return it->second->SendV(&wsa_error_ctx, iov, iovec_count);
761   }
762 
RecvFrom(ares_socket_t as,void * data,size_t data_len,int flags,struct sockaddr * from,ares_socklen_t * from_len,void * user_data)763   static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len,
764                                int flags, struct sockaddr* from,
765                                ares_socklen_t* from_len, void* user_data) {
766     WSAErrorContext wsa_error_ctx;
767     GrpcPolledFdFactoryWindows* self =
768         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
769     auto it = self->sockets_.find(as);
770     CHECK(it != self->sockets_.end());
771     return it->second->RecvFrom(&wsa_error_ctx, data, data_len, flags, from,
772                                 from_len);
773   }
774 
CloseSocket(SOCKET s,void *)775   static int CloseSocket(SOCKET s, void*) {
776     GRPC_TRACE_LOG(cares_resolver, INFO)
777         << "(EventEngine c-ares resolver) c-ares socket: " << s
778         << " CloseSocket";
779     return 0;
780   }
781 };
782 
783 // Adapter to hold the ownership of GrpcPolledFdWindows internally.
784 class GrpcPolledFdWrapper : public GrpcPolledFd {
785  public:
GrpcPolledFdWrapper(GrpcPolledFdWindows * polled_fd)786   explicit GrpcPolledFdWrapper(GrpcPolledFdWindows* polled_fd)
787       : polled_fd_(polled_fd) {}
788 
RegisterForOnReadableLocked(absl::AnyInvocable<void (absl::Status)> read_closure)789   void RegisterForOnReadableLocked(
790       absl::AnyInvocable<void(absl::Status)> read_closure) override {
791     polled_fd_->RegisterForOnReadableLocked(std::move(read_closure));
792   }
793 
RegisterForOnWriteableLocked(absl::AnyInvocable<void (absl::Status)> write_closure)794   void RegisterForOnWriteableLocked(
795       absl::AnyInvocable<void(absl::Status)> write_closure) override {
796     polled_fd_->RegisterForOnWriteableLocked(std::move(write_closure));
797   }
798 
IsFdStillReadableLocked()799   bool IsFdStillReadableLocked() override {
800     return polled_fd_->IsFdStillReadableLocked();
801   }
802 
ShutdownLocked(absl::Status error)803   bool ShutdownLocked(absl::Status error) override {
804     return polled_fd_->ShutdownLocked(error);
805   }
806 
GetWrappedAresSocketLocked()807   ares_socket_t GetWrappedAresSocketLocked() override {
808     return polled_fd_->GetWrappedAresSocketLocked();
809   }
810 
GetName() const811   const char* GetName() const override { return polled_fd_->GetName(); }
812 
813  private:
814   GrpcPolledFdWindows* polled_fd_;
815 };
816 
GrpcPolledFdFactoryWindows(IOCP * iocp)817 GrpcPolledFdFactoryWindows::GrpcPolledFdFactoryWindows(IOCP* iocp)
818     : iocp_(iocp) {}
819 
~GrpcPolledFdFactoryWindows()820 GrpcPolledFdFactoryWindows::~GrpcPolledFdFactoryWindows() {}
821 
Initialize(grpc_core::Mutex * mutex,EventEngine * event_engine)822 void GrpcPolledFdFactoryWindows::Initialize(grpc_core::Mutex* mutex,
823                                             EventEngine* event_engine) {
824   mu_ = mutex;
825   event_engine_ = event_engine;
826 }
827 
NewGrpcPolledFdLocked(ares_socket_t as)828 std::unique_ptr<GrpcPolledFd> GrpcPolledFdFactoryWindows::NewGrpcPolledFdLocked(
829     ares_socket_t as) {
830   auto it = sockets_.find(as);
831   CHECK(it != sockets_.end());
832   return std::make_unique<GrpcPolledFdWrapper>(it->second.get());
833 }
834 
ConfigureAresChannelLocked(ares_channel channel)835 void GrpcPolledFdFactoryWindows::ConfigureAresChannelLocked(
836     ares_channel channel) {
837   static const struct ares_socket_functions kCustomSockFuncs = {
838       /*asocket=*/&CustomSockFuncs::Socket,
839       /*aclose=*/&CustomSockFuncs::CloseSocket,
840       /*aconnect=*/&CustomSockFuncs::Connect,
841       /*arecvfrom=*/&CustomSockFuncs::RecvFrom,
842       /*asendv=*/&CustomSockFuncs::SendV,
843   };
844   ares_set_socket_functions(channel, &kCustomSockFuncs, this);
845 }
846 
847 }  // namespace experimental
848 }  // namespace grpc_event_engine
849 
850 #endif  // GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
851