• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "soft_bus_socket_listener.h"
17 
18 #include "accesstoken_common_log.h"
19 #include "constant.h"
20 #include "remote_command_manager.h"
21 #include "socket.h"
22 #include "soft_bus_manager.h"
23 
24 namespace OHOS {
25 namespace Security {
26 namespace AccessToken {
namespace {
// Upper bound (10 MiB) accepted by the byte-receive callbacks. Declared unsigned because it is only ever compared
// against the uint32_t dataLen supplied by softbus, which avoids a signed/unsigned comparison.
constexpr uint32_t MAX_ONBYTES_RECEIVED_DATA_LEN = 10u * 1024u * 1024u;
// Package name a peer must report in OnBind to be accepted.
const std::string TOKEN_SYNC_PACKAGE_NAME = "ohos.security.distributed_access_token";
// Prefix a peer socket name must start with in OnBind to be accepted.
const std::string TOKEN_SYNC_SOCKET_NAME = "ohos.security.atm_channel.";
} // namespace
32 
33 std::mutex SoftBusSocketListener::socketMutex_;
34 std::map<int32_t, std::string> SoftBusSocketListener::socketBindMap_;
35 
OnBind(int32_t socket,PeerSocketInfo info)36 void SoftBusSocketListener::OnBind(int32_t socket, PeerSocketInfo info)
37 {
38     LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd is %{public}d.", socket);
39 
40     if (socket <= Constant::INVALID_SOCKET_FD) {
41         LOGE(ATM_DOMAIN, ATM_TAG, "Socket fd invalid.");
42         return;
43     }
44     std::string peerSessionName(info.name);
45     if (peerSessionName.find(TOKEN_SYNC_SOCKET_NAME) != 0) {
46         LOGE(ATM_DOMAIN, ATM_TAG, "Peer session name(%{public}s) is invalid.", info.name);
47         return;
48     }
49     std::string packageName(info.pkgName);
50     if (packageName != TOKEN_SYNC_PACKAGE_NAME) {
51         LOGE(ATM_DOMAIN, ATM_TAG, "Peer pkgname(%{public}s) is invalid.", info.pkgName);
52         return;
53     }
54 
55     std::string peerNetworkId(info.networkId);
56     std::lock_guard<std::mutex> guard(socketMutex_);
57     auto iter = socketBindMap_.find(socket);
58     if (iter == socketBindMap_.end()) {
59         socketBindMap_.insert(std::pair<int32_t, std::string>(socket, peerNetworkId));
60     } else {
61         iter->second = peerNetworkId;
62     }
63 }
64 
OnShutdown(int32_t socket,ShutdownReason reason)65 void SoftBusSocketListener::OnShutdown(int32_t socket, ShutdownReason reason)
66 {
67     LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd %{public}d shutdown because %{public}u.", socket, reason);
68 
69     if (socket <= Constant::INVALID_SOCKET_FD) {
70         LOGE(ATM_DOMAIN, ATM_TAG, "Socket fd invalid.");
71         return;
72     }
73 
74     // clear sessionId state
75     std::lock_guard<std::mutex> guard(socketMutex_);
76     auto iter = socketBindMap_.find(socket);
77     if (iter != socketBindMap_.end()) {
78         socketBindMap_.erase(iter);
79     }
80 }
81 
GetNetworkIdBySocket(const int32_t socket,std::string & networkId)82 bool SoftBusSocketListener::GetNetworkIdBySocket(const int32_t socket, std::string& networkId)
83 {
84     if (socket <= Constant::INVALID_SOCKET_FD) {
85         LOGE(ATM_DOMAIN, ATM_TAG, "Socket fd invalid.");
86         return false;
87     }
88 
89     std::lock_guard<std::mutex> guard(socketMutex_);
90     auto iter = socketBindMap_.find(socket);
91     if (iter != socketBindMap_.end()) {
92         networkId = iter->second;
93         return true;
94     }
95     return false;
96 }
97 
OnClientBytes(int32_t socket,const void * data,uint32_t dataLen)98 void SoftBusSocketListener::OnClientBytes(int32_t socket, const void* data, uint32_t dataLen)
99 {
100     LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd %{public}d, recv len %{public}d.", socket, dataLen);
101 
102     if ((socket <= Constant::INVALID_SOCKET_FD) || (data == nullptr) ||
103         (dataLen == 0) || (dataLen > MAX_ONBYTES_RECEIVED_DATA_LEN)) {
104         LOGE(ATM_DOMAIN, ATM_TAG, "Params invalid.");
105         return;
106     }
107 
108     std::string networkId;
109     if (!GetNetworkIdBySocket(socket, networkId)) {
110         LOGE(ATM_DOMAIN, ATM_TAG, "Socket invalid, bind service first.");
111         return;
112     }
113 
114     // channel create in SoftBusDeviceConnectionListener::OnDeviceOnline->RemoteCommandManager::NotifyDeviceOnline
115     auto channel = RemoteCommandManager::GetInstance().GetExecutorChannel(networkId);
116     if (channel == nullptr) {
117         LOGE(ATM_DOMAIN, ATM_TAG, "GetExecutorChannel failed");
118         return;
119     }
120     channel->HandleDataReceived(socket, static_cast<unsigned char*>(const_cast<void*>(data)), dataLen);
121 }
122 
OnServiceBytes(int32_t socket,const void * data,uint32_t dataLen)123 void SoftBusSocketListener::OnServiceBytes(int32_t socket, const void* data, uint32_t dataLen)
124 {
125     LOGI(ATM_DOMAIN, ATM_TAG, "Socket fd %{public}d, recv len %{public}d.", socket, dataLen);
126 
127     if ((socket <= Constant::INVALID_SOCKET_FD) || (data == nullptr) ||
128         (dataLen == 0) || (dataLen > MAX_ONBYTES_RECEIVED_DATA_LEN)) {
129         LOGE(ATM_DOMAIN, ATM_TAG, "Params invalid.");
130         return;
131     }
132 
133     std::string networkId;
134     if (SoftBusManager::GetInstance().GetNetworkIdBySocket(socket, networkId)) {
135         // channel create in SoftBusDeviceConnectionListener::OnDeviceOnline->RemoteCommandManager::NotifyDeviceOnline
136         auto channel = RemoteCommandManager::GetInstance().GetExecutorChannel(networkId);
137         if (channel == nullptr) {
138             LOGE(ATM_DOMAIN, ATM_TAG, "GetExecutorChannel failed");
139             return;
140         }
141         channel->HandleDataReceived(socket, static_cast<unsigned char*>(const_cast<void*>(data)), dataLen);
142     } else {
143         LOGE(ATM_DOMAIN, ATM_TAG, "Unkonow socket.");
144     }
145 }
146 
CleanUpAllBindSocket()147 void SoftBusSocketListener::CleanUpAllBindSocket()
148 {
149     std::lock_guard<std::mutex> guard(socketMutex_);
150     for (auto it = socketBindMap_.begin(); it != socketBindMap_.end();) {
151         ::Shutdown(it->first);
152         it = socketBindMap_.erase(it);
153     }
154 }
155 } // namespace AccessToken
156 } // namespace Security
157 } // namespace OHOS
158