1 /*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "soft_bus_socket_listener.h"
17
18 #include "accesstoken_common_log.h"
19 #include "constant.h"
20 #include "remote_command_manager.h"
21 #include "socket.h"
22 #include "soft_bus_manager.h"
23
24 namespace OHOS {
25 namespace Security {
26 namespace AccessToken {
namespace {
// Largest payload (10 MiB) accepted by OnClientBytes / OnServiceBytes; larger buffers are rejected.
constexpr int32_t MAX_ONBYTES_RECEIVED_DATA_LEN = 10 * 1024 * 1024;
// Package name a peer must report in OnBind for the bind to be recorded.
const std::string TOKEN_SYNC_PACKAGE_NAME = "ohos.security.distributed_access_token";
// Prefix a peer's session name must start with in OnBind for the bind to be recorded.
const std::string TOKEN_SYNC_SOCKET_NAME = "ohos.security.atm_channel.";
} // namespace
32
33 std::mutex SoftBusSocketListener::socketMutex_;
34 std::map<int32_t, std::string> SoftBusSocketListener::socketBindMap_;
35
OnBind(int32_t socket,PeerSocketInfo info)36 void SoftBusSocketListener::OnBind(int32_t socket, PeerSocketInfo info)
37 {
38 LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd is %{public}d.", socket);
39
40 if (socket <= Constant::INVALID_SOCKET_FD) {
41 LOGE(ATM_DOMAIN, ATM_TAG, "Socket fd invalid.");
42 return;
43 }
44 std::string peerSessionName(info.name);
45 if (peerSessionName.find(TOKEN_SYNC_SOCKET_NAME) != 0) {
46 LOGE(ATM_DOMAIN, ATM_TAG, "Peer session name(%{public}s) is invalid.", info.name);
47 return;
48 }
49 std::string packageName(info.pkgName);
50 if (packageName != TOKEN_SYNC_PACKAGE_NAME) {
51 LOGE(ATM_DOMAIN, ATM_TAG, "Peer pkgname(%{public}s) is invalid.", info.pkgName);
52 return;
53 }
54
55 std::string peerNetworkId(info.networkId);
56 std::lock_guard<std::mutex> guard(socketMutex_);
57 auto iter = socketBindMap_.find(socket);
58 if (iter == socketBindMap_.end()) {
59 socketBindMap_.insert(std::pair<int32_t, std::string>(socket, peerNetworkId));
60 } else {
61 iter->second = peerNetworkId;
62 }
63 }
64
OnShutdown(int32_t socket,ShutdownReason reason)65 void SoftBusSocketListener::OnShutdown(int32_t socket, ShutdownReason reason)
66 {
67 LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd %{public}d shutdown because %{public}u.", socket, reason);
68
69 if (socket <= Constant::INVALID_SOCKET_FD) {
70 LOGE(ATM_DOMAIN, ATM_TAG, "Socket fd invalid.");
71 return;
72 }
73
74 // clear sessionId state
75 std::lock_guard<std::mutex> guard(socketMutex_);
76 auto iter = socketBindMap_.find(socket);
77 if (iter != socketBindMap_.end()) {
78 socketBindMap_.erase(iter);
79 }
80 }
81
GetNetworkIdBySocket(const int32_t socket,std::string & networkId)82 bool SoftBusSocketListener::GetNetworkIdBySocket(const int32_t socket, std::string& networkId)
83 {
84 if (socket <= Constant::INVALID_SOCKET_FD) {
85 LOGE(ATM_DOMAIN, ATM_TAG, "Socket fd invalid.");
86 return false;
87 }
88
89 std::lock_guard<std::mutex> guard(socketMutex_);
90 auto iter = socketBindMap_.find(socket);
91 if (iter != socketBindMap_.end()) {
92 networkId = iter->second;
93 return true;
94 }
95 return false;
96 }
97
OnClientBytes(int32_t socket,const void * data,uint32_t dataLen)98 void SoftBusSocketListener::OnClientBytes(int32_t socket, const void* data, uint32_t dataLen)
99 {
100 LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd %{public}d, recv len %{public}d.", socket, dataLen);
101
102 if ((socket <= Constant::INVALID_SOCKET_FD) || (data == nullptr) ||
103 (dataLen == 0) || (dataLen > MAX_ONBYTES_RECEIVED_DATA_LEN)) {
104 LOGE(ATM_DOMAIN, ATM_TAG, "Params invalid.");
105 return;
106 }
107
108 std::string networkId;
109 if (!GetNetworkIdBySocket(socket, networkId)) {
110 LOGE(ATM_DOMAIN, ATM_TAG, "Socket invalid, bind service first.");
111 return;
112 }
113
114 // channel create in SoftBusDeviceConnectionListener::OnDeviceOnline->RemoteCommandManager::NotifyDeviceOnline
115 auto channel = RemoteCommandManager::GetInstance().GetExecutorChannel(networkId);
116 if (channel == nullptr) {
117 LOGE(ATM_DOMAIN, ATM_TAG, "GetExecutorChannel failed");
118 return;
119 }
120 channel->HandleDataReceived(socket, static_cast<unsigned char*>(const_cast<void*>(data)), dataLen);
121 }
122
OnServiceBytes(int32_t socket,const void * data,uint32_t dataLen)123 void SoftBusSocketListener::OnServiceBytes(int32_t socket, const void* data, uint32_t dataLen)
124 {
125 LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd %{public}d, recv len %{public}d.", socket, dataLen);
126
127 if ((socket <= Constant::INVALID_SOCKET_FD) || (data == nullptr) ||
128 (dataLen == 0) || (dataLen > MAX_ONBYTES_RECEIVED_DATA_LEN)) {
129 LOGE(ATM_DOMAIN, ATM_TAG, "Params invalid.");
130 return;
131 }
132
133 std::string networkId;
134 if (SoftBusManager::GetInstance().GetNetworkIdBySocket(socket, networkId)) {
135 // channel create in SoftBusDeviceConnectionListener::OnDeviceOnline->RemoteCommandManager::NotifyDeviceOnline
136 auto channel = RemoteCommandManager::GetInstance().GetExecutorChannel(networkId);
137 if (channel == nullptr) {
138 LOGE(ATM_DOMAIN, ATM_TAG, "GetExecutorChannel failed");
139 return;
140 }
141 channel->HandleDataReceived(socket, static_cast<unsigned char*>(const_cast<void*>(data)), dataLen);
142 } else {
143 LOGE(ATM_DOMAIN, ATM_TAG, "Unkonow socket.");
144 }
145 }
146
CleanUpAllBindSocket()147 void SoftBusSocketListener::CleanUpAllBindSocket()
148 {
149 std::lock_guard<std::mutex> guard(socketMutex_);
150 for (auto it = socketBindMap_.begin(); it != socketBindMap_.end();) {
151 ::Shutdown(it->first);
152 it = socketBindMap_.erase(it);
153 }
154 }
155 } // namespace AccessToken
156 } // namespace Security
157 } // namespace OHOS
158