• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include "soft_bus_channel.h"
16 
17 #include <securec.h>
18 
19 #include "constant_common.h"
20 #include "device_info_manager.h"
21 #ifdef EVENTHANDLER_ENABLE
22 #include "access_event_handler.h"
23 #endif
24 #include "token_sync_manager_service.h"
25 #include "singleton.h"
26 #include "soft_bus_manager.h"
27 
28 namespace OHOS {
29 namespace Security {
30 namespace AccessToken {
namespace {
// Values of the "type" field in a SoftBusMessage.
const std::string REQUEST_TYPE = "request";
const std::string RESPONSE_TYPE = "response";
// Name under which the delayed socket-close task is posted and later removed.
const std::string TASK_NAME_CLOSE_SESSION = "atm_soft_bus_channel_close_session";
// Maximum time ExecuteCommand waits for a response, in milliseconds.
constexpr int32_t EXECUTE_COMMAND_TIME_OUT = 3000;
// Delay before an idle socket is closed, in milliseconds.
constexpr int32_t WAIT_SESSION_CLOSE_MILLISECONDS = 5 * 1000;
// Extra bytes reserved on top of the payload length when sizing a send buffer.
constexpr int RPC_TRANSFER_HEAD_BYTES_LENGTH = 1024 * 256;
// Size of the output buffer used when decompressing received data.
constexpr int RPC_TRANSFER_BYTES_MAX_LENGTH = 1024 * 1024;
} // namespace
SoftBusChannel(const std::string & deviceId)42 SoftBusChannel::SoftBusChannel(const std::string &deviceId)
43     : deviceId_(deviceId), mutex_(), callbacks_(), responseResult_(""), loadedCond_()
44 {
45     LOGD(ATM_DOMAIN, ATM_TAG, "SoftBusChannel(deviceId)");
46     isDelayClosing_ = false;
47     socketFd_ = Constant::INVALID_SOCKET_FD;
48     isSocketUsing_ = false;
49 }
50 
~SoftBusChannel()51 SoftBusChannel::~SoftBusChannel()
52 {
53     LOGD(ATM_DOMAIN, ATM_TAG, "~SoftBusChannel()");
54 }
55 
BuildConnection()56 int SoftBusChannel::BuildConnection()
57 {
58     CancelCloseConnectionIfNeeded();
59 
60     std::unique_lock<std::mutex> lock(socketMutex_);
61     if (socketFd_ != Constant::INVALID_SOCKET_FD) {
62         LOGI(ATM_DOMAIN, ATM_TAG, "Socket is exist, no need open again.");
63         return Constant::SUCCESS;
64     }
65 
66     if (socketFd_ == Constant::INVALID_SOCKET_FD) {
67         LOGI(ATM_DOMAIN, ATM_TAG, "Bind service with device: %{public}s",
68             ConstantCommon::EncryptDevId(deviceId_).c_str());
69         int socket = SoftBusManager::GetInstance().BindService(deviceId_);
70         if (socket == Constant::INVALID_SOCKET_FD) {
71             LOGE(ATM_DOMAIN, ATM_TAG, "Bind service failed.");
72             return Constant::FAILURE;
73         }
74         socketFd_ = socket;
75     }
76     return Constant::SUCCESS;
77 }
78 
#ifdef EVENTHANDLER_ENABLE
/**
 * Fetches the send event handler owned by TokenSyncManagerService.
 *
 * @param handler Out parameter. Receives the service's handler; left unchanged if the service
 *                singleton is unavailable.
 * @return true when a non-null handler was obtained, false otherwise (an error is logged).
 */
static bool GetSendEventHandler(std::shared_ptr<AccessEventHandler>& handler)
{
    auto service = DelayedSingleton<TokenSyncManagerService>::GetInstance();
    if (service != nullptr) {
        handler = service->GetSendEventHandler();
        if (handler != nullptr) {
            return true;
        }
        LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
        return false;
    }
    LOGE(ATM_DOMAIN, ATM_TAG, "TokenSyncManagerService is null.");
    return false;
}
#endif
96 
CloseConnection()97 void SoftBusChannel::CloseConnection()
98 {
99     LOGD(ATM_DOMAIN, ATM_TAG, "Close connection");
100     std::unique_lock<std::mutex> lock(mutex_);
101     if (isDelayClosing_) {
102         return;
103     }
104 
105 #ifdef EVENTHANDLER_ENABLE
106     std::shared_ptr<AccessEventHandler> handler = nullptr;
107     if (!GetSendEventHandler(handler)) {
108         LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
109         return;
110     }
111 #endif
112     std::weak_ptr<SoftBusChannel> weakPtr = shared_from_this();
113     std::function<void()> delayed = ([weakPtr]() {
114         auto self = weakPtr.lock();
115         if (self == nullptr) {
116             LOGE(ATM_DOMAIN, ATM_TAG, "SoftBusChannel is nullptr");
117             return;
118         }
119         std::unique_lock<std::mutex> lock(self->socketMutex_);
120         if (self->isSocketUsing_) {
121             LOGD(ATM_DOMAIN, ATM_TAG, "Socket is in using, cancel close socket");
122         } else {
123             SoftBusManager::GetInstance().CloseSocket(self->socketFd_);
124             self->socketFd_ = Constant::INVALID_SESSION;
125             LOGI(ATM_DOMAIN, ATM_TAG, "Close socket for device: %{public}s",
126                 ConstantCommon::EncryptDevId(self->deviceId_).c_str());
127         }
128         self->isDelayClosing_ = false;
129     });
130 
131     LOGD(ATM_DOMAIN, ATM_TAG, "Close socket after %{public}d ms", WAIT_SESSION_CLOSE_MILLISECONDS);
132 #ifdef EVENTHANDLER_ENABLE
133     handler->ProxyPostTask(delayed, TASK_NAME_CLOSE_SESSION, WAIT_SESSION_CLOSE_MILLISECONDS);
134 #endif
135 
136     isDelayClosing_ = true;
137 }
138 
Release()139 void SoftBusChannel::Release()
140 {
141 #ifdef EVENTHANDLER_ENABLE
142     std::shared_ptr<AccessEventHandler> handler = nullptr;
143     if (!GetSendEventHandler(handler)) {
144         LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
145         return;
146     }
147     handler->ProxyRemoveTask(TASK_NAME_CLOSE_SESSION);
148 #endif
149 }
150 
GetUuid()151 std::string SoftBusChannel::GetUuid()
152 {
153     // to use a lib like libuuid
154     int uuidStrLen = 37; // 32+4+1
155     char uuidbuf[uuidStrLen];
156     RandomUuid(uuidbuf, uuidStrLen);
157     std::string uuid(uuidbuf);
158     LOGD(ATM_DOMAIN, ATM_TAG, "Generated message uuid: %{public}s", ConstantCommon::EncryptDevId(uuid).c_str());
159 
160     return uuid;
161 }
162 
InsertCallback(int result,std::string & uuid)163 void SoftBusChannel::InsertCallback(int result, std::string &uuid)
164 {
165     std::unique_lock<std::mutex> lock(socketMutex_);
166     std::function<void(const std::string &)> callback = [this](const std::string &result) {
167         responseResult_ = std::string(result);
168         loadedCond_.notify_all();
169         LOGD(ATM_DOMAIN, ATM_TAG, "OnResponse called end");
170     };
171     callbacks_.insert(std::pair<std::string, std::function<void(std::string)>>(uuid, callback));
172 
173     isSocketUsing_ = true;
174     lock.unlock();
175 }
176 
ExecuteCommand(const std::string & commandName,const std::string & jsonPayload)177 std::string SoftBusChannel::ExecuteCommand(const std::string &commandName, const std::string &jsonPayload)
178 {
179     if (commandName.empty() || jsonPayload.empty()) {
180         LOGE(ATM_DOMAIN, ATM_TAG, "Invalid params, commandName: %{public}s", commandName.c_str());
181         return "";
182     }
183 
184     std::string uuid = GetUuid();
185 
186     int len = static_cast<int32_t>(RPC_TRANSFER_HEAD_BYTES_LENGTH + jsonPayload.length());
187     unsigned char* buf = new (std::nothrow) unsigned char[len + 1];
188     if (buf == nullptr) {
189         LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory: %{public}d", len);
190         return "";
191     }
192     (void)memset_s(buf, len + 1, 0, len + 1);
193     BytesInfo info;
194     info.bytes = buf;
195     info.bytesLength = len;
196     int result = PrepareBytes(REQUEST_TYPE, uuid, commandName, jsonPayload, info);
197     if (result != Constant::SUCCESS) {
198         delete[] buf;
199         return "";
200     }
201     InsertCallback(result, uuid);
202     int retCode = SendRequestBytes(buf, info.bytesLength);
203     delete[] buf;
204 
205     std::unique_lock<std::mutex> lock2(socketMutex_);
206     if (retCode != Constant::SUCCESS) {
207         LOGE(ATM_DOMAIN, ATM_TAG, "Send request data failed: %{public}d ", retCode);
208         callbacks_.erase(uuid);
209         isSocketUsing_ = false;
210         return "";
211     }
212 
213     LOGD(ATM_DOMAIN, ATM_TAG, "Wait command response");
214     if (loadedCond_.wait_for(lock2, std::chrono::milliseconds(EXECUTE_COMMAND_TIME_OUT)) == std::cv_status::timeout) {
215         LOGW(ATM_DOMAIN, ATM_TAG, "Time out to wait response.");
216         callbacks_.erase(uuid);
217         isSocketUsing_ = false;
218         return "";
219     }
220 
221     isSocketUsing_ = false;
222     return responseResult_;
223 }
224 
HandleDataReceived(int socket,const unsigned char * bytes,int length)225 void SoftBusChannel::HandleDataReceived(int socket, const unsigned char* bytes, int length)
226 {
227     LOGD(ATM_DOMAIN, ATM_TAG, "HandleDataReceived");
228 #ifdef DEBUG_API_PERFORMANCE
229     LOGI(ATM_DOMAIN, ATM_TAG, "Api_performance:recieve message from softbus");
230 #endif
231     if (socket <= 0 || length <= 0) {
232         LOGE(ATM_DOMAIN, ATM_TAG, "Invalid params: socket: %{public}d, data length: %{public}d", socket, length);
233         return;
234     }
235     std::string receiveData = Decompress(bytes, length);
236     if (receiveData.empty()) {
237         LOGE(ATM_DOMAIN, ATM_TAG, "Invalid parameter bytes");
238         return;
239     }
240     std::shared_ptr<SoftBusMessage> message = SoftBusMessage::FromJson(receiveData);
241     if (message == nullptr) {
242         LOGD(ATM_DOMAIN, ATM_TAG, "Invalid json string");
243         return;
244     }
245     if (!message->IsValid()) {
246         LOGD(ATM_DOMAIN, ATM_TAG, "Invalid data, has empty field");
247         return;
248     }
249 
250     std::string type = message->GetType();
251     if (REQUEST_TYPE == (type)) {
252         std::function<void()> delayed = ([weak = weak_from_this(), socket, message]() {
253             auto self = weak.lock();
254             if (self == nullptr) {
255                 LOGE(ATM_DOMAIN, ATM_TAG, "SoftBusChannel is nullptr");
256                 return;
257             }
258             self->HandleRequest(socket, message->GetId(), message->GetCommandName(), message->GetJsonPayload());
259         });
260 
261 #ifdef EVENTHANDLER_ENABLE
262         std::shared_ptr<AccessEventHandler> handler = nullptr;
263         if (!GetSendEventHandler(handler)) {
264             LOGE(ATM_DOMAIN, ATM_TAG, "Fail to get EventHandler");
265             return;
266         }
267         handler->ProxyPostTask(delayed, "HandleDataReceived_HandleRequest");
268 #endif
269     } else if (RESPONSE_TYPE == (type)) {
270         HandleResponse(message->GetId(), message->GetJsonPayload());
271     } else {
272         LOGE(ATM_DOMAIN, ATM_TAG, "Invalid type: %{public}s ", type.c_str());
273     }
274 }
275 
PrepareBytes(const std::string & type,const std::string & id,const std::string & commandName,const std::string & jsonPayload,BytesInfo & info)276 int SoftBusChannel::PrepareBytes(const std::string &type, const std::string &id, const std::string &commandName,
277     const std::string &jsonPayload, BytesInfo &info)
278 {
279     SoftBusMessage messageEntity(type, id, commandName, jsonPayload);
280     std::string json = messageEntity.ToJson();
281     return Compress(json, info.bytes, info.bytesLength);
282 }
283 
Compress(const std::string & json,const unsigned char * compressedBytes,int & compressedLength)284 int SoftBusChannel::Compress(const std::string &json, const unsigned char* compressedBytes, int &compressedLength)
285 {
286     uLong len = compressBound(json.size());
287     // length will not so that long
288     if (compressedLength > 0 && static_cast<int32_t>(len) > compressedLength) {
289         LOGE(ATM_DOMAIN, ATM_TAG,
290             "compress error. data length overflow, bound length: %{public}d, buffer length: %{public}d",
291             static_cast<int32_t>(len), compressedLength);
292         return Constant::FAILURE;
293     }
294 
295     int result = compress(const_cast<Byte*>(compressedBytes), &len,
296         reinterpret_cast<unsigned char*>(const_cast<char*>(json.c_str())), json.size() + 1);
297     if (result != Z_OK) {
298         LOGE(ATM_DOMAIN, ATM_TAG, "Compress failed! error code: %{public}d", result);
299         return result;
300     }
301     LOGD(ATM_DOMAIN, ATM_TAG, "Compress complete. compress %{public}d bytes to %{public}d", compressedLength,
302         static_cast<int32_t>(len));
303     compressedLength = static_cast<int32_t>(len);
304     return Constant::SUCCESS;
305 }
306 
Decompress(const unsigned char * bytes,const int length)307 std::string SoftBusChannel::Decompress(const unsigned char* bytes, const int length)
308 {
309     LOGD(ATM_DOMAIN, ATM_TAG, "Input length: %{public}d", length);
310     uLong len = RPC_TRANSFER_BYTES_MAX_LENGTH;
311     unsigned char* buf = new (std::nothrow) unsigned char[len + 1];
312     if (buf == nullptr) {
313         LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory!");
314         return "";
315     }
316     (void)memset_s(buf, len + 1, 0, len + 1);
317     int result = uncompress(buf, &len, const_cast<unsigned char*>(bytes), length);
318     if (result != Z_OK) {
319         LOGE(ATM_DOMAIN, ATM_TAG,
320             "uncompress failed, error code: %{public}d, bound length: %{public}d, buffer length: %{public}d", result,
321             static_cast<int32_t>(len), length);
322         delete[] buf;
323         return "";
324     }
325     buf[len] = '\0';
326     std::string str(reinterpret_cast<char*>(buf));
327     delete[] buf;
328     return str;
329 }
330 
SendRequestBytes(const unsigned char * bytes,const int bytesLength)331 int SoftBusChannel::SendRequestBytes(const unsigned char* bytes, const int bytesLength)
332 {
333     if (bytesLength == 0) {
334         LOGE(ATM_DOMAIN, ATM_TAG, "Bytes data is invalid.");
335         return Constant::FAILURE;
336     }
337 
338     std::unique_lock<std::mutex> lock(socketMutex_);
339     if (CheckSessionMayReopenLocked() != Constant::SUCCESS) {
340         LOGE(ATM_DOMAIN, ATM_TAG, "Socket invalid and reopen failed!");
341         return Constant::FAILURE;
342     }
343 
344     LOGD(ATM_DOMAIN, ATM_TAG, "Send len (after compress len)= %{public}d", bytesLength);
345 #ifdef DEBUG_API_PERFORMANCE
346     LOGI(ATM_DOMAIN, ATM_TAG, "Api_performance:send command to softbus");
347 #endif
348     int result = ::SendBytes(socketFd_, bytes, bytesLength);
349     if (result != Constant::SUCCESS) {
350         LOGE(ATM_DOMAIN, ATM_TAG, "Fail to send! result= %{public}d", result);
351         return Constant::FAILURE;
352     }
353     LOGD(ATM_DOMAIN, ATM_TAG, "Send successfully.");
354     return Constant::SUCCESS;
355 }
356 
CheckSessionMayReopenLocked()357 int SoftBusChannel::CheckSessionMayReopenLocked()
358 {
359     // when socket is opened, we got a valid sessionid, when socket closed, we will reset sessionid.
360     if (IsSessionAvailable()) {
361         return Constant::SUCCESS;
362     }
363     int socket = SoftBusManager::GetInstance().BindService(deviceId_);
364     if (socket != Constant::INVALID_SESSION) {
365         socketFd_ = socket;
366         return Constant::SUCCESS;
367     }
368     return Constant::FAILURE;
369 }
370 
IsSessionAvailable()371 bool SoftBusChannel::IsSessionAvailable()
372 {
373     return socketFd_ > Constant::INVALID_SESSION;
374 }
375 
CancelCloseConnectionIfNeeded()376 void SoftBusChannel::CancelCloseConnectionIfNeeded()
377 {
378     std::unique_lock<std::mutex> lock(mutex_);
379     if (!isDelayClosing_) {
380         return;
381     }
382     LOGD(ATM_DOMAIN, ATM_TAG, "Cancel close connection");
383 
384     Release();
385     isDelayClosing_ = false;
386 }
387 
HandleRequest(int socket,const std::string & id,const std::string & commandName,const std::string & jsonPayload)388 void SoftBusChannel::HandleRequest(int socket, const std::string &id, const std::string &commandName,
389     const std::string &jsonPayload)
390 {
391     std::shared_ptr<BaseRemoteCommand> command =
392         RemoteCommandFactory::GetInstance().NewRemoteCommandFromJson(commandName, jsonPayload);
393     if (command == nullptr) {
394         // send result back directly
395         LOGW(ATM_DOMAIN, ATM_TAG, "Command %{public}s cannot get from json", commandName.c_str());
396 
397         int sendlen = static_cast<int32_t>(RPC_TRANSFER_HEAD_BYTES_LENGTH + jsonPayload.length());
398         unsigned char* sendbuf = new (std::nothrow) unsigned char[sendlen + 1];
399         if (sendbuf == nullptr) {
400             LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory: %{public}d", sendlen);
401             return;
402         }
403         (void)memset_s(sendbuf, sendlen + 1, 0, sendlen + 1);
404         BytesInfo info;
405         info.bytes = sendbuf;
406         info.bytesLength = sendlen;
407         int sendResult = PrepareBytes(RESPONSE_TYPE, id, commandName, jsonPayload, info);
408         if (sendResult != Constant::SUCCESS) {
409             delete[] sendbuf;
410             return;
411         }
412         int sendResultCode = SendResponseBytes(socket, sendbuf, info.bytesLength);
413         delete[] sendbuf;
414         LOGD(ATM_DOMAIN, ATM_TAG, "Send response result= %{public}d ", sendResultCode);
415         return;
416     }
417 
418     // execute command
419     command->Execute();
420     LOGD(ATM_DOMAIN, ATM_TAG, "Command uniqueId: %{public}s, finish with status: %{public}d, message: %{public}s",
421         ConstantCommon::EncryptDevId(command->remoteProtocol_.uniqueId).c_str(), command->remoteProtocol_.statusCode,
422         command->remoteProtocol_.message.c_str());
423 
424     // send result back
425     std::string resultJsonPayload = command->ToJsonPayload();
426     int len = static_cast<int32_t>(RPC_TRANSFER_HEAD_BYTES_LENGTH + resultJsonPayload.length());
427     unsigned char* buf = new (std::nothrow) unsigned char[len + 1];
428     if (buf == nullptr) {
429         LOGE(ATM_DOMAIN, ATM_TAG, "No enough memory: %{public}d", len);
430         return;
431     }
432     (void)memset_s(buf, len + 1, 0, len + 1);
433     BytesInfo info;
434     info.bytes = buf;
435     info.bytesLength = len;
436     int result = PrepareBytes(RESPONSE_TYPE, id, commandName, resultJsonPayload, info);
437     if (result != Constant::SUCCESS) {
438         delete[] buf;
439         return;
440     }
441     int retCode = SendResponseBytes(socket, buf, info.bytesLength);
442     delete[] buf;
443     LOGD(ATM_DOMAIN, ATM_TAG, "Send response result= %{public}d", retCode);
444 }
445 
HandleResponse(const std::string & id,const std::string & jsonPayload)446 void SoftBusChannel::HandleResponse(const std::string &id, const std::string &jsonPayload)
447 {
448     std::unique_lock<std::mutex> lock(socketMutex_);
449     auto callback = callbacks_.find(id);
450     if (callback != callbacks_.end()) {
451         (callback->second)(jsonPayload);
452         callbacks_.erase(callback);
453     }
454 }
455 
SendResponseBytes(int socket,const unsigned char * bytes,const int bytesLength)456 int SoftBusChannel::SendResponseBytes(int socket, const unsigned char* bytes, const int bytesLength)
457 {
458     LOGD(ATM_DOMAIN, ATM_TAG, "Send len (after compress len)= %{public}d", bytesLength);
459     int result = ::SendBytes(socket, bytes, bytesLength);
460     if (result != Constant::SUCCESS) {
461         LOGE(ATM_DOMAIN, ATM_TAG, "Fail to send! result= %{public}d", result);
462         return Constant::FAILURE;
463     }
464     LOGD(ATM_DOMAIN, ATM_TAG, "Send successfully.");
465     return Constant::SUCCESS;
466 }
467 
FromJson(const std::string & jsonString)468 std::shared_ptr<SoftBusMessage> SoftBusMessage::FromJson(const std::string &jsonString)
469 {
470     CJsonUnique json = CreateJsonFromString(jsonString);
471     if (json == nullptr || cJSON_IsObject(json.get()) == false) {
472         LOGE(ATM_DOMAIN, ATM_TAG, "Failed to parse jsonString");
473         return nullptr;
474     }
475 
476     std::string type;
477     std::string id;
478     std::string commandName;
479     std::string jsonPayload;
480     GetStringFromJson(json.get(), "type", type);
481     GetStringFromJson(json.get(), "id", id);
482     GetStringFromJson(json.get(), "commandName", commandName);
483     GetStringFromJson(json.get(), "jsonPayload", jsonPayload);
484     std::shared_ptr<SoftBusMessage> message = std::make_shared<SoftBusMessage>(type, id, commandName, jsonPayload);
485     return message;
486 }
487 } // namespace AccessToken
488 } // namespace Security
489 } // namespace OHOS
490