• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "dm_transport.h"
17 #include "dm_anonymous.h"
18 #include "dm_comm_tool.h"
19 #include "dm_constants.h"
20 #include "dm_log.h"
21 #include "dm_softbus_cache.h"
22 #include "dm_transport_msg.h"
23 
24 namespace OHOS {
25 namespace DistributedHardware {
26 namespace {
27 // Dsoftbus sendBytes max message length: 4MB
28 constexpr uint32_t MAX_SEND_MSG_LENGTH = 4 * 1024 * 1024;
29 constexpr uint32_t INTERCEPT_STRING_LENGTH = 20;
30 constexpr uint32_t MAX_ROUND_SIZE = 1000;
31 static QosTV g_qosInfo[] = {
32     { .qos = QOS_TYPE_MIN_BW, .value = 256 * 1024},
33     { .qos = QOS_TYPE_MAX_LATENCY, .value = 8000 },
34     { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 }
35 };
36 static uint32_t g_qosTvParamIndex = static_cast<uint32_t>(sizeof(g_qosInfo) / sizeof(g_qosInfo[0]));
37 static std::weak_ptr<DMCommTool> g_dmCommToolWPtr_;
38 }
39 
DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr)40 DMTransport::DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr) : remoteDevSocketIds_({}), localServerSocket_(-1),
41     localSocketName_(""), isSocketSvrCreateFlag_(false), dmCommToolWPtr_(dmCommToolPtr)
42 {
43     LOGI("Ctor DMTransport");
44     g_dmCommToolWPtr_ = dmCommToolPtr;
45 }
46 
OnSocketOpened(int32_t socketId,const PeerSocketInfo & info)47 int32_t DMTransport::OnSocketOpened(int32_t socketId, const PeerSocketInfo &info)
48 {
49     LOGI("OnSocketOpened, socket: %{public}d, peerSocketName: %{public}s, peerNetworkId: %{public}s, "
50         "peerPkgName: %{public}s", socketId, info.name, GetAnonyString(info.networkId).c_str(), info.pkgName);
51     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
52     if (remoteDevSocketIds_.find(info.networkId) == remoteDevSocketIds_.end()) {
53         std::set<int32_t> socketSet;
54         socketSet.insert(socketId);
55         remoteDevSocketIds_[info.networkId] = socketSet;
56         return DM_OK;
57     }
58     remoteDevSocketIds_.at(info.networkId).insert(socketId);
59     return DM_OK;
60 }
61 
OnSocketClosed(int32_t socketId,ShutdownReason reason)62 void DMTransport::OnSocketClosed(int32_t socketId, ShutdownReason reason)
63 {
64     LOGI("OnSocketClosed, socket: %{public}d, reason: %{public}d", socketId, (int32_t)reason);
65     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
66     for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end();) {
67         iter->second.erase(socketId);
68         if (iter->second.empty()) {
69             iter = remoteDevSocketIds_.erase(iter);
70         } else {
71             ++iter;
72         }
73     }
74     sourceSocketIds_.erase(socketId);
75 }
76 
OnBytesReceived(int32_t socketId,const void * data,uint32_t dataLen)77 void DMTransport::OnBytesReceived(int32_t socketId, const void *data, uint32_t dataLen)
78 {
79     if (socketId < 0 || data == nullptr || dataLen == 0 || dataLen > MAX_SEND_MSG_LENGTH) {
80         LOGE("OnBytesReceived param check failed");
81         return;
82     }
83 
84     std::string remoteNeworkId = GetRemoteNetworkIdBySocketId(socketId);
85     if (remoteNeworkId.empty()) {
86         LOGE("Can not find the remote network id by socketId: %{public}d", socketId);
87         return;
88     }
89 
90     uint8_t *buf = reinterpret_cast<uint8_t *>(calloc(dataLen + 1, sizeof(uint8_t)));
91     if (buf == nullptr) {
92         LOGE("OnBytesReceived: malloc memory failed");
93         return;
94     }
95 
96     if (memcpy_s(buf, dataLen + 1,  reinterpret_cast<const uint8_t *>(data), dataLen) != EOK) {
97         LOGE("OnBytesReceived: memcpy memory failed");
98         free(buf);
99         return;
100     }
101 
102     std::string message(buf, buf + dataLen);
103     LOGI("Receive message size: %{public}" PRIu32, dataLen);
104     HandleReceiveMessage(socketId, message);
105     free(buf);
106     return;
107 }
108 
HandleReceiveMessage(const int32_t socketId,const std::string & payload)109 void DMTransport::HandleReceiveMessage(const int32_t socketId, const std::string &payload)
110 {
111     std::string rmtNetworkId = GetRemoteNetworkIdBySocketId(socketId);
112     if (rmtNetworkId.empty()) {
113         LOGE("Can not find networkId by socketId: %{public}d", socketId);
114         return;
115     }
116     if (payload.empty() || payload.size() > MAX_SEND_MSG_LENGTH) {
117         LOGE("payload invalid");
118         return;
119     }
120     LOGI("Receive msg: %{public}s", GetAnonyString(payload).c_str());
121     cJSON *root = cJSON_Parse(payload.c_str());
122     if (root == NULL) {
123         LOGE("the msg is not json format");
124         return;
125     }
126     std::shared_ptr<CommMsg> commMsg = std::make_shared<CommMsg>();
127     FromJson(root, *commMsg);
128     cJSON_Delete(root);
129 
130     std::shared_ptr<InnerCommMsg> innerMsg = std::make_shared<InnerCommMsg>(rmtNetworkId, commMsg, socketId);
131 
132     LOGI("Receive DM msg, code: %{public}d, msg: %{public}s", commMsg->code, GetAnonyString(commMsg->msg).c_str());
133     AppExecFwk::InnerEvent::Pointer msgEvent = AppExecFwk::InnerEvent::Get(commMsg->code, innerMsg);
134     std::shared_ptr<DMCommTool> dmCommToolSPtr = dmCommToolWPtr_.lock();
135     if (dmCommToolSPtr == nullptr) {
136         LOGE("Can not get DMCommTool ptr");
137         return;
138     }
139     if (dmCommToolSPtr->GetEventHandler() == nullptr) {
140         LOGE("Can not get eventHandler");
141         return;
142     }
143     dmCommToolSPtr->GetEventHandler()->SendEvent(msgEvent, 0, AppExecFwk::EventQueue::Priority::IMMEDIATE);
144 }
145 
GetDMCommToolPtr()146 std::shared_ptr<DMCommTool> GetDMCommToolPtr()
147 {
148     if (g_dmCommToolWPtr_.expired()) {
149         LOGE("DMCommTool Weak ptr expired");
150         return nullptr;
151     }
152 
153     std::shared_ptr<DMCommTool> dmCommToolSPtr = g_dmCommToolWPtr_.lock();
154     if (dmCommToolSPtr == nullptr) {
155         LOGE("Can not get DMCommTool ptr");
156         return nullptr;
157     }
158 
159     return dmCommToolSPtr;
160 }
161 
OnBind(int32_t socket,PeerSocketInfo info)162 void OnBind(int32_t socket, PeerSocketInfo info)
163 {
164     std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
165     if (dmCommToolSPtr == nullptr) {
166         LOGE("Can not get DMCommTool ptr");
167         return;
168     }
169     dmCommToolSPtr->GetDMTransportPtr()->OnSocketOpened(socket, info);
170 }
171 
OnShutdown(int32_t socket,ShutdownReason reason)172 void OnShutdown(int32_t socket, ShutdownReason reason)
173 {
174     std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
175     if (dmCommToolSPtr == nullptr) {
176         LOGE("Can not get DMCommTool ptr");
177         return;
178     }
179     dmCommToolSPtr->GetDMTransportPtr()->OnSocketClosed(socket, reason);
180 }
181 
OnBytes(int32_t socket,const void * data,uint32_t dataLen)182 void OnBytes(int32_t socket, const void *data, uint32_t dataLen)
183 {
184     std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
185     if (dmCommToolSPtr == nullptr) {
186         LOGE("Can not get DMCommTool ptr");
187         return;
188     }
189     dmCommToolSPtr->GetDMTransportPtr()->OnBytesReceived(socket, data, dataLen);
190 }
191 
OnMessage(int32_t socket,const void * data,uint32_t dataLen)192 void OnMessage(int32_t socket, const void *data, uint32_t dataLen)
193 {
194     (void)socket;
195     (void)data;
196     (void)dataLen;
197     LOGI("socket: %{public}d, dataLen:%{public}" PRIu32, socket, dataLen);
198 }
199 
OnStream(int32_t socket,const StreamData * data,const StreamData * ext,const StreamFrameInfo * param)200 void OnStream(int32_t socket, const StreamData *data, const StreamData *ext,
201     const StreamFrameInfo *param)
202 {
203     (void)socket;
204     (void)data;
205     (void)ext;
206     (void)param;
207     LOGI("socket: %{public}d", socket);
208 }
209 
OnFile(int32_t socket,FileEvent * event)210 void OnFile(int32_t socket, FileEvent *event)
211 {
212     (void)event;
213     LOGI("socket: %{public}d", socket);
214 }
215 
OnQos(int32_t socket,QoSEvent eventId,const QosTV * qos,uint32_t qosCount)216 void OnQos(int32_t socket, QoSEvent eventId, const QosTV *qos, uint32_t qosCount)
217 {
218     if (qosCount == 0 || qosCount > MAX_ROUND_SIZE) {
219         LOGE("qosCount is invalid!");
220         return;
221     }
222     LOGI("OnQos, socket: %{public}d, QoSEvent: %{public}d, qosCount: %{public}" PRIu32,
223         socket, (int32_t)eventId, qosCount);
224     for (uint32_t idx = 0; idx < qosCount; idx++) {
225         LOGI("QosTV: type: %{public}d, value: %{public}d", (int32_t)qos[idx].qos, qos[idx].value);
226     }
227 }
228 
229 ISocketListener iSocketListener = {
230     .OnBind = OnBind,
231     .OnShutdown = OnShutdown,
232     .OnBytes = OnBytes,
233     .OnMessage = OnMessage,
234     .OnStream = OnStream,
235     .OnFile = OnFile,
236     .OnQos = OnQos
237 };
238 
CreateServerSocket()239 int32_t DMTransport::CreateServerSocket()
240 {
241     LOGI("CreateServerSocket start");
242     localSocketName_ = DM_SYNC_USERID_SESSION_NAME;
243     LOGI("CreateServerSocket , local socketName: %{public}s", localSocketName_.c_str());
244     std::string dmPkgName(DM_PKG_NAME);
245     SocketInfo info = {
246         .name = const_cast<char*>(localSocketName_.c_str()),
247         .pkgName = const_cast<char*>(dmPkgName.c_str()),
248         .dataType = DATA_TYPE_BYTES
249     };
250     int32_t socket = Socket(info);
251     LOGI("CreateServerSocket Finish, socket: %{public}d", socket);
252     return socket;
253 }
254 
CreateClientSocket(const std::string & rmtNetworkId)255 int32_t DMTransport::CreateClientSocket(const std::string &rmtNetworkId)
256 {
257     if (!IsIdLengthValid(rmtNetworkId)) {
258         return ERR_DM_INPUT_PARA_INVALID;
259     }
260     LOGI("CreateClientSocket start, peerNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
261     std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
262     std::string dmPkgName(DM_PKG_NAME);
263     SocketInfo info = {
264         .name = const_cast<char*>(localSocketName_.c_str()),
265         .peerName = const_cast<char*>(peerSocketName.c_str()),
266         .peerNetworkId = const_cast<char*>(rmtNetworkId.c_str()),
267         .pkgName = const_cast<char*>(dmPkgName.c_str()),
268         .dataType = DATA_TYPE_BYTES
269     };
270     int32_t socket = Socket(info);
271     LOGI("Bind Socket server, socket: %{public}d, localSocketName: %{public}s, peerSocketName: %{public}s",
272         socket, localSocketName_.c_str(), peerSocketName.c_str());
273     return socket;
274 }
275 
Init()276 int32_t DMTransport::Init()
277 {
278     LOGI("Init DMTransport");
279     if (isSocketSvrCreateFlag_.load()) {
280         LOGI("SocketServer already create success.");
281         return DM_OK;
282     }
283     int32_t socket = CreateServerSocket();
284     if (socket < DM_OK) {
285         LOGE("CreateSocketServer failed, ret: %{public}d", socket);
286         return ERR_DM_FAILED;
287     }
288 
289     int32_t ret = Listen(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
290     if (ret != DM_OK) {
291         LOGE("Socket Listen failed, error code %{public}d.", ret);
292         return ERR_DM_FAILED;
293     }
294     isSocketSvrCreateFlag_.store(true);
295     localServerSocket_ = socket;
296     LOGI("Finish Init DSoftBus Server Socket, socket: %{public}d", socket);
297     return DM_OK;
298 }
299 
UnInit()300 int32_t DMTransport::UnInit()
301 {
302     {
303         std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
304         for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end(); ++iter) {
305             for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
306                 LOGI("Shutdown client socket: %{public}d to remote dev: %{public}s", *iter1,
307                     GetAnonyString(iter->first).c_str());
308                 Shutdown(*iter1);
309             }
310         }
311         remoteDevSocketIds_.clear();
312         sourceSocketIds_.clear();
313     }
314 
315     if (!isSocketSvrCreateFlag_.load()) {
316         LOGI("DSoftBus Server Socket already remove success.");
317     } else {
318         LOGI("Shutdown DSoftBus Server Socket, socket: %{public}d", localServerSocket_.load());
319         Shutdown(localServerSocket_.load());
320         localServerSocket_ = -1;
321         isSocketSvrCreateFlag_.store(false);
322     }
323     return DM_OK;
324 }
325 
IsDeviceSessionOpened(const std::string & rmtNetworkId,int32_t & socketId)326 bool DMTransport::IsDeviceSessionOpened(const std::string &rmtNetworkId, int32_t &socketId)
327 {
328     if (!IsIdLengthValid(rmtNetworkId)) {
329         return false;
330     }
331     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
332     auto iter = remoteDevSocketIds_.find(rmtNetworkId);
333     if (iter == remoteDevSocketIds_.end()) {
334         return false;
335     }
336     for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
337         if (sourceSocketIds_.find(*iter1) != sourceSocketIds_.end()) {
338             socketId = *iter1;
339             return true;
340         }
341     }
342     return false;
343 }
344 
GetRemoteNetworkIdBySocketId(int32_t socketId)345 std::string DMTransport::GetRemoteNetworkIdBySocketId(int32_t socketId)
346 {
347     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
348     std::string networkId = "";
349     for (auto const &item : remoteDevSocketIds_) {
350         if (item.second.find(socketId) != item.second.end()) {
351             networkId = item.first;
352             break;
353         }
354     }
355     return networkId;
356 }
357 
ClearDeviceSocketOpened(const std::string & remoteDevId,int32_t socketId)358 void DMTransport::ClearDeviceSocketOpened(const std::string &remoteDevId, int32_t socketId)
359 {
360     if (!IsIdLengthValid(remoteDevId)) {
361         return;
362     }
363     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
364     auto iter = remoteDevSocketIds_.find(remoteDevId);
365     if (iter == remoteDevSocketIds_.end()) {
366         return;
367     }
368     iter->second.erase(socketId);
369     if (iter->second.empty()) {
370         remoteDevSocketIds_.erase(iter);
371     }
372     sourceSocketIds_.erase(socketId);
373 }
374 
StartSocket(const std::string & rmtNetworkId,int32_t & socketId)375 int32_t DMTransport::StartSocket(const std::string &rmtNetworkId, int32_t &socketId)
376 {
377     if (!IsIdLengthValid(rmtNetworkId)) {
378         return ERR_DM_INPUT_PARA_INVALID;
379     }
380     if (IsDeviceSessionOpened(rmtNetworkId, socketId)) {
381         LOGE("Softbus session has already opened, deviceId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
382         return DM_OK;
383     }
384 
385     int32_t socket = CreateClientSocket(rmtNetworkId);
386     if (socket < DM_OK) {
387         LOGE("StartSocket failed, ret: %{public}d", socket);
388         return ERR_DM_FAILED;
389     }
390 
391     int32_t ret = Bind(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
392     if (ret < DM_OK) {
393         LOGE("OpenSession fail, rmtNetworkId: %{public}s, socket: %{public}d, ret: %{public}d",
394             GetAnonyString(rmtNetworkId).c_str(), socket, ret);
395         Shutdown(socket);
396         return ERR_DM_FAILED;
397     }
398 
399     LOGI("Bind Socket success, rmtNetworkId:%{public}s, socketId: %{public}d",
400         GetAnonyString(rmtNetworkId).c_str(), socket);
401     std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
402     std::string dmPkgName(DM_PKG_NAME);
403     PeerSocketInfo peerSocketInfo = {
404         .name = const_cast<char*>(peerSocketName.c_str()),
405         .networkId = const_cast<char*>(rmtNetworkId.c_str()),
406         .pkgName = const_cast<char*>(dmPkgName.c_str()),
407         .dataType = DATA_TYPE_BYTES
408     };
409     OnSocketOpened(socket, peerSocketInfo);
410     sourceSocketIds_.insert(socket);
411     socketId = socket;
412     return DM_OK;
413 }
414 
StopSocket(const std::string & rmtNetworkId)415 int32_t DMTransport::StopSocket(const std::string &rmtNetworkId)
416 {
417     if (!IsIdLengthValid(rmtNetworkId)) {
418         return ERR_DM_INPUT_PARA_INVALID;
419     }
420     int32_t socketId = -1;
421     if (!IsDeviceSessionOpened(rmtNetworkId, socketId)) {
422         LOGI("remote dev may be not opened, rmtNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
423         return ERR_DM_FAILED;
424     }
425 
426     LOGI("StopSocket rmtNetworkId: %{public}s, socketId: %{public}d",
427         GetAnonyString(rmtNetworkId).c_str(), socketId);
428     Shutdown(socketId);
429     ClearDeviceSocketOpened(rmtNetworkId, socketId);
430     return DM_OK;
431 }
432 
Send(const std::string & rmtNetworkId,const std::string & payload,int32_t socketId)433 int32_t DMTransport::Send(const std::string &rmtNetworkId, const std::string &payload, int32_t socketId)
434 {
435     if (!IsIdLengthValid(rmtNetworkId) || !IsMessageLengthValid(payload)) {
436         return ERR_DM_INPUT_PARA_INVALID;
437     }
438     if (socketId <= 0) {
439         LOGI("The session is not open, target networkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
440         return ERR_DM_FAILED;
441     }
442     uint32_t payLoadSize = payload.size();
443     LOGI("Send payload size: %{public}" PRIu32 ", target networkId: %{public}s, socketId: %{public}d",
444         static_cast<uint32_t>(payload.size()), GetAnonyString(rmtNetworkId).c_str(), socketId);
445 
446     if (payLoadSize > MAX_SEND_MSG_LENGTH) {
447         LOGE("Send error: msg size: %{public}" PRIu32 " too long", payLoadSize);
448         return ERR_DM_FAILED;
449     }
450     uint8_t *buf = reinterpret_cast<uint8_t *>(calloc((payLoadSize), sizeof(uint8_t)));
451     if (buf == nullptr) {
452         LOGE("Send: malloc memory failed");
453         return ERR_DM_FAILED;
454     }
455 
456     if (memcpy_s(buf, payLoadSize, reinterpret_cast<const uint8_t *>(payload.c_str()),
457                  payLoadSize) != EOK) {
458         LOGE("Send: memcpy memory failed");
459         free(buf);
460         return ERR_DM_FAILED;
461     }
462 
463     int32_t ret = SendBytes(socketId, buf, payLoadSize);
464     free(buf);
465     if (ret != DM_OK) {
466         LOGE("dsoftbus send error, ret: %{public}d", ret);
467         return ERR_DM_FAILED;
468     }
469     LOGI("Send payload success");
470     return DM_OK;
471 }
472 } // DistributedHardware
473 } // OHOS