1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <errno.h>
#include <fcntl.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>

#include "beget_ext.h"
#include "control_fd.h"
#include "init_utils.h"
#include "securec.h"
25
26 static CmdService g_cmdService;
27
28 CallbackControlFdProcess g_controlFdFunc = NULL;
29
OnClose(const TaskHandle task)30 static void OnClose(const TaskHandle task)
31 {
32 CmdTask *agent = (CmdTask *)LE_GetUserData(task);
33 BEGET_ERROR_CHECK(agent != NULL, return, "[control_fd] Can not get agent");
34 OH_ListRemove(&agent->item);
35 OH_ListInit(&agent->item);
36 }
37
CheckSocketPermission(const TaskHandle task)38 CONTROL_FD_STATIC int CheckSocketPermission(const TaskHandle task)
39 {
40 struct ucred uc = {-1, -1, -1};
41 socklen_t len = sizeof(uc);
42 if (getsockopt(LE_GetSocketFd(task), SOL_SOCKET, SO_PEERCRED, &uc, &len) < 0) {
43 BEGET_LOGE("Failed to get socket option. err = %d", errno);
44 return -1;
45 }
46 // Only root is permitted to use control fd of init.
47 if (uc.uid != 0) { // non-root user
48 errno = EPERM;
49 return -1;
50 }
51
52 return 0;
53 }
54
CmdOnRecvMessage(const TaskHandle task,const uint8_t * buffer,uint32_t buffLen)55 CONTROL_FD_STATIC void CmdOnRecvMessage(const TaskHandle task, const uint8_t *buffer, uint32_t buffLen)
56 {
57 if (buffer == NULL) {
58 return;
59 }
60 CmdTask *agent = (CmdTask *)LE_GetUserData(task);
61 BEGET_ERROR_CHECK(agent != NULL, return, "[control_fd] Can not get agent");
62
63 // parse msg to exec
64 CmdMessage *msg = (CmdMessage *)buffer;
65 if ((msg->type >= ACTION_MAX) || (msg->cmd[0] == '\0') || (msg->ptyName[0] == '\0')) {
66 BEGET_LOGE("[control_fd] Received msg is invaild");
67 return;
68 }
69
70 if (CheckSocketPermission(task) < 0) {
71 BEGET_LOGE("Check socket permission failed, err = %d", errno);
72 return;
73 }
74 #ifndef STARTUP_INIT_TEST
75 agent->pid = fork();
76 if (agent->pid == 0) {
77 OpenConsole();
78 char *realPath = GetRealPath(msg->ptyName);
79 BEGET_ERROR_CHECK(realPath != NULL, return, "Failed get realpath, err=%d", errno);
80 int n = strncmp(realPath, "/dev/pts/", strlen("/dev/pts/"));
81 BEGET_ERROR_CHECK(n == 0, free(realPath); _exit(1), "pts path %s is invaild", realPath);
82 int fd = open(realPath, O_RDWR);
83 free(realPath);
84 BEGET_ERROR_CHECK(fd >= 0, return, "Failed open %s, err=%d", msg->ptyName, errno);
85 (void)dup2(fd, STDIN_FILENO);
86 (void)dup2(fd, STDOUT_FILENO);
87 (void)dup2(fd, STDERR_FILENO); // Redirect fd to 0, 1, 2
88 (void)close(fd);
89 if (g_controlFdFunc != NULL) {
90 g_controlFdFunc(msg->type, msg->cmd, NULL);
91 }
92 exit(0);
93 } else if (agent->pid < 0) {
94 BEGET_LOGE("[control_fd] Failed to fork child process, err = %d", errno);
95 }
96 #endif
97 return;
98 }
99
SendMessage(LoopHandle loop,TaskHandle task,const char * message)100 CONTROL_FD_STATIC int SendMessage(LoopHandle loop, TaskHandle task, const char *message)
101 {
102 if (message == NULL) {
103 BEGET_LOGE("[control_fd] Invalid parameter");
104 return -1;
105 }
106 BufferHandle handle = NULL;
107 uint32_t bufferSize = strlen(message) + 1;
108 handle = LE_CreateBuffer(loop, bufferSize);
109 char *buff = (char *)LE_GetBufferInfo(handle, NULL, &bufferSize);
110 BEGET_ERROR_CHECK(buff != NULL, return -1, "[control_fd] Failed get buffer info");
111 int ret = memcpy_s(buff, bufferSize, message, strlen(message) + 1);
112 BEGET_ERROR_CHECK(ret == 0, LE_FreeBuffer(LE_GetDefaultLoop(), task, handle);
113 return -1, "[control_fd] Failed memcpy_s err=%d", errno);
114 LE_STATUS status = LE_Send(loop, task, handle, strlen(message) + 1);
115 BEGET_ERROR_CHECK(status == LE_SUCCESS, return -1, "[control_fd] Failed le send msg");
116 return 0;
117 }
118
CmdOnIncommingConnect(const LoopHandle loop,const TaskHandle server)119 CONTROL_FD_STATIC int CmdOnIncommingConnect(const LoopHandle loop, const TaskHandle server)
120 {
121 TaskHandle client = NULL;
122 LE_StreamInfo info = {};
123 #ifndef STARTUP_INIT_TEST
124 info.baseInfo.flags = TASK_STREAM | TASK_PIPE | TASK_CONNECT;
125 #else
126 info.baseInfo.flags = TASK_STREAM | TASK_PIPE | TASK_CONNECT | TASK_TEST;
127 #endif
128 info.baseInfo.close = OnClose;
129 info.baseInfo.userDataSize = sizeof(CmdTask);
130 info.disConnectComplete = NULL;
131 info.sendMessageComplete = NULL;
132 info.recvMessage = CmdOnRecvMessage;
133 int ret = LE_AcceptStreamClient(LE_GetDefaultLoop(), server, &client, &info);
134 BEGET_ERROR_CHECK(ret == 0, return -1, "[control_fd] Failed accept stream")
135 CmdTask *agent = (CmdTask *)LE_GetUserData(client);
136 BEGET_ERROR_CHECK(agent != NULL, return -1, "[control_fd] Invalid agent");
137 agent->task = client;
138 OH_ListInit(&agent->item);
139 ret = SendMessage(LE_GetDefaultLoop(), agent->task, "connect success.");
140 BEGET_ERROR_CHECK(ret == 0, return -1, "[control_fd] Failed send msg");
141 OH_ListAddTail(&g_cmdService.head, &agent->item);
142 return 0;
143 }
144
/*
 * Initialise the control-fd command service.
 *
 * Resets the client list, registers func as the callback run for accepted
 * requests and creates the stream server on the default loop.
 *
 * socketPath - path handed to the stream server as its address
 * func       - callback stored in g_controlFdFunc
 *
 * Does nothing except log when either argument is NULL. A failure to create
 * the server is logged; the function has no return value to report it.
 */
