1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "fd_holder_internal.h"
#include <errno.h>
#include <stdio.h>
#include <unistd.h>
#include "beget_ext.h"
#include "securec.h"
21
22 #ifndef PAGE_SIZE
23 #define PAGE_SIZE (4096U)
24 #endif
25
BuildControlMessage(struct msghdr * msghdr,int * fds,int fdCount,bool sendUcred)26 int BuildControlMessage(struct msghdr *msghdr, int *fds, int fdCount, bool sendUcred)
27 {
28 if (msghdr == NULL || (fdCount > 0 && fds == NULL)) {
29 BEGET_LOGE("Build control message with invalid parameter");
30 return -1;
31 }
32
33 if (fdCount > 0) {
34 msghdr->msg_controllen = CMSG_SPACE(sizeof(int) * fdCount);
35 } else {
36 msghdr->msg_controllen = 0;
37 }
38
39 if (sendUcred) {
40 msghdr->msg_controllen += CMSG_SPACE(sizeof(struct ucred));
41 }
42
43 msghdr->msg_control = calloc(1, ((msghdr->msg_controllen == 0) ? 1 : msghdr->msg_controllen));
44 BEGET_ERROR_CHECK(msghdr->msg_control != NULL, return -1, "Failed to build control message");
45
46 struct cmsghdr *cmsg = NULL;
47 cmsg = CMSG_FIRSTHDR(msghdr);
48
49 if (fdCount > 0) {
50 cmsg->cmsg_level = SOL_SOCKET;
51 cmsg->cmsg_type = SCM_RIGHTS;
52 cmsg->cmsg_len = CMSG_LEN(sizeof(int) * fdCount);
53 int ret = memcpy_s(CMSG_DATA(cmsg), cmsg->cmsg_len, fds, sizeof(int) * fdCount);
54 BEGET_ERROR_CHECK(ret == 0, free(msghdr->msg_control);
55 return -1, "Control message is not valid");
56 // build ucred info
57 cmsg = CMSG_NXTHDR(msghdr, cmsg);
58 }
59
60 if (sendUcred) {
61 BEGET_ERROR_CHECK(cmsg != NULL, free(msghdr->msg_control);
62 return -1, "Control message is not valid");
63
64 struct ucred *ucred;
65 cmsg->cmsg_level = SOL_SOCKET;
66 cmsg->cmsg_type = SCM_CREDENTIALS;
67 cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred));
68 ucred = (struct ucred*) CMSG_DATA(cmsg);
69 ucred->pid = getpid();
70 ucred->uid = getuid();
71 ucred->gid = getgid();
72 }
73 return 0;
74 }
75
GetFdsFromMsg(size_t * outFdCount,pid_t * requestPid,struct msghdr msghdr)76 STATIC int *GetFdsFromMsg(size_t *outFdCount, pid_t *requestPid, struct msghdr msghdr)
77 {
78 if ((msghdr.msg_flags) & MSG_TRUNC) {
79 BEGET_LOGE("Message was truncated when receiving fds");
80 return NULL;
81 }
82
83 struct cmsghdr *cmsg = NULL;
84 int *fds = NULL;
85 size_t fdCount = 0;
86 for (cmsg = CMSG_FIRSTHDR(&msghdr); cmsg != NULL; cmsg = CMSG_NXTHDR(&msghdr, cmsg)) {
87 if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
88 fds = (int*)CMSG_DATA(cmsg);
89 fdCount = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
90 BEGET_ERROR_CHECK(fdCount <= MAX_HOLD_FDS, return NULL, "Too many fds returned.");
91 }
92 if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
93 cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
94 // Ignore credentials
95 if (requestPid != NULL) {
96 struct ucred *ucred = (struct ucred*)CMSG_DATA(cmsg);
97 *requestPid = ucred->pid;
98 }
99 continue;
100 }
101 }
102 int *outFds = NULL;
103 if (fds != NULL && fdCount > 0) {
104 outFds = calloc(fdCount + 1, sizeof(int));
105 BEGET_ERROR_CHECK(outFds != NULL, return NULL, "Failed to allocate memory for fds");
106 BEGET_ERROR_CHECK(memcpy_s(outFds, sizeof(int) * (fdCount + 1), fds, sizeof(int) * fdCount) == 0,
107 free(outFds); return NULL, "Failed to copy fds");
108 }
109 *outFdCount = fdCount;
110 return outFds;
111 }
112
113 // This function will allocate memory to store FDs
114 // Remember to delete when not used anymore.
ReceiveFds(int sock,struct iovec iovec,size_t * outFdCount,bool nonblock,pid_t * requestPid)115 int *ReceiveFds(int sock, struct iovec iovec, size_t *outFdCount, bool nonblock, pid_t *requestPid)
116 {
117 CMSG_BUFFER_TYPE(CMSG_SPACE(sizeof(struct ucred)) +
118 CMSG_SPACE(sizeof(int) * MAX_HOLD_FDS)) control;
119
120 BEGET_ERROR_CHECK(sizeof(control) <= PAGE_SIZE, return NULL, "Too many fds, out of memory");
121
122 struct msghdr msghdr = {
123 .msg_iov = &iovec,
124 .msg_iovlen = 1,
125 .msg_control = &control,
126 .msg_controllen = sizeof(control),
127 .msg_flags = 0,
128 };
129
130 int flags = MSG_CMSG_CLOEXEC | MSG_TRUNC;
131 if (nonblock) {
132 flags |= MSG_DONTWAIT;
133 }
134 ssize_t rc = TEMP_FAILURE_RETRY(recvmsg(sock, &msghdr, flags));
135 BEGET_ERROR_CHECK(rc >= 0, return NULL, "Failed to get fds from remote, err = %d", errno);
136 return GetFdsFromMsg(outFdCount, requestPid, msghdr);
137 }