1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
#include "soft_bus_channel.h"

#include <memory>

#include <securec.h>

#include "constant_common.h"
#include "device_info_manager.h"
#ifdef EVENTHANDLER_ENABLE
#include "access_event_handler.h"
#endif
#include "token_sync_manager_service.h"
#include "singleton.h"
#include "soft_bus_manager.h"
27
28 namespace OHOS {
29 namespace Security {
30 namespace AccessToken {
namespace {
// Values of the message "type" field exchanged over the channel.
const std::string REQUEST_TYPE = "request";
const std::string RESPONSE_TYPE = "response";
// Name of the delayed task posted by CloseConnection(); Release() removes it by this name.
const std::string TASK_NAME_CLOSE_SESSION = "atm_soft_bus_channel_close_session";
// How long ExecuteCommand() waits for a response, in milliseconds.
constexpr int32_t EXECUTE_COMMAND_TIME_OUT = 3000;
// Delay before CloseConnection()'s task closes the socket, in milliseconds.
constexpr int32_t WAIT_SESSION_CLOSE_MILLISECONDS = 5 * 1000;
// Extra bytes reserved on top of the payload length when sizing a send buffer (256 KiB).
constexpr int RPC_TRANSFER_HEAD_BYTES_LENGTH = 1024 * 256;
// Capacity of the buffer Decompress() inflates into (1 MiB).
constexpr int RPC_TRANSFER_BYTES_MAX_LENGTH = 1024 * 1024;
} // namespace
SoftBusChannel(const std::string & deviceId)42 SoftBusChannel::SoftBusChannel(const std::string &deviceId)
43 : deviceId_(deviceId), mutex_(), callbacks_(), responseResult_(""), loadedCond_()
44 {
45 LOGD(ATM_DOMAIN, ATM_TAG, "SoftBusChannel(deviceId)");
46 isDelayClosing_ = false;
47 socketFd_ = Constant::INVALID_SOCKET_FD;
48 isSocketUsing_ = false;
49 }
50
~SoftBusChannel()51 SoftBusChannel::~SoftBusChannel()
52 {
53 LOGD(ATM_DOMAIN, ATM_TAG, "~SoftBusChannel()");
54 }
55
BuildConnection()56 int SoftBusChannel::BuildConnection()
57 {
58 CancelCloseConnectionIfNeeded();
59
60 std::unique_lock<std::mutex> lock(socketMutex_);
61 if (socketFd_ != Constant::INVALID_SOCKET_FD) {
62 LOGI(ATM_DOMAIN, ATM_TAG, "Socket is exist, no need open again.");
63 return Constant::SUCCESS;
64 }
65
66 if (socketFd_ == Constant::INVALID_SOCKET_FD) {
67 LOGI(ATM_DOMAIN, ATM_TAG, "Bind service with device: %{public}s",
68 ConstantCommon::EncryptDevId(deviceId_).c_str());
69 int socket = SoftBusManager::GetInstance().BindService(deviceId_);
70 if (socket == Constant::INVALID_SOCKET_FD) {
71 LOGE(ATM_DOMAIN, ATM_TAG, "Bind service failed.");
72 return Constant::FAILURE;
73 }
74 socketFd_ = socket;
75 }
76 return Constant::SUCCESS;
77 }
78
#ifdef EVENTHANDLER_ENABLE
/**
 * Fetches the send event handler from TokenSyncManagerService.
 *
 * @param handler receives the handler; left untouched if the service is null, set to the
 *        (null) value returned by the service if the service has no handler
 * @return true when a non-null handler was obtained, false otherwise
 */
static bool GetSendEventHandler(std::shared_ptr<AccessEventHandler>& handler)
{
    auto service = DelayedSingleton<TokenSyncManagerService>::GetInstance();
    if (service != nullptr) {
        handler = service->GetSendEventHandler();
        if (handler != nullptr) {
            return true;
        }
        LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
        return false;
    }
    LOGE(ATM_DOMAIN, ATM_TAG, "TokenSyncManagerService is null.");
    return false;
}
#endif
96
CloseConnection()97 void SoftBusChannel::CloseConnection()
98 {
99 LOGD(ATM_DOMAIN, ATM_TAG, "Close connection");
100 std::unique_lock<std::mutex> lock(mutex_);
101 if (isDelayClosing_) {
102 return;
103 }
104
105 #ifdef EVENTHANDLER_ENABLE
106 std::shared_ptr<AccessEventHandler> handler = nullptr;
107 if (!GetSendEventHandler(handler)) {
108 LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
109 return;
110 }
111 #endif
112 std::weak_ptr<SoftBusChannel> weakPtr = shared_from_this();
113 std::function<void()> delayed = ([weakPtr]() {
114 auto self = weakPtr.lock();
115 if (self == nullptr) {
116 LOGE(ATM_DOMAIN, ATM_TAG, "SoftBusChannel is nullptr");
117 return;
118 }
119 std::unique_lock<std::mutex> lock(self->socketMutex_);
120 if (self->isSocketUsing_) {
121 LOGD(ATM_DOMAIN, ATM_TAG, "Socket is in using, cancel close socket");
122 } else {
123 SoftBusManager::GetInstance().CloseSocket(self->socketFd_);
124 self->socketFd_ = Constant::INVALID_SESSION;
125 LOGI(ATM_DOMAIN, ATM_TAG, "Close socket for device: %{public}s",
126 ConstantCommon::EncryptDevId(self->deviceId_).c_str());
127 }
128 self->isDelayClosing_ = false;
129 });
130
131 LOGD(ATM_DOMAIN, ATM_TAG, "Close socket after %{public}d ms", WAIT_SESSION_CLOSE_MILLISECONDS);
132 #ifdef EVENTHANDLER_ENABLE
133 handler->ProxyPostTask(delayed, TASK_NAME_CLOSE_SESSION, WAIT_SESSION_CLOSE_MILLISECONDS);
134 #endif
135
136 isDelayClosing_ = true;
137 }
138
Release()139 void SoftBusChannel::Release()
140 {
141 #ifdef EVENTHANDLER_ENABLE
142 std::shared_ptr<AccessEventHandler> handler = nullptr;
143 if (!GetSendEventHandler(handler)) {
144 LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
145 return;
146 }
147 handler->ProxyRemoveTask(TASK_NAME_CLOSE_SESSION);
148 #endif
149 }
150
GetUuid()151 std::string SoftBusChannel::GetUuid()
152 {
153 // to use a lib like libuuid
154 int uuidStrLen = 37; // 32+4+1
155 char uuidbuf[uuidStrLen];
156 RandomUuid(uuidbuf, uuidStrLen);
157 std::string uuid(uuidbuf);
158 LOGD(ATM_DOMAIN, ATM_TAG, "Generated message uuid: %{public}s", ConstantCommon::EncryptDevId(uuid).c_str());
159
160 return uuid;
161 }
162
InsertCallback(int result,std::string & uuid)163 void SoftBusChannel::InsertCallback(int result, std::string &uuid)
164 {
165 std::unique_lock<std::mutex> lock(socketMutex_);
166 std::function<void(const std::string &)> callback = [this](const std::string &result) {
167 responseResult_ = std::string(result);
168 loadedCond_.notify_all();
169 LOGD(ATM_DOMAIN, ATM_TAG, "OnResponse called end");
170 };
171 callbacks_.insert(std::pair<std::string, std::function<void(std::string)>>(uuid, callback));
172
173 isSocketUsing_ = true;
174 lock.unlock();
175 }
176
ExecuteCommand(const std::string & commandName,const std::string & jsonPayload)177 std::string SoftBusChannel::ExecuteCommand(const std::string &commandName, const std::string &jsonPayload)
178 {
179 if (commandName.empty() || jsonPayload.empty()) {
180 LOGE(ATM_DOMAIN, ATM_TAG, "Invalid params, commandName: %{public}s", commandName.c_str());
181 return "";
182 }
183
184 std::string uuid = GetUuid();
185
186 int len = static_cast<int32_t>(RPC_TRANSFER_HEAD_BYTES_LENGTH + jsonPayload.length());
187 unsigned char* buf = new (std::nothrow) unsigned char[len + 1];
188 if (buf == nullptr) {
189 LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory: %{public}d", len);
190 return "";
191 }
192 (void)memset_s(buf, len + 1, 0, len + 1);
193 BytesInfo info;
194 info.bytes = buf;
195 info.bytesLength = len;
196 int result = PrepareBytes(REQUEST_TYPE, uuid, commandName, jsonPayload, info);
197 if (result != Constant::SUCCESS) {
198 delete[] buf;
199 return "";
200 }
201 InsertCallback(result, uuid);
202 int retCode = SendRequestBytes(buf, info.bytesLength);
203 delete[] buf;
204
205 std::unique_lock<std::mutex> lock2(socketMutex_);
206 if (retCode != Constant::SUCCESS) {
207 LOGE(ATM_DOMAIN, ATM_TAG, "Send request data failed: %{public}d ", retCode);
208 callbacks_.erase(uuid);
209 isSocketUsing_ = false;
210 return "";
211 }
212
213 LOGD(ATM_DOMAIN, ATM_TAG, "Wait command response");
214 if (loadedCond_.wait_for(lock2, std::chrono::milliseconds(EXECUTE_COMMAND_TIME_OUT)) == std::cv_status::timeout) {
215 LOGW(ATM_DOMAIN, ATM_TAG, "Time out to wait response.");
216 callbacks_.erase(uuid);
217 isSocketUsing_ = false;
218 return "";
219 }
220
221 isSocketUsing_ = false;
222 return responseResult_;
223 }
224
HandleDataReceived(int socket,const unsigned char * bytes,int length)225 void SoftBusChannel::HandleDataReceived(int socket, const unsigned char* bytes, int length)
226 {
227 LOGD(ATM_DOMAIN, ATM_TAG, "HandleDataReceived");
228 #ifdef DEBUG_API_PERFORMANCE
229 LOGI(ATM_DOMAIN, ATM_TAG, "Api_performance:recieve message from softbus");
230 #endif
231 if (socket <= 0 || length <= 0) {
232 LOGE(ATM_DOMAIN, ATM_TAG, "Invalid params: socket: %{public}d, data length: %{public}d", socket, length);
233 return;
234 }
235 std::string receiveData = Decompress(bytes, length);
236 if (receiveData.empty()) {
237 LOGE(ATM_DOMAIN, ATM_TAG, "Invalid parameter bytes");
238 return;
239 }
240 std::shared_ptr<SoftBusMessage> message = SoftBusMessage::FromJson(receiveData);
241 if (message == nullptr) {
242 LOGD(ATM_DOMAIN, ATM_TAG, "Invalid json string");
243 return;
244 }
245 if (!message->IsValid()) {
246 LOGD(ATM_DOMAIN, ATM_TAG, "Invalid data, has empty field");
247 return;
248 }
249
250 std::string type = message->GetType();
251 if (REQUEST_TYPE == (type)) {
252 std::function<void()> delayed = ([weak = weak_from_this(), socket, message]() {
253 auto self = weak.lock();
254 if (self == nullptr) {
255 LOGE(ATM_DOMAIN, ATM_TAG, "SoftBusChannel is nullptr");
256 return;
257 }
258 self->HandleRequest(socket, message->GetId(), message->GetCommandName(), message->GetJsonPayload());
259 });
260
261 #ifdef EVENTHANDLER_ENABLE
262 std::shared_ptr<AccessEventHandler> handler = nullptr;
263 if (!GetSendEventHandler(handler)) {
264 LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
265 return;
266 }
267 handler->ProxyPostTask(delayed, "HandleDataReceived_HandleRequest");
268 #endif
269 } else if (RESPONSE_TYPE == (type)) {
270 HandleResponse(message->GetId(), message->GetJsonPayload());
271 } else {
272 LOGE(ATM_DOMAIN, ATM_TAG, "Invalid type: %{public}s ", type.c_str());
273 }
274 }
275
PrepareBytes(const std::string & type,const std::string & id,const std::string & commandName,const std::string & jsonPayload,BytesInfo & info)276 int SoftBusChannel::PrepareBytes(const std::string &type, const std::string &id, const std::string &commandName,
277 const std::string &jsonPayload, BytesInfo &info)
278 {
279 SoftBusMessage messageEntity(type, id, commandName, jsonPayload);
280 std::string json = messageEntity.ToJson();
281 return Compress(json, info.bytes, info.bytesLength);
282 }
283
Compress(const std::string & json,const unsigned char * compressedBytes,int & compressedLength)284 int SoftBusChannel::Compress(const std::string &json, const unsigned char* compressedBytes, int &compressedLength)
285 {
286 uLong len = compressBound(json.size());
287 // length will not so that long
288 if (compressedLength > 0 && static_cast<int32_t>(len) > compressedLength) {
289 LOGE(ATM_DOMAIN, ATM_TAG,
290 "compress error. data length overflow, bound length: %{public}d, buffer length: %{public}d",
291 static_cast<int32_t>(len), compressedLength);
292 return Constant::FAILURE;
293 }
294
295 int result = compress(const_cast<Byte*>(compressedBytes), &len,
296 reinterpret_cast<unsigned char*>(const_cast<char*>(json.c_str())), json.size() + 1);
297 if (result != Z_OK) {
298 LOGE(ATM_DOMAIN, ATM_TAG, "Compress failed! error code: %{public}d", result);
299 return result;
300 }
301 LOGD(ATM_DOMAIN, ATM_TAG, "Compress complete. compress %{public}d bytes to %{public}d", compressedLength,
302 static_cast<int32_t>(len));
303 compressedLength = static_cast<int32_t>(len);
304 return Constant::SUCCESS;
305 }
306
Decompress(const unsigned char * bytes,const int length)307 std::string SoftBusChannel::Decompress(const unsigned char* bytes, const int length)
308 {
309 LOGD(ATM_DOMAIN, ATM_TAG, "Input length: %{public}d", length);
310 uLong len = RPC_TRANSFER_BYTES_MAX_LENGTH;
311 unsigned char* buf = new (std::nothrow) unsigned char[len + 1];
312 if (buf == nullptr) {
313 LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory!");
314 return "";
315 }
316 (void)memset_s(buf, len + 1, 0, len + 1);
317 int result = uncompress(buf, &len, const_cast<unsigned char*>(bytes), length);
318 if (result != Z_OK) {
319 LOGE(ATM_DOMAIN, ATM_TAG,
320 "uncompress failed, error code: %{public}d, bound length: %{public}d, buffer length: %{public}d", result,
321 static_cast<int32_t>(len), length);
322 delete[] buf;
323 return "";
324 }
325 buf[len] = '\0';
326 std::string str(reinterpret_cast<char*>(buf));
327 delete[] buf;
328 return str;
329 }
330
SendRequestBytes(const unsigned char * bytes,const int bytesLength)331 int SoftBusChannel::SendRequestBytes(const unsigned char* bytes, const int bytesLength)
332 {
333 if (bytesLength == 0) {
334 LOGE(ATM_DOMAIN, ATM_TAG, "Bytes data is invalid.");
335 return Constant::FAILURE;
336 }
337
338 std::unique_lock<std::mutex> lock(socketMutex_);
339 if (CheckSessionMayReopenLocked() != Constant::SUCCESS) {
340 LOGE(ATM_DOMAIN, ATM_TAG, "Socket invalid and reopen failed!");
341 return Constant::FAILURE;
342 }
343
344 LOGD(ATM_DOMAIN, ATM_TAG, "Send len (after compress len)= %{public}d", bytesLength);
345 #ifdef DEBUG_API_PERFORMANCE
346 LOGI(ATM_DOMAIN, ATM_TAG, "Api_performance:send command to softbus");
347 #endif
348 int result = ::SendBytes(socketFd_, bytes, bytesLength);
349 if (result != Constant::SUCCESS) {
350 LOGE(ATM_DOMAIN, ATM_TAG, "Fail to send! result= %{public}d", result);
351 return Constant::FAILURE;
352 }
353 LOGD(ATM_DOMAIN, ATM_TAG, "Send successfully.");
354 return Constant::SUCCESS;
355 }
356
CheckSessionMayReopenLocked()357 int SoftBusChannel::CheckSessionMayReopenLocked()
358 {
359 // when socket is opened, we got a valid sessionid, when socket closed, we will reset sessionid.
360 if (IsSessionAvailable()) {
361 return Constant::SUCCESS;
362 }
363 int socket = SoftBusManager::GetInstance().BindService(deviceId_);
364 if (socket != Constant::INVALID_SESSION) {
365 socketFd_ = socket;
366 return Constant::SUCCESS;
367 }
368 return Constant::FAILURE;
369 }
370
IsSessionAvailable()371 bool SoftBusChannel::IsSessionAvailable()
372 {
373 return socketFd_ > Constant::INVALID_SESSION;
374 }
375
CancelCloseConnectionIfNeeded()376 void SoftBusChannel::CancelCloseConnectionIfNeeded()
377 {
378 std::unique_lock<std::mutex> lock(mutex_);
379 if (!isDelayClosing_) {
380 return;
381 }
382 LOGD(ATM_DOMAIN, ATM_TAG, "Cancel close connection");
383
384 Release();
385 isDelayClosing_ = false;
386 }
387
HandleRequest(int socket,const std::string & id,const std::string & commandName,const std::string & jsonPayload)388 void SoftBusChannel::HandleRequest(int socket, const std::string &id, const std::string &commandName,
389 const std::string &jsonPayload)
390 {
391 std::shared_ptr<BaseRemoteCommand> command =
392 RemoteCommandFactory::GetInstance().NewRemoteCommandFromJson(commandName, jsonPayload);
393 if (command == nullptr) {
394 // send result back directly
395 LOGW(ATM_DOMAIN, ATM_TAG, "Command %{public}s cannot get from json", commandName.c_str());
396
397 int sendlen = static_cast<int32_t>(RPC_TRANSFER_HEAD_BYTES_LENGTH + jsonPayload.length());
398 unsigned char* sendbuf = new (std::nothrow) unsigned char[sendlen + 1];
399 if (sendbuf == nullptr) {
400 LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory: %{public}d", sendlen);
401 return;
402 }
403 (void)memset_s(sendbuf, sendlen + 1, 0, sendlen + 1);
404 BytesInfo info;
405 info.bytes = sendbuf;
406 info.bytesLength = sendlen;
407 int sendResult = PrepareBytes(RESPONSE_TYPE, id, commandName, jsonPayload, info);
408 if (sendResult != Constant::SUCCESS) {
409 delete[] sendbuf;
410 return;
411 }
412 int sendResultCode = SendResponseBytes(socket, sendbuf, info.bytesLength);
413 delete[] sendbuf;
414 LOGD(ATM_DOMAIN, ATM_TAG, "Send response result= %{public}d ", sendResultCode);
415 return;
416 }
417
418 // execute command
419 command->Execute();
420 LOGD(ATM_DOMAIN, ATM_TAG, "Command uniqueId: %{public}s, finish with status: %{public}d, message: %{public}s",
421 ConstantCommon::EncryptDevId(command->remoteProtocol_.uniqueId).c_str(), command->remoteProtocol_.statusCode,
422 command->remoteProtocol_.message.c_str());
423
424 // send result back
425 std::string resultJsonPayload = command->ToJsonPayload();
426 int len = static_cast<int32_t>(RPC_TRANSFER_HEAD_BYTES_LENGTH + resultJsonPayload.length());
427 unsigned char* buf = new (std::nothrow) unsigned char[len + 1];
428 if (buf == nullptr) {
429 LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory: %{public}d", len);
430 return;
431 }
432 (void)memset_s(buf, len + 1, 0, len + 1);
433 BytesInfo info;
434 info.bytes = buf;
435 info.bytesLength = len;
436 int result = PrepareBytes(RESPONSE_TYPE, id, commandName, resultJsonPayload, info);
437 if (result != Constant::SUCCESS) {
438 delete[] buf;
439 return;
440 }
441 int retCode = SendResponseBytes(socket, buf, info.bytesLength);
442 delete[] buf;
443 LOGD(ATM_DOMAIN, ATM_TAG, "Send response result= %{public}d", retCode);
444 }
445
HandleResponse(const std::string & id,const std::string & jsonPayload)446 void SoftBusChannel::HandleResponse(const std::string &id, const std::string &jsonPayload)
447 {
448 std::unique_lock<std::mutex> lock(socketMutex_);
449 auto callback = callbacks_.find(id);
450 if (callback != callbacks_.end()) {
451 (callback->second)(jsonPayload);
452 callbacks_.erase(callback);
453 }
454 }
455
SendResponseBytes(int socket,const unsigned char * bytes,const int bytesLength)456 int SoftBusChannel::SendResponseBytes(int socket, const unsigned char* bytes, const int bytesLength)
457 {
458 LOGD(ATM_DOMAIN, ATM_TAG, "Send len (after compress len)= %{public}d", bytesLength);
459 int result = ::SendBytes(socket, bytes, bytesLength);
460 if (result != Constant::SUCCESS) {
461 LOGE(ATM_DOMAIN, ATM_TAG, "Fail to send! result= %{public}d", result);
462 return Constant::FAILURE;
463 }
464 LOGD(ATM_DOMAIN, ATM_TAG, "Send successfully.");
465 return Constant::SUCCESS;
466 }
467
FromJson(const std::string & jsonString)468 std::shared_ptr<SoftBusMessage> SoftBusMessage::FromJson(const std::string &jsonString)
469 {
470 CJsonUnique json = CreateJsonFromString(jsonString);
471 if (json == nullptr || cJSON_IsObject(json.get()) == false) {
472 LOGE(ATM_DOMAIN, ATM_TAG, "Failed to parse jsonString");
473 return nullptr;
474 }
475
476 std::string type;
477 std::string id;
478 std::string commandName;
479 std::string jsonPayload;
480 GetStringFromJson(json.get(), "type", type);
481 GetStringFromJson(json.get(), "id", id);
482 GetStringFromJson(json.get(), "commandName", commandName);
483 GetStringFromJson(json.get(), "jsonPayload", jsonPayload);
484 std::shared_ptr<SoftBusMessage> message = std::make_shared<SoftBusMessage>(type, id, commandName, jsonPayload);
485 return message;
486 }
487 } // namespace AccessToken
488 } // namespace Security
489 } // namespace OHOS
490