• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_H_
19 
20 #include <queue>
21 #include <string>
22 #include <mutex>
23 #include <memory>
24 
25 #include "actor/msg.h"
26 #include "include/backend/distributed/rpc/tcp/constants.h"
27 #include "distributed/rpc/tcp/event_loop.h"
28 #include "distributed/rpc/tcp/socket_operation.h"
29 
30 namespace mindspore {
31 namespace distributed {
32 namespace rpc {
/*
 * SendMetrics collects metrics about the messages sent through a connection.
 */
struct SendMetrics {
  // Records one more sent message and tracks the largest message size seen so far.
  void UpdateMax(size_t size) {
    accum_msg_count++;
    if (size > max_msg_size) {
      max_msg_size = size;
    }
  }

  // Records the outcome of the message currently named by `last_send_msg_name`.
  // @param fail: Whether sending that message failed.
  // @param err: Error code stored into `error_code`; only used when `fail` is true.
  void UpdateError(bool fail, int err = 0) {
    if (fail) {
      last_fail_msg_name = last_send_msg_name;
      error_code = err;
    } else {
      last_succ_msg_name = last_send_msg_name;
    }
  }

  // Resets all the metrics to their initial (zero / empty) state.
  void Reset() {
    accum_msg_count = 0;
    max_msg_size = 0;
    error_code = 0;
    last_succ_msg_name.clear();
    last_fail_msg_name.clear();
    last_send_msg_name.clear();
  }

  // The number of messages recorded via UpdateMax (a message count, not a byte count).
  size_t accum_msg_count{0};

  // The largest size passed to UpdateMax so far (presumably the message body size in bytes).
  size_t max_msg_size{0};

  // The error code recorded by the most recent failing UpdateError call.
  int error_code{0};

  // Name of the last message reported as successfully sent.
  std::string last_succ_msg_name;
  // Name of the last message reported as failed.
  std::string last_fail_msg_name;
  // Name of the message whose outcome UpdateError records; set by the caller before reporting.
  std::string last_send_msg_name;
};
76 
77 /*
78  * Represents a TCP or SSL connection.
79  */
80 struct Connection {
81  public:
82   Connection();
83   ~Connection() = default;
84 
85   // Initialize the connection(eg. add some socket event handlers).
86   int Initialize();
87 
88   // Create a new socket operation if needed.
89   void InitSocketOperation();
90 
91   // Delete this socket fd(source client socket) and add back to the connection.
92   bool ReconnectSourceSocket(int fd, uint32_t events, int *soError, uint32_t error);
93 
94   // Disconnect the socket fd from source.
95   void Disconnect(int fd);
96 
97   // Close this connection.
98   void Close();
99 
100   int ReceiveMessage();
101   void CheckMessageType();
102 
103   // Fill the message to be sent based on the input message.
104   void FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg);
105 
106   void FillRecvMessage();
107 
IsSameConnection108   bool IsSame(const Connection *that) {
109     return !(that != nullptr && that->destination == destination && that->is_remote == is_remote);
110   }
111 
112   // Send all the messages in the message queue.
113   size_t Flush();
114 
115   /**
116    * @description: Set callback to allocate memory for this connection when receiving message from the remote.
117    * @param {MemAllocateCallback} &allocate_cb: The allocating memory callback.
118    * @return {void}
119    */
SetAllocateCallbackConnection120   void SetAllocateCallback(const MemAllocateCallback &allocate_cb) { allocate_cb_ = allocate_cb; }
121 
122   /**
123    * @description: Set callback to free message for this connection.
124    * @param {MemFreeCallback} &free_cb: The callback which frees the real memory after message is sent to peer.
125    * @return {void}
126    */
SetMessageFreeCallbackConnection127   void SetMessageFreeCallback(const MemFreeCallback &free_cb) { free_cb_ = free_cb; }
128 
129   /**
130    * @description: Returns the free callback.
131    * @return {const MemFreeCallback &}
132    */
free_cbConnection133   const MemFreeCallback &free_cb() const { return free_cb_; }
134 
135   /**
136    * @description: Free the real data of the message using free callback.
137    * @param {MessageBase} *msg: The MessageBase object.
138    * @return {bool}: Whether successfully freeing the real data.
139    */
140   bool FreeMessageMemory(MessageBase *msg);
141 
142   // The socket used by this connection.
143   int socket_fd;
144 
145   // Indicates whether this connection is deleted from link manager.
146   bool deleted;
147 
148   // Indicates the priority of this connection.
149   ConnectionPriority priority{ConnectionPriority::kPriorityHigh};
150 
151   // Indicates whether this connection is connected from remote client.
152   // A connection is remote only when the connection is created by the `OnAccept` callback.
153   bool is_remote;
154 
155   // TCP or SSL.
156   ConnectionType type;
157 
158   // The socket address(ip:port) of client and server of this connection.
159   std::string source;
160   std::string destination;
161 
162   // Peer address.
163   std::string peer;
164 
165   // Specific operations for the socket in this connection.
166   SocketOperation *socket_operation;
167 
168   bool enable_ssl{false};
169 
170   // The state of this connection(eg. kInit/kConnecting/..)
171   ConnectionState state{kInit};
172 
173   // The threads for handling the receive and send requsets on this connection.
174   EventLoop *send_event_loop;
175   EventLoop *recv_event_loop;
176 
177   // Collects data sending metrics.
178   SendMetrics *send_metrics;
179 
180   // The message data waiting to be sent and receive through this connection..
181   MessageBase *send_message;
182   MessageBase *recv_message;
183 
184   // Owned by the tcp_comm.
185   std::shared_ptr<std::mutex> conn_mutex;
186 
187   // Owned by connection itself.
188   std::mutex conn_owned_mutex_;
189 
190   State recv_state;
191 
192   // Total length of received and sent messages.
193   size_t total_recv_len;
194   size_t total_send_len;
195   size_t recv_len;
196 
197   std::string send_to;
198   std::string send_from;
199   std::string recv_to;
200   std::string recv_from;
201 
202   // Message header.
203   MessageHeader send_msg_header;
204   MessageHeader recv_msg_header;
205 
206   // The message structure of kernel.
207   struct msghdr send_kernel_msg;
208   struct msghdr recv_kernel_msg;
209 
210   struct iovec recv_io_vec[RECV_MSG_IO_VEC_LEN];
211   struct iovec send_io_vec[SEND_MSG_IO_VEC_LEN];
212 
213   ParseType recv_message_type{kTcpMsg};
214 
215   // Callbacks for io events
216   ConnectionCallBack event_callback;
217   ConnectionCallBack succ_callback;
218   ConnectionCallBack write_callback;
219   ConnectionCallBack read_callback;
220 
221   // Function for handling received messages.
222   MessageHandler message_handler;
223 
224   // Buffer for messages to be sent.
225   std::queue<MessageBase *> send_message_queue;
226 
227   uint64_t output_buffer_size;
228 
229   // The error code when sending or receiving messages.
230   int error_code;
231 
232   // The method used to allocate memory when server receiving message from the remote.
233   MemAllocateCallback allocate_cb_;
234 
235   // The method used to free the memory after client sending data to the remote.
236   MemFreeCallback free_cb_;
237 
238  private:
239   // Add handler for socket connect event.
240   int AddConnnectEventHandler();
241 
242   // Parse message from socket recv buffer.
243   bool ParseMessage();
244 
245   // After ParseMessage, set from url and to url into recv message.
246   bool SetUrlForRecvMessage();
247 
248   // Make a http message based on given input message.
249   std::string GenerateHttpMessage(MessageBase *msg);
250 
251   // Change the header body from network byte order to host byte order.
252   void ReorderHeader(MessageHeader *header) const;
253 
254   /**
255    * @description: Get the real data pointer of the message.
256    * @param {MessageBase} *msg: The MessageBase object.
257    * @return {void *}: The pointer to the memory of the real data.
258    */
259   void *GetMessageBaseRealData(const MessageBase *msg) const;
260 
261   /**
262    * @description: Get size of the real data size.
263    * @param {MessageBase} *msg: The MessageBase object.
264    * @return {size_t}: The size of the real data.
265    */
266   size_t GetMessageBaseRealDataSize(const MessageBase *msg) const;
267 
268   std::string advertise_addr_;
269 };
270 }  // namespace rpc
271 }  // namespace distributed
272 }  // namespace mindspore
273 
274 #endif
275