1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "soft_bus_base_socket.h"
17
18 #include <cinttypes>
19
20 #include "remote_connect_listener_manager.h"
21
22 #define LOG_TAG "USER_AUTH_SA"
23 namespace OHOS {
24 namespace UserIam {
25 namespace UserAuth {
26 using namespace OHOS::DistributedHardware;
// Package name string for user IAM. It is not referenced by any code visible in this file.
const std::string USERIAM_PACKAGE_NAME = "ohos.useriam";
// Delay handed to RelativeTimer when a reply timer is registered in StartReplyTimer.
static constexpr uint32_t REPLY_TIMER_LEN_MS = 5 * 1000; // 5s
// Timer id value that StartReplyTimer/StopReplyTimer/SendRequest treat as "no timer".
static constexpr uint32_t INVALID_TIMER_ID = 0;
// Guards g_messageSeq; locked in BaseSocket::GetMessageSeq.
static std::recursive_mutex g_seqMutex;
// File-wide message sequence counter, incremented on every BaseSocket::GetMessageSeq call.
static uint32_t g_messageSeq = 0;
32
BaseSocket(const int32_t socketId)33 BaseSocket::BaseSocket(const int32_t socketId)
34 : socketId_(socketId)
35 {
36 currTraceInfo_.msgType = -1;
37 currTraceInfo_.socketId = socketId;
38 IAM_LOGI("create socket id %{public}d.", socketId_);
39 }
40
~BaseSocket()41 BaseSocket::~BaseSocket()
42 {
43 Shutdown(socketId_);
44 IAM_LOGI("close socket id %{public}d.", socketId_);
45 }
46
GetSocketId()47 int32_t BaseSocket::GetSocketId()
48 {
49 return socketId_;
50 }
51
GetCurrTraceInfo()52 RemoteConnectFaultTrace BaseSocket::GetCurrTraceInfo()
53 {
54 return currTraceInfo_;
55 }
56
InsertMsgCallback(uint32_t messageSeq,const std::string & connectionName,const MsgCallback & callback,uint32_t timerId)57 void BaseSocket::InsertMsgCallback(uint32_t messageSeq, const std::string &connectionName,
58 const MsgCallback &callback, uint32_t timerId)
59 {
60 IAM_LOGD("start. messageSeq:%{public}u, timerId:%{public}u", messageSeq, timerId);
61 IF_FALSE_LOGE_AND_RETURN(callback != nullptr);
62
63 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
64 CallbackInfo callbackInfo = {
65 .connectionName = connectionName,
66 .msgCallback = callback,
67 .timerId = timerId,
68 .sendTime = std::chrono::steady_clock::now()
69 };
70 callbackMap_.insert(std::pair<int32_t, CallbackInfo>(messageSeq, callbackInfo));
71 }
72
RemoveMsgCallback(uint32_t messageSeq)73 void BaseSocket::RemoveMsgCallback(uint32_t messageSeq)
74 {
75 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
76 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
77 callbackMap_.erase(messageSeq);
78 }
79
GetConnectionName(uint32_t messageSeq)80 std::string BaseSocket::GetConnectionName(uint32_t messageSeq)
81 {
82 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
83 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
84 std::string connectionName;
85 auto iter = callbackMap_.find(messageSeq);
86 if (iter != callbackMap_.end()) {
87 connectionName = iter->second.connectionName;
88 }
89 return connectionName;
90 }
91
GetMsgCallback(uint32_t messageSeq)92 MsgCallback BaseSocket::GetMsgCallback(uint32_t messageSeq)
93 {
94 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
95 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
96 MsgCallback callback = nullptr;
97 auto iter = callbackMap_.find(messageSeq);
98 if (iter != callbackMap_.end()) {
99 callback = iter->second.msgCallback;
100 }
101 return callback;
102 }
103
PrintTransferDuration(uint32_t messageSeq)104 void BaseSocket::PrintTransferDuration(uint32_t messageSeq)
105 {
106 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
107 auto iter = callbackMap_.find(messageSeq);
108 if (iter == callbackMap_.end()) {
109 IAM_LOGE("message seq not found");
110 return;
111 }
112
113 auto receiveAckTime = std::chrono::steady_clock::now();
114 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(receiveAckTime - iter->second.sendTime);
115 IAM_LOGI("messageSeq:%{public}u MessageTransferDuration:%{public}" PRIu64 " ms", messageSeq,
116 static_cast<uint64_t>(duration.count()));
117 }
118
GetReplyTimer(uint32_t messageSeq)119 uint32_t BaseSocket::GetReplyTimer(uint32_t messageSeq)
120 {
121 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
122 std::lock_guard<std::recursive_mutex> lock(callbackMutex_);
123 uint32_t timerId = 0;
124 auto iter = callbackMap_.find(messageSeq);
125 if (iter != callbackMap_.end()) {
126 timerId = iter->second.timerId;
127 }
128 return timerId;
129 }
130
StartReplyTimer(uint32_t messageSeq)131 uint32_t BaseSocket::StartReplyTimer(uint32_t messageSeq)
132 {
133 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
134 uint32_t timerId = GetReplyTimer(messageSeq);
135 if (timerId != INVALID_TIMER_ID) {
136 IAM_LOGI("timer is already start");
137 return timerId;
138 }
139
140 timerId = RelativeTimer::GetInstance().Register(
141 [weakSelf = weak_from_this(), messageSeq, socketId = socketId_] {
142 auto self = weakSelf.lock();
143 if (self == nullptr) {
144 IAM_LOGE("socket %{public}d is released", socketId);
145 return;
146 }
147 self->ReplyTimerTimeOut(messageSeq);
148 },
149 REPLY_TIMER_LEN_MS);
150
151 return timerId;
152 }
153
StopReplyTimer(uint32_t messageSeq)154 void BaseSocket::StopReplyTimer(uint32_t messageSeq)
155 {
156 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
157 uint32_t timerId = GetReplyTimer(messageSeq);
158 if (timerId == INVALID_TIMER_ID) {
159 IAM_LOGI("timer is already stop");
160 return;
161 }
162
163 RelativeTimer::GetInstance().Unregister(timerId);
164 }
165
ReplyTimerTimeOut(uint32_t messageSeq)166 void BaseSocket::ReplyTimerTimeOut(uint32_t messageSeq)
167 {
168 IAM_LOGD("start. messageSeq:%{public}u", messageSeq);
169 std::string connectionName = GetConnectionName(messageSeq);
170 if (connectionName.empty()) {
171 IAM_LOGE("GetMsgCallback connectionName fail");
172 return;
173 }
174 currTraceInfo_.reason = "ack time out";
175 ReportConnectFaultTrace(currTraceInfo_);
176
177 RemoteConnectListenerManager::GetInstance().OnConnectionDown(connectionName);
178 RemoveMsgCallback(messageSeq);
179 IAM_LOGI("reply timer is timeout, messageSeq:%{public}u", messageSeq);
180 }
181
GetMessageSeq()182 int32_t BaseSocket::GetMessageSeq()
183 {
184 IAM_LOGD("start.");
185 std::lock_guard<std::recursive_mutex> lock(g_seqMutex);
186 g_messageSeq++;
187 return g_messageSeq;
188 }
189
SetDeviceNetworkId(const std::string networkId,std::shared_ptr<Attributes> & attributes)190 ResultCode BaseSocket::SetDeviceNetworkId(const std::string networkId, std::shared_ptr<Attributes> &attributes)
191 {
192 IAM_LOGD("start.");
193 IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
194
195 bool setDeviceNetworkIdRet = attributes->SetStringValue(Attributes::ATTR_COLLECTOR_NETWORK_ID, networkId);
196 if (setDeviceNetworkIdRet == false) {
197 IAM_LOGE("SetStringValue fail");
198 return GENERAL_ERROR;
199 }
200
201 return SUCCESS;
202 }
203
RefreshTraceInfo(const std::string & connectionName,int32_t msgType,bool ack,uint32_t messageSeq)204 void BaseSocket::RefreshTraceInfo(const std::string &connectionName, int32_t msgType, bool ack, uint32_t messageSeq)
205 {
206 currTraceInfo_.connectionName = connectionName;
207 currTraceInfo_.msgType = msgType;
208 currTraceInfo_.ack = ack;
209 currTraceInfo_.messageSeq = messageSeq;
210 }
211
SendRequest(const ConnectionInfo & connectionInfo)212 ResultCode BaseSocket::SendRequest(const ConnectionInfo &connectionInfo)
213 {
214 IAM_LOGD("start.");
215 IF_FALSE_LOGE_AND_RETURN_VAL(connectionInfo.attributes != nullptr, INVALID_PARAMETERS);
216 IF_FALSE_LOGE_AND_RETURN_VAL(connectionInfo.socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
217
218 int32_t messageSeq = GetMessageSeq();
219 int32_t msgType = -1;
220 // remote pin authentication may not contain the msgType parameter, no need to check the result
221 connectionInfo.attributes->GetInt32Value(Attributes::ATTR_MSG_TYPE, msgType);
222 RefreshTraceInfo(connectionInfo.connectionName, msgType, false, messageSeq);
223 std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
224 connectionInfo.connectionName, connectionInfo.srcEndPoint, connectionInfo.destEndPoint,
225 connectionInfo.attributes);
226 if (softBusMessage == nullptr) {
227 IAM_LOGE("softBusMessage is nullptr");
228 return GENERAL_ERROR;
229 }
230
231 std::shared_ptr<Attributes> request = softBusMessage->CreateMessage(false);
232 if (request == nullptr) {
233 IAM_LOGE("creatMessage fail");
234 return GENERAL_ERROR;
235 }
236
237 std::vector<uint8_t> data = request->Serialize();
238 int ret = SendBytes(connectionInfo.socketId, data.data(), data.size());
239 if (ret != SUCCESS) {
240 IAM_LOGE("fail to send message, result= %{public}d", ret);
241 return GENERAL_ERROR;
242 }
243
244 uint32_t timerId = StartReplyTimer(messageSeq);
245 if (timerId == INVALID_TIMER_ID) {
246 IAM_LOGE("create reply timer fail");
247 return GENERAL_ERROR;
248 }
249
250 InsertMsgCallback(messageSeq, connectionInfo.connectionName, connectionInfo.callback, timerId);
251 IAM_LOGI("SendRequest success.");
252 return SUCCESS;
253 }
254
SendResponse(const int32_t socketId,const std::string & connectionName,const std::string & srcEndPoint,const std::string & destEndPoint,const std::shared_ptr<Attributes> & attributes,uint32_t messageSeq)255 ResultCode BaseSocket::SendResponse(const int32_t socketId, const std::string &connectionName,
256 const std::string &srcEndPoint, const std::string &destEndPoint, const std::shared_ptr<Attributes> &attributes,
257 uint32_t messageSeq)
258 {
259 IAM_LOGD("start.");
260 IF_FALSE_LOGE_AND_RETURN_VAL(attributes != nullptr, INVALID_PARAMETERS);
261 IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
262 int32_t msgType = -1;
263 // remote pin authentication may not contain the msgType parameter, no need to check the result
264 attributes->GetInt32Value(Attributes::ATTR_MSG_TYPE, msgType);
265 RefreshTraceInfo(connectionName, msgType, true, messageSeq);
266
267 std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(messageSeq,
268 connectionName, srcEndPoint, destEndPoint, attributes);
269 if (softBusMessage == nullptr) {
270 IAM_LOGE("softBusMessage is nullptr");
271 return GENERAL_ERROR;
272 }
273
274 std::shared_ptr<Attributes> response = softBusMessage->CreateMessage(true);
275 if (response == nullptr) {
276 IAM_LOGE("creatMessage fail");
277 return GENERAL_ERROR;
278 }
279
280 std::vector<uint8_t> data = response->Serialize();
281 int ret = SendBytes(socketId, data.data(), data.size());
282 if (ret != SUCCESS) {
283 IAM_LOGE("fail to send message, result= %{public}d", ret);
284 return GENERAL_ERROR;
285 }
286
287 IAM_LOGI("SendResponse success.");
288 return SUCCESS;
289 }
290
ParseMessage(const std::string & networkId,void * message,uint32_t messageLen)291 std::shared_ptr<SoftBusMessage> BaseSocket::ParseMessage(const std::string &networkId,
292 void *message, uint32_t messageLen)
293 {
294 IAM_LOGD("start.");
295 IF_FALSE_LOGE_AND_RETURN_VAL(message != nullptr, nullptr);
296 IF_FALSE_LOGE_AND_RETURN_VAL(messageLen != 0, nullptr);
297
298 std::shared_ptr<SoftBusMessage> softBusMessage = Common::MakeShared<SoftBusMessage>(0, "", "", "", nullptr);
299 if (softBusMessage == nullptr) {
300 IAM_LOGE("softBusMessage is nullptr");
301 return nullptr;
302 }
303
304 std::shared_ptr<Attributes> attributes = softBusMessage->ParseMessage(message, messageLen);
305 if (attributes == nullptr) {
306 IAM_LOGE("parseMessage fail");
307 return nullptr;
308 }
309 int32_t msgType = -1;
310 // remote pin authentication may not contain the msgType parameter, no need to check the result
311 attributes->GetInt32Value(Attributes::ATTR_MSG_TYPE, msgType);
312 RefreshTraceInfo(softBusMessage->GetConnectionName(), msgType, softBusMessage->GetAckFlag(),
313 softBusMessage->GetMessageSeq());
314
315 int32_t ret = SetDeviceNetworkId(networkId, attributes);
316 if (ret != SUCCESS) {
317 IAM_LOGE("SetDeviceNetworkId fail");
318 return nullptr;
319 }
320
321 IAM_LOGD("ParseMessage success.");
322 return softBusMessage;
323 }
324
ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage,std::shared_ptr<Attributes> response)325 void BaseSocket::ProcessMessage(std::shared_ptr<SoftBusMessage> softBusMessage, std::shared_ptr<Attributes> response)
326 {
327 IF_FALSE_LOGE_AND_RETURN(softBusMessage != nullptr);
328 IF_FALSE_LOGE_AND_RETURN(response != nullptr);
329
330 bool setResultCode = response->SetInt32Value(Attributes::ATTR_RESULT_CODE, GENERAL_ERROR);
331 IF_FALSE_LOGE_AND_RETURN(setResultCode);
332
333 uint32_t messageVersion = softBusMessage->GetMessageVersion();
334 if (messageVersion != DEFAULT_MESSAGE_VERSION) {
335 IAM_LOGE("support message version %{public}u, receive message version %{public}u", DEFAULT_MESSAGE_VERSION,
336 messageVersion);
337 std::vector<uint32_t> supportedVersions = { DEFAULT_MESSAGE_VERSION };
338 bool setSupportedVersionsRet = response->SetUint32ArrayValue(Attributes::ATTR_SUPPORTED_MSG_VERSION,
339 supportedVersions);
340 IF_FALSE_LOGE_AND_RETURN(setSupportedVersionsRet);
341 return;
342 }
343
344 std::string connectionName = softBusMessage->GetConnectionName();
345 std::string destEndPoint = softBusMessage->GetDestEndPoint();
346
347 std::shared_ptr<ConnectionListener> connectionListener =
348 RemoteConnectListenerManager::GetInstance().FindListener(connectionName, destEndPoint);
349 if (connectionListener == nullptr) {
350 IAM_LOGE("connectionListener is nullptr");
351 return;
352 }
353
354 auto beginTime = std::chrono::steady_clock::now();
355 connectionListener->OnMessage(connectionName, destEndPoint, softBusMessage->GetAttributes(), response);
356 auto endTime = std::chrono::steady_clock::now();
357 auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - beginTime);
358 IAM_LOGI("messageSeq:%{public}u ProcessMessageDuration:%{public}" PRIu64 " ms", softBusMessage->GetMessageSeq(),
359 static_cast<uint64_t>(duration.count()));
360 }
361
ProcDataReceive(const int32_t socketId,std::shared_ptr<SoftBusMessage> & softBusMessage)362 ResultCode BaseSocket::ProcDataReceive(const int32_t socketId, std::shared_ptr<SoftBusMessage> &softBusMessage)
363 {
364 IAM_LOGD("start.");
365 IF_FALSE_LOGE_AND_RETURN_VAL(softBusMessage != nullptr, INVALID_PARAMETERS);
366 IF_FALSE_LOGE_AND_RETURN_VAL(socketId != INVALID_SOCKET_ID, INVALID_PARAMETERS);
367
368 std::shared_ptr<Attributes> request = softBusMessage->GetAttributes();
369 if (request == nullptr) {
370 IAM_LOGE("GetAttributes fail");
371 return GENERAL_ERROR;
372 }
373
374 uint32_t messageSeq = softBusMessage->GetMessageSeq();
375 bool ack = softBusMessage->GetAckFlag();
376 if (ack == true) {
377 PrintTransferDuration(messageSeq);
378 MsgCallback callback = GetMsgCallback(messageSeq);
379 if (callback == nullptr) {
380 IAM_LOGE("GetMsgCallback fail");
381 return GENERAL_ERROR;
382 }
383
384 callback(request);
385 StopReplyTimer(messageSeq);
386 RemoveMsgCallback(messageSeq);
387 } else {
388 std::string connectionName = softBusMessage->GetConnectionName();
389 std::string srcEndPoint = softBusMessage->GetSrcEndPoint();
390 std::string destEndPoint = softBusMessage->GetDestEndPoint();
391
392 std::shared_ptr<Attributes> response = Common::MakeShared<Attributes>();
393 if (response == nullptr) {
394 IAM_LOGE("create fail");
395 return GENERAL_ERROR;
396 }
397
398 ProcessMessage(softBusMessage, response);
399
400 SendResponse(socketId, connectionName, destEndPoint, srcEndPoint, response, messageSeq);
401 }
402
403 IAM_LOGI("ProcDataReceive success.");
404 return SUCCESS;
405 }
406 } // namespace UserAuth
407 } // namespace UserIam
408 } // namespace OHOS