• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
18 #define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
19 
20 #include <utility>
21 #include <string>
22 #include <memory>
23 #include <map>
24 #include <vector>
25 #include <unordered_map>
26 
27 #include "ps/core/node.h"
28 #include "ps/core/communicator/message.h"
29 #include "ps/core/follower_scaler.h"
30 #include "utils/ms_exception.h"
31 #include "ps/constants.h"
32 #include "ps/core/node_info.h"
33 #include "ps/core/recovery_base.h"
34 #include "ps/core/communicator/task_executor.h"
35 #include "ps/core/communicator/communicator_base.h"
36 
37 namespace mindspore {
38 namespace ps {
39 namespace core {
40 class FollowerScaler;
41 class AbstractNode : public Node {
42  public:
AbstractNode()43   AbstractNode()
44       : heart_beat_thread_(nullptr),
45         client_to_scheduler_thread_(nullptr),
46         client_to_scheduler_(nullptr),
47         server_(nullptr),
48         server_thread_(nullptr),
49         worker_num_(-1),
50         server_num_(-1),
51         is_current_node_scale_in_(false),
52         follower_scaler_(nullptr),
53         node_recovery_(nullptr),
54         scheduler_ip_(""),
55         scheduler_port_(0) {}
56   ~AbstractNode() override = default;
57 
58   typedef void (AbstractNode::*ResponseHandler)(const std::shared_ptr<MessageMeta> &meta, const void *data,
59                                                 size_t size);
60   typedef void (AbstractNode::*ServerHandler)(const std::shared_ptr<TcpConnection> &conn,
61                                               const std::shared_ptr<MessageMeta> &meta, const Protos &protos,
62                                               const void *data, size_t size);
63 
64   using DataPtr = std::shared_ptr<unsigned char[]>;
65   using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
66   using RequestHandler =
67     std::function<void(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
68                        const DataPtr &data, size_t size)>;
69 
70   bool Broadcast(const NodeRole &node_role, const DataPtr &message, size_t size, int command,
71                  const uint32_t &timeout = kCommTimeoutInSeconds);
72 
73   // When the business layer finish scale out, it should call this function
74   void set_ready_for_scale_out();
75   // When the business layer finish scale in, it should call this function
76   void set_ready_for_scale_in();
77 
78   // Send scale_out_done instructions to the scheduler.
79   void set_scale_out_done();
80 
81   // Send scale_in_done instructions to the scheduler.
82   void set_scale_in_done();
83 
84   // The worker/server sends the event to the scheduler, and then the scheduler broadcasts this event to all nodes.
85   void BroadcastEvent(const uint32_t &event);
86 
87   // Set the callback corresponding to the event.
88   void RegisterEventCallback(const ClusterEvent &event, const EventCallback &event_cb);
89   // Set the callback corresponding to the custom event.
90   void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb);
91 
92   bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
93             const uint32_t &timeout = kTimeoutInSeconds);
94   bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
95             const std::vector<size_t> &lens, int command, const uint32_t &timeout = kTimeoutInSeconds);
96   bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
97             VectorPtr *output, const uint32_t &timeout = kTimeoutInSeconds);
98   bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
99             const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
100             const uint32_t &timeout = kTimeoutInSeconds);
101 
102   uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
103   std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
104                                                        VectorPtr *output);
105   bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
106 
107   // Initialize the scaler for server to process before/after scaling operations.
108   bool InitFollowerScaler();
109 
110   // Register barriers before scaling operations for server.
111   void RegisterFollowerScalerBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier);
112   void RegisterFollowerScalerBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier);
113 
114   // Register handlers after scaling operations for server.
115   void RegisterFollowerScalerHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler);
116   void RegisterFollowerScalerHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler);
117 
118   int32_t worker_num() const;
119   int32_t server_num() const;
120 
121   void set_worker_num(const int32_t &worker_num);
122   void set_server_num(const int32_t &server_num);
123 
124   std::string scheduler_ip() const;
125   void set_scheduler_ip(const std::string &scheduler_ip);
126 
127   uint16_t scheduler_port() const;
128   void set_scheduler_port(const uint16_t &scheduler_port);
129 
130   ClusterState cluster_state() const;
131 
132   void set_handler(const RequestHandler &handler);
133   void Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, const void *data,
134                 size_t size);
135 
136   std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port,
137                                                         const std::shared_ptr<TaskExecutor> &task_executor);
138   std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port,
139                                                        uint32_t worker_num, uint32_t server_num,
140                                                        const std::shared_ptr<TaskExecutor> &task_executor);
141 
142  protected:
143   void Register(const std::shared_ptr<TcpClient> &client);
144   bool Heartbeat(const std::shared_ptr<TcpClient> &client);
145   void FetchServers(const std::shared_ptr<TcpClient> &client);
146 
147   void ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
148   void ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
149   void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
150 
151   void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
152                            const Protos &protos, const void *data, size_t size);
153   void ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
154                      const Protos &protos, const void *data, size_t size);
155 
156   void ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
157                        const Protos &protos, const void *data, size_t size);
158 
159   void ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
160                       const Protos &protos, const void *data, size_t size);
161 
162   // The worker/server processes the scale_out_done message from scheduelr
163   void ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
164                            const Protos &protos, const void *data, size_t size);
165   // The worker/server processes the scale_in_done message from scheduelr
166   void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
167                           const Protos &protos, const void *data, size_t size);
168 
169   // The worker/server processes the SEND_EVENT message from scheduelr
170   void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
171                     const Protos &protos, const void *data, size_t size);
172 
173   void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
174   void UpdateSchedulerTime();
175   bool CheckSchedulerTimeout() const;
176   bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
177   bool WaitForDisconnect(const uint32_t &timeout);
178   bool InitClientToScheduler();
179   const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id);
180   bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
181                        const uint32_t &timeout = kCommTimeoutInSeconds);
182   bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
183                        const Protos &, const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
184   uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
185                             const Protos &protos, const void *data, size_t size);
186   void ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
187                                  const void *data, size_t size);
188   void ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
189                        const Protos &protos, const void *data, size_t size);
190   void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta);
191   void RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
192                           size_t size);
193   uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
194   uint64_t NextActualRankRequestId(const uint32_t &rank_id);
195   void InitCommandHandler();
196   void InitServerHandler();
197 
198   // when initializing the node, should initializing the node info.
199   void InitNodeInfo(const NodeRole &role);
200   // Initialize worker num and server num by cluster config.
201   void InitNodeNum();
202   // Node recover by cluster config.
203   bool Recover();
204 
205   // Trigger the callback corresponding to the event.
206   void OnEventCallback(const ClusterEvent &event);
207   // Trigger the callback corresponding to the custom event.
208   void OnCustomEventCallback(const uint32_t &event);
209 
210   bool IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info);
211 
212   void CreateTcpServer();
213 
214   std::unique_ptr<std::thread> heart_beat_thread_;
215   std::unique_ptr<std::thread> client_to_scheduler_thread_;
216   std::shared_ptr<TcpClient> client_to_scheduler_;
217 
218   // the key is: <node_role,rank_id>, the value is: <ip, port>
219   std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
220   // the map's key is: rank_id
221   std::unordered_map<uint32_t, std::shared_ptr<TcpClient>> connected_nodes_;
222 
223   // the key is <rank_id, rank_request_id>
224   std::map<std::pair<uint32_t, uint64_t>, std::shared_ptr<std::vector<unsigned char>>> received_data_;
225   std::mutex receive_callbacks_mutex_;
226   // the key is <rank_id, rank_request_id>
227   std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;
228   std::condition_variable receive_cond_;
229 
230   // the key is rank_id, the value is rank_id's expected request_id
231   std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_;
232   // the key is rank_id, the value is rank_id's actual request_id
233   std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
234   std::mutex rank_request_ids_mutex;
235   timeval scheduler_time_{0, 0};
236   std::unordered_map<NodeCommand, ResponseHandler> handlers_;
237   std::unordered_map<NodeCommand, ServerHandler> server_handler_;
238 
239   // Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA
240   std::shared_ptr<TcpServer> server_;
241   std::unique_ptr<std::thread> server_thread_;
242 
243   int32_t worker_num_;
244   int32_t server_num_;
245 
246   // Identify whether the current node is a scale in node.
247   std::atomic<bool> is_current_node_scale_in_;
248 
249   // Each ClusterEvent corresponds to a EventCallback to process the event.
250   std::map<ClusterEvent, EventCallback> event_to_callback_;
251 
252   // Each custom event corresponds to a EventCallback to process the event.
253   // This event is sent to the scheduler, and then the scheduler broadcasts this event to all nodes.
254   // for example:
255   // In order to ensure the consistency of the cluster, the server broadcasts an iteration_end event to notify all other
256   // nodes to modify the iteration status
257   std::map<uint32_t, EventCallback> custom_event_to_callback_;
258 
259   // Scaler for worker/server node.
260   std::unique_ptr<FollowerScaler> follower_scaler_;
261 
262   // Recovery for worker/server node.
263   std::unique_ptr<RecoveryBase> node_recovery_;
264 
265   // The ip of scheduler.
266   std::string scheduler_ip_;
267   // The port of scheduler.
268   uint16_t scheduler_port_;
269 
270   // Synchronize all node metadata from the scheduler.
271   std::unordered_map<std::string, NodeInfo> all_nodes_info_;
272   RequestHandler request_handler_;
273 
274   std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
275   std::mutex communicator_mutex_;
276 };
277 }  // namespace core
278 }  // namespace ps
279 }  // namespace mindspore
280 #endif  // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_
281