1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "rs_profiler_socket.h"

#include <cerrno>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <securec.h>
#include <sys/ioctl.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include "rs_profiler_log.h"
#include "rs_profiler_utils.h"
30
31 namespace OHOS::Rosen {
32
// Splits a millisecond duration into a timeval: whole seconds plus the leftover
// milliseconds expressed in microseconds (the layout SetTimeout hands to setsockopt).
static timeval GetTimeoutDesc(uint32_t milliseconds)
{
    constexpr uint32_t factor = 1000u; // milliseconds per second, and microseconds per millisecond
    const uint32_t seconds = milliseconds / factor;
    const uint32_t leftoverMs = milliseconds - seconds * factor;

    timeval desc {};
    desc.tv_sec = seconds;
    desc.tv_usec = leftoverMs * factor;
    return desc;
}
42
// Queries the send timeout (SO_SNDTIMEO) currently configured on `socket`.
// The getsockopt() status is ignored and the result starts zero-initialised, so a failed query
// that leaves the buffer untouched reads back as a zero timeout.
static timeval GetTimeout(int32_t socket)
{
    timeval current {};
    auto length = static_cast<socklen_t>(sizeof(current));
    static_cast<void>(getsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, &current, &length));
    return current;
}
50
// Applies `timeout` as the send timeout (SO_SNDTIMEO) of `socket`; best-effort, the status is ignored.
static void SetTimeout(int32_t socket, const timeval& timeout)
{
    const auto length = static_cast<socklen_t>(sizeof timeout);
    static_cast<void>(setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, length));
}
55
SetTimeout(int32_t socket,uint32_t milliseconds)56 static void SetTimeout(int32_t socket, uint32_t milliseconds)
57 {
58 SetTimeout(socket, GetTimeoutDesc(milliseconds));
59 }
60
// Returns `flags` with the bit(s) in `flag` set when `enable` is true, or cleared otherwise.
static int32_t ToggleFlag(uint32_t flags, uint32_t flag, bool enable)
{
    if (enable) {
        return flags | flag;
    }
    return flags & ~flag;
}
65
SetBlocking(int32_t socket,bool enable)66 static void SetBlocking(int32_t socket, bool enable)
67 {
68 fcntl(socket, F_SETFL, ToggleFlag(fcntl(socket, F_GETFL, 0), O_NONBLOCK, !enable));
69 }
70
SetCloseOnExec(int32_t socket,bool enable)71 static void SetCloseOnExec(int32_t socket, bool enable)
72 {
73 fcntl(socket, F_SETFD, ToggleFlag(fcntl(socket, F_GETFD, 0), FD_CLOEXEC, enable));
74 }
75
~Socket()76 Socket::~Socket()
77 {
78 Shutdown();
79 }
80
Connected() const81 bool Socket::Connected() const
82 {
83 return (socket_ != -1) && (client_ != -1) && (state_ == SocketState::CONNECTED);
84 }
85
GetState() const86 SocketState Socket::GetState() const
87 {
88 return state_;
89 }
90
Shutdown()91 void Socket::Shutdown()
92 {
93 shutdown(socket_, SHUT_RDWR);
94 fdsan_close_with_tag(socket_, LOG_DOMAIN);
95 socket_ = -1;
96
97 shutdown(client_, SHUT_RDWR);
98 fdsan_close_with_tag(client_, LOG_DOMAIN);
99 client_ = -1;
100
101 state_ = SocketState::SHUTDOWN;
102 }
103
Open(uint16_t port)104 void Socket::Open(uint16_t port)
105 {
106 socket_ = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
107 if (socket_ == -1) {
108 Shutdown();
109 return;
110 }
111 fdsan_exchange_owner_tag(socket_, 0, LOG_DOMAIN);
112
113 const std::string socketName = "render_service_" + std::to_string(port);
114 sockaddr_un address {};
115 address.sun_family = AF_UNIX;
116 address.sun_path[0] = 0;
117 ::memmove_s(address.sun_path + 1, sizeof(address.sun_path) - 1, socketName.data(), socketName.size());
118
119 const size_t addressSize = offsetof(sockaddr_un, sun_path) + socketName.size() + 1;
120 if (bind(socket_, reinterpret_cast<sockaddr*>(&address), addressSize) == -1) {
121 Shutdown();
122 return;
123 }
124
125 const int32_t maxConnections = 5;
126 if (listen(socket_, maxConnections) != 0) {
127 Shutdown();
128 return;
129 }
130
131 SetBlocking(socket_, false);
132 SetCloseOnExec(socket_, true);
133
134 state_ = SocketState::CREATE;
135 }
136
AcceptClient()137 void Socket::AcceptClient()
138 {
139 client_ = accept4(socket_, nullptr, nullptr, SOCK_CLOEXEC);
140 if (client_ == -1) {
141 if ((errno != EWOULDBLOCK) && (errno != EAGAIN) && (errno != EINTR)) {
142 Shutdown();
143 }
144 return;
145 }
146 fdsan_exchange_owner_tag(client_, 0, LOG_DOMAIN);
147
148 SetBlocking(client_, false);
149 SetCloseOnExec(client_, true);
150
151 int32_t nodelay = 1;
152 setsockopt(client_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&nodelay), sizeof(nodelay));
153
154 state_ = SocketState::CONNECTED;
155 }
156
Available()157 size_t Socket::Available()
158 {
159 int32_t size = 0;
160 const auto result = ioctl(client_, FIONREAD, &size);
161 if (result == -1) {
162 HRPE("Socket: Available failed: %d", errno);
163 return 0u;
164 }
165 return static_cast<size_t>(size);
166 }
167
SendWhenReady(const void * data,size_t size)168 bool Socket::SendWhenReady(const void* data, size_t size)
169 {
170 if (!data || (size == 0)) {
171 return true;
172 }
173
174 SetBlocking(client_, true);
175
176 const timeval previousTimeout = GetTimeout(client_);
177
178 const uint32_t timeoutMilliseconds = 40;
179 SetTimeout(client_, timeoutMilliseconds);
180
181 const char* bytes = reinterpret_cast<const char*>(data);
182 size_t sent = 0;
183 while (sent < size) {
184 if (PollSend(1) == 0) {
185 // wait for 1ms in worst case to have socket ready for sending
186 continue;
187 }
188 const ssize_t sentBytes = send(client_, bytes, size - sent, 0);
189 if ((sentBytes <= 0) && (errno != EINTR)) {
190 HRPE("Socket: SendWhenReady: Invoke shutdown: %d", errno);
191 Shutdown();
192 return false;
193 }
194 auto actualSentBytes = static_cast<size_t>(sentBytes);
195 sent += actualSentBytes;
196 bytes += actualSentBytes;
197 }
198
199 SetTimeout(client_, previousTimeout);
200 SetBlocking(client_, false);
201 return true;
202 }
203
Receive(void * data,size_t & size)204 bool Socket::Receive(void* data, size_t& size)
205 {
206 if (!data || (size == 0)) {
207 return true;
208 }
209
210 SetBlocking(client_, false);
211
212 const ssize_t receivedBytes = recv(client_, static_cast<char*>(data), size, 0);
213 if (receivedBytes > 0) {
214 size = static_cast<size_t>(receivedBytes);
215 } else {
216 size = 0;
217 if ((errno == EWOULDBLOCK) || (errno == EAGAIN) || (errno == EINTR)) {
218 return true;
219 }
220 HRPE("Socket: Receive: Invoke shutdown: %d", errno);
221 Shutdown();
222 return false;
223 }
224 return true;
225 }
226
ReceiveWhenReady(void * data,size_t size)227 bool Socket::ReceiveWhenReady(void* data, size_t size)
228 {
229 if (!data || (size == 0)) {
230 return true;
231 }
232
233 const timeval previousTimeout = GetTimeout(client_);
234 const uint32_t bandwitdth = 10000; // KB/ms
235 const uint32_t timeoutPad = 100;
236 const uint32_t timeout = size / bandwitdth + timeoutPad;
237
238 SetBlocking(client_, true);
239 SetTimeout(client_, timeout);
240
241 size_t received = 0;
242 char* bytes = static_cast<char*>(data);
243 while (received < size) {
244 // receivedBytes can only be -1 or [0, size - received] (from recv man)
245 const ssize_t receivedBytes = recv(client_, bytes, size - received, 0);
246 if ((receivedBytes == -1) && (errno != EINTR)) {
247 HRPE("Socket: ReceiveWhenReady: Invoke shutdown: %d", errno);
248 Shutdown();
249 return false;
250 }
251
252 // so receivedBytes here always [0, size - received]
253 // then received can't be > `size` and it can't be overflowed
254 auto actualReceivedBytes = static_cast<size_t>(receivedBytes);
255 received += actualReceivedBytes;
256 bytes += actualReceivedBytes;
257 }
258
259 SetTimeout(client_, previousTimeout);
260 SetBlocking(client_, false);
261 return true;
262 }
263
PollReceive(int timeout)264 int Socket::PollReceive(int timeout)
265 {
266 struct pollfd pollFd = {0};
267 pollFd.fd = client_;
268 pollFd.events = POLLIN;
269 return poll(&pollFd, 1, timeout);
270 }
271
PollSend(int timeout)272 int Socket::PollSend(int timeout)
273 {
274 struct pollfd pollFd = {0};
275 pollFd.fd = client_;
276 pollFd.events = POLLOUT;
277 return poll(&pollFd, 1, timeout);
278 }
279
280 } // namespace OHOS::Rosen