1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "distributed/rpc/tcp/connection_pool.h"
18
19 namespace mindspore {
20 namespace distributed {
21 namespace rpc {
SetLinkPattern(bool linkPattern)22 void ConnectionPool::SetLinkPattern(bool linkPattern) { double_link_ = linkPattern; }
23
CloseConnection(Connection * conn)24 void ConnectionPool::CloseConnection(Connection *conn) {
25 if (conn == nullptr) {
26 return;
27 }
28
29 // Trigger Exit message note that this should be called before erasing link. Because we may chang deleted flag
30 // by to in this fun. And if deleted has been set to true, it means Exit message has been send before, do nothing.
31 if (!conn->deleted) {
32 DeleteConnInfo(conn);
33 }
34 conn->Close();
35 delete conn;
36 conn = nullptr;
37 }
38
FindConnection(const std::string & dst_url)39 Connection *ConnectionPool::FindConnection(const std::string &dst_url) {
40 std::lock_guard<std::mutex> lock(mutex_);
41 Connection *conn = nullptr;
42 auto iter = connections_.find(dst_url);
43 if (iter != connections_.end()) {
44 conn = iter->second;
45 }
46 return conn;
47 }
48
ResetAllConnMetrics()49 void ConnectionPool::ResetAllConnMetrics() {
50 for (const auto &iter : local_conns_) {
51 iter.second->send_metrics->Reset();
52 }
53 for (const auto &iter : remote_conns_) {
54 iter.second->send_metrics->Reset();
55 }
56 }
57
DeleteConnection(const std::string & dst_url)58 void ConnectionPool::DeleteConnection(const std::string &dst_url) {
59 Connection *conn = FindConnection(dst_url);
60 if (conn != nullptr) {
61 std::lock_guard<std::mutex> lock(mutex_);
62 (void)connections_.erase(dst_url);
63 CloseConnection(conn);
64 }
65 }
66
DeleteAllConnections(std::map<std::string,Connection * > * links) const67 void ConnectionPool::DeleteAllConnections(std::map<std::string, Connection *> *links) const {
68 if (links == nullptr) {
69 return;
70 }
71 auto iter = links->begin();
72 while (iter != links->end()) {
73 Connection *conn = iter->second;
74 if (conn == nullptr) {
75 continue;
76 }
77 // erase link
78 if (conn->recv_message != nullptr) {
79 delete conn->recv_message;
80 }
81 iter = links->erase(iter);
82 delete conn;
83 conn = nullptr;
84 }
85 }
86
AddConnection(Connection * conn)87 void ConnectionPool::AddConnection(Connection *conn) {
88 if (conn == nullptr) {
89 MS_LOG(ERROR) << "The connection is null";
90 return;
91 }
92 Connection *tmpConn = FindConnection(conn->destination);
93 if (tmpConn != nullptr) {
94 MS_LOG(INFO) << "unLink fd:" << tmpConn->socket_fd << ",to:" << tmpConn->destination.c_str();
95 CloseConnection(tmpConn);
96 }
97 std::lock_guard<std::mutex> lock(mutex_);
98 (void)connections_.emplace(conn->destination, conn);
99 }
100
DeleteConnInfo(int fd)101 void ConnectionPool::DeleteConnInfo(int fd) {
102 auto iter = conn_infos_.find(fd);
103 if (iter == conn_infos_.end()) {
104 return;
105 }
106 auto conn_infos = iter->second;
107 auto iter2 = conn_infos.begin();
108
109 while (iter2 != conn_infos.end()) {
110 auto linkInfo = *iter2;
111 if (linkInfo == nullptr) {
112 continue;
113 }
114 if (linkInfo->delete_callback) {
115 linkInfo->delete_callback(linkInfo->to, linkInfo->from);
116 }
117 iter2 = conn_infos.erase(iter2);
118 delete linkInfo;
119 }
120 (void)conn_infos_.erase(fd);
121 }
122
DeleteConnInfo(Connection * conn)123 void ConnectionPool::DeleteConnInfo(Connection *conn) {
124 if (conn == nullptr) {
125 return;
126 }
127 int fd = conn->socket_fd;
128 // If run in double link pattern, link fd and send fd must be the same, send Exit message bind on this fd
129 if (double_link_) {
130 DeleteConnInfo(fd);
131 return;
132 }
133
134 // If run in single link pattern, link fd and send fd may not be the same, we should send Exit message bind
135 // on link fd and remote link fd. Here 'deleted' flag should be set true to avoid duplicate Exit message with
136 // same aid.
137 conn->deleted = true;
138 DeleteConnInfo(conn->socket_fd);
139
140 if (conn->socket_fd != fd) {
141 MS_LOG(INFO) << "delete linker bind on link fd:" << conn->socket_fd << ",delete fd:" << fd;
142 }
143 }
144
DeleteAllConnInfos()145 void ConnectionPool::DeleteAllConnInfos() {
146 auto iter = conn_infos_.begin();
147 while (iter != conn_infos_.end()) {
148 auto conn_infos = iter->second;
149 auto iter2 = conn_infos.begin();
150
151 while (iter2 != conn_infos.end()) {
152 auto linkInfo = *iter2;
153 iter2 = conn_infos.erase(iter2);
154 delete linkInfo;
155 }
156 iter = conn_infos_.erase(iter);
157 }
158 }
159
FindConnInfo(int fd,const std::string & dst_url)160 ConnectionInfo *ConnectionPool::FindConnInfo(int fd, const std::string &dst_url) {
161 auto iter = conn_infos_.find(fd);
162 if (iter == conn_infos_.end()) {
163 return nullptr;
164 }
165 auto conn_infos = iter->second;
166 auto iter2 = conn_infos.begin();
167
168 while (iter2 != conn_infos.end()) {
169 auto linkInfo = *iter2;
170 if (linkInfo == nullptr) {
171 continue;
172 }
173 if (linkInfo->to == dst_url) {
174 return linkInfo;
175 }
176 ++iter2;
177 }
178 return nullptr;
179 }
180
AddConnInfo(int fd,const std::string & dst_url,DeleteCallBack callback)181 void ConnectionPool::AddConnInfo(int fd, const std::string &dst_url, DeleteCallBack callback) {
182 ConnectionInfo *linker = FindConnInfo(fd, dst_url);
183 if (linker != nullptr) {
184 return;
185 }
186 // This linker will be deleted in `DeleteConnInfo` or `DeleteAllConnInfos`.
187 linker = new (std::nothrow) ConnectionInfo();
188 if (linker == nullptr) {
189 MS_LOG(ERROR) << "new ConnectionInfo fail dAid:" << dst_url;
190 return;
191 }
192 linker->from = "";
193 linker->to = dst_url;
194 linker->socket_fd = fd;
195 linker->delete_callback = callback;
196 (void)conn_infos_[fd].insert(linker);
197 }
198
ReverseConnInfo(int fromFd,int toFd)199 bool ConnectionPool::ReverseConnInfo(int fromFd, int toFd) {
200 auto iter = conn_infos_.find(fromFd);
201 if (iter == conn_infos_.end()) {
202 return false;
203 }
204 auto conn_infos = iter->second;
205 (void)conn_infos_.erase(fromFd);
206 conn_infos_[toFd] = conn_infos;
207 return true;
208 }
209
Finalize()210 void ConnectionPool::Finalize() {
211 DeleteAllConnections(&local_conns_);
212 DeleteAllConnections(&remote_conns_);
213 DeleteAllConnInfos();
214 }
215 } // namespace rpc
216 } // namespace distributed
217 } // namespace mindspore
218