• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "soft_bus_base_socket.h"
17 
18 #include <cinttypes>
19 
20 #include "remote_connect_listener_manager.h"
21 
22 #define LOG_TAG "USER_AUTH_SA"
23 namespace OHOS {
24 namespace UserIam {
25 namespace UserAuth {
26 using namespace OHOS::DistributedHardware;
// Package name string for the user IAM service.
// NOTE(review): not referenced in the code visible here — confirm where it is used.
const std::string USERIAM_PACKAGE_NAME = "ohos.useriam";
// Delay, in milliseconds, before an un-acked request's reply timer fires (see StartReplyTimer).
static constexpr uint32_t REPLY_TIMER_LEN_MS = 5 * 1000; // 5s
// Timer id meaning "no timer": returned by GetReplyTimer when no entry exists and
// treated as a registration failure when returned by RelativeTimer::Register.
static constexpr uint32_t INVALID_TIMER_ID = 0;
// Serializes access to g_messageSeq; file-static, so shared by every BaseSocket in the process.
static std::recursive_mutex g_seqMutex;
// Last sequence number handed out by BaseSocket::GetMessageSeq().
static uint32_t g_messageSeq = 0;
32 
BaseSocket(const int32_t socketId)33 BaseSocket::BaseSocket(const int32_t socketId)
34     : socketId_(socketId)
35 {
36     currTraceInfo_.msgType = -1;
37     currTraceInfo_.socketId = socketId;
38     IAM_LOGI("create socket id %{public}d.", socketId_);
39 }
40 
~BaseSocket()41 BaseSocket::~BaseSocket()
42 {
43     Shutdown(socketId_);
44     IAM_LOGI("close socket id %{public}d.", socketId_);
45 }
46 
GetSocketId()47 int32_t BaseSocket::GetSocketId()
48 {
49     return socketId_;
50 }
51 
GetCurrTraceInfo()52 RemoteConnectFaultTrace BaseSocket::GetCurrTraceInfo()
53 {
54     return currTraceInfo_;
55 }
56 
InsertMsgCallback(uint32_t messageSeq,const std::string & connectionName,const MsgCallback & callback,uint32_t timerId)57 void BaseSocket::InsertMsgCallback(uint32_t messageSeq, const std::string &connectionName,
58     const MsgCallback &callback, uint32_t timerId)
59 {
60     IAM_LOGD("start. messageSeq:%{public}u, timerId:%{public}u", messageSeq, timerId);
61     IF_FALSE_LOGE_AND_RETURN(callback != nullptr);
62 
63     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
64     CallbackInfo callbackInfo = {
65         .connectionName = connectionName,
66         .msgCallback = callback,
67         .timerId = timerId,
68         .sendTime = std::chrono::steady_clock::now()
69     };
70     callbackMap_.insert(std::pair<int32_t, CallbackInfo>(messageSeq, callbackInfo));
71 }
72 
RemoveMsgCallback(uint32_t messageSeq)73 void BaseSocket::RemoveMsgCallback(uint32_t messageSeq)
74 {
75     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
76     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
77     callbackMap_.erase(messageSeq);
78 }
79 
GetConnectionName(uint32_t messageSeq)80 std::string BaseSocket::GetConnectionName(uint32_t messageSeq)
81 {
82     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
83     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
84     std::string connectionName;
85     auto iter = callbackMap_.find(messageSeq);
86     if (iter != callbackMap_.end()) {
87         connectionName = iter->second.connectionName;
88     }
89     return connectionName;
90 }
91 
GetMsgCallback(uint32_t messageSeq)92 MsgCallback BaseSocket::GetMsgCallback(uint32_t messageSeq)
93 {
94     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
95     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
96     MsgCallback callback = nullptr;
97     auto iter = callbackMap_.find(messageSeq);
98     if (iter != callbackMap_.end()) {
99         callback = iter->second.msgCallback;
100     }
101     return callback;
102 }
103 
PrintTransferDuration(uint32_t messageSeq)104 void BaseSocket::PrintTransferDuration(uint32_t messageSeq)
105 {
106     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
107     auto iter = callbackMap_.find(messageSeq);
108     if (iter == callbackMap_.end()) {
109         IAM_LOGE("message seq not found");
110         return;
111     }
112 
113     auto receiveAckTime = std::chrono::steady_clock::now();
114     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(receiveAckTime - iter->second.sendTime);
115     IAM_LOGI("messageSeq:%{public}u MessageTransferDuration:%{public}" PRIu64 " ms", messageSeq,
116         static_cast<uint64_t>(duration.count()));
117 }
118 
GetReplyTimer(uint32_t messageSeq)119 uint32_t BaseSocket::GetReplyTimer(uint32_t messageSeq)
120 {
121     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
122     std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
123     uint32_t timerId = 0;
124     auto iter = callbackMap_.find(messageSeq);
125     if (iter != callbackMap_.end()) {
126         timerId = iter->second.timerId;
127     }
128     return timerId;
129 }
130 
StartReplyTimer(uint32_t messageSeq)131 uint32_t BaseSocket::StartReplyTimer(uint32_t messageSeq)
132 {
133     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
134     uint32_t timerId = GetReplyTimer(messageSeq);
135     if (timerId != INVALID_TIMER_ID) {
136         IAM_LOGI("timer is already start");
137         return timerId;
138     }
139 
140     timerId = RelativeTimer::GetInstance().Register(
141         [weakSelf = weak_from_this(), messageSeq, socketId = socketId_] {
142             auto self = weakSelf.lock();
143             if (self == nullptr) {
144                 IAM_LOGE("socket %{public}d is released", socketId);
145                 return;
146             }
147             self->ReplyTimerTimeOut(messageSeq);
148         },
149         REPLY_TIMER_LEN_MS);
150 
151     return timerId;
152 }
153 
StopReplyTimer(uint32_t messageSeq)154 void BaseSocket::StopReplyTimer(uint32_t messageSeq)
155 {
156     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
157     uint32_t timerId = GetReplyTimer(messageSeq);
158     if (timerId == INVALID_TIMER_ID) {
159         IAM_LOGI("timer is already stop");
160         return;
161     }
162 
163     RelativeTimer::GetInstance().Unregister(timerId);
164 }
165 
ReplyTimerTimeOut(uint32_t messageSeq)166 void BaseSocket::ReplyTimerTimeOut(uint32_t messageSeq)
167 {
168     IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
169     std::string connectionName = GetConnectionName(messageSeq);
170     if (connectionName.empty()) {
171         IAM_LOGE("GetMsgCallback connectionName fail");
172         return;
173     }
174     currTraceInfo_.reason = "ack time out";
175     ReportConnectFaultTrace(currTraceInfo_);
176 
177     RemoteConnectListenerManager::GetInstance().OnConnectionDown(connectionName);
178     RemoveMsgCallback(messageSeq);
179     IAM_LOGI("reply timer is timeout, messageSeq:%{public}u", messageSeq);
180 }
181 
GetMessageSeq()182 int32_t BaseSocket::GetMessageSeq()
183 {
184     IAM_LOGD("start.");
185     std::lock_guard<std::recursive_mutex> lock(g_seqMutex);
186     g_messageSeq++;
187     return g_messageSeq;
188 }
189 
SetDeviceNetworkId(const std::string networkId,std::shared_ptr<Attributes> & attributes)190 ResultCode BaseSocket::SetDeviceNetworkId(const std::string networkId, std::shared_ptr<Attributes> &attributes)
191 {
192     IAM_LOGD("start.");
193     IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
194 
195     bool setDeviceNetworkIdRet = attributes->SetStringValue(Attributes::ATTR_COLLECTOR_NETWORK_ID, networkId);
196     if (setDeviceNetworkIdRet == false) {
197         IAM_LOGE("SetStringValue fail");
198         return GENERAL_ERROR;
199     }
200 
201     return SUCCESS;
202 }
203 
RefreshTraceInfo(const std::string & connectionName,int32_t msgType,bool ack,uint32_t messageSeq)204 void BaseSocket::RefreshTraceInfo(const std::string &connectionName, int32_t msgType, bool ack, uint32_t messageSeq)
205 {
206     currTraceInfo_.connectionName = connectionName;
207     currTraceInfo_.msgType = msgType;
208     currTraceInfo_.ack = ack;
209     currTraceInfo_.messageSeq = messageSeq;
210 }
211 
SendRequest(const ConnectionInfo & connectionInfo)212 ResultCode BaseSocket::SendRequest(const ConnectionInfo &connectionInfo)
213 {
214     IAM_LOGD("start.");
215     IF_FALSE_LOGE_AND_RETURN_VAL(connectionInfo.attributes != nullptr, INVALID_PARAMETERS);
216     IF_FALSE_LOGE_AND_RETURN_VAL(connectionInfo.socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
217 
218     int32_t messageSeq = GetMessageSeq();
219     int32_t msgType = -1;
220     // remote pin authentication may not contain the msgType parameter, no need to check the result
221     connectionInfo.attributes->GetInt32Value(Attributes::ATTR_MSG_TYPE, msgType);
222     RefreshTraceInfo(connectionInfo.connectionName, msgType, false, messageSeq);
223     std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
224         connectionInfo.connectionName, connectionInfo.srcEndPoint, connectionInfo.destEndPoint,
225         connectionInfo.attributes);
226     if (softBusMessage == nullptr) {
227         IAM_LOGE("softBusMessage is nullptr");
228         return GENERAL_ERROR;
229     }
230 
231     std::shared_ptr<Attributes> request = softBusMessage->CreateMessage(false);
232     if (request == nullptr) {
233         IAM_LOGE("creatMessage fail");
234         return GENERAL_ERROR;
235     }
236 
237     std::vector<uint8_t> data = request->Serialize();
238     int ret = SendBytes(connectionInfo.socketId, data.data(), data.size());
239     if (ret != SUCCESS) {
240         IAM_LOGE("fail to send message, result= %{public}d", ret);
241         return GENERAL_ERROR;
242     }
243 
244     uint32_t timerId = StartReplyTimer(messageSeq);
245     if (timerId == INVALID_TIMER_ID) {
246         IAM_LOGE("create reply timer fail");
247         return GENERAL_ERROR;
248     }
249 
250     InsertMsgCallback(messageSeq, connectionInfo.connectionName, connectionInfo.callback, timerId);
251     IAM_LOGI("SendRequest success.");
252     return SUCCESS;
253 }
254 
SendResponse(const int32_t socketId,const std::string & connectionName,const std::string & srcEndPoint,const std::string & destEndPoint,const std::shared_ptr<Attributes> & attributes,uint32_t messageSeq)255 ResultCode BaseSocket::SendResponse(const int32_t socketId, const std::string &connectionName,
256     const std::string &srcEndPoint, const std::string &destEndPoint, const std::shared_ptr<Attributes> &attributes,
257     uint32_t messageSeq)
258 {
259     IAM_LOGD("start.");
260     IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
261     IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
262     int32_t msgType = -1;
263     // remote pin authentication may not contain the msgType parameter, no need to check the result
264     attributes->GetInt32Value(Attributes::ATTR_MSG_TYPE, msgType);
265     RefreshTraceInfo(connectionName, msgType, true, messageSeq);
266 
267     std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
268         connectionName, srcEndPoint, destEndPoint, attributes);
269     if (softBusMessage == nullptr) {
270         IAM_LOGE("softBusMessage is nullptr");
271         return GENERAL_ERROR;
272     }
273 
274     std::shared_ptr<Attributes> response = softBusMessage->CreateMessage(true);
275     if (response == nullptr) {
276         IAM_LOGE("creatMessage fail");
277         return GENERAL_ERROR;
278     }
279 
280     std::vector<uint8_t> data = response->Serialize();
281     int ret = SendBytes(socketId, data.data(), data.size());
282     if (ret != SUCCESS) {
283         IAM_LOGE("fail to send message, result= %{public}d", ret);
284         return GENERAL_ERROR;
285     }
286 
287     IAM_LOGI("SendResponse success.");
288     return SUCCESS;
289 }
290 
ParseMessage(const std::string & networkId,void * message,uint32_t messageLen)291 std::shared_ptr<SoftBusMessage> BaseSocket::ParseMessage(const std::string &networkId,
292     void *message, uint32_t messageLen)
293 {
294     IAM_LOGD("start.");
295     IF_FALSE_LOGE_AND_RETURN_VAL(message != nullptr, nullptr);
296     IF_FALSE_LOGE_AND_RETURN_VAL(messageLen != 0, nullptr);
297 
298     std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(0, "", "", "", nullptr);
299     if (softBusMessage == nullptr) {
300         IAM_LOGE("softBusMessage is nullptr");
301         return nullptr;
302     }
303 
304     std::shared_ptr<Attributes> attributes = softBusMessage->ParseMessage(message, messageLen);
305     if (attributes == nullptr) {
306         IAM_LOGE("parseMessage fail");
307         return nullptr;
308     }
309     int32_t msgType = -1;
310     // remote pin authentication may not contain the msgType parameter, no need to check the result
311     attributes->GetInt32Value(Attributes::ATTR_MSG_TYPE, msgType);
312     RefreshTraceInfo(softBusMessage->GetConnectionName(), msgType, softBusMessage->GetAckFlag(),
313         softBusMessage->GetMessageSeq());
314 
315     int32_t ret = SetDeviceNetworkId(networkId, attributes);
316     if (ret != SUCCESS) {
317         IAM_LOGE("SetDeviceNetworkId fail");
318         return nullptr;
319     }
320 
321     IAM_LOGD("ParseMessage success.");
322     return softBusMessage;
323 }
324 
ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage,std::shared_ptr<Attributes> response)325 void BaseSocket::ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage, std::shared_ptr<Attributes> response)
326 {
327     IF_FALSE_LOGE_AND_RETURN(softBusMessage != nullptr);
328     IF_FALSE_LOGE_AND_RETURN(response != nullptr);
329 
330     bool setResultCode = response->SetInt32Value(Attributes::ATTR_RESULT_CODE, GENERAL_ERROR);
331     IF_FALSE_LOGE_AND_RETURN(setResultCode);
332 
333     uint32_t messageVersion = softBusMessage->GetMessageVersion();
334     if (messageVersion != DEFAULT_MESSAGE_VERSION) {
335         IAM_LOGE("support message version %{public}u, receive message version %{public}u", DEFAULT_MESSAGE_VERSION,
336             messageVersion);
337         std::vector<uint32_t> supportedVersions = { DEFAULT_MESSAGE_VERSION };
338         bool setSupportedVersionsRet = response->SetUint32ArrayValue(Attributes::ATTR_SUPPORTED_MSG_VERSION,
339             supportedVersions);
340         IF_FALSE_LOGE_AND_RETURN(setSupportedVersionsRet);
341         return;
342     }
343 
344     std::string connectionName = softBusMessage->GetConnectionName();
345     std::string destEndPoint = softBusMessage->GetDestEndPoint();
346 
347     std::shared_ptr<ConnectionListener> connectionListener =
348         RemoteConnectListenerManager::GetInstance().FindListener(connectionName, destEndPoint);
349     if (connectionListener == nullptr) {
350         IAM_LOGE("connectionListener is nullptr");
351         return;
352     }
353 
354     auto beginTime = std::chrono::steady_clock::now();
355     connectionListener->OnMessage(connectionName, destEndPoint, softBusMessage->GetAttributes(), response);
356     auto endTime = std::chrono::steady_clock::now();
357     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - beginTime);
358     IAM_LOGI("messageSeq:%{public}u ProcessMessageDuration:%{public}" PRIu64 " ms", softBusMessage->GetMessageSeq(),
359         static_cast<uint64_t>(duration.count()));
360 }
361 
ProcDataReceive(const int32_t socketId,std::shared_ptr<SoftBusMessage> & softBusMessage)362 ResultCode BaseSocket::ProcDataReceive(const int32_t socketId, std::shared_ptr<SoftBusMessage> &softBusMessage)
363 {
364     IAM_LOGD("start.");
365     IF_FALSE_LOGE_AND_RETURN_VAL(softBusMessage != nullptr, INVALID_PARAMETERS);
366     IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
367 
368     std::shared_ptr<Attributes> request = softBusMessage->GetAttributes();
369     if (request == nullptr) {
370         IAM_LOGE("GetAttributes fail");
371         return GENERAL_ERROR;
372     }
373 
374     uint32_t messageSeq = softBusMessage->GetMessageSeq();
375     bool ack = softBusMessage->GetAckFlag();
376     if (ack == true) {
377         PrintTransferDuration(messageSeq);
378         MsgCallback callback = GetMsgCallback(messageSeq);
379         if (callback == nullptr) {
380             IAM_LOGE("GetMsgCallback fail");
381             return GENERAL_ERROR;
382         }
383 
384         callback(request);
385         StopReplyTimer(messageSeq);
386         RemoveMsgCallback(messageSeq);
387     } else {
388         std::string connectionName = softBusMessage->GetConnectionName();
389         std::string srcEndPoint = softBusMessage->GetSrcEndPoint();
390         std::string destEndPoint = softBusMessage->GetDestEndPoint();
391 
392         std::shared_ptr<Attributes> response = Common::MakeShared<Attributes>();
393         if (response == nullptr) {
394             IAM_LOGE("create fail");
395             return GENERAL_ERROR;
396         }
397 
398         ProcessMessage(softBusMessage, response);
399 
400         SendResponse(socketId, connectionName, destEndPoint, srcEndPoint, response, messageSeq);
401     }
402 
403     IAM_LOGI("ProcDataReceive success.");
404     return SUCCESS;
405 }
406 } // namespace UserAuth
407 } // namespace UserIam
408 } // namespace OHOS