• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Implementation of sandbox2::Comms class.
16 //
17 // Warning: This class is not multi-thread safe (for callers). It uses a single
18 // communications channel (an AF_UNIX socket), so it requires exactly one sender
19 // and one receiver. If you plan to use it from many threads, provide external
20 // exclusive locking.
21 
22 #include "sandboxed_api/sandbox2/comms.h"
23 
24 #include <sys/socket.h>
25 #include <sys/uio.h>
26 #include <sys/un.h>
27 #include <syscall.h>
28 #include <unistd.h>
29 
30 #include <algorithm>
31 #include <atomic>
32 #include <cerrno>
33 #include <cstdint>
34 #include <cstdlib>
35 #include <cstring>
36 #include <memory>
37 #include <optional>
38 #include <string>
39 #include <vector>
40 
41 #include "absl/base/dynamic_annotations.h"
42 #include "absl/status/status.h"
43 #include "absl/status/statusor.h"
44 #include "absl/strings/numbers.h"
45 #include "absl/strings/str_format.h"
46 #include "absl/strings/string_view.h"
47 #include "google/protobuf/message_lite.h"
48 #include "sandboxed_api/sandbox2/util.h"
49 #include "sandboxed_api/util/fileops.h"
50 #include "sandboxed_api/util/raw_logging.h"
51 #include "sandboxed_api/util/status.h"
52 #include "sandboxed_api/util/status.pb.h"
53 #include "sandboxed_api/util/status_macros.h"
54 
55 namespace sandbox2 {
56 
// Scope marker placed around calls that may block. In this build it is an
// empty placeholder. The destructor is user-provided (not `= default`) so that
// `PotentiallyBlockingRegion region;` locals do not trigger "unused variable"
// warnings.
class PotentiallyBlockingRegion {
 public:
  ~PotentiallyBlockingRegion() {}
};
63 namespace {
64 
65 using sapi::file_util::fileops::FDCloser;
66 
// Classifies an errno value from a socket operation. Returns false for the
// listed transient/recoverable codes and true for everything else. Callers in
// this file call Terminate() when this returns true.
bool IsFatalError(int saved_errno) {
  const bool recoverable =
      saved_errno == EAGAIN || saved_errno == EWOULDBLOCK ||
      saved_errno == EFAULT || saved_errno == EINTR ||
      saved_errno == EINVAL || saved_errno == ENOMEM;
  return !recoverable;
}
72 
GetDefaultCommsFd()73 int GetDefaultCommsFd() {
74   if (const char* var = getenv(Comms::kSandbox2CommsFDEnvVar); var) {
75     int fd;
76     SAPI_RAW_CHECK(absl::SimpleAtoi(var, &fd), "cannot parse comms fd var");
77     unsetenv(Comms::kSandbox2CommsFDEnvVar);
78     return fd;
79   }
80   return Comms::kSandbox2ClientCommsFD;
81 }
82 
CreateSockaddrUn(const std::string & socket_name,bool abstract_uds,sockaddr_un * sun)83 socklen_t CreateSockaddrUn(const std::string& socket_name, bool abstract_uds,
84                            sockaddr_un* sun) {
85   sun->sun_family = AF_UNIX;
86   bzero(sun->sun_path, sizeof(sun->sun_path));
87   socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name.c_str());
88   if (abstract_uds) {
89     // Create an 'abstract socket address' by specifying a leading null byte.
90     // The remainder of the path is used as a unique name, but no file is
91     // created on the filesystem. No need to NUL-terminate the string. See `man
92     // 7 unix` for further explanation.
93     strncpy(&sun->sun_path[1], socket_name.c_str(), sizeof(sun->sun_path) - 1);
94     // Len is complicated - it's essentially size of the path, plus initial
95     // NUL-byte, minus size of the sun.sun_family.
96     slen++;
97   } else {
98     // Create the socket address as it was passed from the constructor.
99     strncpy(&sun->sun_path[0], socket_name.c_str(), sizeof(sun->sun_path));
100   }
101 
102   // This takes care of the socket address overflow.
103   if (slen > sizeof(sockaddr_un)) {
104     SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated");
105     slen = sizeof(sockaddr_un);
106   }
107   return slen;
108 }
109 }  // namespace
110 
Comms(int fd,absl::string_view name)111 Comms::Comms(int fd, absl::string_view name) : raw_comms_(RawCommsFdImpl(fd)) {
112   // Generate a unique and meaningful socket name for this FD.
113   // Note: getpid()/gettid() are non-blocking syscalls.
114   if (name.empty()) {
115     name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd,
116                             getpid(), syscall(__NR_gettid));
117   } else {
118     name_ = std::string(name);
119   }
120 
121   // File descriptor is already connected.
122   state_ = State::kConnected;
123 }
124 
Comms(Comms::DefaultConnectionTag)125 Comms::Comms(Comms::DefaultConnectionTag) : Comms(GetDefaultCommsFd()) {}
126 
~Comms()127 Comms::~Comms() { Terminate(); }
128 
GetConnectionFD() const129 int Comms::GetConnectionFD() const {
130   return GetRawComms() == nullptr ? -1 : GetRawComms()->GetConnectionFD();
131 }
132 
Create(absl::string_view socket_name,bool abstract_uds)133 absl::StatusOr<ListeningComms> ListeningComms::Create(
134     absl::string_view socket_name, bool abstract_uds) {
135   ListeningComms comms(std::string(socket_name), abstract_uds);
136   SAPI_RETURN_IF_ERROR(comms.Listen());
137   return comms;
138 }
139 
Listen()140 absl::Status ListeningComms::Listen() {
141   bind_fd_ = FDCloser(socket(AF_UNIX, SOCK_STREAM, 0));  // Non-blocking
142   if (bind_fd_.get() == -1) {
143     return absl::ErrnoToStatus(errno, "socket(AF_UNIX) failed");
144   }
145 
146   sockaddr_un sus;
147   socklen_t slen = CreateSockaddrUn(socket_name_, abstract_uds_, &sus);
148   // bind() is non-blocking.
149   if (bind(bind_fd_.get(), reinterpret_cast<sockaddr*>(&sus), slen) == -1) {
150     return absl::ErrnoToStatus(errno, "bind failed");
151   }
152 
153   // listen() non-blocking.
154   if (listen(bind_fd_.get(), 0) == -1) {
155     return absl::ErrnoToStatus(errno, "listen failed");
156   }
157 
158   SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str());
159   return absl::OkStatus();
160 }
161 
Accept()162 absl::StatusOr<Comms> ListeningComms::Accept() {
163   sockaddr_un suc;
164   socklen_t len = sizeof(suc);
165   int connection_fd;
166   {
167     PotentiallyBlockingRegion region;
168     connection_fd = TEMP_FAILURE_RETRY(
169         accept(bind_fd_.get(), reinterpret_cast<sockaddr*>(&suc), &len));
170   }
171   if (connection_fd == -1) {
172     return absl::ErrnoToStatus(errno, "accept failed");
173   }
174   SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(),
175                 connection_fd);
176   return Comms(connection_fd, socket_name_);
177 }
178 
Connect(const std::string & socket_name,bool abstract_uds)179 absl::StatusOr<Comms> Comms::Connect(const std::string& socket_name,
180                                      bool abstract_uds) {
181   FDCloser connection_fd(socket(AF_UNIX, SOCK_STREAM, 0));  // Non-blocking
182   if (connection_fd.get() == -1) {
183     return absl::ErrnoToStatus(errno, "socket(AF_UNIX)");
184   }
185 
186   sockaddr_un suc;
187   socklen_t slen = CreateSockaddrUn(socket_name, abstract_uds, &suc);
188   int ret;
189   {
190     PotentiallyBlockingRegion region;
191     ret = TEMP_FAILURE_RETRY(
192         connect(connection_fd.get(), reinterpret_cast<sockaddr*>(&suc), slen));
193   }
194   if (ret == -1) {
195     return absl::ErrnoToStatus(errno, "connect(connection_fd)");
196   }
197 
198   SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name.c_str(),
199                 connection_fd.get());
200   return Comms(connection_fd.Release(), socket_name);
201 }
202 
Terminate()203 void Comms::Terminate() {
204   state_ = State::kTerminated;
205 
206   raw_comms_ = std::unique_ptr<RawComms>();
207   listening_comms_.reset();
208 }
209 
SendTLV(uint32_t tag,size_t length,const void * value)210 bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) {
211   if (length > GetMaxMsgSize()) {
212     SAPI_RAW_LOG(ERROR, "Maximum TLV message size exceeded: (%zu > %zu)",
213                  length, GetMaxMsgSize());
214     return false;
215   }
216   if (length > kWarnMsgSize) {
217     // TODO(cblichmann): Use LOG_FIRST_N once Abseil logging is released.
218     static std::atomic<int> times_warned = 0;
219     if (times_warned.fetch_add(1, std::memory_order_relaxed) < 10) {
220       SAPI_RAW_LOG(
221           WARNING,
222           "TLV message of size %zu detected. Please consider switching "
223           "to Buffer API instead.",
224           length);
225     }
226   }
227 
228   SAPI_RAW_VLOG(3, "Sending a TLV message, tag: 0x%08x, length: %zu", tag,
229                 length);
230 
231   // To maintain consistency with `RecvTL()`, we wrap `tag` and `length` in a TL
232   // struct.
233   const InternalTLV tl = {
234       .tag = tag,
235       .len = length,
236   };
237 
238   const size_t inline_size =
239       std::min(length, kSendTLVTempBufferSize - sizeof(tl));
240   uint8_t tlv[kSendTLVTempBufferSize];
241   memcpy(tlv, &tl, sizeof(tl));
242   memcpy(&tlv[sizeof(tl)], value, inline_size);
243   if (!Send(&tlv, sizeof(tl) + inline_size)) {
244     return false;
245   }
246   if (inline_size < length) {
247     return Send(reinterpret_cast<const uint8_t*>(value) + inline_size,
248                 length - inline_size);
249   }
250   return true;
251 }
252 
RecvString(std::string * v)253 bool Comms::RecvString(std::string* v) {
254   uint32_t tag;
255   if (!RecvTLV(&tag, v)) {
256     return false;
257   }
258 
259   if (tag != kTagString) {
260     v->clear();
261     SAPI_RAW_LOG(ERROR, "Expected (kTagString == 0x%x), got: 0x%x", kTagString,
262                  tag);
263     return false;
264   }
265   return true;
266 }
267 
SendString(const std::string & v)268 bool Comms::SendString(const std::string& v) {
269   return SendTLV(kTagString, v.length(), v.c_str());
270 }
271 
RecvBytes(std::vector<uint8_t> * buffer)272 bool Comms::RecvBytes(std::vector<uint8_t>* buffer) {
273   uint32_t tag;
274   if (!RecvTLV(&tag, buffer)) {
275     return false;
276   }
277   if (tag != kTagBytes) {
278     buffer->clear();
279     SAPI_RAW_LOG(ERROR, "Expected (kTagBytes == 0x%x), got: 0x%u", kTagBytes,
280                  tag);
281     return false;
282   }
283   return true;
284 }
285 
SendBytes(const uint8_t * v,size_t len)286 bool Comms::SendBytes(const uint8_t* v, size_t len) {
287   return SendTLV(kTagBytes, len, v);
288 }
289 
SendBytes(const std::vector<uint8_t> & buffer)290 bool Comms::SendBytes(const std::vector<uint8_t>& buffer) {
291   return SendBytes(buffer.data(), buffer.size());
292 }
293 
RecvCreds(pid_t * pid,uid_t * uid,gid_t * gid)294 bool Comms::RecvCreds(pid_t* pid, uid_t* uid, gid_t* gid) {
295   ucred uc;
296   socklen_t sls = sizeof(uc);
297   int rc;
298   {
299     // Not completely sure if getsockopt() can block on SO_PEERCRED, but let's
300     // play it safe.
301     PotentiallyBlockingRegion region;
302     rc = getsockopt(GetConnectionFD(), SOL_SOCKET, SO_PEERCRED, &uc, &sls);
303   }
304   if (rc == -1) {
305     SAPI_RAW_PLOG(ERROR, "getsockopt(SO_PEERCRED)");
306     return false;
307   }
308   *pid = uc.pid;
309   *uid = uc.uid;
310   *gid = uc.gid;
311 
312   SAPI_RAW_VLOG(2, "Received credentials from PID/UID/GID: %d/%u/%u", *pid,
313                 *uid, *gid);
314   return true;
315 }
316 
RecvFD(int * fd)317 bool Comms::RecvFD(int* fd) {
318   char fd_msg[8192];
319   cmsghdr* cmsg = reinterpret_cast<cmsghdr*>(fd_msg);
320 
321   InternalTLV tlv;
322   iovec iov = {.iov_base = &tlv, .iov_len = sizeof(tlv)};
323 
324   msghdr msg = {
325       .msg_name = nullptr,
326       .msg_namelen = 0,
327       .msg_iov = &iov,
328       .msg_iovlen = 1,
329       .msg_control = cmsg,
330       .msg_controllen = sizeof(fd_msg),
331       .msg_flags = 0,
332   };
333 
334   if (GetRawComms() == nullptr) {
335     SAPI_RAW_LOG(ERROR, "RecvFD: connection terminated");
336     return false;
337   }
338 
339   ssize_t len = GetRawComms()->RawRecvMsg(&msg);
340   if (len < 0) {
341     if (IsFatalError(errno)) {
342       Terminate();
343     }
344     SAPI_RAW_PLOG(ERROR, "recvmsg(SCM_RIGHTS)");
345     return false;
346   }
347   if (len == 0) {
348     Terminate();
349     SAPI_RAW_VLOG(1, "RecvFD: end-point terminated the connection.");
350     return false;
351   }
352   if (len != sizeof(tlv)) {
353     SAPI_RAW_LOG(ERROR, "Expected size: %zu, got %zd", sizeof(tlv), len);
354     return false;
355   }
356   // At this point, we know that op() has been called successfully, therefore
357   // msg struct has been fully populated. Apparently MSAN is not aware of
358   // syscall(__NR_recvmsg) semantics so we need to suppress the error (here and
359   // everywhere below).
360   ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&tlv, sizeof(tlv));
361 
362   if (tlv.tag != kTagFd) {
363     SAPI_RAW_LOG(ERROR, "Expected (kTagFD: 0x%x), got: 0x%x", kTagFd, tlv.tag);
364     return false;
365   }
366 
367   cmsg = CMSG_FIRSTHDR(&msg);
368   ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(cmsg, sizeof(cmsghdr));
369   while (cmsg) {
370     if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
371       if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) {
372         SAPI_RAW_VLOG(1,
373                       "recvmsg(SCM_RIGHTS): cmsg->cmsg_len != "
374                       "CMSG_LEN(sizeof(int)), skipping");
375         continue;
376       }
377       int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
378       *fd = fds[0];
379       ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(fd, sizeof(int));
380       return true;
381     }
382     cmsg = CMSG_NXTHDR(&msg, cmsg);
383   }
384   SAPI_RAW_LOG(ERROR,
385                "Haven't received the SCM_RIGHTS message, process is probably "
386                "out of free file descriptors");
387   return false;
388 }
389 
SendFD(int fd)390 bool Comms::SendFD(int fd) {
391   char fd_msg[CMSG_SPACE(sizeof(int))] = {0};
392   cmsghdr* cmsg = reinterpret_cast<cmsghdr*>(fd_msg);
393   cmsg->cmsg_level = SOL_SOCKET;
394   cmsg->cmsg_type = SCM_RIGHTS;
395   cmsg->cmsg_len = CMSG_LEN(sizeof(int));
396 
397   int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
398   fds[0] = fd;
399 
400   InternalTLV tlv = {kTagFd, 0};
401 
402   iovec iov;
403   iov.iov_base = &tlv;
404   iov.iov_len = sizeof(tlv);
405 
406   msghdr msg;
407   msg.msg_name = nullptr;
408   msg.msg_namelen = 0;
409   msg.msg_iov = &iov;
410   msg.msg_iovlen = 1;
411   msg.msg_control = cmsg;
412   msg.msg_controllen = sizeof(fd_msg);
413   msg.msg_flags = 0;
414 
415   if (GetRawComms() == nullptr) {
416     SAPI_RAW_LOG(ERROR, "SendFD: connection terminated");
417     return false;
418   }
419 
420   ssize_t len = GetRawComms()->RawSendMsg(&msg);
421   if (len == -1 && errno == EPIPE) {
422     Terminate();
423     SAPI_RAW_LOG(ERROR, "sendmsg(SCM_RIGHTS): Peer disconnected");
424     return false;
425   }
426   if (len < 0) {
427     if (IsFatalError(errno)) {
428       Terminate();
429     }
430     SAPI_RAW_PLOG(ERROR, "sendmsg(SCM_RIGHTS)");
431     return false;
432   }
433   if (len != sizeof(tlv)) {
434     SAPI_RAW_LOG(ERROR, "Expected to send %zu bytes, sent %zd", sizeof(tlv),
435                  len);
436     return false;
437   }
438   return true;
439 }
440 
RecvProtoBuf(google::protobuf::MessageLite * message)441 bool Comms::RecvProtoBuf(google::protobuf::MessageLite* message) {
442   uint32_t tag;
443   std::vector<uint8_t> bytes;
444   if (!RecvTLV(&tag, &bytes)) {
445     if (IsConnected()) {
446       SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", name_);
447     } else {
448       Terminate();
449       SAPI_RAW_VLOG(2, "Connection terminated (%s)", name_.c_str());
450     }
451     return false;
452   }
453 
454   if (tag != kTagProto2) {
455     SAPI_RAW_LOG(ERROR, "Expected tag: 0x%x, got: 0x%u", kTagProto2, tag);
456     return false;
457   }
458   return message->ParseFromArray(bytes.data(), bytes.size());
459 }
460 
SendProtoBuf(const google::protobuf::MessageLite & message)461 bool Comms::SendProtoBuf(const google::protobuf::MessageLite& message) {
462   std::string str;
463   if (!message.SerializeToString(&str)) {
464     SAPI_RAW_LOG(ERROR, "Couldn't serialize the ProtoBuf");
465     return false;
466   }
467 
468   return SendTLV(kTagProto2, str.length(),
469                  reinterpret_cast<const uint8_t*>(str.data()));
470 }
471 
472 // *****************************************************************************
473 // All methods below are private, for internal use only.
474 // *****************************************************************************
475 
GetConnectionFD() const476 int Comms::RawCommsFdImpl::GetConnectionFD() const {
477   return connection_fd_.get();
478 }
479 
MoveToAnotherFd()480 void Comms::RawCommsFdImpl::MoveToAnotherFd() {
481   SAPI_RAW_CHECK(connection_fd_.get() != -1,
482                  "Cannot move comms fd as it's not connected");
483   FDCloser new_fd(dup(connection_fd_.get()));
484   SAPI_RAW_CHECK(new_fd.get() != -1, "Failed to move comms to another fd");
485   connection_fd_.Swap(new_fd);
486 }
487 
RawSend(const void * data,size_t len)488 ssize_t Comms::RawCommsFdImpl::RawSend(const void* data, size_t len) {
489   PotentiallyBlockingRegion region;
490   return TEMP_FAILURE_RETRY(write(connection_fd_.get(), data, len));
491 }
492 
RawRecv(void * data,size_t len)493 ssize_t Comms::RawCommsFdImpl::RawRecv(void* data, size_t len) {
494   PotentiallyBlockingRegion region;
495   return TEMP_FAILURE_RETRY(read(connection_fd_.get(), data, len));
496 }
497 
RawSendMsg(const void * msg)498 ssize_t Comms::RawCommsFdImpl::RawSendMsg(const void* msg) {
499   PotentiallyBlockingRegion region;
500   // Use syscall, otherwise we would need to allow socketcall() on PPC.
501   return TEMP_FAILURE_RETRY(util::Syscall(__NR_sendmsg, connection_fd_.get(),
502                                           reinterpret_cast<uintptr_t>(msg), 0));
503 }
504 
RawRecvMsg(void * msg)505 ssize_t Comms::RawCommsFdImpl::RawRecvMsg(void* msg) {
506   PotentiallyBlockingRegion region;
507   // Use syscall, otherwise we would need to allow socketcall() on PPC.
508   return TEMP_FAILURE_RETRY(util::Syscall(__NR_recvmsg, connection_fd_.get(),
509                                           reinterpret_cast<uintptr_t>(msg), 0));
510 }
511 
Send(const void * data,size_t len)512 bool Comms::Send(const void* data, size_t len) {
513   if (GetRawComms() == nullptr) {
514     SAPI_RAW_LOG(ERROR, "Send: connection terminated");
515     return false;
516   }
517 
518   size_t total_sent = 0;
519   const char* bytes = reinterpret_cast<const char*>(data);
520   while (total_sent < len) {
521     ssize_t s = GetRawComms()->RawSend(&bytes[total_sent], len - total_sent);
522     if (s == -1 && errno == EPIPE) {
523       Terminate();
524       // We do not expect the other end to disappear.
525       SAPI_RAW_LOG(ERROR, "Send: end-point terminated the connection");
526       return false;
527     }
528     if (s == -1) {
529       SAPI_RAW_PLOG(ERROR, "write");
530       if (IsFatalError(errno)) {
531         Terminate();
532       }
533       return false;
534     }
535     if (s == 0) {
536       SAPI_RAW_LOG(ERROR,
537                    "Couldn't write more bytes, wrote: %zu, requested: %zu",
538                    total_sent, len);
539       return false;
540     }
541     total_sent += s;
542   }
543   return true;
544 }
545 
Recv(void * data,size_t len)546 bool Comms::Recv(void* data, size_t len) {
547   if (GetRawComms() == nullptr) {
548     SAPI_RAW_LOG(ERROR, "Recv: connection terminated");
549     return false;
550   }
551 
552   size_t total_recv = 0;
553   char* bytes = reinterpret_cast<char*>(data);
554   while (total_recv < len) {
555     ssize_t s = GetRawComms()->RawRecv(&bytes[total_recv], len - total_recv);
556     if (s == -1) {
557       SAPI_RAW_PLOG(ERROR, "read");
558       if (IsFatalError(errno)) {
559         Terminate();
560       }
561       return false;
562     }
563     if (s == 0) {
564       Terminate();
565       // The other end might have finished its work.
566       SAPI_RAW_VLOG(2, "Recv: end-point terminated the connection.");
567       return false;
568     }
569     total_recv += s;
570   }
571   return true;
572 }
573 
574 // Internal helper method (low level).
RecvTL(uint32_t * tag,size_t * length)575 bool Comms::RecvTL(uint32_t* tag, size_t* length) {
576   InternalTLV tl;
577   if (!Recv(reinterpret_cast<uint8_t*>(&tl), sizeof(tl))) {
578     SAPI_RAW_VLOG(2, "RecvTL: Can't read tag and length");
579     return false;
580   }
581   *tag = tl.tag;
582   *length = tl.len;
583   if (*length > GetMaxMsgSize()) {
584     SAPI_RAW_LOG(ERROR, "Maximum TLV message size exceeded: (%zu > %zd)",
585                  *length, GetMaxMsgSize());
586     return false;
587   }
588   if (*length > kWarnMsgSize) {
589     static std::atomic<int> times_warned = 0;
590     if (times_warned.fetch_add(1, std::memory_order_relaxed) < 10) {
591       SAPI_RAW_LOG(
592           WARNING,
593           "TLV message of size: %zu detected. Please consider switching to "
594           "Buffer API instead.",
595           *length);
596     }
597   }
598   return true;
599 }
600 
RecvTLV(uint32_t * tag,std::vector<uint8_t> * value)601 bool Comms::RecvTLV(uint32_t* tag, std::vector<uint8_t>* value) {
602   return RecvTLVGeneric(tag, value);
603 }
604 
RecvTLV(uint32_t * tag,std::string * value)605 bool Comms::RecvTLV(uint32_t* tag, std::string* value) {
606   return RecvTLVGeneric(tag, value);
607 }
608 
609 template <typename T>
RecvTLVGeneric(uint32_t * tag,T * value)610 bool Comms::RecvTLVGeneric(uint32_t* tag, T* value) {
611   size_t length;
612   if (!RecvTL(tag, &length)) {
613     return false;
614   }
615 
616   value->resize(length);
617   return length == 0 || Recv(reinterpret_cast<uint8_t*>(value->data()), length);
618 }
619 
RecvTLV(uint32_t * tag,size_t * length,void * buffer,size_t buffer_size,std::optional<uint32_t> expected_tag)620 bool Comms::RecvTLV(uint32_t* tag, size_t* length, void* buffer,
621                     size_t buffer_size, std::optional<uint32_t> expected_tag) {
622   if (!RecvTL(tag, length)) {
623     return false;
624   }
625 
626   if (expected_tag.has_value() && *tag != *expected_tag) {
627     SAPI_RAW_LOG(ERROR, "Expected tag: 0x%08x, got: 0x%x", *expected_tag, *tag);
628     return false;
629   }
630 
631   if (*length == 0) {
632     return true;
633   }
634 
635   if (*length > buffer_size) {
636     SAPI_RAW_LOG(ERROR, "Buffer size too small (0x%zx > 0x%zx)", *length,
637                  buffer_size);
638     return false;
639   }
640 
641   return Recv(reinterpret_cast<uint8_t*>(buffer), *length);
642 }
643 
RecvInt(void * buffer,size_t len,uint32_t tag)644 bool Comms::RecvInt(void* buffer, size_t len, uint32_t tag) {
645   uint32_t received_tag;
646   size_t received_length;
647   if (!RecvTLV(&received_tag, &received_length, buffer, len, tag)) {
648     return false;
649   }
650 
651   if (received_length != len) {
652     SAPI_RAW_LOG(ERROR, "Expected length: %zu, got: %zu", len, received_length);
653     return false;
654   }
655   return true;
656 }
657 
RecvStatus(absl::Status * status)658 bool Comms::RecvStatus(absl::Status* status) {
659   sapi::StatusProto proto;
660   if (!RecvProtoBuf(&proto)) {
661     return false;
662   }
663   *status = sapi::MakeStatusFromProto(proto);
664   return true;
665 }
666 
SendStatus(const absl::Status & status)667 bool Comms::SendStatus(const absl::Status& status) {
668   sapi::StatusProto proto;
669   sapi::SaveStatusToProto(status, &proto);
670   return SendProtoBuf(proto);
671 }
672 
MoveToAnotherFd()673 void Comms::MoveToAnotherFd() {
674   SAPI_RAW_CHECK(GetRawComms() != nullptr,
675                  "Cannot move comms fd as it's not connected");
676   GetRawComms()->MoveToAnotherFd();
677 }
678 
679 }  // namespace sandbox2
680