/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
18 #define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
19 
#include <atomic>
#include <condition_variable>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
#include "include/backend/distributed/ps/constants.h"
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/message.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/node.h"
#include "ps/core/node_info.h"
#include "ps/core/recovery_base.h"
#include "utils/ms_exception.h"
36 
37 namespace mindspore {
38 namespace ps {
39 namespace core {
40 class AbstractNode : public Node {
41  public:
AbstractNode()42   AbstractNode()
43       : heart_beat_thread_(nullptr),
44         client_to_scheduler_thread_(nullptr),
45         client_to_scheduler_(nullptr),
46         client_to_server_(nullptr),
47         server_(nullptr),
48         server_thread_(nullptr),
49         worker_num_(0),
50         server_num_(0),
51         is_connected_to_scheduler_(false),
52         is_current_node_scale_in_(false),
53         node_recovery_(nullptr),
54         persistent_state_(PersistentState::NOT_ENABLE_PERSIST),
55         scheduler_ip_(""),
56         scheduler_port_(0),
57         is_recover(false) {}
58   ~AbstractNode() override;
59 
60   typedef void (AbstractNode::*ResponseHandler)(const std::shared_ptr<MessageMeta> &meta, const void *data,
61                                                 size_t size);
62   typedef void (AbstractNode::*ServerHandler)(const std::shared_ptr<TcpConnection> &conn,
63                                               const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
64                                               const void *data, size_t size);
65 
66   using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
67   using RequestHandler = std::function<void(const std::shared_ptr<TcpConnection> &conn,
68                                             const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size)>;
69   using CancelSafeModeFn = std::function<void()>;
70 
71   bool Broadcast(const NodeRole &node_role, const std::string &message, int command,
72                  const uint32_t &timeout = kCommTimeoutInSeconds);
73 
74   // When the business layer finish scale out, it should call this function
75   void set_ready_for_scale_out();
76   // When the business layer finish scale in, it should call this function
77   void set_ready_for_scale_in();
78 
79   // Send scale_out_done instructions to the scheduler.
80   void set_scale_out_done();
81 
82   // Send scale_in_done instructions to the scheduler.
83   void set_scale_in_done();
84 
85   // The worker/server sends the event to the scheduler, and then the scheduler broadcasts this event to all nodes.
86   void BroadcastEvent(const uint32_t &event);
87 
88   // Set the callback corresponding to the event.
89   void RegisterEventCallback(const ClusterEvent &event, const EventCallback &event_cb);
90   // Set the callback corresponding to the custom event.
91   void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb);
92 
93   bool Send(const NodeRole &node_role, const uint32_t &rank_id, const void *message, size_t len, int command,
94             VectorPtr *output = nullptr, const uint32_t &timeout = kCommTimeoutInSeconds);
95 
96   bool Send(const NodeRole &node_role, const uint32_t &rank_id, const std::string &msg, int command,
97             VectorPtr *output = nullptr, const uint32_t &timeout = kCommTimeoutInSeconds);
98   bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &msgs,
99             int command, std::vector<VectorPtr> *output = nullptr, const uint32_t &timeout = kCommTimeoutInSeconds);
100 
101   // The interface that sends sync message to the scheduler.
102   bool SendToScheduler(const void *message, size_t len, NodeCommand command, VectorPtr *output = nullptr,
103                        const uint32_t &timeout = kCommTimeoutInSeconds);
104 
105   uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
106 
107   using CheckFailReturnFun = std::function<bool()>;
108   uint64_t FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data, size_t size);
109   bool FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output,
110                         const uint32_t &timeout = kCommTimeoutInSeconds);
111 
112   std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
113                                                        VectorPtr *output);
114   bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
115 
116   PersistentState persistent_state() const;
117   void set_persistent_state(PersistentState persistent_state);
118 
119   uint32_t worker_num() const;
120   uint32_t server_num() const;
121 
122   void set_worker_num(const uint32_t &worker_num);
123   void set_server_num(const uint32_t &server_num);
124 
125   std::string scheduler_ip() const;
126   void set_scheduler_ip(const std::string &scheduler_ip);
127 
128   uint16_t scheduler_port() const;
129   void set_scheduler_port(const uint16_t &scheduler_port);
130 
131   ClusterState cluster_state() const;
132 
133   void set_handler(const RequestHandler &handler);
134   void Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, const void *data,
135                 size_t size);
136 
137   bool HasIterationFailed(uint32_t iteration_num) const;
138   // register cancel SafeMode function to node
SetCancelSafeModeCallBack(const CancelSafeModeFn & fn)139   void SetCancelSafeModeCallBack(const CancelSafeModeFn &fn) { cancelSafeModeFn_ = fn; }
140 
141   // server node and worker node send exception message to scheduler
142   void SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info);
143 
144  protected:
145   virtual void Register(const std::shared_ptr<TcpClient> &client);
146   bool Heartbeat(const std::shared_ptr<TcpClient> &client);
147   void FetchServers(const std::shared_ptr<TcpClient> &client);
148 
149   void ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
150   void ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
151   void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
152 
153   // Process the response messages about actor route table service.
154   void ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
155 
156   void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
157                            const Protos &protos, const void *data, size_t size);
158   void ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
159                      const Protos &protos, const void *data, size_t size);
160 
161   void ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
162                        const Protos &protos, const void *data, size_t size);
163 
164   void ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
165                       const Protos &protos, const void *data, size_t size);
166 
167   // The worker/server processes the scale_out_done message from scheduelr
168   void ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
169                            const Protos &protos, const void *data, size_t size);
170   // The worker/server processes the scale_in_done message from scheduelr
171   void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
172                           const Protos &protos, const void *data, size_t size);
173 
174   // The worker/server processes the scheduler recovery message from scheduelr
175   void ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
176                                 const Protos &, const void *data, size_t size);
177 
178   // The worker/server processes the SEND_EVENT message from scheduelr
179   void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
180                     const Protos &protos, const void *data, size_t size);
181 
182   void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
183   void UpdateSchedulerTime();
184   bool CheckSchedulerTimeout() const;
185   bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
186   bool WaitForDisconnect(const uint32_t &timeout);
187   virtual bool InitClientToScheduler();
188   void InitClientToServer();
189   const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id,
190                                                          const NodeRole &role = NodeRole::SERVER);
191   bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
192                        const uint32_t &timeout = kCommTimeoutInSeconds);
193   bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
194                        const Protos &, const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
195   uint64_t SendCollectiveMeta(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
196                               const Protos &protos, const void *data, size_t size);
197   void ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
198                                  const Protos &protos, const void *data, size_t size);
199   void ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
200                        const Protos &protos, const void *data, size_t size);
201   void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta);
202   void RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
203                           size_t size);
204   uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
205   uint64_t NextActualRankRequestId(const uint32_t &rank_id);
206   void InitCommandHandler();
207   void RegisterActorRouteTableRspHandler();
208   void InitServerHandler();
209 
210   // Register collective communication initialization response methods.
RegisterInitCollectCommResphandler()211   virtual void RegisterInitCollectCommResphandler() {}
212 
213   // Register recovery response methods.
RegisterRecoveryRespHandler()214   virtual void RegisterRecoveryRespHandler() {}
215 
216   // when initializing the node, should initializing the node info.
217   void InitNodeInfo(const NodeRole &role);
218   // Initialize worker num and server num by cluster config.
219   void InitNodeNum();
220   // Node recover by cluster config.
221   bool Recover();
222 
223   // Trigger the callback corresponding to the event.
224   void OnEventCallback(const ClusterEvent &event);
225   // Trigger the callback corresponding to the custom event.
226   void OnCustomEventCallback(const uint32_t &event);
227 
228   bool IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info);
229 
230   void CreateTcpServer(const std::pair<uint32_t, uint32_t> &port_range = {});
231 
232   void UpdateClusterState(const ClusterState &state);
233 
234   void PersistMetaData();
235 
236   void ProcessPrepareBuildingNetwork(const std::shared_ptr<TcpConnection> &conn,
237                                      const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
238                                      size_t size);
239 
240   bool FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output, const uint32_t &timeout);
241   void OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data);
242   void ConnectToScheduler();
243 
244   void ProcessScaleOutRollback(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
245                                const Protos &, const void *data, size_t size);
246 
247   std::unique_ptr<std::thread> heart_beat_thread_;
248   std::unique_ptr<std::thread> client_to_scheduler_thread_;
249   std::shared_ptr<TcpClient> client_to_scheduler_;
250   std::shared_ptr<TcpClient> client_to_server_;
251   // the key is: <node_role,rank_id>, the value is: <ip, port>
252   std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
253   // the map's key is: rank_id
254   std::map<std::pair<NodeRole, uint32_t>, std::shared_ptr<TcpClient>> connected_nodes_;
255 
256   // the key is <rank_id, rank_request_id>
257   std::map<std::pair<uint32_t, uint64_t>, VectorPtr> received_data_;
258   std::mutex receive_callbacks_mutex_;
259   // the key is <rank_id, rank_request_id>
260   std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;
261   std::condition_variable receive_cond_;
262 
263   // the key is rank_id, the value is rank_id's expected request_id
264   std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_;
265   // the key is rank_id, the value is rank_id's actual request_id
266   std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
267   std::mutex rank_request_ids_mutex;
268   timeval scheduler_time_{0, 0};
269   std::unordered_map<NodeCommand, ResponseHandler> handlers_;
270   std::unordered_map<NodeCommand, ServerHandler> server_handler_;
271 
272   // send_rank_id, recv CollectiveMessageMeta and data
273   std::unordered_map<uint32_t, std::vector<std::pair<CollectiveMessageMeta, std::shared_ptr<std::vector<uint8_t>>>>>
274     fl_received_data_;
275   std::mutex fl_receive_mutex_;
276   std::condition_variable fl_receive_cond_;
277 
278   // Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA
279   std::shared_ptr<TcpServer> server_;
280   std::unique_ptr<std::thread> server_thread_;
281   std::unique_ptr<std::thread> message_callback_thread_;
282 
283   uint32_t worker_num_;
284   uint32_t server_num_;
285   std::atomic<bool> is_connected_to_scheduler_;
286   // Identify whether the current node is a scale in node.
287   std::atomic<bool> is_current_node_scale_in_;
288 
289   // Each ClusterEvent corresponds to a EventCallback to process the event.
290   std::map<ClusterEvent, EventCallback> event_to_callback_;
291 
292   // Each custom event corresponds to a EventCallback to process the event.
293   // This event is sent to the scheduler, and then the scheduler broadcasts this event to all nodes.
294   // for example:
295   // In order to ensure the consistency of the cluster, the server broadcasts an iteration_end event to notify all other
296   // nodes to modify the iteration status
297   std::map<uint32_t, EventCallback> custom_event_to_callback_;
298 
299   // Recovery for worker/server node.
300   std::unique_ptr<RecoveryBase> node_recovery_;
301 
302   // The state of the persistent storage, such as ready to be persisted, in the process of being persisted, has
303   // completed the persistence, etc.
304   std::atomic<PersistentState> persistent_state_;
305 
306   // The ip of scheduler.
307   std::string scheduler_ip_;
308   // The port of scheduler.
309   uint16_t scheduler_port_;
310 
311   // Synchronize all node metadata from the scheduler.
312   std::unordered_map<std::string, NodeInfo> all_nodes_info_;
313   RequestHandler request_handler_;
314 
315   std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
316   std::mutex communicator_mutex_;
317   std::mutex cluster_state_mutex_;
318 
319   size_t failed_iteration_num_ = 0;
320   bool iteration_failed_ = false;
321   CancelSafeModeFn cancelSafeModeFn_;
322 
323   std::atomic<bool> is_recover;
324 };
325 using AbstractNodePtr = std::shared_ptr<AbstractNode>;
326 }  // namespace core
327 }  // namespace ps
328 }  // namespace mindspore
329 #endif  // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
330