1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "distributed/rpc/tcp/connection.h"
18
19 #include <memory>
20 #include <utility>
21
22 #include "distributed/rpc/tcp/tcp_socket_operation.h"
23 #include "distributed/rpc/tcp/ssl_socket_operation.h"
24 #include "distributed/rpc/tcp/connection_pool.h"
25
26 namespace mindspore {
27 namespace distributed {
28 namespace rpc {
// Throttling state for the disconnect warning emitted by SocketEventHandler:
// the warning is printed once every kPrintCountInterval (1000) occurrences, and each print is followed by a
// usleep(kPrintTimeInterval) pause, i.e. 50000 microseconds (50ms), so that the log file does not grow too large.
// Number of disconnect events seen so far; guarded by kPrintCountMutex.
static size_t kPrintCount = 0;
// Protects kPrintCount.
static std::mutex kPrintCountMutex;
// Print the warning only when the counter is a multiple of this value.
const size_t kPrintCountInterval = 1000;
// Pause after each printed warning, in microseconds (the value is passed to usleep).
const int kPrintTimeInterval = 50000;
34
35 // Handle socket events like read/write.
SocketEventHandler(int fd,uint32_t events,void * context)36 void SocketEventHandler(int fd, uint32_t events, void *context) {
37 Connection *conn = reinterpret_cast<Connection *>(context);
38 if (conn == nullptr) {
39 return;
40 }
41
42 if (fd != conn->socket_fd) {
43 MS_LOG(ERROR) << "Failed to reuse connection, delete and close fd: " << fd << ", connfd: " << conn->socket_fd
44 << ", event: " << events;
45 if (conn->recv_event_loop->DeleteEpollEvent(fd) != RPC_OK) {
46 MS_LOG(ERROR) << "Failed to delete epoll event for fd: " << fd;
47 }
48 conn->state = ConnectionState::kDisconnecting;
49 if (conn->event_callback != nullptr) {
50 conn->event_callback(conn);
51 } else {
52 MS_LOG(ERROR) << "No event_callback found for fd: " << fd << ", events: " << events;
53 }
54 return;
55 }
56 // Handle write event.
57 if ((events & EPOLLOUT) > 0) {
58 (void)conn->recv_event_loop->UpdateEpollEvent(fd, EPOLLIN | EPOLLHUP | EPOLLERR);
59 if (conn->write_callback != nullptr) {
60 conn->write_callback(conn);
61 }
62 }
63 // Handle read event.
64 if (events & EPOLLIN) {
65 if (conn->read_callback != nullptr) {
66 conn->read_callback(conn);
67 }
68 }
69
70 std::lock_guard<std::mutex> conn_lock(conn->conn_owned_mutex_);
71 // Handle disconnect event.
72 if (conn->state == ConnectionState::kDisconnecting || (events & (uint32_t)(EPOLLHUP | EPOLLRDHUP | EPOLLERR))) {
73 {
74 std::lock_guard<std::mutex> count_lock(kPrintCountMutex);
75 if (kPrintCount++ % kPrintCountInterval == 0) {
76 MS_LOG(WARNING) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
77 << ", errcode: " << conn->error_code << ", errno: " << errno << " " << strerror(errno)
78 << ", to: " << conn->destination.c_str() << ", type:" << conn->recv_message_type
79 << ", remote: " << conn->is_remote << ", count: " << kPrintCount << ", from: " << conn->source;
80 (void)usleep(kPrintTimeInterval);
81 }
82 }
83 conn->state = ConnectionState::kDisconnecting;
84 if (conn->event_callback != nullptr) {
85 conn->event_callback(conn);
86 } else {
87 MS_LOG(ERROR) << "No event_callback found for fd: " << fd << ", events: " << events;
88 }
89 }
90 }
91
92 // Handle new connect event.
NewConnectEventHandler(int fd,uint32_t events,void * context)93 void NewConnectEventHandler(int fd, uint32_t events, void *context) {
94 int retval = 0;
95 Connection *conn = reinterpret_cast<Connection *>(context);
96 if (conn == nullptr) {
97 return;
98 }
99 conn->socket_operation->NewConnEventHandler(fd, events, context);
100
101 if (conn->state == ConnectionState::kDisconnecting) {
102 conn->Disconnect(fd);
103 return;
104 } else if (conn->state != ConnectionState::kConnected) {
105 // The handshake is not complete
106 return;
107 }
108
109 retval = conn->recv_event_loop->DeleteEpollEvent(fd);
110 if (retval > 0) {
111 MS_LOG(ERROR) << "Failed to remove epoll remove connect handler for fd: " << fd;
112 return;
113 }
114
115 retval = conn->recv_event_loop->SetEventHandler(conn->socket_fd, EPOLLIN | EPOLLHUP | EPOLLRDHUP | EPOLLERR,
116 SocketEventHandler, reinterpret_cast<void *>(conn));
117 if (retval != RPC_OK) {
118 MS_LOG(ERROR) << "Failed to add socket event handler for fd: " << fd << ", events: " << events;
119 conn->Disconnect(fd);
120 return;
121 }
122
123 conn->write_callback(conn);
124 SocketEventHandler(fd, events, context);
125 return;
126 }
127
Connection()128 Connection::Connection()
129 : socket_fd(-1),
130 deleted(false),
131 is_remote(false),
132 type(kTcp),
133 socket_operation(nullptr),
134 state(kInit),
135 send_event_loop(nullptr),
136 recv_event_loop(nullptr),
137 send_metrics(nullptr),
138 send_message(nullptr),
139 recv_message(nullptr),
140 recv_state(kMsgHeader),
141 total_recv_len(0),
142 total_send_len(0),
143 recv_len(0),
144 event_callback(nullptr),
145 succ_callback(nullptr),
146 write_callback(nullptr),
147 read_callback(nullptr),
148 output_buffer_size(0),
149 error_code(0) {
150 // Initialize the recv kernel message structure.
151 recv_kernel_msg.msg_control = nullptr;
152 recv_kernel_msg.msg_controllen = 0;
153 recv_kernel_msg.msg_flags = 0;
154 recv_kernel_msg.msg_name = nullptr;
155 recv_kernel_msg.msg_namelen = 0;
156 recv_kernel_msg.msg_iov = recv_io_vec;
157 recv_kernel_msg.msg_iovlen = RECV_MSG_IO_VEC_LEN;
158
159 // Initialize the send message header.
160 // This variable will be deleted in the `Close` method.
161 send_metrics = new SendMetrics();
162 for (unsigned int i = 0; i < MAGICID_LEN; i++) {
163 if (i < sizeof(RPC_MAGICID) - 1) {
164 send_msg_header.magic[i] = RPC_MAGICID[i];
165 } else {
166 send_msg_header.magic[i] = '\0';
167 }
168 }
169
170 // Initialize the send kernel message structure.
171 send_kernel_msg.msg_control = nullptr;
172 send_kernel_msg.msg_controllen = 0;
173 send_kernel_msg.msg_flags = 0;
174 send_kernel_msg.msg_name = nullptr;
175 send_kernel_msg.msg_namelen = 0;
176 send_kernel_msg.msg_iov = send_io_vec;
177 send_kernel_msg.msg_iovlen = SEND_MSG_IO_VEC_LEN;
178 }
179
Initialize()180 int Connection::Initialize() {
181 InitSocketOperation();
182 return AddConnnectEventHandler();
183 }
184
InitSocketOperation()185 void Connection::InitSocketOperation() {
186 if (socket_operation != nullptr) {
187 return;
188 }
189 // This variable will be deleted in the `Close` method.
190 if (!enable_ssl) {
191 socket_operation = new (std::nothrow) TCPSocketOperation();
192 } else {
193 socket_operation = new (std::nothrow) SSLSocketOperation();
194 }
195 MS_EXCEPTION_IF_NULL(socket_operation);
196 if (!socket_operation->Initialize()) {
197 MS_LOG(EXCEPTION) << "Failed to initialize the socket operation.";
198 }
199 }
200
ReconnectSourceSocket(int fd,uint32_t events,int * soError,uint32_t error)201 bool Connection::ReconnectSourceSocket(int fd, uint32_t events, int *soError, uint32_t error) {
202 if (soError == nullptr) {
203 return false;
204 }
205 MS_EXCEPTION_IF_NULL(recv_event_loop);
206 socklen_t len = sizeof(*soError);
207
208 int retval = recv_event_loop->DeleteEpollEvent(fd);
209 if (retval > 0) {
210 MS_LOG(ERROR) << "Failed to delete event for fd: " << fd << ", event: " << events;
211 return false;
212 }
213
214 retval = getsockopt(fd, SOL_SOCKET, SO_ERROR, soError, &len);
215 if (retval > 0) {
216 *soError = errno;
217 }
218 if (*soError > 0 || error > 0) {
219 MS_LOG(INFO) << "Connection from " << source << " to " << destination << " is not available yet, errno: " << errno
220 << " " << strerror(errno) << ", events: " << events;
221 return false;
222 }
223 retval = recv_event_loop->SetEventHandler(socket_fd, EPOLLIN | EPOLLHUP | EPOLLRDHUP | EPOLLERR, SocketEventHandler,
224 reinterpret_cast<void *>(this));
225 if (retval != RPC_OK) {
226 MS_LOG(ERROR) << "Failed to add socket event handler for fd: " << fd << ", events: " << events;
227 return false;
228 }
229 return true;
230 }
231
Disconnect(int fd)232 void Connection::Disconnect(int fd) {
233 if (LOG_CHECK_EVERY_N()) {
234 MS_LOG(INFO) << "New connection fail fd: " << fd << ", state: " << state << ", errno: " << errno
235 << ", to: " << destination.c_str() << ", type: " << recv_message_type;
236 }
237 state = ConnectionState::kDisconnecting;
238 event_callback(this);
239 return;
240 }
241
Close()242 void Connection::Close() {
243 if (recv_event_loop != nullptr) {
244 if (recv_event_loop->DeleteEpollEvent(socket_fd) == RPC_ERROR) {
245 MS_LOG(WARNING) << "Failed to delete epoll event " << socket_fd;
246 }
247 }
248
249 // There's no need to release the recv_message because the lifecycle of this data is passed to the caller.
250
251 if (total_send_len != 0 && send_message != nullptr) {
252 delete send_message;
253 send_message = nullptr;
254 }
255
256 MessageBase *tmpMsg = nullptr;
257 while (!send_message_queue.empty()) {
258 tmpMsg = send_message_queue.front();
259 send_message_queue.pop();
260 delete tmpMsg;
261 tmpMsg = nullptr;
262 }
263
264 if (socket_operation != nullptr) {
265 socket_operation->Close(this);
266 delete socket_operation;
267 socket_operation = nullptr;
268 }
269
270 if (send_metrics != nullptr) {
271 delete send_metrics;
272 send_metrics = nullptr;
273 }
274 }
275
ReceiveMessage()276 int Connection::ReceiveMessage() {
277 bool ok = ParseMessage();
278 // If no message parsed, wait for next read
279 if (!ok) {
280 if (state == ConnectionState::kDisconnecting) {
281 return -1;
282 }
283 return 0;
284 }
285
286 // Call msg handler if set
287 if (message_handler) {
288 auto result = message_handler(recv_message);
289 if (result != rpc::NULL_MSG) {
290 // Send the result message back to the tcp client if any.
291 FillSendMessage(result, "", false);
292 (void)Flush();
293 }
294 } else {
295 MS_LOG(INFO) << "Message handler was not found";
296 }
297 return 1;
298 }
299
CheckMessageType()300 void Connection::CheckMessageType() {
301 if (recv_message_type != ParseType::kUnknown) {
302 return;
303 }
304
305 std::string magic_id = "";
306 magic_id.resize(sizeof(RPC_MAGICID) - 1);
307 char *buf = const_cast<char *>(magic_id.data());
308
309 ssize_t size = socket_operation->ReceivePeek(this, buf, sizeof(RPC_MAGICID) - 1);
310 if (size < static_cast<int>(sizeof(RPC_MAGICID) - 1)) {
311 if (size == 0) {
312 MS_LOG(INFO) << "Set connection disconnecting for fd: " << socket_fd << ", size: " << size
313 << ", magic size: " << static_cast<int>(sizeof(RPC_MAGICID) - 1) << ", errno: " << errno;
314 state = ConnectionState::kDisconnecting;
315 }
316 return;
317 }
318 if (strncmp(RPC_MAGICID, magic_id.c_str(), sizeof(RPC_MAGICID) - 1) == 0) {
319 recv_state = State::kMsgHeader;
320 recv_message_type = ParseType::kTcpMsg;
321 }
322 return;
323 }
324
GenerateHttpMessage(MessageBase * msg)325 std::string Connection::GenerateHttpMessage(MessageBase *msg) {
326 if (msg == nullptr) {
327 return "";
328 }
329 static const std::string postLineBegin = std::string() + "POST /";
330 static const std::string postLineEnd = std::string() + " HTTP/1.1\r\n";
331 static const std::string userAgentLineBegin = std::string() + "User-Agent: libprocess/";
332 static const std::string fromLineBegin = std::string() + "Libprocess-From: ";
333 static const std::string connectLine = std::string() + "Connection: Keep-Alive\r\n";
334 static const std::string hostLine = std::string() + "Host: \r\n";
335 static const std::string chunkedBeginLine = std::string() + "Transfer-Encoding: chunked\r\n\r\n";
336 static const std::string chunkedEndLine = std::string() + "\r\n" + "0\r\n" + "\r\n";
337 static const std::string commonEndLine = std::string() + "\r\n";
338
339 std::string postLine;
340 if (msg->To().Name() != "") {
341 postLine = postLineBegin + msg->To().Name() + "/" + msg->Name() + postLineEnd;
342 } else {
343 postLine = postLineBegin + msg->Name() + postLineEnd;
344 }
345
346 std::string userAgentLine = userAgentLineBegin + msg->From().Name() + "@" + advertise_addr_ + commonEndLine;
347 std::string fromLine = fromLineBegin + msg->From().Name() + "@" + advertise_addr_ + commonEndLine;
348
349 if (msg->Body().size() > 0) {
350 std::ostringstream bodyLine;
351 bodyLine << std::hex << msg->Body().size() << "\r\n";
352 (void)bodyLine.write(msg->Body().data(), msg->Body().size());
353 return postLine + userAgentLine + fromLine + connectLine + hostLine + chunkedBeginLine + bodyLine.str() +
354 chunkedEndLine;
355 }
356 return postLine + userAgentLine + fromLine + connectLine + hostLine + commonEndLine;
357 }
358
FillSendMessage(MessageBase * msg,const std::string & advertiseUrl,bool isHttpKmsg)359 void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg) {
360 if (msg == nullptr || send_metrics == nullptr) {
361 return;
362 }
363 if (msg->type == MessageBase::Type::KMSG) {
364 // The total len of array variable `send_io_vec` is `SEND_MSG_IO_VEC_LEN` whose value is 5 currently.
365 size_t index = 0;
366 if (!isHttpKmsg) {
367 send_to = msg->to;
368 send_from = msg->from;
369 FillMessageHeader(*msg, &send_msg_header);
370
371 send_io_vec[index].iov_base = &send_msg_header;
372 send_io_vec[index].iov_len = sizeof(send_msg_header);
373 ++index;
374 send_io_vec[index].iov_base = const_cast<char *>(msg->name.data());
375 send_io_vec[index].iov_len = msg->name.size();
376 ++index;
377 send_io_vec[index].iov_base = const_cast<char *>(send_to.data());
378 send_io_vec[index].iov_len = send_to.size();
379 ++index;
380 send_io_vec[index].iov_base = const_cast<char *>(send_from.data());
381 send_io_vec[index].iov_len = send_from.size();
382 ++index;
383 send_io_vec[index].iov_base = GetMessageBaseRealData(msg);
384 // The real size of the data body.
385 size_t real_data_size = GetMessageBaseRealDataSize(msg);
386 send_io_vec[index].iov_len = real_data_size;
387 ++index;
388 send_kernel_msg.msg_iov = send_io_vec;
389 send_kernel_msg.msg_iovlen = index;
390 total_send_len =
391 UlongToUint(sizeof(send_msg_header)) + msg->name.size() + send_to.size() + send_from.size() + real_data_size;
392 send_message = msg;
393
394 // update metrics
395 send_metrics->UpdateMax(real_data_size);
396 send_metrics->last_send_msg_name = msg->name;
397 return;
398 } else {
399 if (advertise_addr_.empty()) {
400 size_t idx = advertiseUrl.find(URL_PROTOCOL_IP_SEPARATOR);
401 if (idx == std::string::npos) {
402 advertise_addr_ = advertiseUrl;
403 } else {
404 advertise_addr_ = advertiseUrl.substr(idx + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
405 }
406 }
407 msg->body = GenerateHttpMessage(msg);
408 }
409
410 send_io_vec[index].iov_base = GetMessageBaseRealData(msg);
411 size_t real_data_size = GetMessageBaseRealDataSize(msg);
412 send_io_vec[index].iov_len = real_data_size;
413 ++index;
414 send_kernel_msg.msg_iov = send_io_vec;
415 send_kernel_msg.msg_iovlen = index;
416 total_send_len = UlongToUint(real_data_size);
417 send_message = msg;
418
419 // update metrics
420 send_metrics->UpdateMax(real_data_size);
421 send_metrics->last_send_msg_name = msg->name;
422 }
423 }
424
FillRecvMessage()425 void Connection::FillRecvMessage() {
426 size_t recvNameLen = static_cast<size_t>(recv_msg_header.name_len);
427 size_t recvToLen = static_cast<size_t>(recv_msg_header.to_len);
428 size_t recvFromLen = static_cast<size_t>(recv_msg_header.from_len);
429 size_t recvBodyLen = static_cast<size_t>(recv_msg_header.body_len);
430 if (recvNameLen > MAX_KMSG_NAME_LEN || recvToLen > MAX_KMSG_TO_LEN || recvFromLen > MAX_KMSG_FROM_LEN ||
431 recvBodyLen > MAX_KMSG_BODY_LEN) {
432 MS_LOG(ERROR) << "Drop invalid tcp data.";
433 state = ConnectionState::kDisconnecting;
434 return;
435 }
436
437 // The total len of array variable `recv_io_vec` is `RECV_MSG_IO_VEC_LEN` whose value is 4 currently.
438 int i = 0;
439
440 // This new message will be assigned to `recv_message` later.
441 MessageBase *msg = new (std::nothrow) MessageBase();
442 MS_EXCEPTION_IF_NULL(msg);
443
444 msg->name.resize(recvNameLen);
445 recv_to.resize(recvToLen);
446 recv_from.resize(recvFromLen);
447
448 if (allocate_cb_) {
449 void *allocated_mem = allocate_cb_(recvBodyLen);
450 msg->data = allocated_mem;
451 msg->size = recvBodyLen;
452 } else {
453 msg->body.resize(recvBodyLen);
454 }
455
456 recv_io_vec[i].iov_base = const_cast<char *>(msg->name.data());
457 recv_io_vec[i].iov_len = msg->name.size();
458 ++i;
459 recv_io_vec[i].iov_base = const_cast<char *>(recv_to.data());
460 recv_io_vec[i].iov_len = recv_to.size();
461 ++i;
462 recv_io_vec[i].iov_base = const_cast<char *>(recv_from.data());
463 recv_io_vec[i].iov_len = recv_from.size();
464 ++i;
465 recv_io_vec[i].iov_base = GetMessageBaseRealData(msg);
466 // The real size of the data body.
467 size_t real_data_size = GetMessageBaseRealDataSize(msg);
468 recv_io_vec[i].iov_len = real_data_size;
469 ++i;
470
471 recv_kernel_msg.msg_iov = recv_io_vec;
472 recv_kernel_msg.msg_iovlen = IntToSize(i);
473 total_recv_len = msg->name.size() + recv_to.size() + recv_from.size() + real_data_size;
474
475 // There is no need to delete recv_message first because the recv_message has already been returned to the caller and
476 // it's the caller's responsibility to release the received message after using it.
477 // The real data raw pointer is allocated by callback set by the caller. So the caller should be responsible for its
478 // releasing as well.
479 recv_message = msg;
480 }
481
Flush()482 size_t Connection::Flush() {
483 size_t total_send_bytes = 0;
484 while (!send_message_queue.empty() || total_send_len != 0) {
485 if (total_send_len == 0) {
486 FillSendMessage(send_message_queue.front(), source, false);
487 send_message_queue.pop();
488 }
489 size_t sendLen = 0;
490 int retval = socket_operation->SendMessage(this, &send_kernel_msg, total_send_len, &sendLen);
491 if (retval == IO_RW_OK && sendLen > 0) {
492 total_send_len -= sendLen;
493 if (total_send_len == 0) {
494 // update metrics
495 send_metrics->UpdateError(false);
496
497 size_t real_data_size = GetMessageBaseRealDataSize(send_message);
498 output_buffer_size -= real_data_size;
499 total_send_bytes += real_data_size;
500
501 if (!FreeMessageMemory(send_message)) {
502 MS_LOG(ERROR) << "Failed to free memory of the send message.";
503 }
504 delete send_message;
505 send_message = nullptr;
506 break;
507 }
508 } else if (retval == IO_RW_OK && sendLen == 0) {
509 // EAGAIN
510 MS_LOG(ERROR) << "Failed to send message and update the epoll event";
511 (void)recv_event_loop->UpdateEpollEvent(socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR);
512 continue;
513 } else {
514 // update metrics
515 send_metrics->UpdateError(true, error_code);
516 MS_LOG(WARNING) << "Failed to send data, change connection state to disconnecting, errno: " << errno << " "
517 << strerror(errno);
518 state = ConnectionState::kDisconnecting;
519 break;
520 }
521 }
522 return total_send_bytes;
523 }
524
AddConnnectEventHandler()525 int Connection::AddConnnectEventHandler() {
526 return recv_event_loop->SetEventHandler(socket_fd, EPOLLIN | EPOLLHUP | EPOLLERR, NewConnectEventHandler,
527 reinterpret_cast<void *>(this));
528 }
529
ParseMessage()530 bool Connection::ParseMessage() {
531 int retval = 0;
532 size_t recvLen = 0;
533 char *recvBuf = nullptr;
534
535 switch (recv_state) {
536 // Parse message header.
537 case State::kMsgHeader:
538 recvBuf = reinterpret_cast<char *>(&recv_msg_header) + recv_len;
539 retval = socket_operation->Receive(this, recvBuf, sizeof(MessageHeader) - recv_len, &recvLen);
540 if (retval != IO_RW_OK) {
541 state = ConnectionState::kDisconnecting;
542 recv_len += recvLen;
543 return false;
544 }
545 if ((recvLen + recv_len) != sizeof(MessageHeader)) {
546 recv_len += recvLen;
547 return false;
548 }
549 recv_len = 0;
550
551 if (strncmp(recv_msg_header.magic, RPC_MAGICID, sizeof(RPC_MAGICID) - 1) != 0) {
552 MS_LOG(ERROR) << "Failed to check magicid, RPC_MAGICID: " << RPC_MAGICID
553 << ", recv magic_id: " << recv_msg_header.magic;
554 state = ConnectionState::kDisconnecting;
555 return false;
556 }
557 ReorderHeader(&recv_msg_header);
558 FillRecvMessage();
559 if (state == ConnectionState::kDisconnecting) {
560 return false;
561 }
562 recv_state = State::kBody;
563
564 // Parse message body.
565 case State::kBody:
566 recvLen = 0;
567 retval = socket_operation->ReceiveMessage(this, &recv_kernel_msg, total_recv_len, &recvLen);
568 if (recvLen != total_recv_len) {
569 if (retval != IO_RW_OK) {
570 state = ConnectionState::kDisconnecting;
571 return false;
572 }
573 total_recv_len -= recvLen;
574 return false;
575 }
576 if (!SetUrlForRecvMessage()) {
577 MS_LOG(ERROR) << "Set url info for recv message failed.";
578 return false;
579 }
580 recv_state = State::kMsgHeader;
581 break;
582 default:
583 return false;
584 }
585 return true;
586 }
587
SetUrlForRecvMessage()588 bool Connection::SetUrlForRecvMessage() {
589 auto recv_from_separator_pos = recv_from.find('@');
590 auto recv_to_separator_pos = recv_to.find('@');
591 if (recv_from_separator_pos == std::string::npos && recv_to_separator_pos == std::string::npos) {
592 MS_LOG(ERROR) << "Invalid message format, can not find separator '@'";
593 return false;
594 }
595
596 std::string from_name = recv_from.substr(0, recv_from_separator_pos);
597 std::string from_url = recv_from.substr(recv_from_separator_pos + 1);
598 std::string to_name = recv_to.substr(0, recv_to_separator_pos);
599 std::string to_url = recv_to.substr(recv_to_separator_pos + 1);
600 recv_message->from = AID(from_name, from_url);
601 recv_message->to = AID(to_name, to_url);
602
603 return true;
604 }
605
ReorderHeader(MessageHeader * header) const606 void Connection::ReorderHeader(MessageHeader *header) const {
607 header->name_len = ntohl(header->name_len);
608 header->to_len = ntohl(header->to_len);
609 header->from_len = ntohl(header->from_len);
610 header->body_len = ntohl(header->body_len);
611 }
612
FreeMessageMemory(MessageBase * msg)613 bool Connection::FreeMessageMemory(MessageBase *msg) {
614 if (msg == nullptr) {
615 MS_LOG(ERROR) << "The message is nullptr.";
616 return false;
617 }
618 if (msg->data == nullptr) {
619 MS_LOG(DEBUG) << "No need to free the raw pointer of message.";
620 return true;
621 }
622
623 // Use callback to release the real memory of the data.
624 if (!free_cb_) {
625 MS_LOG(ERROR) << "The free memory callback is not set. Can't free the data in message.";
626 return false;
627 }
628 bool free_result = free_cb_(msg->data);
629 if (!free_result) {
630 MS_LOG(ERROR) << "Failed to free message data memory.";
631 return false;
632 }
633 return true;
634 }
635
GetMessageBaseRealData(const MessageBase * msg) const636 void *Connection::GetMessageBaseRealData(const MessageBase *msg) const {
637 MS_ERROR_IF_NULL_W_RET_VAL(msg, nullptr);
638 // The 'data' attribute is preferred.
639 if (msg->data != nullptr) {
640 return msg->data;
641 }
642
643 // Parse 'body' attribute if 'data' is empty.
644 if (!msg->body.empty()) {
645 return const_cast<char *>(msg->body.data());
646 }
647
648 MS_LOG(ERROR) << "The message object has neither 'data' nor 'body' attributes.";
649 return nullptr;
650 }
651
GetMessageBaseRealDataSize(const MessageBase * msg) const652 size_t Connection::GetMessageBaseRealDataSize(const MessageBase *msg) const {
653 MS_ERROR_IF_NULL_W_RET_VAL(msg, 0);
654 // The 'size' attribute is preferred.
655 if (msg->data != nullptr) {
656 return msg->size;
657 }
658
659 // Parse 'body' attribute if 'data' is empty.
660 if (!msg->body.empty()) {
661 return msg->body.size();
662 }
663
664 MS_LOG(ERROR) << "The message object has neither 'data' nor 'body' attributes.";
665 return 0;
666 }
667 } // namespace rpc
668 } // namespace distributed
669 } // namespace mindspore
670