1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include <unistd.h>
17
18 #include "dm_transport.h"
19 #include "dm_anonymous.h"
20 #include "dm_comm_tool.h"
21 #include "dm_constants.h"
22 #include "dm_log.h"
23 #include "dm_softbus_cache.h"
24 #include "dm_transport_msg.h"
25 #include "softbus_error_code.h"
26
27 namespace OHOS {
28 namespace DistributedHardware {
29 namespace {
30 // Dsoftbus sendBytes max message length: 4MB
31 constexpr uint32_t MAX_SEND_MSG_LENGTH = 4 * 1024 * 1024;
32 constexpr uint32_t INTERCEPT_STRING_LENGTH = 20;
33 constexpr uint32_t MAX_ROUND_SIZE = 1000;
34 const int32_t USLEEP_TIME_US_200000 = 200000; // 200ms
35 static QosTV g_qosInfo[] = {
36 { .qos = QOS_TYPE_MIN_BW, .value = 256 * 1024},
37 { .qos = QOS_TYPE_MAX_LATENCY, .value = 8000 },
38 { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 }
39 };
40 static uint32_t g_qosTvParamIndex = static_cast<uint32_t>(sizeof(g_qosInfo) / sizeof(g_qosInfo[0]));
41 static std::weak_ptr<DMCommTool> g_dmCommToolWPtr_;
42 }
43
DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr)44 DMTransport::DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr) : remoteDevSocketIds_({}), localServerSocket_(-1),
45 localSocketName_(""), isSocketSvrCreateFlag_(false), dmCommToolWPtr_(dmCommToolPtr)
46 {
47 LOGI("Ctor DMTransport");
48 g_dmCommToolWPtr_ = dmCommToolPtr;
49 }
50
OnSocketOpened(int32_t socketId,const PeerSocketInfo & info)51 int32_t DMTransport::OnSocketOpened(int32_t socketId, const PeerSocketInfo &info)
52 {
53 LOGI("OnSocketOpened, socket: %{public}d, peerSocketName: %{public}s, peerNetworkId: %{public}s, "
54 "peerPkgName: %{public}s", socketId, info.name, GetAnonyString(info.networkId).c_str(), info.pkgName);
55 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
56 if (remoteDevSocketIds_.find(info.networkId) == remoteDevSocketIds_.end()) {
57 std::set<int32_t> socketSet;
58 socketSet.insert(socketId);
59 remoteDevSocketIds_[info.networkId] = socketSet;
60 return DM_OK;
61 }
62 remoteDevSocketIds_.at(info.networkId).insert(socketId);
63 return DM_OK;
64 }
65
OnSocketClosed(int32_t socketId,ShutdownReason reason)66 void DMTransport::OnSocketClosed(int32_t socketId, ShutdownReason reason)
67 {
68 LOGI("OnSocketClosed, socket: %{public}d, reason: %{public}d", socketId, (int32_t)reason);
69 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
70 for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end();) {
71 iter->second.erase(socketId);
72 if (iter->second.empty()) {
73 iter = remoteDevSocketIds_.erase(iter);
74 } else {
75 ++iter;
76 }
77 }
78 sourceSocketIds_.erase(socketId);
79 }
80
OnBytesReceived(int32_t socketId,const void * data,uint32_t dataLen)81 void DMTransport::OnBytesReceived(int32_t socketId, const void *data, uint32_t dataLen)
82 {
83 if (socketId < 0 || data == nullptr || dataLen == 0 || dataLen > MAX_SEND_MSG_LENGTH) {
84 LOGE("OnBytesReceived param check failed");
85 return;
86 }
87
88 std::string remoteNeworkId = GetRemoteNetworkIdBySocketId(socketId);
89 if (remoteNeworkId.empty()) {
90 LOGE("Can not find the remote network id by socketId: %{public}d", socketId);
91 return;
92 }
93
94 uint8_t *buf = reinterpret_cast<uint8_t *>(calloc(dataLen + 1, sizeof(uint8_t)));
95 if (buf == nullptr) {
96 LOGE("OnBytesReceived: malloc memory failed");
97 return;
98 }
99
100 if (memcpy_s(buf, dataLen + 1, reinterpret_cast<const uint8_t *>(data), dataLen) != EOK) {
101 LOGE("OnBytesReceived: memcpy memory failed");
102 free(buf);
103 return;
104 }
105
106 std::string message(buf, buf + dataLen);
107 LOGI("Receive message size: %{public}" PRIu32, dataLen);
108 HandleReceiveMessage(socketId, message);
109 free(buf);
110 return;
111 }
112
HandleReceiveMessage(const int32_t socketId,const std::string & payload)113 void DMTransport::HandleReceiveMessage(const int32_t socketId, const std::string &payload)
114 {
115 std::string rmtNetworkId = GetRemoteNetworkIdBySocketId(socketId);
116 if (rmtNetworkId.empty()) {
117 LOGE("Can not find networkId by socketId: %{public}d", socketId);
118 return;
119 }
120 if (payload.empty() || payload.size() > MAX_SEND_MSG_LENGTH) {
121 LOGE("payload invalid");
122 return;
123 }
124 LOGI("Receive msg: %{public}s", GetAnonyString(payload).c_str());
125 cJSON *root = cJSON_Parse(payload.c_str());
126 if (root == NULL) {
127 LOGE("the msg is not json format");
128 return;
129 }
130 std::shared_ptr<CommMsg> commMsg = std::make_shared<CommMsg>();
131 FromJson(root, *commMsg);
132 cJSON_Delete(root);
133
134 std::shared_ptr<InnerCommMsg> innerMsg = std::make_shared<InnerCommMsg>(rmtNetworkId, commMsg, socketId);
135
136 LOGI("Receive DM msg, code: %{public}d, msg: %{public}s", commMsg->code, GetAnonyString(commMsg->msg).c_str());
137 AppExecFwk::InnerEvent::Pointer msgEvent = AppExecFwk::InnerEvent::Get(commMsg->code, innerMsg);
138 std::shared_ptr<DMCommTool> dmCommToolSPtr = dmCommToolWPtr_.lock();
139 if (dmCommToolSPtr == nullptr) {
140 LOGE("Can not get DMCommTool ptr");
141 return;
142 }
143 if (dmCommToolSPtr->GetEventHandler() == nullptr) {
144 LOGE("Can not get eventHandler");
145 return;
146 }
147 dmCommToolSPtr->GetEventHandler()->SendEvent(msgEvent, 0, AppExecFwk::EventQueue::Priority::IMMEDIATE);
148 }
149
150 //LCOV_EXCL_START
GetDMCommToolPtr()151 std::shared_ptr<DMCommTool> GetDMCommToolPtr()
152 {
153 if (g_dmCommToolWPtr_.expired()) {
154 LOGE("DMCommTool Weak ptr expired");
155 return nullptr;
156 }
157
158 std::shared_ptr<DMCommTool> dmCommToolSPtr = g_dmCommToolWPtr_.lock();
159 if (dmCommToolSPtr == nullptr) {
160 LOGE("Can not get DMCommTool ptr");
161 return nullptr;
162 }
163
164 return dmCommToolSPtr;
165 }
166 //LCOV_EXCL_STOP
167
OnBind(int32_t socket,PeerSocketInfo info)168 void OnBind(int32_t socket, PeerSocketInfo info)
169 {
170 std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
171 if (dmCommToolSPtr == nullptr) {
172 LOGE("Can not get DMCommTool ptr");
173 return;
174 }
175 dmCommToolSPtr->GetDMTransportPtr()->OnSocketOpened(socket, info);
176 }
177
OnShutdown(int32_t socket,ShutdownReason reason)178 void OnShutdown(int32_t socket, ShutdownReason reason)
179 {
180 std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
181 if (dmCommToolSPtr == nullptr) {
182 LOGE("Can not get DMCommTool ptr");
183 return;
184 }
185 dmCommToolSPtr->GetDMTransportPtr()->OnSocketClosed(socket, reason);
186 }
187
OnBytes(int32_t socket,const void * data,uint32_t dataLen)188 void OnBytes(int32_t socket, const void *data, uint32_t dataLen)
189 {
190 std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
191 if (dmCommToolSPtr == nullptr) {
192 LOGE("Can not get DMCommTool ptr");
193 return;
194 }
195 dmCommToolSPtr->GetDMTransportPtr()->OnBytesReceived(socket, data, dataLen);
196 }
197
OnMessage(int32_t socket,const void * data,uint32_t dataLen)198 void OnMessage(int32_t socket, const void *data, uint32_t dataLen)
199 {
200 (void)socket;
201 (void)data;
202 (void)dataLen;
203 LOGI("socket: %{public}d, dataLen:%{public}" PRIu32, socket, dataLen);
204 }
205
OnStream(int32_t socket,const StreamData * data,const StreamData * ext,const StreamFrameInfo * param)206 void OnStream(int32_t socket, const StreamData *data, const StreamData *ext,
207 const StreamFrameInfo *param)
208 {
209 (void)socket;
210 (void)data;
211 (void)ext;
212 (void)param;
213 LOGI("socket: %{public}d", socket);
214 }
215
OnFile(int32_t socket,FileEvent * event)216 void OnFile(int32_t socket, FileEvent *event)
217 {
218 (void)event;
219 LOGI("socket: %{public}d", socket);
220 }
221
OnQos(int32_t socket,QoSEvent eventId,const QosTV * qos,uint32_t qosCount)222 void OnQos(int32_t socket, QoSEvent eventId, const QosTV *qos, uint32_t qosCount)
223 {
224 if (qosCount == 0 || qosCount > MAX_ROUND_SIZE) {
225 LOGE("qosCount is invalid!");
226 return;
227 }
228 LOGI("OnQos, socket: %{public}d, QoSEvent: %{public}d, qosCount: %{public}" PRIu32,
229 socket, (int32_t)eventId, qosCount);
230 for (uint32_t idx = 0; idx < qosCount; idx++) {
231 LOGI("QosTV: type: %{public}d, value: %{public}d", (int32_t)qos[idx].qos, qos[idx].value);
232 }
233 }
234
235 ISocketListener iSocketListener = {
236 .OnBind = OnBind,
237 .OnShutdown = OnShutdown,
238 .OnBytes = OnBytes,
239 .OnMessage = OnMessage,
240 .OnStream = OnStream,
241 .OnFile = OnFile,
242 .OnQos = OnQos
243 };
244
CreateServerSocket()245 int32_t DMTransport::CreateServerSocket()
246 {
247 LOGI("CreateServerSocket start");
248 localSocketName_ = DM_SYNC_USERID_SESSION_NAME;
249 LOGI("CreateServerSocket , local socketName: %{public}s", localSocketName_.c_str());
250 std::string dmPkgName(DM_PKG_NAME);
251 SocketInfo info = {
252 .name = const_cast<char*>(localSocketName_.c_str()),
253 .pkgName = const_cast<char*>(dmPkgName.c_str()),
254 .dataType = DATA_TYPE_BYTES
255 };
256 int32_t socket = Socket(info);
257 LOGI("CreateServerSocket Finish, socket: %{public}d", socket);
258 return socket;
259 }
260
CreateClientSocket(const std::string & rmtNetworkId)261 int32_t DMTransport::CreateClientSocket(const std::string &rmtNetworkId)
262 {
263 if (!IsIdLengthValid(rmtNetworkId)) {
264 return ERR_DM_INPUT_PARA_INVALID;
265 }
266 LOGI("CreateClientSocket start, peerNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
267 std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
268 std::string dmPkgName(DM_PKG_NAME);
269 SocketInfo info = {
270 .name = const_cast<char*>(localSocketName_.c_str()),
271 .peerName = const_cast<char*>(peerSocketName.c_str()),
272 .peerNetworkId = const_cast<char*>(rmtNetworkId.c_str()),
273 .pkgName = const_cast<char*>(dmPkgName.c_str()),
274 .dataType = DATA_TYPE_BYTES
275 };
276 int32_t socket = Socket(info);
277 LOGI("Bind Socket server, socket: %{public}d, localSocketName: %{public}s, peerSocketName: %{public}s",
278 socket, localSocketName_.c_str(), peerSocketName.c_str());
279 return socket;
280 }
281
Init()282 int32_t DMTransport::Init()
283 {
284 LOGI("Init DMTransport");
285 if (isSocketSvrCreateFlag_.load()) {
286 LOGI("SocketServer already create success.");
287 return DM_OK;
288 }
289 int32_t socket = CreateServerSocket();
290 if (socket < DM_OK) {
291 LOGE("CreateSocketServer failed, ret: %{public}d", socket);
292 return ERR_DM_FAILED;
293 }
294
295 int32_t ret = Listen(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
296 if (ret != DM_OK) {
297 LOGE("Socket Listen failed, error code %{public}d.", ret);
298 return ERR_DM_FAILED;
299 }
300 isSocketSvrCreateFlag_.store(true);
301 localServerSocket_ = socket;
302 LOGI("Finish Init DSoftBus Server Socket, socket: %{public}d", socket);
303 return DM_OK;
304 }
305
UnInit()306 int32_t DMTransport::UnInit()
307 {
308 {
309 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
310 for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end(); ++iter) {
311 for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
312 LOGI("Shutdown client socket: %{public}d to remote dev: %{public}s", *iter1,
313 GetAnonyString(iter->first).c_str());
314 Shutdown(*iter1);
315 }
316 }
317 remoteDevSocketIds_.clear();
318 sourceSocketIds_.clear();
319 }
320
321 if (!isSocketSvrCreateFlag_.load()) {
322 LOGI("DSoftBus Server Socket already remove success.");
323 } else {
324 LOGI("Shutdown DSoftBus Server Socket, socket: %{public}d", localServerSocket_.load());
325 Shutdown(localServerSocket_.load());
326 localServerSocket_ = -1;
327 isSocketSvrCreateFlag_.store(false);
328 }
329 return DM_OK;
330 }
331
IsDeviceSessionOpened(const std::string & rmtNetworkId,int32_t & socketId)332 bool DMTransport::IsDeviceSessionOpened(const std::string &rmtNetworkId, int32_t &socketId)
333 {
334 if (!IsIdLengthValid(rmtNetworkId)) {
335 return false;
336 }
337 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
338 auto iter = remoteDevSocketIds_.find(rmtNetworkId);
339 if (iter == remoteDevSocketIds_.end()) {
340 return false;
341 }
342 for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
343 if (sourceSocketIds_.find(*iter1) != sourceSocketIds_.end()) {
344 socketId = *iter1;
345 return true;
346 }
347 }
348 return false;
349 }
350
GetRemoteNetworkIdBySocketId(int32_t socketId)351 std::string DMTransport::GetRemoteNetworkIdBySocketId(int32_t socketId)
352 {
353 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
354 std::string networkId = "";
355 for (auto const &item : remoteDevSocketIds_) {
356 if (item.second.find(socketId) != item.second.end()) {
357 networkId = item.first;
358 break;
359 }
360 }
361 return networkId;
362 }
363
ClearDeviceSocketOpened(const std::string & remoteDevId,int32_t socketId)364 void DMTransport::ClearDeviceSocketOpened(const std::string &remoteDevId, int32_t socketId)
365 {
366 if (!IsIdLengthValid(remoteDevId)) {
367 return;
368 }
369 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
370 auto iter = remoteDevSocketIds_.find(remoteDevId);
371 if (iter == remoteDevSocketIds_.end()) {
372 return;
373 }
374 iter->second.erase(socketId);
375 if (iter->second.empty()) {
376 remoteDevSocketIds_.erase(iter);
377 }
378 sourceSocketIds_.erase(socketId);
379 }
380
StartSocket(const std::string & rmtNetworkId,int32_t & socketId)381 int32_t DMTransport::StartSocket(const std::string &rmtNetworkId, int32_t &socketId)
382 {
383 int32_t errCode = ERR_DM_FAILED;
384 int32_t count = 0;
385 const int32_t maxCount = 10;
386
387 do {
388 errCode = StartSocketInner(rmtNetworkId, socketId);
389 if (errCode != ERR_DM_SOCKET_IN_USED) {
390 break;
391 }
392 count++;
393 usleep(USLEEP_TIME_US_200000);
394 } while (count < maxCount);
395
396 return errCode;
397 }
398
StartSocketInner(const std::string & rmtNetworkId,int32_t & socketId)399 int32_t DMTransport::StartSocketInner(const std::string &rmtNetworkId, int32_t &socketId)
400 {
401 if (!IsIdLengthValid(rmtNetworkId)) {
402 return ERR_DM_INPUT_PARA_INVALID;
403 }
404 if (IsDeviceSessionOpened(rmtNetworkId, socketId)) {
405 LOGE("Softbus session has already opened, deviceId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
406 return ERR_DM_SOCKET_IN_USED;
407 }
408
409 int32_t socket = CreateClientSocket(rmtNetworkId);
410 if (socket < DM_OK) {
411 LOGE("StartSocket failed, ret: %{public}d", socket);
412 return ERR_DM_FAILED;
413 }
414
415 int32_t ret = Bind(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
416 if (ret < DM_OK) {
417 if (ret == SOFTBUS_TRANS_SOCKET_IN_USE) {
418 LOGI("Softbus trans socket in use.");
419 return ERR_DM_SOCKET_IN_USED;
420 }
421 LOGE("OpenSession fail, rmtNetworkId: %{public}s, socket: %{public}d, ret: %{public}d",
422 GetAnonyString(rmtNetworkId).c_str(), socket, ret);
423 Shutdown(socket);
424 return ERR_DM_FAILED;
425 }
426
427 LOGI("Bind Socket success, rmtNetworkId:%{public}s, socketId: %{public}d",
428 GetAnonyString(rmtNetworkId).c_str(), socket);
429 std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
430 std::string dmPkgName(DM_PKG_NAME);
431 PeerSocketInfo peerSocketInfo = {
432 .name = const_cast<char*>(peerSocketName.c_str()),
433 .networkId = const_cast<char*>(rmtNetworkId.c_str()),
434 .pkgName = const_cast<char*>(dmPkgName.c_str()),
435 .dataType = DATA_TYPE_BYTES
436 };
437 OnSocketOpened(socket, peerSocketInfo);
438 sourceSocketIds_.insert(socket);
439 socketId = socket;
440 return DM_OK;
441 }
442
StopSocket(const std::string & rmtNetworkId)443 int32_t DMTransport::StopSocket(const std::string &rmtNetworkId)
444 {
445 if (!IsIdLengthValid(rmtNetworkId)) {
446 return ERR_DM_INPUT_PARA_INVALID;
447 }
448 int32_t socketId = -1;
449 if (!IsDeviceSessionOpened(rmtNetworkId, socketId)) {
450 LOGI("remote dev may be not opened, rmtNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
451 return ERR_DM_FAILED;
452 }
453
454 LOGI("StopSocket rmtNetworkId: %{public}s, socketId: %{public}d",
455 GetAnonyString(rmtNetworkId).c_str(), socketId);
456 Shutdown(socketId);
457 ClearDeviceSocketOpened(rmtNetworkId, socketId);
458 return DM_OK;
459 }
460
Send(const std::string & rmtNetworkId,const std::string & payload,int32_t socketId)461 int32_t DMTransport::Send(const std::string &rmtNetworkId, const std::string &payload, int32_t socketId)
462 {
463 if (!IsIdLengthValid(rmtNetworkId) || !IsMessageLengthValid(payload)) {
464 return ERR_DM_INPUT_PARA_INVALID;
465 }
466 if (socketId <= 0) {
467 LOGI("The session is not open, target networkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
468 return ERR_DM_FAILED;
469 }
470 uint32_t payLoadSize = payload.size();
471 LOGI("Send payload size: %{public}" PRIu32 ", target networkId: %{public}s, socketId: %{public}d",
472 static_cast<uint32_t>(payload.size()), GetAnonyString(rmtNetworkId).c_str(), socketId);
473
474 if (payLoadSize > MAX_SEND_MSG_LENGTH) {
475 LOGE("Send error: msg size: %{public}" PRIu32 " too long", payLoadSize);
476 return ERR_DM_FAILED;
477 }
478 uint8_t *buf = reinterpret_cast<uint8_t *>(calloc((payLoadSize), sizeof(uint8_t)));
479 if (buf == nullptr) {
480 LOGE("Send: malloc memory failed");
481 return ERR_DM_FAILED;
482 }
483
484 if (memcpy_s(buf, payLoadSize, reinterpret_cast<const uint8_t *>(payload.c_str()),
485 payLoadSize) != EOK) {
486 LOGE("Send: memcpy memory failed");
487 free(buf);
488 return ERR_DM_FAILED;
489 }
490
491 int32_t ret = SendBytes(socketId, buf, payLoadSize);
492 free(buf);
493 if (ret != DM_OK) {
494 LOGE("dsoftbus send error, ret: %{public}d", ret);
495 return ERR_DM_FAILED;
496 }
497 LOGI("Send payload success");
498 return DM_OK;
499 }
500 } // DistributedHardware
501 } // OHOS