void CmdServiceInit(const char *socketPath, CallbackControlFdProcess func)
{
    if ((socketPath == NULL) || (func == NULL)) {
        BEGET_LOGE("[control_fd] Invalid parameter");
        return;
    }
    OH_ListInit(&g_cmdService.head);
    LE_StreamServerInfo info = {};
    info.baseInfo.flags = TASK_STREAM | TASK_SERVER | TASK_PIPE;
    info.server = (char *)socketPath;
    info.socketId = -1;
    info.baseInfo.close = NULL;
    info.disConnectComplete = NULL;
    info.incommingConnect = CmdOnIncommingConnect;
    info.sendMessageComplete = NULL;
    info.recvMessage = NULL;
    g_controlFdFunc = func;
    LE_STATUS status = LE_CreateStreamServer(LE_GetDefaultLoop(), &g_cmdService.serverTask, &info);
    if (status != LE_SUCCESS) {
        // Previously discarded silently; without the server no client can connect.
        BEGET_LOGE("[control_fd] Failed to create stream server");
    }
}
164
/*
 * Comparison callback for OH_ListFind: data points to the pid being looked
 * up. Returns 0 when the node's CmdTask has that pid, otherwise the non-zero
 * difference between the two pids.
 */
static int ClientTraversalProc(ListNode *node, void *data)
{
    const int wantedPid = *(int *)data;
    const CmdTask *entry = ListEntry(node, CmdTask, item);
    return wantedPid - entry->pid;
}
171
CmdServiceProcessDelClient(pid_t pid)172 void CmdServiceProcessDelClient(pid_t pid)
173 {
174 ListNode *node = OH_ListFind(&g_cmdService.head, (void *)&pid, ClientTraversalProc);
175 if (node != NULL) {
176 CmdTask *agent = ListEntry(node, CmdTask, item);
177 OH_ListRemove(&agent->item);
178 OH_ListInit(&agent->item);
179 LE_CloseTask(LE_GetDefaultLoop(), agent->task);
180 }
181 }
182