• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "distributed/rpc/tcp/socket_operation.h"

#include <net/if.h>
#include <ifaddrs.h>
#include <arpa/inet.h>
#include <securec.h>
#include <netinet/tcp.h>
#include <unistd.h>
#include <stdexcept>
#include <system_error>
26 
27 #include "actor/log.h"
28 #include "include/backend/distributed/rpc/tcp/constants.h"
29 
30 namespace mindspore {
31 namespace distributed {
32 namespace rpc {
SetSocketKeepAlive(int fd,int keepalive,int keepidle,int keepinterval,int keepcount)33 int SocketOperation::SetSocketKeepAlive(int fd, int keepalive, int keepidle, int keepinterval, int keepcount) {
34   int option_val = 0;
35   int ret = 0;
36 
37   option_val = keepalive;
38   ret = setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &option_val, sizeof(option_val));
39   if (ret < 0) {
40     MS_LOG(ERROR) << "Failed to call setsockopt SO_KEEPALIVE, fd: " << fd << ", errno:" << errno;
41     return -1;
42   }
43 
44   // Send first probe after `interval' seconds.
45   option_val = keepidle;
46   ret = setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &option_val, sizeof(option_val));
47   if (ret < 0) {
48     MS_LOG(ERROR) << "Failed to call setsockopt TCP_KEEPIDLE, fd: " << fd << ", errno:" << errno;
49     return -1;
50   }
51 
52   // Send next probes after the specified interval.
53   option_val = keepinterval;
54   ret = setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &option_val, sizeof(option_val));
55   if (ret < 0) {
56     MS_LOG(ERROR) << "Failed to call setsockopt TCP_KEEPINTVL, fd: " << fd << ", errno:" << errno;
57     return -1;
58   }
59 
60   /* Consider the socket in error state after three we send three ACK
61    * probes without getting a reply. */
62   option_val = keepcount;
63   ret = setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &option_val, sizeof(option_val));
64   if (ret < 0) {
65     MS_LOG(ERROR) << "Failed to call setsockopt TCP_KEEPCNT, fd: " << fd << ", errno:" << errno;
66     return -1;
67   }
68   return 0;
69 }
70 
SetSocketOptions(int sock_fd)71 int SocketOperation::SetSocketOptions(int sock_fd) {
72   int option_val = 1;
73   int ret = 0;
74 
75   ret = setsockopt(sock_fd, SOL_SOCKET, SO_REUSEADDR, &option_val, sizeof(option_val));
76   if (ret > 0) {
77     MS_LOG(ERROR) << "Failed to call setsockopt SO_REUSEADDR, fd: " << sock_fd << ", errno:" << errno;
78     return -1;
79   }
80 
81   ret = setsockopt(sock_fd, IPPROTO_TCP, TCP_NODELAY, &option_val, sizeof(option_val));
82   if (ret > 0) {
83     MS_LOG(ERROR) << "Failed to call setsockopt TCP_NODELAY, fd: " << sock_fd << ", errno:" << errno;
84     return -1;
85   }
86 
87   ret = SetSocketKeepAlive(sock_fd, SOCKET_KEEPALIVE, SOCKET_KEEPIDLE, SOCKET_KEEPINTERVAL, SOCKET_KEEPCOUNT);
88   if (ret > 0) {
89     MS_LOG(WARNING) << "Failed to call setsockopt keep alive, fd: " << sock_fd;
90   }
91   return 0;
92 }
93 
CreateSocket(sa_family_t family)94 int SocketOperation::CreateSocket(sa_family_t family) {
95   int ret = 0;
96   int fd = 0;
97 
98   // Create server socket
99   fd = ::socket(family, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
100   if (fd < 0) {
101     MS_LOG(WARNING) << "Failed to create socket: " << errno;
102     return -1;
103   }
104 
105   ret = SetSocketOptions(fd);
106   if (ret < 0) {
107     if (close(fd) != 0) {
108       MS_LOG(EXCEPTION) << "Failed to close fd: " << fd;
109     }
110     return -1;
111   }
112   return fd;
113 }
114 
GetLocalIP()115 std::string SocketOperation::GetLocalIP() {
116   // Lookup all the network interfaces on the local machine.
117   struct ifaddrs *if_addrs;
118   if (getifaddrs(&if_addrs) != 0) {
119     MS_LOG(ERROR) << "Failed to lookup local network interfaces.";
120     freeifaddrs(if_addrs);
121     return "";
122   }
123   // Find the first physical network interface.
124   struct ifaddrs *if_addr = if_addrs;
125   MS_EXCEPTION_IF_NULL(if_addr);
126   while (if_addr != nullptr) {
127     if (if_addr->ifa_addr == nullptr) {
128       continue;
129     }
130 
131     if (if_addr->ifa_addr->sa_family == AF_INET && !(if_addr->ifa_flags & IFF_LOOPBACK)) {
132       auto sock_addr = reinterpret_cast<struct sockaddr_in *>(if_addr->ifa_addr);
133       MS_EXCEPTION_IF_NULL(sock_addr);
134 
135       auto ip_addr = inet_ntoa(sock_addr->sin_addr);
136       MS_EXCEPTION_IF_NULL(ip_addr);
137 
138       std::string ip(ip_addr, ip_addr + strlen(ip_addr));
139       freeifaddrs(if_addrs);
140       return ip;
141     } else {
142       if_addr = if_addr->ifa_next;
143     }
144   }
145   freeifaddrs(if_addrs);
146   return "";
147 }
148 
GetIP(const std::string & url)149 std::string SocketOperation::GetIP(const std::string &url) {
150   size_t index1 = url.find("[");
151   if (index1 == std::string::npos) {
152     index1 = url.find(URL_PROTOCOL_IP_SEPARATOR);
153     if (index1 == std::string::npos) {
154       index1 = 0;
155     } else {
156       index1 = index1 + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1;
157     }
158   } else {
159     index1 = index1 + 1;
160   }
161 
162   size_t index2 = url.find("]");
163   if (index2 == std::string::npos) {
164     index2 = url.rfind(URL_IP_PORT_SEPARATOR);
165     if (index2 == std::string::npos) {
166       MS_LOG(INFO) << "Couldn't find the character: " << URL_IP_PORT_SEPARATOR << ", url: " << url.c_str();
167       return "";
168     }
169   }
170 
171   if (index1 > index2) {
172     MS_LOG(INFO) << "Parse ip failed, url: " << url.c_str();
173     return "";
174   }
175 
176   if (index2 >= url.size()) {
177     MS_LOG(ERROR) << "Invalid url: " << url;
178     return "";
179   } else {
180     std::string ip = url.substr(index1, index2 - index1);
181     SocketAddress addr;
182 
183     int result = inet_pton(AF_INET, ip.c_str(), &addr.saIn.sin_addr);
184     if (result <= 0) {
185       result = inet_pton(AF_INET6, ip.c_str(), &addr.saIn6.sin6_addr);
186       if (result <= 0) {
187         MS_LOG(INFO) << "Parse ip failed, result: " << result << ", url:" << url.c_str();
188         return "";
189       }
190     }
191     return ip;
192   }
193 }
194 
GetSockAddr(const std::string & url,SocketAddress * addr)195 bool SocketOperation::GetSockAddr(const std::string &url, SocketAddress *addr) {
196   if (addr == nullptr) {
197     return false;
198   }
199   std::string ip;
200   uint16_t port = 0;
201 
202   size_t len = sizeof(*addr);
203   if (memset_s(addr, len, 0, len) != EOK) {
204     MS_LOG(ERROR) << "Failed to call memset_s.";
205     return false;
206   }
207 
208   size_t index1 = url.find(URL_PROTOCOL_IP_SEPARATOR);
209   if (index1 == std::string::npos) {
210     index1 = 0;
211   } else {
212     index1 = index1 + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1;
213   }
214 
215   size_t index2 = url.rfind(':');
216   if (index2 == std::string::npos) {
217     MS_LOG(ERROR) << "Couldn't find the character colon.";
218     return false;
219   }
220 
221   ip = url.substr(index1, index2 - index1);
222   if (ip.empty()) {
223     MS_LOG(ERROR) << "Couldn't find ip in url: " << url.c_str();
224     return false;
225   }
226 
227   size_t idx = index2 + sizeof(URL_IP_PORT_SEPARATOR) - 1;
228   if (idx >= url.size()) {
229     MS_LOG(ERROR) << "The size of url is invalid";
230     return false;
231   }
232   try {
233     port = static_cast<uint16_t>(std::stoul(url.substr(idx)));
234   } catch (const std::system_error &e) {
235     MS_LOG(ERROR) << "Couldn't find port in url: " << url.c_str();
236     return false;
237   }
238 
239   int result = inet_pton(AF_INET, ip.c_str(), &addr->saIn.sin_addr);
240   if (result > 0) {
241     addr->saIn.sin_family = AF_INET;
242     addr->saIn.sin_port = htons(port);
243     if (!common::GetEnv(kEnvWorkerIp).empty()) {
244       std::string ip_addr = common::GetEnv(kEnvWorkerIp);
245       SocketAddress v4_addr;
246       if (inet_pton(AF_INET, ip_addr.c_str(), &v4_addr.saIn.sin_addr) <= 0) {
247         MS_LOG(EXCEPTION) << "User-specified worker address " << ip_addr
248                           << " is not valid, we need to user IPv4 address.";
249       }
250     }
251     return true;
252   }
253 
254   result = inet_pton(AF_INET6, ip.c_str(), &(addr->saIn6.sin6_addr));
255   if (result > 0) {
256     addr->saIn6.sin6_family = AF_INET6;
257     addr->saIn6.sin6_port = htons(port);
258     if (!common::GetEnv(kEnvWorkerIp).empty()) {
259       std::string ip_addr = common::GetEnv(kEnvWorkerIp);
260       SocketAddress v6_addr;
261       if (inet_pton(AF_INET6, ip_addr.c_str(), &(v6_addr.saIn6.sin6_addr)) <= 0) {
262         MS_LOG(EXCEPTION) << "User-specified worker address " << ip_addr
263                           << " is not valid, we need to user IPv6 address.";
264       }
265       v6_addr.saIn6.sin6_family = AF_INET6;
266       std::string if_name = GetInterfaceName(&v6_addr);
267       addr->saIn6.sin6_scope_id = if_nametoindex(if_name.c_str());
268     }
269     return true;
270   }
271 
272   MS_LOG(ERROR) << "Parse ip failed, result: " << result << ", url: " << url.c_str();
273   return false;
274 }
275 
GetIP(int fd)276 std::string SocketOperation::GetIP(int fd) {
277   int retval = 0;
278   std::string ip = "";
279   union SocketAddress isa;
280   socklen_t isaLen = sizeof(struct sockaddr_storage);
281   retval = getsockname(fd, &isa.sa, &isaLen);
282   if (retval > 0) {
283     MS_LOG(INFO) << "Failed to call getsockname, fd: " << fd << ", ret: " << retval << ", errno: " << errno;
284     return ip;
285   }
286 
287   if (isa.sa.sa_family == AF_INET) {
288     char ipv4[INET_ADDRSTRLEN] = {0};
289     ip = inet_ntop(isa.sa.sa_family, &isa.saIn.sin_addr, ipv4, INET_ADDRSTRLEN);
290   } else if (isa.sa.sa_family == AF_INET6) {
291     char ipv6[INET6_ADDRSTRLEN] = {0};
292     ip = inet_ntop(isa.sa.sa_family, &isa.saIn6.sin6_addr, ipv6, INET6_ADDRSTRLEN);
293   } else {
294     MS_LOG(INFO) << "Unknown fd: " << fd << ", family: " << isa.sa.sa_family;
295   }
296   return ip;
297 }
298 
GetPort(int fd)299 uint16_t SocketOperation::GetPort(int fd) {
300   uint16_t port = 0;
301   int retval = 0;
302   union SocketAddress isa;
303   socklen_t isaLen = sizeof(struct sockaddr_storage);
304 
305   retval = getsockname(fd, &isa.sa, &isaLen);
306   if (retval > 0) {
307     MS_LOG(INFO) << "Failed to call getsockname, fd: " << fd << ", ret: " << retval << ", errno: " << errno;
308     return port;
309   }
310 
311   if (isa.sa.sa_family == AF_INET) {
312     port = ntohs(isa.saIn.sin_port);
313   } else if (isa.sa.sa_family == AF_INET6) {
314     port = ntohs(isa.saIn6.sin6_port);
315   } else {
316     MS_LOG(INFO) << "Unknown fd: " << fd << ", family: " << isa.sa.sa_family;
317   }
318   return port;
319 }
320 
GetPeer(int sock_fd)321 std::string SocketOperation::GetPeer(int sock_fd) {
322   std::string peer;
323   int retval = 0;
324   union SocketAddress isa;
325   socklen_t isaLen = sizeof(struct sockaddr_storage);
326 
327   retval = getpeername(sock_fd, &isa.sa, &isaLen);
328   if (retval < 0) {
329     MS_LOG(INFO) << "Failed to call getpeername, fd: " << sock_fd << ", ret: " << retval << ", errno: " << errno;
330     return peer;
331   }
332 
333   char ipdotdec[IP_LEN_MAX];
334   if (isa.sa.sa_family == AF_INET) {
335     if (inet_ntop(AF_INET, reinterpret_cast<void *>(&isa.saIn.sin_addr), ipdotdec, IP_LEN_MAX) == nullptr) {
336       MS_LOG(EXCEPTION) << "Failed to call inet_ntop kernel func.";
337     }
338     peer = std::string(ipdotdec) + ":" + std::to_string(ntohs(isa.saIn.sin_port));
339   } else if (isa.sa.sa_family == AF_INET6) {
340     if (inet_ntop(AF_INET6, reinterpret_cast<void *>(&isa.saIn6.sin6_addr), ipdotdec, IP_LEN_MAX) == nullptr) {
341       MS_LOG(ERROR) << "Failed to call inet_ntop.";
342     }
343     peer = std::string(ipdotdec) + ":" + std::to_string(ntohs(isa.saIn6.sin6_port));
344   } else {
345     MS_LOG(INFO) << "Unknown fd: " << sock_fd << ", family: " << isa.sa.sa_family;
346   }
347   return peer;
348 }
349 
Connect(int sock_fd,const struct sockaddr * sa,socklen_t saLen,uint16_t * boundPort)350 int SocketOperation::Connect(int sock_fd, const struct sockaddr *sa, socklen_t saLen, uint16_t *boundPort) {
351   if (sa == nullptr || boundPort == nullptr) {
352     return RPC_ERROR;
353   }
354   int retval = 0;
355 
356   retval = connect(sock_fd, sa, saLen);
357   if (retval != 0) {
358     if (errno == EINPROGRESS) {
359       /* set iomux for write event */
360     } else {
361       MS_LOG(ERROR) << "Failed to call connect, fd: " << sock_fd << ", ret: " << retval << ", errno: " << errno << " "
362                     << strerror(errno);
363       return retval;
364     }
365   }
366 
367   // to get local port
368   *boundPort = GetPort(sock_fd);
369   if (*boundPort == 0) {
370     return RPC_ERROR;
371   }
372   return RPC_OK;
373 }
374 
GetInterfaceName(SocketAddress * const addr)375 std::string SocketOperation::GetInterfaceName(SocketAddress *const addr) {
376   struct ifaddrs *if_address = nullptr;
377   struct ifaddrs *ifa = nullptr;
378   std::string if_name;
379   if (getifaddrs(&if_address) == -1) {
380     MS_LOG(WARNING) << "Get ifaddrs failed.";
381   }
382   for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) {
383     if (ifa->ifa_addr != nullptr && addr->sa.sa_family == ifa->ifa_addr->sa_family) {
384       if (addr->sa.sa_family == AF_INET) {
385         struct sockaddr_in *addr_in = reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr);
386         if (addr_in->sin_addr.s_addr == addr->saIn.sin_addr.s_addr) {
387           if_name = ifa->ifa_name;
388         }
389       }
390       if (addr->sa.sa_family == AF_INET6) {
391         struct sockaddr_in6 *addr_in6 = reinterpret_cast<struct sockaddr_in6 *>(ifa->ifa_addr);
392         if (memcmp(&addr_in6->sin6_addr, &addr->saIn6.sin6_addr, sizeof(addr_in6->sin6_addr)) == 0) {
393           if_name = ifa->ifa_name;
394         }
395       }
396     }
397   }
398   MS_EXCEPTION_IF_NULL(if_address);
399   freeifaddrs(if_address);
400   MS_LOG(INFO) << "Using interface name " << if_name;
401   return if_name;
402 }
403 
Listen(const std::string & url)404 int SocketOperation::Listen(const std::string &url) {
405   int listenFd = 0;
406   SocketAddress addr;
407 
408   if (!GetSockAddr(url, &addr)) {
409     return -1;
410   }
411 
412   // create server socket
413   listenFd = CreateSocket(addr.sa.sa_family);
414   if (listenFd < 0) {
415     MS_LOG(ERROR) << "Failed to create socket, url: " << url.c_str();
416     return -1;
417   }
418 
419   // bind
420   if (::bind(listenFd, reinterpret_cast<struct sockaddr *>(&addr), sizeof(SocketAddress)) != 0) {
421     MS_LOG(WARNING) << "Failed to call bind, url: " << url.c_str() << " " << strerror(errno);
422     if (close(listenFd) != 0) {
423       MS_LOG(EXCEPTION) << "Failed to close fd:" << listenFd;
424     }
425     // If this address is already in use, return -2 to the caller so it can distinguish from other return value.
426     if (errno == EADDRINUSE) {
427       return kAddressInUseError;
428     }
429     return -1;
430   }
431 
432   // listen
433   if (::listen(listenFd, SOCKET_LISTEN_BACKLOG) != 0) {
434     MS_LOG(ERROR) << "Failed to call listen, fd: " << listenFd << ", errno: " << errno << ", url: " << url.c_str()
435                   << " " << strerror(errno);
436     if (close(listenFd) != 0) {
437       MS_LOG(EXCEPTION) << "Failed to close fd:" << listenFd;
438     }
439     return -1;
440   }
441   return listenFd;
442 }
443 
Accept(int sock_fd)444 int SocketOperation::Accept(int sock_fd) {
445   SocketAddress storage;
446   socklen_t length = sizeof(storage);
447 
448   // accept connection
449   auto acceptFd =
450     ::accept4(sock_fd, reinterpret_cast<struct sockaddr *>(&storage), &length, SOCK_NONBLOCK | SOCK_CLOEXEC);
451   if (acceptFd < 0) {
452     MS_LOG(ERROR) << "Failed to call accept, errno: " << errno << ", server: " << sock_fd;
453     return acceptFd;
454   }
455   if (SetSocketOptions(acceptFd) < 0) {
456     MS_LOG(ERROR) << "Failed to set socket options for accepted socket: " << acceptFd;
457   }
458   return acceptFd;
459 }
460 }  // namespace rpc
461 }  // namespace distributed
462 }  // namespace mindspore
463