1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "dm_transport.h"
17 #include "dm_anonymous.h"
18 #include "dm_comm_tool.h"
19 #include "dm_constants.h"
20 #include "dm_log.h"
21 #include "dm_softbus_cache.h"
22 #include "dm_transport_msg.h"
23
24 namespace OHOS {
25 namespace DistributedHardware {
26 namespace {
27 // Dsoftbus sendBytes max message length: 4MB
28 constexpr uint32_t MAX_SEND_MSG_LENGTH = 4 * 1024 * 1024;
29 constexpr uint32_t INTERCEPT_STRING_LENGTH = 20;
30 constexpr uint32_t MAX_ROUND_SIZE = 1000;
31 static QosTV g_qosInfo[] = {
32 { .qos = QOS_TYPE_MIN_BW, .value = 256 * 1024},
33 { .qos = QOS_TYPE_MAX_LATENCY, .value = 8000 },
34 { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 }
35 };
36 static uint32_t g_qosTvParamIndex = static_cast<uint32_t>(sizeof(g_qosInfo) / sizeof(g_qosInfo[0]));
37 static std::weak_ptr<DMCommTool> g_dmCommToolWPtr_;
38 }
39
DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr)40 DMTransport::DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr) : remoteDevSocketIds_({}), localServerSocket_(-1),
41 localSocketName_(""), isSocketSvrCreateFlag_(false), dmCommToolWPtr_(dmCommToolPtr)
42 {
43 LOGI("Ctor DMTransport");
44 g_dmCommToolWPtr_ = dmCommToolPtr;
45 }
46
OnSocketOpened(int32_t socketId,const PeerSocketInfo & info)47 int32_t DMTransport::OnSocketOpened(int32_t socketId, const PeerSocketInfo &info)
48 {
49 LOGI("OnSocketOpened, socket: %{public}d, peerSocketName: %{public}s, peerNetworkId: %{public}s, "
50 "peerPkgName: %{public}s", socketId, info.name, GetAnonyString(info.networkId).c_str(), info.pkgName);
51 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
52 if (remoteDevSocketIds_.find(info.networkId) == remoteDevSocketIds_.end()) {
53 std::set<int32_t> socketSet;
54 socketSet.insert(socketId);
55 remoteDevSocketIds_[info.networkId] = socketSet;
56 return DM_OK;
57 }
58 remoteDevSocketIds_.at(info.networkId).insert(socketId);
59 return DM_OK;
60 }
61
OnSocketClosed(int32_t socketId,ShutdownReason reason)62 void DMTransport::OnSocketClosed(int32_t socketId, ShutdownReason reason)
63 {
64 LOGI("OnSocketClosed, socket: %{public}d, reason: %{public}d", socketId, (int32_t)reason);
65 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
66 for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end();) {
67 iter->second.erase(socketId);
68 if (iter->second.empty()) {
69 iter = remoteDevSocketIds_.erase(iter);
70 } else {
71 ++iter;
72 }
73 }
74 sourceSocketIds_.erase(socketId);
75 }
76
OnBytesReceived(int32_t socketId,const void * data,uint32_t dataLen)77 void DMTransport::OnBytesReceived(int32_t socketId, const void *data, uint32_t dataLen)
78 {
79 if (socketId < 0 || data == nullptr || dataLen == 0 || dataLen > MAX_SEND_MSG_LENGTH) {
80 LOGE("OnBytesReceived param check failed");
81 return;
82 }
83
84 std::string remoteNeworkId = GetRemoteNetworkIdBySocketId(socketId);
85 if (remoteNeworkId.empty()) {
86 LOGE("Can not find the remote network id by socketId: %{public}d", socketId);
87 return;
88 }
89
90 uint8_t *buf = reinterpret_cast<uint8_t *>(calloc(dataLen + 1, sizeof(uint8_t)));
91 if (buf == nullptr) {
92 LOGE("OnBytesReceived: malloc memory failed");
93 return;
94 }
95
96 if (memcpy_s(buf, dataLen + 1, reinterpret_cast<const uint8_t *>(data), dataLen) != EOK) {
97 LOGE("OnBytesReceived: memcpy memory failed");
98 free(buf);
99 return;
100 }
101
102 std::string message(buf, buf + dataLen);
103 LOGI("Receive message size: %{public}" PRIu32, dataLen);
104 HandleReceiveMessage(socketId, message);
105 free(buf);
106 return;
107 }
108
HandleReceiveMessage(const int32_t socketId,const std::string & payload)109 void DMTransport::HandleReceiveMessage(const int32_t socketId, const std::string &payload)
110 {
111 std::string rmtNetworkId = GetRemoteNetworkIdBySocketId(socketId);
112 if (rmtNetworkId.empty()) {
113 LOGE("Can not find networkId by socketId: %{public}d", socketId);
114 return;
115 }
116 if (payload.empty() || payload.size() > MAX_SEND_MSG_LENGTH) {
117 LOGE("payload invalid");
118 return;
119 }
120 LOGI("Receive msg: %{public}s", GetAnonyString(payload).c_str());
121 cJSON *root = cJSON_Parse(payload.c_str());
122 if (root == NULL) {
123 LOGE("the msg is not json format");
124 return;
125 }
126 std::shared_ptr<CommMsg> commMsg = std::make_shared<CommMsg>();
127 FromJson(root, *commMsg);
128 cJSON_Delete(root);
129
130 std::shared_ptr<InnerCommMsg> innerMsg = std::make_shared<InnerCommMsg>(rmtNetworkId, commMsg, socketId);
131
132 LOGI("Receive DM msg, code: %{public}d, msg: %{public}s", commMsg->code, GetAnonyString(commMsg->msg).c_str());
133 AppExecFwk::InnerEvent::Pointer msgEvent = AppExecFwk::InnerEvent::Get(commMsg->code, innerMsg);
134 std::shared_ptr<DMCommTool> dmCommToolSPtr = dmCommToolWPtr_.lock();
135 if (dmCommToolSPtr == nullptr) {
136 LOGE("Can not get DMCommTool ptr");
137 return;
138 }
139 if (dmCommToolSPtr->GetEventHandler() == nullptr) {
140 LOGE("Can not get eventHandler");
141 return;
142 }
143 dmCommToolSPtr->GetEventHandler()->SendEvent(msgEvent, 0, AppExecFwk::EventQueue::Priority::IMMEDIATE);
144 }
145
GetDMCommToolPtr()146 std::shared_ptr<DMCommTool> GetDMCommToolPtr()
147 {
148 if (g_dmCommToolWPtr_.expired()) {
149 LOGE("DMCommTool Weak ptr expired");
150 return nullptr;
151 }
152
153 std::shared_ptr<DMCommTool> dmCommToolSPtr = g_dmCommToolWPtr_.lock();
154 if (dmCommToolSPtr == nullptr) {
155 LOGE("Can not get DMCommTool ptr");
156 return nullptr;
157 }
158
159 return dmCommToolSPtr;
160 }
161
OnBind(int32_t socket,PeerSocketInfo info)162 void OnBind(int32_t socket, PeerSocketInfo info)
163 {
164 std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
165 if (dmCommToolSPtr == nullptr) {
166 LOGE("Can not get DMCommTool ptr");
167 return;
168 }
169 dmCommToolSPtr->GetDMTransportPtr()->OnSocketOpened(socket, info);
170 }
171
OnShutdown(int32_t socket,ShutdownReason reason)172 void OnShutdown(int32_t socket, ShutdownReason reason)
173 {
174 std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
175 if (dmCommToolSPtr == nullptr) {
176 LOGE("Can not get DMCommTool ptr");
177 return;
178 }
179 dmCommToolSPtr->GetDMTransportPtr()->OnSocketClosed(socket, reason);
180 }
181
OnBytes(int32_t socket,const void * data,uint32_t dataLen)182 void OnBytes(int32_t socket, const void *data, uint32_t dataLen)
183 {
184 std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
185 if (dmCommToolSPtr == nullptr) {
186 LOGE("Can not get DMCommTool ptr");
187 return;
188 }
189 dmCommToolSPtr->GetDMTransportPtr()->OnBytesReceived(socket, data, dataLen);
190 }
191
OnMessage(int32_t socket,const void * data,uint32_t dataLen)192 void OnMessage(int32_t socket, const void *data, uint32_t dataLen)
193 {
194 (void)socket;
195 (void)data;
196 (void)dataLen;
197 LOGI("socket: %{public}d, dataLen:%{public}" PRIu32, socket, dataLen);
198 }
199
OnStream(int32_t socket,const StreamData * data,const StreamData * ext,const StreamFrameInfo * param)200 void OnStream(int32_t socket, const StreamData *data, const StreamData *ext,
201 const StreamFrameInfo *param)
202 {
203 (void)socket;
204 (void)data;
205 (void)ext;
206 (void)param;
207 LOGI("socket: %{public}d", socket);
208 }
209
OnFile(int32_t socket,FileEvent * event)210 void OnFile(int32_t socket, FileEvent *event)
211 {
212 (void)event;
213 LOGI("socket: %{public}d", socket);
214 }
215
OnQos(int32_t socket,QoSEvent eventId,const QosTV * qos,uint32_t qosCount)216 void OnQos(int32_t socket, QoSEvent eventId, const QosTV *qos, uint32_t qosCount)
217 {
218 if (qosCount == 0 || qosCount > MAX_ROUND_SIZE) {
219 LOGE("qosCount is invalid!");
220 return;
221 }
222 LOGI("OnQos, socket: %{public}d, QoSEvent: %{public}d, qosCount: %{public}" PRIu32,
223 socket, (int32_t)eventId, qosCount);
224 for (uint32_t idx = 0; idx < qosCount; idx++) {
225 LOGI("QosTV: type: %{public}d, value: %{public}d", (int32_t)qos[idx].qos, qos[idx].value);
226 }
227 }
228
229 ISocketListener iSocketListener = {
230 .OnBind = OnBind,
231 .OnShutdown = OnShutdown,
232 .OnBytes = OnBytes,
233 .OnMessage = OnMessage,
234 .OnStream = OnStream,
235 .OnFile = OnFile,
236 .OnQos = OnQos
237 };
238
CreateServerSocket()239 int32_t DMTransport::CreateServerSocket()
240 {
241 LOGI("CreateServerSocket start");
242 localSocketName_ = DM_SYNC_USERID_SESSION_NAME;
243 LOGI("CreateServerSocket , local socketName: %{public}s", localSocketName_.c_str());
244 std::string dmPkgName(DM_PKG_NAME);
245 SocketInfo info = {
246 .name = const_cast<char*>(localSocketName_.c_str()),
247 .pkgName = const_cast<char*>(dmPkgName.c_str()),
248 .dataType = DATA_TYPE_BYTES
249 };
250 int32_t socket = Socket(info);
251 LOGI("CreateServerSocket Finish, socket: %{public}d", socket);
252 return socket;
253 }
254
CreateClientSocket(const std::string & rmtNetworkId)255 int32_t DMTransport::CreateClientSocket(const std::string &rmtNetworkId)
256 {
257 if (!IsIdLengthValid(rmtNetworkId)) {
258 return ERR_DM_INPUT_PARA_INVALID;
259 }
260 LOGI("CreateClientSocket start, peerNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
261 std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
262 std::string dmPkgName(DM_PKG_NAME);
263 SocketInfo info = {
264 .name = const_cast<char*>(localSocketName_.c_str()),
265 .peerName = const_cast<char*>(peerSocketName.c_str()),
266 .peerNetworkId = const_cast<char*>(rmtNetworkId.c_str()),
267 .pkgName = const_cast<char*>(dmPkgName.c_str()),
268 .dataType = DATA_TYPE_BYTES
269 };
270 int32_t socket = Socket(info);
271 LOGI("Bind Socket server, socket: %{public}d, localSocketName: %{public}s, peerSocketName: %{public}s",
272 socket, localSocketName_.c_str(), peerSocketName.c_str());
273 return socket;
274 }
275
Init()276 int32_t DMTransport::Init()
277 {
278 LOGI("Init DMTransport");
279 if (isSocketSvrCreateFlag_.load()) {
280 LOGI("SocketServer already create success.");
281 return DM_OK;
282 }
283 int32_t socket = CreateServerSocket();
284 if (socket < DM_OK) {
285 LOGE("CreateSocketServer failed, ret: %{public}d", socket);
286 return ERR_DM_FAILED;
287 }
288
289 int32_t ret = Listen(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
290 if (ret != DM_OK) {
291 LOGE("Socket Listen failed, error code %{public}d.", ret);
292 return ERR_DM_FAILED;
293 }
294 isSocketSvrCreateFlag_.store(true);
295 localServerSocket_ = socket;
296 LOGI("Finish Init DSoftBus Server Socket, socket: %{public}d", socket);
297 return DM_OK;
298 }
299
UnInit()300 int32_t DMTransport::UnInit()
301 {
302 {
303 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
304 for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end(); ++iter) {
305 for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
306 LOGI("Shutdown client socket: %{public}d to remote dev: %{public}s", *iter1,
307 GetAnonyString(iter->first).c_str());
308 Shutdown(*iter1);
309 }
310 }
311 remoteDevSocketIds_.clear();
312 sourceSocketIds_.clear();
313 }
314
315 if (!isSocketSvrCreateFlag_.load()) {
316 LOGI("DSoftBus Server Socket already remove success.");
317 } else {
318 LOGI("Shutdown DSoftBus Server Socket, socket: %{public}d", localServerSocket_.load());
319 Shutdown(localServerSocket_.load());
320 localServerSocket_ = -1;
321 isSocketSvrCreateFlag_.store(false);
322 }
323 return DM_OK;
324 }
325
IsDeviceSessionOpened(const std::string & rmtNetworkId,int32_t & socketId)326 bool DMTransport::IsDeviceSessionOpened(const std::string &rmtNetworkId, int32_t &socketId)
327 {
328 if (!IsIdLengthValid(rmtNetworkId)) {
329 return false;
330 }
331 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
332 auto iter = remoteDevSocketIds_.find(rmtNetworkId);
333 if (iter == remoteDevSocketIds_.end()) {
334 return false;
335 }
336 for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
337 if (sourceSocketIds_.find(*iter1) != sourceSocketIds_.end()) {
338 socketId = *iter1;
339 return true;
340 }
341 }
342 return false;
343 }
344
GetRemoteNetworkIdBySocketId(int32_t socketId)345 std::string DMTransport::GetRemoteNetworkIdBySocketId(int32_t socketId)
346 {
347 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
348 std::string networkId = "";
349 for (auto const &item : remoteDevSocketIds_) {
350 if (item.second.find(socketId) != item.second.end()) {
351 networkId = item.first;
352 break;
353 }
354 }
355 return networkId;
356 }
357
ClearDeviceSocketOpened(const std::string & remoteDevId,int32_t socketId)358 void DMTransport::ClearDeviceSocketOpened(const std::string &remoteDevId, int32_t socketId)
359 {
360 if (!IsIdLengthValid(remoteDevId)) {
361 return;
362 }
363 std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
364 auto iter = remoteDevSocketIds_.find(remoteDevId);
365 if (iter == remoteDevSocketIds_.end()) {
366 return;
367 }
368 iter->second.erase(socketId);
369 if (iter->second.empty()) {
370 remoteDevSocketIds_.erase(iter);
371 }
372 sourceSocketIds_.erase(socketId);
373 }
374
StartSocket(const std::string & rmtNetworkId,int32_t & socketId)375 int32_t DMTransport::StartSocket(const std::string &rmtNetworkId, int32_t &socketId)
376 {
377 if (!IsIdLengthValid(rmtNetworkId)) {
378 return ERR_DM_INPUT_PARA_INVALID;
379 }
380 if (IsDeviceSessionOpened(rmtNetworkId, socketId)) {
381 LOGE("Softbus session has already opened, deviceId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
382 return DM_OK;
383 }
384
385 int32_t socket = CreateClientSocket(rmtNetworkId);
386 if (socket < DM_OK) {
387 LOGE("StartSocket failed, ret: %{public}d", socket);
388 return ERR_DM_FAILED;
389 }
390
391 int32_t ret = Bind(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
392 if (ret < DM_OK) {
393 LOGE("OpenSession fail, rmtNetworkId: %{public}s, socket: %{public}d, ret: %{public}d",
394 GetAnonyString(rmtNetworkId).c_str(), socket, ret);
395 Shutdown(socket);
396 return ERR_DM_FAILED;
397 }
398
399 LOGI("Bind Socket success, rmtNetworkId:%{public}s, socketId: %{public}d",
400 GetAnonyString(rmtNetworkId).c_str(), socket);
401 std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
402 std::string dmPkgName(DM_PKG_NAME);
403 PeerSocketInfo peerSocketInfo = {
404 .name = const_cast<char*>(peerSocketName.c_str()),
405 .networkId = const_cast<char*>(rmtNetworkId.c_str()),
406 .pkgName = const_cast<char*>(dmPkgName.c_str()),
407 .dataType = DATA_TYPE_BYTES
408 };
409 OnSocketOpened(socket, peerSocketInfo);
410 sourceSocketIds_.insert(socket);
411 socketId = socket;
412 return DM_OK;
413 }
414
StopSocket(const std::string & rmtNetworkId)415 int32_t DMTransport::StopSocket(const std::string &rmtNetworkId)
416 {
417 if (!IsIdLengthValid(rmtNetworkId)) {
418 return ERR_DM_INPUT_PARA_INVALID;
419 }
420 int32_t socketId = -1;
421 if (!IsDeviceSessionOpened(rmtNetworkId, socketId)) {
422 LOGI("remote dev may be not opened, rmtNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
423 return ERR_DM_FAILED;
424 }
425
426 LOGI("StopSocket rmtNetworkId: %{public}s, socketId: %{public}d",
427 GetAnonyString(rmtNetworkId).c_str(), socketId);
428 Shutdown(socketId);
429 ClearDeviceSocketOpened(rmtNetworkId, socketId);
430 return DM_OK;
431 }
432
Send(const std::string & rmtNetworkId,const std::string & payload,int32_t socketId)433 int32_t DMTransport::Send(const std::string &rmtNetworkId, const std::string &payload, int32_t socketId)
434 {
435 if (!IsIdLengthValid(rmtNetworkId) || !IsMessageLengthValid(payload)) {
436 return ERR_DM_INPUT_PARA_INVALID;
437 }
438 if (socketId <= 0) {
439 LOGI("The session is not open, target networkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
440 return ERR_DM_FAILED;
441 }
442 uint32_t payLoadSize = payload.size();
443 LOGI("Send payload size: %{public}" PRIu32 ", target networkId: %{public}s, socketId: %{public}d",
444 static_cast<uint32_t>(payload.size()), GetAnonyString(rmtNetworkId).c_str(), socketId);
445
446 if (payLoadSize > MAX_SEND_MSG_LENGTH) {
447 LOGE("Send error: msg size: %{public}" PRIu32 " too long", payLoadSize);
448 return ERR_DM_FAILED;
449 }
450 uint8_t *buf = reinterpret_cast<uint8_t *>(calloc((payLoadSize), sizeof(uint8_t)));
451 if (buf == nullptr) {
452 LOGE("Send: malloc memory failed");
453 return ERR_DM_FAILED;
454 }
455
456 if (memcpy_s(buf, payLoadSize, reinterpret_cast<const uint8_t *>(payload.c_str()),
457 payLoadSize) != EOK) {
458 LOGE("Send: memcpy memory failed");
459 free(buf);
460 return ERR_DM_FAILED;
461 }
462
463 int32_t ret = SendBytes(socketId, buf, payLoadSize);
464 free(buf);
465 if (ret != DM_OK) {
466 LOGE("dsoftbus send error, ret: %{public}d", ret);
467 return ERR_DM_FAILED;
468 }
469 LOGI("Send payload success");
470 return DM_OK;
471 }
472 } // DistributedHardware
473 } // OHOS