1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/libs/utils/unix_sockets.h"
17
18 #include <fcntl.h>
19 #include <sys/uio.h>
20 #include <unistd.h>
21
22 #include <cstring>
23 #include <memory>
24 #include <ostream>
25 #include <utility>
26 #include <vector>
27
28 #include <android-base/logging.h>
29
30 #include "common/libs/fs/shared_fd.h"
31 #include "common/libs/utils/result.h"
32
33 // This would use android::base::ReceiveFileDescriptors, but it silently drops
34 // SCM_CREDENTIALS control messages.
35
36 namespace cuttlefish {
37
FromRaw(const cmsghdr * cmsg)38 ControlMessage ControlMessage::FromRaw(const cmsghdr* cmsg) {
39 ControlMessage message;
40 message.data_ =
41 std::vector<char>((char*)cmsg, ((char*)cmsg) + cmsg->cmsg_len);
42 if (message.IsFileDescriptors()) {
43 size_t fdcount =
44 static_cast<size_t>(cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
45 for (int i = 0; i < fdcount; i++) {
46 // Use memcpy as CMSG_DATA may be unaligned
47 int fd = -1;
48 memcpy(&fd, CMSG_DATA(cmsg) + (i * sizeof(int)), sizeof(fd));
49 message.fds_.push_back(fd);
50 }
51 }
52 return message;
53 }
54
FromFileDescriptors(const std::vector<SharedFD> & fds)55 Result<ControlMessage> ControlMessage::FromFileDescriptors(
56 const std::vector<SharedFD>& fds) {
57 ControlMessage message;
58 message.data_.resize(CMSG_SPACE(fds.size() * sizeof(int)), 0);
59 message.Raw()->cmsg_len = CMSG_LEN(fds.size() * sizeof(int));
60 message.Raw()->cmsg_level = SOL_SOCKET;
61 message.Raw()->cmsg_type = SCM_RIGHTS;
62 for (int i = 0; i < fds.size(); i++) {
63 int fd_copy = fds[i]->Fcntl(F_DUPFD_CLOEXEC, 3);
64 CF_EXPECT(fd_copy >= 0, "Failed to duplicate fd: " << fds[i]->StrError());
65 message.fds_.push_back(fd_copy);
66 // Following the CMSG_DATA spec, use memcpy to avoid alignment issues.
67 memcpy(CMSG_DATA(message.Raw()) + (i * sizeof(int)), &fd_copy, sizeof(int));
68 }
69 return message;
70 }
71
FromCredentials(const ucred & credentials)72 ControlMessage ControlMessage::FromCredentials(const ucred& credentials) {
73 ControlMessage message;
74 message.data_.resize(CMSG_SPACE(sizeof(ucred)), 0);
75 message.Raw()->cmsg_len = CMSG_LEN(sizeof(ucred));
76 message.Raw()->cmsg_level = SOL_SOCKET;
77 message.Raw()->cmsg_type = SCM_CREDENTIALS;
78 // Following the CMSG_DATA spec, use memcpy to avoid alignment issues.
79 memcpy(CMSG_DATA(message.Raw()), &credentials, sizeof(credentials));
80 return message;
81 }
82
ControlMessage(ControlMessage && existing)83 ControlMessage::ControlMessage(ControlMessage&& existing) {
84 // Enforce that the old ControlMessage is left empty, so it doesn't try to
85 // close any file descriptors. https://stackoverflow.com/a/17735913
86 data_ = std::move(existing.data_);
87 existing.data_.clear();
88 fds_ = std::move(existing.fds_);
89 existing.fds_.clear();
90 }
91
operator =(ControlMessage && existing)92 ControlMessage& ControlMessage::operator=(ControlMessage&& existing) {
93 // Enforce that the old ControlMessage is left empty, so it doesn't try to
94 // close any file descriptors. https://stackoverflow.com/a/17735913
95 data_ = std::move(existing.data_);
96 existing.data_.clear();
97 fds_ = std::move(existing.fds_);
98 existing.fds_.clear();
99 return *this;
100 }
101
~ControlMessage()102 ControlMessage::~ControlMessage() {
103 for (const auto& fd : fds_) {
104 if (close(fd) != 0) {
105 PLOG(ERROR) << "Failed to close fd " << fd
106 << ", may have leaked or closed prematurely";
107 }
108 }
109 }
110
Raw()111 cmsghdr* ControlMessage::Raw() {
112 return reinterpret_cast<cmsghdr*>(data_.data());
113 }
114
Raw() const115 const cmsghdr* ControlMessage::Raw() const {
116 return reinterpret_cast<const cmsghdr*>(data_.data());
117 }
118
IsCredentials() const119 bool ControlMessage::IsCredentials() const {
120 bool right_level = Raw()->cmsg_level == SOL_SOCKET;
121 bool right_type = Raw()->cmsg_type == SCM_CREDENTIALS;
122 bool enough_data = Raw()->cmsg_len >= sizeof(cmsghdr) + sizeof(ucred);
123 return right_level && right_type && enough_data;
124 }
125
AsCredentials() const126 Result<ucred> ControlMessage::AsCredentials() const {
127 CF_EXPECT(IsCredentials(), "Control message does not hold a credential");
128 ucred credentials;
129 memcpy(&credentials, CMSG_DATA(Raw()), sizeof(ucred));
130 return credentials;
131 }
132
IsFileDescriptors() const133 bool ControlMessage::IsFileDescriptors() const {
134 bool right_level = Raw()->cmsg_level == SOL_SOCKET;
135 bool right_type = Raw()->cmsg_type == SCM_RIGHTS;
136 return right_level && right_type;
137 }
138
AsSharedFDs() const139 Result<std::vector<SharedFD>> ControlMessage::AsSharedFDs() const {
140 CF_EXPECT(IsFileDescriptors(), "Message does not contain file descriptors");
141 size_t fdcount =
142 static_cast<size_t>(Raw()->cmsg_len - CMSG_LEN(0)) / sizeof(int);
143 std::vector<SharedFD> shared_fds;
144 for (int i = 0; i < fdcount; i++) {
145 // Use memcpy as CMSG_DATA may be unaligned
146 int fd = -1;
147 memcpy(&fd, CMSG_DATA(Raw()) + (i * sizeof(int)), sizeof(fd));
148 SharedFD shared_fd = SharedFD::Dup(fd);
149 CF_EXPECT(shared_fd->IsOpen(), "Could not dup FD " << fd);
150 shared_fds.push_back(shared_fd);
151 }
152 return shared_fds;
153 }
154
HasFileDescriptors()155 bool UnixSocketMessage::HasFileDescriptors() {
156 for (const auto& control_message : control) {
157 if (control_message.IsFileDescriptors()) {
158 return true;
159 }
160 }
161 return false;
162 }
FileDescriptors()163 Result<std::vector<SharedFD>> UnixSocketMessage::FileDescriptors() {
164 std::vector<SharedFD> fds;
165 for (const auto& control_message : control) {
166 if (control_message.IsFileDescriptors()) {
167 auto additional_fds = CF_EXPECT(control_message.AsSharedFDs());
168 fds.insert(fds.end(), additional_fds.begin(), additional_fds.end());
169 }
170 }
171 return fds;
172 }
HasCredentials()173 bool UnixSocketMessage::HasCredentials() {
174 for (const auto& control_message : control) {
175 if (control_message.IsCredentials()) {
176 return true;
177 }
178 }
179 return false;
180 }
Credentials()181 Result<ucred> UnixSocketMessage::Credentials() {
182 std::vector<ucred> credentials;
183 for (const auto& control_message : control) {
184 if (control_message.IsCredentials()) {
185 auto creds = CF_EXPECT(control_message.AsCredentials(),
186 "Message claims to have credentials but does not");
187 credentials.push_back(creds);
188 }
189 }
190 if (credentials.size() == 0) {
191 return CF_ERR("No credentials present");
192 } else if (credentials.size() == 1) {
193 return credentials[0];
194 } else {
195 return CF_ERR("Excepted 1 credential, received " << credentials.size());
196 }
197 }
198
UnixMessageSocket(SharedFD socket)199 UnixMessageSocket::UnixMessageSocket(SharedFD socket) : socket_(socket) {
200 socklen_t ln = sizeof(max_message_size_);
201 CHECK(socket->GetSockOpt(SOL_SOCKET, SO_SNDBUF, &max_message_size_, &ln) == 0)
202 << "error: can't retrieve socket max message size: "
203 << socket->StrError();
204 }
205
EnableCredentials(bool enable)206 Result<void> UnixMessageSocket::EnableCredentials(bool enable) {
207 int flag = enable ? 1 : 0;
208 if (socket_->SetSockOpt(SOL_SOCKET, SO_PASSCRED, &flag, sizeof(flag)) != 0) {
209 return CF_ERR("Could not set credential status to " << enable << ": "
210 << socket_->StrError());
211 }
212 return {};
213 }
214
WriteMessage(const UnixSocketMessage & message)215 Result<void> UnixMessageSocket::WriteMessage(const UnixSocketMessage& message) {
216 auto control_size = 0;
217 for (const auto& control : message.control) {
218 control_size += control.data_.size();
219 }
220 std::vector<char> message_control(control_size, 0);
221 msghdr message_header{};
222 message_header.msg_control = message_control.data();
223 message_header.msg_controllen = message_control.size();
224 auto cmsg = CMSG_FIRSTHDR(&message_header);
225 for (const ControlMessage& control : message.control) {
226 CF_EXPECT(cmsg != nullptr,
227 "Control messages did not fit in control buffer");
228 /* size() should match CMSG_SPACE */
229 memcpy(cmsg, control.data_.data(), control.data_.size());
230 cmsg = CMSG_NXTHDR(&message_header, cmsg);
231 }
232
233 iovec message_iovec;
234 message_iovec.iov_base = (void*)message.data.data();
235 message_iovec.iov_len = message.data.size();
236 message_header.msg_name = nullptr;
237 message_header.msg_namelen = 0;
238 message_header.msg_iov = &message_iovec;
239 message_header.msg_iovlen = 1;
240 message_header.msg_flags = 0;
241
242 auto bytes_sent = socket_->SendMsg(&message_header, MSG_NOSIGNAL);
243 CF_EXPECT(bytes_sent >= 0, "Failed to send message: " << socket_->StrError());
244 CF_EXPECT(bytes_sent == message.data.size(),
245 "Failed to send entire message. Sent "
246 << bytes_sent << ", excepted to send " << message.data.size());
247 return {};
248 }
249
ReadMessage()250 Result<UnixSocketMessage> UnixMessageSocket::ReadMessage() {
251 msghdr message_header{};
252 std::vector<char> message_control(max_message_size_, 0);
253 message_header.msg_control = message_control.data();
254 message_header.msg_controllen = message_control.size();
255 std::vector<char> message_data(max_message_size_, 0);
256 iovec message_iovec;
257 message_iovec.iov_base = message_data.data();
258 message_iovec.iov_len = message_data.size();
259 message_header.msg_iov = &message_iovec;
260 message_header.msg_iovlen = 1;
261 message_header.msg_name = nullptr;
262 message_header.msg_namelen = 0;
263 message_header.msg_flags = 0;
264
265 auto bytes_read = socket_->RecvMsg(&message_header, MSG_CMSG_CLOEXEC);
266 CF_EXPECT(bytes_read >= 0, "Read error: " << socket_->StrError());
267 CF_EXPECT(!(message_header.msg_flags & MSG_TRUNC),
268 "Message was truncated on read");
269 CF_EXPECT(!(message_header.msg_flags & MSG_CTRUNC),
270 "Message control data was truncated on read");
271 CF_EXPECT(!(message_header.msg_flags & MSG_ERRQUEUE), "Error queue error");
272 UnixSocketMessage managed_message;
273 for (auto cmsg = CMSG_FIRSTHDR(&message_header); cmsg != nullptr;
274 cmsg = CMSG_NXTHDR(&message_header, cmsg)) {
275 managed_message.control.emplace_back(ControlMessage::FromRaw(cmsg));
276 }
277 message_data.resize(bytes_read);
278 managed_message.data = std::move(message_data);
279
280 return managed_message;
281 }
282
283 } // namespace cuttlefish
284