• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include "uds/service_endpoint.h"
2 
3 #include <poll.h>
4 #include <sys/epoll.h>
5 #include <sys/eventfd.h>
6 #include <sys/socket.h>
7 #include <sys/un.h>
8 #include <algorithm>  // std::min
9 
10 #include <android-base/logging.h>
11 #include <android-base/strings.h>
12 #include <cutils/sockets.h>
13 #include <pdx/service.h>
14 #include <selinux/selinux.h>
15 #include <uds/channel_manager.h>
16 #include <uds/client_channel_factory.h>
17 #include <uds/ipc_helper.h>
18 
19 namespace {
20 
21 constexpr int kMaxBackLogForSocketListen = 1;
22 
23 using android::pdx::BorrowedChannelHandle;
24 using android::pdx::BorrowedHandle;
25 using android::pdx::ChannelReference;
26 using android::pdx::ErrorStatus;
27 using android::pdx::FileReference;
28 using android::pdx::LocalChannelHandle;
29 using android::pdx::LocalHandle;
30 using android::pdx::Status;
31 using android::pdx::uds::ChannelInfo;
32 using android::pdx::uds::ChannelManager;
33 
34 struct MessageState {
GetLocalFileHandle__anon2c8f2ab30111::MessageState35   bool GetLocalFileHandle(int index, LocalHandle* handle) {
36     if (index < 0) {
37       handle->Reset(index);
38     } else if (static_cast<size_t>(index) < request.file_descriptors.size()) {
39       *handle = std::move(request.file_descriptors[index]);
40     } else {
41       return false;
42     }
43     return true;
44   }
45 
GetLocalChannelHandle__anon2c8f2ab30111::MessageState46   bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
47     if (index < 0) {
48       *handle = LocalChannelHandle{nullptr, index};
49     } else if (static_cast<size_t>(index) < request.channels.size()) {
50       auto& channel_info = request.channels[index];
51       *handle = ChannelManager::Get().CreateHandle(
52           std::move(channel_info.data_fd),
53           std::move(channel_info.pollin_event_fd),
54           std::move(channel_info.pollhup_event_fd));
55     } else {
56       return false;
57     }
58     return true;
59   }
60 
PushFileHandle__anon2c8f2ab30111::MessageState61   Status<FileReference> PushFileHandle(BorrowedHandle handle) {
62     if (!handle)
63       return handle.Get();
64     response.file_descriptors.push_back(std::move(handle));
65     return response.file_descriptors.size() - 1;
66   }
67 
PushChannelHandle__anon2c8f2ab30111::MessageState68   Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) {
69     if (!handle)
70       return handle.value();
71 
72     if (auto* channel_data =
73             ChannelManager::Get().GetChannelData(handle.value())) {
74       ChannelInfo<BorrowedHandle> channel_info{
75           channel_data->data_fd(), channel_data->pollin_event_fd(),
76           channel_data->pollhup_event_fd()};
77       response.channels.push_back(std::move(channel_info));
78       return response.channels.size() - 1;
79     } else {
80       return ErrorStatus{EINVAL};
81     }
82   }
83 
PushChannelHandle__anon2c8f2ab30111::MessageState84   Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd,
85                                              BorrowedHandle pollin_event_fd,
86                                              BorrowedHandle pollhup_event_fd) {
87     if (!data_fd || !pollin_event_fd || !pollhup_event_fd)
88       return ErrorStatus{EINVAL};
89     ChannelInfo<BorrowedHandle> channel_info{std::move(data_fd),
90                                              std::move(pollin_event_fd),
91                                              std::move(pollhup_event_fd)};
92     response.channels.push_back(std::move(channel_info));
93     return response.channels.size() - 1;
94   }
95 
WriteData__anon2c8f2ab30111::MessageState96   Status<size_t> WriteData(const iovec* vector, size_t vector_length) {
97     size_t size = 0;
98     for (size_t i = 0; i < vector_length; i++) {
99       const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base);
100       response_data.insert(response_data.end(), data, data + vector[i].iov_len);
101       size += vector[i].iov_len;
102     }
103     return size;
104   }
105 
ReadData__anon2c8f2ab30111::MessageState106   Status<size_t> ReadData(const iovec* vector, size_t vector_length) {
107     size_t size_remaining = request_data.size() - request_data_read_pos;
108     size_t size = 0;
109     for (size_t i = 0; i < vector_length && size_remaining > 0; i++) {
110       size_t size_to_copy = std::min(size_remaining, vector[i].iov_len);
111       memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos,
112              size_to_copy);
113       size += size_to_copy;
114       request_data_read_pos += size_to_copy;
115       size_remaining -= size_to_copy;
116     }
117     return size;
118   }
119 
120   android::pdx::uds::RequestHeader<LocalHandle> request;
121   android::pdx::uds::ResponseHeader<BorrowedHandle> response;
122   std::vector<LocalHandle> sockets_to_close;
123   std::vector<uint8_t> request_data;
124   size_t request_data_read_pos{0};
125   std::vector<uint8_t> response_data;
126 };
127 
128 }  // anonymous namespace
129 
130 namespace android {
131 namespace pdx {
132 namespace uds {
133 
Endpoint(const std::string & endpoint_path,bool blocking,bool use_init_socket_fd)134 Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
135                    bool use_init_socket_fd)
136     : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)},
137       is_blocking_{blocking} {
138   LocalHandle fd;
139   if (use_init_socket_fd) {
140     // Cut off the /dev/socket/ prefix from the full socket path and use the
141     // resulting "name" to retrieve the file descriptor for the socket created
142     // by the init process.
143     constexpr char prefix[] = "/dev/socket/";
144     CHECK(android::base::StartsWith(endpoint_path_, prefix))
145         << "Endpoint::Endpoint: Socket name '" << endpoint_path_
146         << "' must begin with '" << prefix << "'";
147     std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1);
148     fd.Reset(android_get_control_socket(socket_name.c_str()));
149     CHECK(fd.IsValid())
150         << "Endpoint::Endpoint: Unable to obtain the control socket fd for '"
151         << socket_name << "'";
152     fcntl(fd.Get(), F_SETFD, FD_CLOEXEC);
153   } else {
154     fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0));
155     CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: "
156                         << strerror(errno);
157 
158     sockaddr_un local;
159     local.sun_family = AF_UNIX;
160     strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path));
161     local.sun_path[sizeof(local.sun_path) - 1] = '\0';
162 
163     unlink(local.sun_path);
164     int ret =
165         bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
166     CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
167   }
168   Init(std::move(fd));
169 }
170 
Endpoint(LocalHandle socket_fd)171 Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }
172 
Init(LocalHandle socket_fd)173 void Endpoint::Init(LocalHandle socket_fd) {
174   if (socket_fd) {
175     CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
176         << "Endpoint::Endpoint: listen error: " << strerror(errno);
177   }
178   cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
179   CHECK(cancel_event_fd_.IsValid())
180       << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
181 
182   epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC));
183   CHECK(epoll_fd_.IsValid())
184       << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);
185 
186   if (socket_fd) {
187     epoll_event socket_event;
188     socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
189     socket_event.data.fd = socket_fd.Get();
190     int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
191                         &socket_event);
192     CHECK_EQ(ret, 0)
193         << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
194         << strerror(errno);
195   }
196 
197   epoll_event cancel_event;
198   cancel_event.events = EPOLLIN;
199   cancel_event.data.fd = cancel_event_fd_.Get();
200 
201   int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
202                       &cancel_event);
203   CHECK_EQ(ret, 0)
204       << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
205       << strerror(errno);
206   socket_fd_ = std::move(socket_fd);
207 }
208 
AllocateMessageState()209 void* Endpoint::AllocateMessageState() { return new MessageState; }
210 
FreeMessageState(void * state)211 void Endpoint::FreeMessageState(void* state) {
212   delete static_cast<MessageState*>(state);
213 }
214 
AcceptConnection(Message * message)215 Status<void> Endpoint::AcceptConnection(Message* message) {
216   if (!socket_fd_)
217     return ErrorStatus(EBADF);
218 
219   sockaddr_un remote;
220   socklen_t addrlen = sizeof(remote);
221   LocalHandle connection_fd{accept4(socket_fd_.Get(),
222                                     reinterpret_cast<sockaddr*>(&remote),
223                                     &addrlen, SOCK_CLOEXEC)};
224   if (!connection_fd) {
225     ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s",
226           strerror(errno));
227     return ErrorStatus(errno);
228   }
229 
230   LocalHandle local_socket;
231   LocalHandle remote_socket;
232   auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
233   if (!status)
234     return status;
235 
236   // Borrow the local channel handle before we move it into OnNewChannel().
237   BorrowedHandle channel_handle = local_socket.Borrow();
238   status = OnNewChannel(std::move(local_socket));
239   if (!status)
240     return status;
241 
242   // Send the channel socket fd to the client.
243   ChannelConnectionInfo<LocalHandle> connection_info;
244   connection_info.channel_fd = std::move(remote_socket);
245   status = SendData(connection_fd.Borrow(), connection_info);
246 
247   if (status) {
248     // Get the CHANNEL_OPEN message from client over the channel socket.
249     status = ReceiveMessageForChannel(channel_handle, message);
250   } else {
251     CloseChannel(GetChannelId(channel_handle));
252   }
253 
254   // Don't need the connection socket anymore. Further communication should
255   // happen over the channel socket.
256   shutdown(connection_fd.Get(), SHUT_WR);
257   return status;
258 }
259 
SetService(Service * service)260 Status<void> Endpoint::SetService(Service* service) {
261   service_ = service;
262   return {};
263 }
264 
SetChannel(int channel_id,Channel * channel)265 Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) {
266   std::lock_guard<std::mutex> autolock(channel_mutex_);
267   auto channel_data = channels_.find(channel_id);
268   if (channel_data == channels_.end())
269     return ErrorStatus{EINVAL};
270   channel_data->second.channel_state = channel;
271   return {};
272 }
273 
OnNewChannel(LocalHandle channel_fd)274 Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) {
275   std::lock_guard<std::mutex> autolock(channel_mutex_);
276   Status<void> status;
277   status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr));
278   return status;
279 }
280 
OnNewChannelLocked(LocalHandle channel_fd,Channel * channel_state)281 Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked(
282     LocalHandle channel_fd, Channel* channel_state) {
283   epoll_event event;
284   event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
285   event.data.fd = channel_fd.Get();
286   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) {
287     ALOGE(
288         "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n",
289         strerror(errno));
290     return ErrorStatus(errno);
291   }
292   ChannelData channel_data;
293   channel_data.data_fd = std::move(channel_fd);
294   channel_data.channel_state = channel_state;
295   for (;;) {
296     // Try new channel IDs until we find one which is not already in the map.
297     if (last_channel_id_++ == std::numeric_limits<int32_t>::max())
298       last_channel_id_ = 1;
299     auto iter = channels_.lower_bound(last_channel_id_);
300     if (iter == channels_.end() || iter->first != last_channel_id_) {
301       channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_);
302       iter = channels_.emplace_hint(iter, last_channel_id_,
303                                     std::move(channel_data));
304       return std::make_pair(last_channel_id_, &iter->second);
305     }
306   }
307 }
308 
ReenableEpollEvent(const BorrowedHandle & fd)309 Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) {
310   epoll_event event;
311   event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
312   event.data.fd = fd.Get();
313   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) {
314     ALOGE(
315         "Endpoint::ReenableEpollEvent: Failed to re-enable channel to "
316         "endpoint: %s\n",
317         strerror(errno));
318     return ErrorStatus(errno);
319   }
320   return {};
321 }
322 
CloseChannel(int channel_id)323 Status<void> Endpoint::CloseChannel(int channel_id) {
324   std::lock_guard<std::mutex> autolock(channel_mutex_);
325   return CloseChannelLocked(channel_id);
326 }
327 
CloseChannelLocked(int32_t channel_id)328 Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) {
329   ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id);
330 
331   auto iter = channels_.find(channel_id);
332   if (iter == channels_.end())
333     return ErrorStatus{EINVAL};
334 
335   int channel_fd = iter->second.data_fd.Get();
336   Status<void> status;
337   epoll_event dummy;  // See BUGS in man 2 epoll_ctl.
338   if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &dummy) < 0) {
339     status.SetError(errno);
340     ALOGE(
341         "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: "
342         "%s\n",
343         strerror(errno));
344   } else {
345     status.SetValue();
346   }
347 
348   channel_fd_to_id_.erase(channel_fd);
349   channels_.erase(iter);
350   return status;
351 }
352 
ModifyChannelEvents(int channel_id,int clear_mask,int set_mask)353 Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask,
354                                            int set_mask) {
355   std::lock_guard<std::mutex> autolock(channel_mutex_);
356 
357   auto search = channels_.find(channel_id);
358   if (search != channels_.end()) {
359     auto& channel_data = search->second;
360     channel_data.event_set.ModifyEvents(clear_mask, set_mask);
361     return {};
362   }
363 
364   return ErrorStatus{EINVAL};
365 }
366 
CreateChannelSocketPair(LocalHandle * local_socket,LocalHandle * remote_socket)367 Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket,
368                                                LocalHandle* remote_socket) {
369   Status<void> status;
370   char* endpoint_context = nullptr;
371   // Make sure the channel socket has the correct SELinux label applied.
372   // Here we get the label from the endpoint file descriptor, which should be
373   // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace
374   // "endpoint" with "channel" to produce the channel label such as this:
375   // "u:object_r:pdx_service_channel_socket:s0".
376   if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) {
377     std::string channel_context = endpoint_context;
378     freecon(endpoint_context);
379     const std::string suffix = "_endpoint_socket";
380     auto pos = channel_context.find(suffix);
381     if (pos != std::string::npos) {
382       channel_context.replace(pos, suffix.size(), "_channel_socket");
383     } else {
384       ALOGW(
385           "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' "
386           "does not contain expected substring '%s'",
387           channel_context.c_str(), suffix.c_str());
388     }
389     ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1,
390              "Endpoint::CreateChannelSocketPair: Failed to set channel socket "
391              "security context: %s",
392              strerror(errno));
393   } else {
394     ALOGE(
395         "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint "
396         "socket's security context: %s",
397         strerror(errno));
398   }
399 
400   int channel_pair[2] = {};
401   if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) {
402     ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s",
403           strerror(errno));
404     status.SetError(errno);
405     return status;
406   }
407 
408   setsockcreatecon_raw(nullptr);
409 
410   local_socket->Reset(channel_pair[0]);
411   remote_socket->Reset(channel_pair[1]);
412 
413   int optval = 1;
414   if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval,
415                  sizeof(optval)) == -1) {
416     ALOGE(
417         "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of "
418         "the credentials for channel %d: %s",
419         local_socket->Get(), strerror(errno));
420     status.SetError(errno);
421   }
422   return status;
423 }
424 
PushChannel(Message * message,int,Channel * channel,int * channel_id)425 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message,
426                                                   int /*flags*/,
427                                                   Channel* channel,
428                                                   int* channel_id) {
429   LocalHandle local_socket;
430   LocalHandle remote_socket;
431   auto status = CreateChannelSocketPair(&local_socket, &remote_socket);
432   if (!status)
433     return status.error_status();
434 
435   std::lock_guard<std::mutex> autolock(channel_mutex_);
436   auto channel_data_status =
437       OnNewChannelLocked(std::move(local_socket), channel);
438   if (!channel_data_status)
439     return channel_data_status.error_status();
440 
441   ChannelData* channel_data;
442   std::tie(*channel_id, channel_data) = channel_data_status.take();
443 
444   // Flags are ignored for now.
445   // TODO(xiaohuit): Implement those.
446 
447   auto* state = static_cast<MessageState*>(message->GetState());
448   Status<ChannelReference> ref = state->PushChannelHandle(
449       remote_socket.Borrow(), channel_data->event_set.pollin_event_fd(),
450       channel_data->event_set.pollhup_event_fd());
451   if (!ref)
452     return ref.error_status();
453   state->sockets_to_close.push_back(std::move(remote_socket));
454   return RemoteChannelHandle{ref.get()};
455 }
456 
CheckChannel(const Message *,ChannelReference,Channel **)457 Status<int> Endpoint::CheckChannel(const Message* /*message*/,
458                                    ChannelReference /*ref*/,
459                                    Channel** /*channel*/) {
460   // TODO(xiaohuit): Implement this.
461   return ErrorStatus(EFAULT);
462 }
463 
GetChannelState(int32_t channel_id)464 Channel* Endpoint::GetChannelState(int32_t channel_id) {
465   std::lock_guard<std::mutex> autolock(channel_mutex_);
466   auto channel_data = channels_.find(channel_id);
467   return (channel_data != channels_.end()) ? channel_data->second.channel_state
468                                            : nullptr;
469 }
470 
GetChannelSocketFd(int32_t channel_id)471 BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) {
472   std::lock_guard<std::mutex> autolock(channel_mutex_);
473   BorrowedHandle handle;
474   auto channel_data = channels_.find(channel_id);
475   if (channel_data != channels_.end())
476     handle = channel_data->second.data_fd.Borrow();
477   return handle;
478 }
479 
GetChannelEventFd(int32_t channel_id)480 Status<std::pair<BorrowedHandle, BorrowedHandle>> Endpoint::GetChannelEventFd(
481     int32_t channel_id) {
482   std::lock_guard<std::mutex> autolock(channel_mutex_);
483   auto channel_data = channels_.find(channel_id);
484   if (channel_data != channels_.end()) {
485     return {{channel_data->second.event_set.pollin_event_fd(),
486              channel_data->second.event_set.pollhup_event_fd()}};
487   }
488   return ErrorStatus(ENOENT);
489 }
490 
GetChannelId(const BorrowedHandle & channel_fd)491 int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) {
492   std::lock_guard<std::mutex> autolock(channel_mutex_);
493   auto iter = channel_fd_to_id_.find(channel_fd.Get());
494   return (iter != channel_fd_to_id_.end()) ? iter->second : -1;
495 }
496 
ReceiveMessageForChannel(const BorrowedHandle & channel_fd,Message * message)497 Status<void> Endpoint::ReceiveMessageForChannel(
498     const BorrowedHandle& channel_fd, Message* message) {
499   RequestHeader<LocalHandle> request;
500   int32_t channel_id = GetChannelId(channel_fd);
501   auto status = ReceiveData(channel_fd.Borrow(), &request);
502   if (!status) {
503     if (status.error() == ESHUTDOWN) {
504       BuildCloseMessage(channel_id, message);
505       return {};
506     } else {
507       CloseChannel(channel_id);
508       return status;
509     }
510   }
511 
512   MessageInfo info;
513   info.pid = request.cred.pid;
514   info.tid = -1;
515   info.cid = channel_id;
516   info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID
517                                 : GetNextAvailableMessageId();
518   info.euid = request.cred.uid;
519   info.egid = request.cred.gid;
520   info.op = request.op;
521   info.flags = 0;
522   info.service = service_;
523   info.channel = GetChannelState(channel_id);
524   info.send_len = request.send_len;
525   info.recv_len = request.max_recv_len;
526   info.fd_count = request.file_descriptors.size();
527   static_assert(sizeof(info.impulse) == request.impulse_payload.size(),
528                 "Impulse payload sizes must be the same in RequestHeader and "
529                 "MessageInfo");
530   memcpy(info.impulse, request.impulse_payload.data(),
531          request.impulse_payload.size());
532   *message = Message{info};
533   auto* state = static_cast<MessageState*>(message->GetState());
534   state->request = std::move(request);
535   if (request.send_len > 0 && !request.is_impulse) {
536     state->request_data.resize(request.send_len);
537     status = ReceiveData(channel_fd, state->request_data.data(),
538                          state->request_data.size());
539   }
540 
541   if (status && request.is_impulse)
542     status = ReenableEpollEvent(channel_fd);
543 
544   if (!status) {
545     if (status.error() == ESHUTDOWN) {
546       BuildCloseMessage(channel_id, message);
547       return {};
548     } else {
549       CloseChannel(channel_id);
550       return status;
551     }
552   }
553 
554   return status;
555 }
556 
BuildCloseMessage(int32_t channel_id,Message * message)557 void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) {
558   ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id);
559   MessageInfo info;
560   info.pid = -1;
561   info.tid = -1;
562   info.cid = channel_id;
563   info.mid = GetNextAvailableMessageId();
564   info.euid = -1;
565   info.egid = -1;
566   info.op = opcodes::CHANNEL_CLOSE;
567   info.flags = 0;
568   info.service = service_;
569   info.channel = GetChannelState(channel_id);
570   info.send_len = 0;
571   info.recv_len = 0;
572   info.fd_count = 0;
573   *message = Message{info};
574 }
575 
MessageReceive(Message * message)576 Status<void> Endpoint::MessageReceive(Message* message) {
577   // Receive at most one event from the epoll set. This should prevent multiple
578   // dispatch threads from attempting to handle messages on the same socket at
579   // the same time.
580   epoll_event event;
581   int count = RETRY_EINTR(
582       epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0));
583   if (count < 0) {
584     ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n",
585           strerror(errno));
586     return ErrorStatus{errno};
587   } else if (count == 0) {
588     return ErrorStatus{ETIMEDOUT};
589   }
590 
591   if (event.data.fd == cancel_event_fd_.Get()) {
592     return ErrorStatus{ESHUTDOWN};
593   }
594 
595   if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
596     auto status = AcceptConnection(message);
597     if (!status)
598       return status;
599     return ReenableEpollEvent(socket_fd_.Borrow());
600   }
601 
602   BorrowedHandle channel_fd{event.data.fd};
603   return ReceiveMessageForChannel(channel_fd, message);
604 }
605 
MessageReply(Message * message,int return_code)606 Status<void> Endpoint::MessageReply(Message* message, int return_code) {
607   const int32_t channel_id = message->GetChannelId();
608   auto channel_socket = GetChannelSocketFd(channel_id);
609   if (!channel_socket)
610     return ErrorStatus{EBADF};
611 
612   auto* state = static_cast<MessageState*>(message->GetState());
613   switch (message->GetOp()) {
614     case opcodes::CHANNEL_CLOSE:
615       return CloseChannel(channel_id);
616 
617     case opcodes::CHANNEL_OPEN:
618       if (return_code < 0) {
619         return CloseChannel(channel_id);
620       } else {
621         // Open messages do not have a payload and may not transfer any channels
622         // or file descriptors on behalf of the service.
623         state->response_data.clear();
624         state->response.file_descriptors.clear();
625         state->response.channels.clear();
626 
627         // Return the channel event-related fds in a single ChannelInfo entry
628         // with an empty data_fd member.
629         auto status = GetChannelEventFd(channel_id);
630         if (!status)
631           return status.error_status();
632 
633         auto handles = status.take();
634         state->response.channels.push_back({BorrowedHandle(),
635                                             std::move(handles.first),
636                                             std::move(handles.second)});
637         return_code = 0;
638       }
639       break;
640   }
641 
642   state->response.ret_code = return_code;
643   state->response.recv_len = state->response_data.size();
644   auto status = SendData(channel_socket, state->response);
645   if (status && !state->response_data.empty()) {
646     status = SendData(channel_socket, state->response_data.data(),
647                       state->response_data.size());
648   }
649 
650   if (status)
651     status = ReenableEpollEvent(channel_socket);
652 
653   return status;
654 }
655 
MessageReplyFd(Message * message,unsigned int push_fd)656 Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) {
657   auto* state = static_cast<MessageState*>(message->GetState());
658   auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)});
659   if (!ref)
660     return ref.error_status();
661   return MessageReply(message, ref.get());
662 }
663 
MessageReplyChannelHandle(Message * message,const LocalChannelHandle & handle)664 Status<void> Endpoint::MessageReplyChannelHandle(
665     Message* message, const LocalChannelHandle& handle) {
666   auto* state = static_cast<MessageState*>(message->GetState());
667   auto ref = state->PushChannelHandle(handle.Borrow());
668   if (!ref)
669     return ref.error_status();
670   return MessageReply(message, ref.get());
671 }
672 
MessageReplyChannelHandle(Message * message,const BorrowedChannelHandle & handle)673 Status<void> Endpoint::MessageReplyChannelHandle(
674     Message* message, const BorrowedChannelHandle& handle) {
675   auto* state = static_cast<MessageState*>(message->GetState());
676   auto ref = state->PushChannelHandle(handle.Duplicate());
677   if (!ref)
678     return ref.error_status();
679   return MessageReply(message, ref.get());
680 }
681 
MessageReplyChannelHandle(Message * message,const RemoteChannelHandle & handle)682 Status<void> Endpoint::MessageReplyChannelHandle(
683     Message* message, const RemoteChannelHandle& handle) {
684   return MessageReply(message, handle.value());
685 }
686 
ReadMessageData(Message * message,const iovec * vector,size_t vector_length)687 Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector,
688                                          size_t vector_length) {
689   auto* state = static_cast<MessageState*>(message->GetState());
690   return state->ReadData(vector, vector_length);
691 }
692 
WriteMessageData(Message * message,const iovec * vector,size_t vector_length)693 Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector,
694                                           size_t vector_length) {
695   auto* state = static_cast<MessageState*>(message->GetState());
696   return state->WriteData(vector, vector_length);
697 }
698 
PushFileHandle(Message * message,const LocalHandle & handle)699 Status<FileReference> Endpoint::PushFileHandle(Message* message,
700                                                const LocalHandle& handle) {
701   auto* state = static_cast<MessageState*>(message->GetState());
702   return state->PushFileHandle(handle.Borrow());
703 }
704 
PushFileHandle(Message * message,const BorrowedHandle & handle)705 Status<FileReference> Endpoint::PushFileHandle(Message* message,
706                                                const BorrowedHandle& handle) {
707   auto* state = static_cast<MessageState*>(message->GetState());
708   return state->PushFileHandle(handle.Duplicate());
709 }
710 
PushFileHandle(Message *,const RemoteHandle & handle)711 Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/,
712                                                const RemoteHandle& handle) {
713   return handle.Get();
714 }
715 
PushChannelHandle(Message * message,const LocalChannelHandle & handle)716 Status<ChannelReference> Endpoint::PushChannelHandle(
717     Message* message, const LocalChannelHandle& handle) {
718   auto* state = static_cast<MessageState*>(message->GetState());
719   return state->PushChannelHandle(handle.Borrow());
720 }
721 
PushChannelHandle(Message * message,const BorrowedChannelHandle & handle)722 Status<ChannelReference> Endpoint::PushChannelHandle(
723     Message* message, const BorrowedChannelHandle& handle) {
724   auto* state = static_cast<MessageState*>(message->GetState());
725   return state->PushChannelHandle(handle.Duplicate());
726 }
727 
PushChannelHandle(Message *,const RemoteChannelHandle & handle)728 Status<ChannelReference> Endpoint::PushChannelHandle(
729     Message* /*message*/, const RemoteChannelHandle& handle) {
730   return handle.value();
731 }
732 
GetFileHandle(Message * message,FileReference ref) const733 LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const {
734   LocalHandle handle;
735   auto* state = static_cast<MessageState*>(message->GetState());
736   state->GetLocalFileHandle(ref, &handle);
737   return handle;
738 }
739 
GetChannelHandle(Message * message,ChannelReference ref) const740 LocalChannelHandle Endpoint::GetChannelHandle(Message* message,
741                                               ChannelReference ref) const {
742   LocalChannelHandle handle;
743   auto* state = static_cast<MessageState*>(message->GetState());
744   state->GetLocalChannelHandle(ref, &handle);
745   return handle;
746 }
747 
Cancel()748 Status<void> Endpoint::Cancel() {
749   if (eventfd_write(cancel_event_fd_.Get(), 1) < 0)
750     return ErrorStatus{errno};
751   return {};
752 }
753 
Create(const std::string & endpoint_path,mode_t,bool blocking)754 std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path,
755                                            mode_t /*unused_mode*/,
756                                            bool blocking) {
757   return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking));
758 }
759 
CreateAndBindSocket(const std::string & endpoint_path,bool blocking)760 std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
761     const std::string& endpoint_path, bool blocking) {
762   return std::unique_ptr<Endpoint>(
763       new Endpoint(endpoint_path, blocking, false));
764 }
765 
CreateFromSocketFd(LocalHandle socket_fd)766 std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
767   return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
768 }
769 
RegisterNewChannelForTests(LocalHandle channel_fd)770 Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
771   int optval = 1;
772   if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
773                  sizeof(optval)) == -1) {
774     ALOGE(
775         "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
776         "of the credentials for channel %d: %s",
777         channel_fd.Get(), strerror(errno));
778     return ErrorStatus(errno);
779   }
780   return OnNewChannel(std::move(channel_fd));
781 }
782 
783 }  // namespace uds
784 }  // namespace pdx
785 }  // namespace android
786