• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "test/core/event_engine/test_suite/posix/oracle_event_engine_posix.h"
16 
17 #include <grpc/event_engine/event_engine.h>
18 #include <grpc/support/alloc.h>
19 #include <poll.h>
20 #include <sys/socket.h>
21 #include <unistd.h>
22 
23 #include <algorithm>
24 #include <cerrno>
25 #include <cstring>
26 #include <memory>
27 
28 #include "absl/log/check.h"
29 #include "absl/log/log.h"
30 #include "absl/status/status.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/time/clock.h"
34 #include "absl/time/time.h"
35 #include "src/core/lib/address_utils/sockaddr_utils.h"
36 #include "src/core/lib/iomgr/resolved_address.h"
37 #include "src/core/util/crash.h"
38 #include "src/core/util/strerror.h"
39 
40 namespace grpc_event_engine {
41 namespace experimental {
42 
43 namespace {
44 
// Message written to the listener's pipe to tell the accept thread to stop.
// Declared as a constexpr array so the constant itself can never be re-pointed.
constexpr char kStopMessage[] = "STOP";
46 
CreateGRPCResolvedAddress(const EventEngine::ResolvedAddress & ra)47 grpc_resolved_address CreateGRPCResolvedAddress(
48     const EventEngine::ResolvedAddress& ra) {
49   grpc_resolved_address grpc_addr;
50   memcpy(grpc_addr.addr, ra.address(), ra.size());
51   grpc_addr.len = ra.size();
52   return grpc_addr;
53 }
54 
55 // Blocks until poll(2) indicates that one of the fds has pending I/O
56 // the deadline is reached whichever comes first. Returns an OK
57 // status a valid I/O event is available for at least one of the fds, a Status
58 // with canonical code DEADLINE_EXCEEDED if the deadline expired and a non-OK
59 // Status if any other error occurred.
PollFds(struct pollfd * pfds,int nfds,absl::Duration timeout)60 absl::Status PollFds(struct pollfd* pfds, int nfds, absl::Duration timeout) {
61   int rv;
62   while (true) {
63     if (timeout != absl::InfiniteDuration()) {
64       rv = poll(pfds, nfds,
65                 static_cast<int>(absl::ToInt64Milliseconds(timeout)));
66     } else {
67       rv = poll(pfds, nfds, /* timeout = */ -1);
68     }
69     const int saved_errno = errno;
70     errno = saved_errno;
71     if (rv >= 0 || errno != EINTR) {
72       break;
73     }
74   }
75   if (rv < 0) {
76     return absl::UnknownError(grpc_core::StrError(errno));
77   }
78   if (rv == 0) {
79     return absl::CancelledError("Deadline exceeded");
80   }
81   return absl::OkStatus();
82 }
83 
BlockUntilReadable(int fd)84 absl::Status BlockUntilReadable(int fd) {
85   struct pollfd pfd;
86   pfd.fd = fd;
87   pfd.events = POLLIN;
88   pfd.revents = 0;
89   return PollFds(&pfd, 1, absl::InfiniteDuration());
90 }
91 
BlockUntilWritableWithTimeout(int fd,absl::Duration timeout)92 absl::Status BlockUntilWritableWithTimeout(int fd, absl::Duration timeout) {
93   struct pollfd pfd;
94   pfd.fd = fd;
95   pfd.events = POLLOUT;
96   pfd.revents = 0;
97   return PollFds(&pfd, 1, timeout);
98 }
99 
BlockUntilWritable(int fd)100 absl::Status BlockUntilWritable(int fd) {
101   return BlockUntilWritableWithTimeout(fd, absl::InfiniteDuration());
102 }
103 
// Tries to read up to num_expected_bytes from the socket; a non-positive
// num_expected_bytes selects a default of 1024 bytes. Reading stops as soon
// as the buffer is full, read(2) returns 0, or read(2) fails with an error
// other than EINTR. saved_errno receives the errno left by the last read(2)
// call. Returns the bytes that were actually read.
std::string TryReadBytes(int sockfd, int& saved_errno, int num_expected_bytes) {
  static constexpr int kDefaultNumExpectedBytes = 1024;
  const int capacity =
      num_expected_bytes > 0 ? num_expected_bytes : kDefaultNumExpectedBytes;
  std::string storage(capacity, '\0');
  int filled = 0;
  while (true) {
    errno = 0;
    const int rc = read(sockfd, &storage[filled], capacity - filled);
    if (rc > 0) {
      filled += rc;
      if (filled < capacity) {
        continue;  // Made progress and there is still room: keep reading.
      }
      break;
    }
    if (rc < 0 && errno == EINTR) {
      continue;  // Interrupted by a signal: retry.
    }
    break;  // End of stream or a non-retryable error.
  }
  // Capture errno before any further library call can disturb it.
  saved_errno = errno;
  storage.resize(filled);
  return storage;
}
126 
127 // Blocks calling thread until the specified number of bytes have been
128 // read from the provided socket or it encounters an unrecoverable error. It
129 // puts the read bytes into a string and returns the string. If it encounters an
130 // error, it returns an empty string and updates saved_errno with the
131 // appropriate errno.
ReadBytes(int sockfd,int & saved_errno,int num_expected_bytes)132 std::string ReadBytes(int sockfd, int& saved_errno, int num_expected_bytes) {
133   std::string read_data;
134   do {
135     saved_errno = 0;
136     read_data += TryReadBytes(sockfd, saved_errno,
137                               num_expected_bytes - read_data.length());
138     if (saved_errno == EAGAIN &&
139         read_data.length() < static_cast<size_t>(num_expected_bytes)) {
140       CHECK_OK(BlockUntilReadable(sockfd));
141     } else if (saved_errno != 0 && num_expected_bytes > 0) {
142       read_data.clear();
143       break;
144     }
145   } while (read_data.length() < static_cast<size_t>(num_expected_bytes));
146   return read_data;
147 }
148 
// Tries to write the given bytes to the socket. Writing stops once all bytes
// are sent, write(2) returns 0, or write(2) fails with an error other than
// EINTR. saved_errno receives the errno left by the last write(2) call.
// Returns the number of bytes actually written (never negative).
int TryWriteBytes(int sockfd, int& saved_errno, std::string write_bytes) {
  const char* const payload = write_bytes.c_str();
  const int total = write_bytes.length();
  int sent = 0;
  int rc = 0;
  do {
    errno = 0;
    rc = write(sockfd, payload + sent, total - sent);
    if (rc > 0) {
      sent += rc;
    }
  } while (sent < total && (rc > 0 || (rc < 0 && errno == EINTR)));
  saved_errno = errno;
  return sent;
}
166 
167 // Blocks calling thread until the specified number of bytes have been
168 // written over the provided socket or it encounters an unrecoverable error. The
169 // bytes to write are specified as a string. If it encounters an error, it
170 // returns an empty string and updates saved_errno with the appropriate errno
171 // and returns a value less than zero.
WriteBytes(int sockfd,int & saved_errno,std::string write_bytes)172 int WriteBytes(int sockfd, int& saved_errno, std::string write_bytes) {
173   int ret = 0;
174   int original_write_length = write_bytes.length();
175   do {
176     saved_errno = 0;
177     ret = TryWriteBytes(sockfd, saved_errno, write_bytes);
178     if (saved_errno == EAGAIN && ret < static_cast<int>(write_bytes.length())) {
179       CHECK_GE(ret, 0);
180       CHECK_OK(BlockUntilWritable(sockfd));
181     } else if (saved_errno != 0) {
182       CHECK_LT(ret, 0);
183       return ret;
184     }
185     write_bytes = write_bytes.substr(ret, std::string::npos);
186   } while (!write_bytes.empty());
187   return original_write_length;
188 }
189 }  // namespace
190 
PosixOracleEndpoint(int socket_fd)191 PosixOracleEndpoint::PosixOracleEndpoint(int socket_fd)
192     : socket_fd_(socket_fd) {
193   read_ops_ = grpc_core::Thread(
194       "read_ops_thread",
195       [](void* arg) {
196         static_cast<PosixOracleEndpoint*>(arg)->ProcessReadOperations();
197       },
198       this);
199   write_ops_ = grpc_core::Thread(
200       "write_ops_thread",
201       [](void* arg) {
202         static_cast<PosixOracleEndpoint*>(arg)->ProcessWriteOperations();
203       },
204       this);
205   read_ops_.Start();
206   write_ops_.Start();
207 }
208 
Shutdown()209 void PosixOracleEndpoint::Shutdown() {
210   grpc_core::MutexLock lock(&mu_);
211   if (std::exchange(is_shutdown_, true)) {
212     return;
213   }
214   read_ops_channel_ = ReadOperation();
215   read_op_signal_->Notify();
216   write_ops_channel_ = WriteOperation();
217   write_op_signal_->Notify();
218   read_ops_.Join();
219   write_ops_.Join();
220 }
221 
Create(int socket_fd)222 std::unique_ptr<PosixOracleEndpoint> PosixOracleEndpoint::Create(
223     int socket_fd) {
224   return std::make_unique<PosixOracleEndpoint>(socket_fd);
225 }
226 
~PosixOracleEndpoint()227 PosixOracleEndpoint::~PosixOracleEndpoint() {
228   Shutdown();
229   close(socket_fd_);
230 }
231 
Read(absl::AnyInvocable<void (absl::Status)> on_read,SliceBuffer * buffer,const ReadArgs * args)232 bool PosixOracleEndpoint::Read(absl::AnyInvocable<void(absl::Status)> on_read,
233                                SliceBuffer* buffer, const ReadArgs* args) {
234   grpc_core::MutexLock lock(&mu_);
235   CHECK_NE(buffer, nullptr);
236   int read_hint_bytes =
237       args != nullptr ? std::max(1, static_cast<int>(args->read_hint_bytes))
238                       : 0;
239   read_ops_channel_ =
240       ReadOperation(read_hint_bytes, buffer, std::move(on_read));
241   read_op_signal_->Notify();
242   return false;
243 }
244 
Write(absl::AnyInvocable<void (absl::Status)> on_writable,SliceBuffer * data,const WriteArgs *)245 bool PosixOracleEndpoint::Write(
246     absl::AnyInvocable<void(absl::Status)> on_writable, SliceBuffer* data,
247     const WriteArgs* /*args*/) {
248   grpc_core::MutexLock lock(&mu_);
249   CHECK_NE(data, nullptr);
250   write_ops_channel_ = WriteOperation(data, std::move(on_writable));
251   write_op_signal_->Notify();
252   return false;
253 }
254 
ProcessReadOperations()255 void PosixOracleEndpoint::ProcessReadOperations() {
256   LOG(INFO) << "Starting thread to process read ops ...";
257   while (true) {
258     read_op_signal_->WaitForNotification();
259     read_op_signal_ = std::make_unique<grpc_core::Notification>();
260     auto read_op = std::exchange(read_ops_channel_, ReadOperation());
261     if (!read_op.IsValid()) {
262       read_op(std::string(), absl::CancelledError("Closed"));
263       break;
264     }
265     int saved_errno;
266     std::string read_data =
267         ReadBytes(socket_fd_, saved_errno, read_op.GetNumBytesToRead());
268     read_op(read_data, read_data.empty()
269                            ? absl::CancelledError(
270                                  absl::StrCat("Read failed with error = ",
271                                               grpc_core::StrError(saved_errno)))
272                            : absl::OkStatus());
273   }
274   LOG(INFO) << "Shutting down read ops thread ...";
275 }
276 
ProcessWriteOperations()277 void PosixOracleEndpoint::ProcessWriteOperations() {
278   LOG(INFO) << "Starting thread to process write ops ...";
279   while (true) {
280     write_op_signal_->WaitForNotification();
281     write_op_signal_ = std::make_unique<grpc_core::Notification>();
282     auto write_op = std::exchange(write_ops_channel_, WriteOperation());
283     if (!write_op.IsValid()) {
284       write_op(absl::CancelledError("Closed"));
285       break;
286     }
287     int saved_errno;
288     int ret = WriteBytes(socket_fd_, saved_errno, write_op.GetBytesToWrite());
289     write_op(ret < 0 ? absl::CancelledError(
290                            absl::StrCat("Write failed with error = ",
291                                         grpc_core::StrError(saved_errno)))
292                      : absl::OkStatus());
293   }
294   LOG(INFO) << "Shutting down write ops thread ...";
295 }
296 
PosixOracleListener(EventEngine::Listener::AcceptCallback on_accept,absl::AnyInvocable<void (absl::Status)> on_shutdown,std::unique_ptr<MemoryAllocatorFactory> memory_allocator_factory)297 PosixOracleListener::PosixOracleListener(
298     EventEngine::Listener::AcceptCallback on_accept,
299     absl::AnyInvocable<void(absl::Status)> on_shutdown,
300     std::unique_ptr<MemoryAllocatorFactory> memory_allocator_factory)
301     : on_accept_(std::move(on_accept)),
302       on_shutdown_(std::move(on_shutdown)),
303       memory_allocator_factory_(std::move(memory_allocator_factory)) {
304   if (pipe(pipefd_) == -1) {
305     grpc_core::Crash(absl::StrFormat("Error creating pipe: %s",
306                                      grpc_core::StrError(errno).c_str()));
307   }
308 }
309 
Start()310 absl::Status PosixOracleListener::Start() {
311   grpc_core::MutexLock lock(&mu_);
312   CHECK(!listener_fds_.empty());
313   if (std::exchange(is_started_, true)) {
314     return absl::InternalError("Cannot start listener more than once ...");
315   }
316   serve_ = grpc_core::Thread(
317       "accept_thread",
318       [](void* arg) {
319         static_cast<PosixOracleListener*>(arg)->HandleIncomingConnections();
320       },
321       this);
322   serve_.Start();
323   return absl::OkStatus();
324 }
325 
~PosixOracleListener()326 PosixOracleListener::~PosixOracleListener() {
327   grpc_core::MutexLock lock(&mu_);
328   if (!is_started_) {
329     serve_.Join();
330     return;
331   }
332   for (int i = 0; i < static_cast<int>(listener_fds_.size()); i++) {
333     shutdown(listener_fds_[i], SHUT_RDWR);
334   }
335   // Send a STOP message over the pipe.
336   CHECK(write(pipefd_[1], kStopMessage, strlen(kStopMessage)) != -1);
337   serve_.Join();
338   on_shutdown_(absl::OkStatus());
339 }
340 
HandleIncomingConnections()341 void PosixOracleListener::HandleIncomingConnections() {
342   LOG(INFO) << "Starting accept thread ...";
343   CHECK(!listener_fds_.empty());
344   int nfds = listener_fds_.size();
345   // Add one extra file descriptor to poll the pipe fd.
346   ++nfds;
347   struct pollfd* pfds =
348       static_cast<struct pollfd*>(gpr_malloc(sizeof(struct pollfd) * nfds));
349   memset(pfds, 0, sizeof(struct pollfd) * nfds);
350   while (true) {
351     for (int i = 0; i < nfds; i++) {
352       pfds[i].fd = i == nfds - 1 ? pipefd_[0] : listener_fds_[i];
353       pfds[i].events = POLLIN;
354       pfds[i].revents = 0;
355     }
356     if (!PollFds(pfds, nfds, absl::InfiniteDuration()).ok()) {
357       break;
358     }
359     int saved_errno = 0;
360     if ((pfds[nfds - 1].revents & POLLIN) &&
361         ReadBytes(pipefd_[0], saved_errno, strlen(kStopMessage)) ==
362             std::string(kStopMessage)) {
363       break;
364     }
365     for (int i = 0; i < nfds - 1; i++) {
366       if (!(pfds[i].revents & POLLIN)) {
367         continue;
368       }
369       // pfds[i].fd has a readable event.
370       int client_sock_fd = accept(pfds[i].fd, nullptr, nullptr);
371       if (client_sock_fd < 0) {
372         LOG(ERROR) << "Error accepting new connection: "
373                    << grpc_core::StrError(errno)
374                    << ". Ignoring connection attempt ...";
375         continue;
376       }
377       on_accept_(PosixOracleEndpoint::Create(client_sock_fd),
378                  memory_allocator_factory_->CreateMemoryAllocator("test"));
379     }
380   }
381   LOG(INFO) << "Shutting down accept thread ...";
382   gpr_free(pfds);
383 }
384 
Bind(const EventEngine::ResolvedAddress & addr)385 absl::StatusOr<int> PosixOracleListener::Bind(
386     const EventEngine::ResolvedAddress& addr) {
387   grpc_core::MutexLock lock(&mu_);
388   if (is_started_) {
389     return absl::FailedPreconditionError(
390         "Listener is already started, ports can no longer be bound");
391   }
392   int new_socket;
393   int opt = -1;
394   grpc_resolved_address address = CreateGRPCResolvedAddress(addr);
395   const char* scheme = grpc_sockaddr_get_uri_scheme(&address);
396   if (scheme == nullptr || strcmp(scheme, "ipv6") != 0) {
397     return absl::UnimplementedError(
398         "Unsupported bind address type. Only IPV6 addresses are supported "
399         "currently by the PosixOracleListener ...");
400   }
401 
402   // Creating a new socket file descriptor.
403   if ((new_socket = socket(AF_INET6, SOCK_STREAM, 0)) <= 0) {
404     return absl::UnknownError(
405         absl::StrCat("Error creating socket: ", grpc_core::StrError(errno)));
406   }
407   // MacOS biulds fail if SO_REUSEADDR and SO_REUSEPORT are set in the same
408   // setsockopt syscall. So they are set separately one after the other.
409   if (setsockopt(new_socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) {
410     return absl::UnknownError(absl::StrCat("Error setsockopt(SO_REUSEADDR): ",
411                                            grpc_core::StrError(errno)));
412   }
413   if (setsockopt(new_socket, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt))) {
414     return absl::UnknownError(absl::StrCat("Error setsockopt(SO_REUSEPORT): ",
415                                            grpc_core::StrError(errno)));
416   }
417 
418   // Forcefully bind the new socket.
419   if (bind(new_socket, reinterpret_cast<const struct sockaddr*>(addr.address()),
420            address.len) < 0) {
421     return absl::UnknownError(
422         absl::StrCat("Error bind: ", grpc_core::StrError(errno)));
423   }
424   // Set the new socket to listen for one active connection at a time.
425   if (listen(new_socket, 1) < 0) {
426     return absl::UnknownError(
427         absl::StrCat("Error listen: ", grpc_core::StrError(errno)));
428   }
429   listener_fds_.push_back(new_socket);
430   return 0;
431 }
432 
433 // PosixOracleEventEngine implements blocking connect. It blocks the calling
434 // thread until either connect succeeds or fails with timeout.
Connect(OnConnectCallback on_connect,const ResolvedAddress & addr,const EndpointConfig &,MemoryAllocator,EventEngine::Duration timeout)435 EventEngine::ConnectionHandle PosixOracleEventEngine::Connect(
436     OnConnectCallback on_connect, const ResolvedAddress& addr,
437     const EndpointConfig& /*args*/, MemoryAllocator /*memory_allocator*/,
438     EventEngine::Duration timeout) {
439   int client_sock_fd;
440   absl::Time deadline = absl::Now() + absl::FromChrono(timeout);
441   grpc_resolved_address address = CreateGRPCResolvedAddress(addr);
442   const char* scheme = grpc_sockaddr_get_uri_scheme(&address);
443   if (scheme == nullptr || strcmp(scheme, "ipv6") != 0) {
444     on_connect(
445         absl::CancelledError("Unsupported bind address type. Only ipv6 "
446                              "addresses are currently supported."));
447     return {};
448   }
449   if ((client_sock_fd = socket(AF_INET6, SOCK_STREAM, 0)) < 0) {
450     on_connect(absl::CancelledError(
451         absl::StrCat("Connect failed: socket creation error: ",
452                      grpc_core::StrError(errno).c_str())));
453     return {};
454   }
455   int err;
456   int num_retries = 0;
457   static constexpr int kMaxRetries = 5;
458   do {
459     err = connect(client_sock_fd, const_cast<struct sockaddr*>(addr.address()),
460                   address.len);
461     if (err < 0 && (errno == EINPROGRESS || errno == EWOULDBLOCK)) {
462       auto status = BlockUntilWritableWithTimeout(
463           client_sock_fd,
464           std::max(deadline - absl::Now(), absl::ZeroDuration()));
465       if (!status.ok()) {
466         on_connect(status);
467         return {};
468       }
469     } else if (err < 0) {
470       if (errno != ECONNREFUSED || ++num_retries > kMaxRetries) {
471         on_connect(absl::CancelledError("Connect failed."));
472         return {};
473       }
474       // If ECONNREFUSED && num_retries < kMaxRetries, wait a while and try
475       // again.
476       absl::SleepFor(absl::Milliseconds(100));
477     }
478   } while (err < 0 && absl::Now() < deadline);
479   if (err < 0 && absl::Now() >= deadline) {
480     on_connect(absl::CancelledError("Deadline exceeded"));
481   } else {
482     on_connect(PosixOracleEndpoint::Create(client_sock_fd));
483   }
484   return {};
485 }
486 
487 }  // namespace experimental
488 }  // namespace grpc_event_engine
489