1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ 18 #define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ 19 20 #include <functional> 21 #include <map> 22 #include <queue> 23 #include <unordered_map> 24 #include <utility> 25 #include <vector> 26 #include <memory> 27 #include <string> 28 #include "include/backend/distributed/ps/constants.h" 29 #include "ps/core/communicator/communicator_base.h" 30 #include "ps/core/communicator/message.h" 31 #include "ps/core/communicator/task_executor.h" 32 #include "ps/core/node.h" 33 #include "ps/core/node_info.h" 34 #include "ps/core/recovery_base.h" 35 #include "utils/ms_exception.h" 36 37 namespace mindspore { 38 namespace ps { 39 namespace core { 40 class AbstractNode : public Node { 41 public: AbstractNode()42 AbstractNode() 43 : heart_beat_thread_(nullptr), 44 client_to_scheduler_thread_(nullptr), 45 client_to_scheduler_(nullptr), 46 client_to_server_(nullptr), 47 server_(nullptr), 48 server_thread_(nullptr), 49 worker_num_(0), 50 server_num_(0), 51 is_connected_to_scheduler_(false), 52 is_current_node_scale_in_(false), 53 node_recovery_(nullptr), 54 persistent_state_(PersistentState::NOT_ENABLE_PERSIST), 55 scheduler_ip_(""), 56 scheduler_port_(0), 57 is_recover(false) {} 58 ~AbstractNode() override; 59 60 typedef void (AbstractNode::*ResponseHandler)(const std::shared_ptr<MessageMeta> &meta, const void *data, 61 size_t size); 62 typedef void (AbstractNode::*ServerHandler)(const std::shared_ptr<TcpConnection> &conn, 63 const std::shared_ptr<MessageMeta> &meta, const Protos &protos, 64 const void *data, size_t size); 65 66 using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; 67 using RequestHandler = std::function<void(const std::shared_ptr<TcpConnection> &conn, 68 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size)>; 69 using CancelSafeModeFn = std::function<void()>; 70 71 bool Broadcast(const NodeRole &node_role, const std::string &message, int command, 72 const uint32_t &timeout = kCommTimeoutInSeconds); 73 74 // When the business layer finish scale out, it should call this function 75 void set_ready_for_scale_out(); 76 // When the business layer finish scale in, it should call this function 77 void set_ready_for_scale_in(); 78 79 // Send scale_out_done instructions to the scheduler. 80 void set_scale_out_done(); 81 82 // Send scale_in_done instructions to the scheduler. 83 void set_scale_in_done(); 84 85 // The worker/server sends the event to the scheduler, and then the scheduler broadcasts this event to all nodes. 86 void BroadcastEvent(const uint32_t &event); 87 88 // Set the callback corresponding to the event. 89 void RegisterEventCallback(const ClusterEvent &event, const EventCallback &event_cb); 90 // Set the callback corresponding to the custom event. 91 void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb); 92 93 bool Send(const NodeRole &node_role, const uint32_t &rank_id, const void *message, size_t len, int command, 94 VectorPtr *output = nullptr, const uint32_t &timeout = kCommTimeoutInSeconds); 95 96 bool Send(const NodeRole &node_role, const uint32_t &rank_id, const std::string &msg, int command, 97 VectorPtr *output = nullptr, const uint32_t &timeout = kCommTimeoutInSeconds); 98 bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &msgs, 99 int command, std::vector<VectorPtr> *output = nullptr, const uint32_t &timeout = kCommTimeoutInSeconds); 100 101 // The interface that sends sync message to the scheduler. 102 bool SendToScheduler(const void *message, size_t len, NodeCommand command, VectorPtr *output = nullptr, 103 const uint32_t &timeout = kCommTimeoutInSeconds); 104 105 uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size); 106 107 using CheckFailReturnFun = std::function<bool()>; 108 uint64_t FlCollectiveSendAsync(const CollectiveMessageMeta &collective_meta, const void *data, size_t size); 109 bool FlCollectiveWait(const CollectiveMessageMeta &expect_meta, size_t expect_size, VectorPtr *output, 110 const uint32_t &timeout = kCommTimeoutInSeconds); 111 112 std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id, 113 VectorPtr *output); 114 bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds); 115 116 PersistentState persistent_state() const; 117 void set_persistent_state(PersistentState persistent_state); 118 119 uint32_t worker_num() const; 120 uint32_t server_num() const; 121 122 void set_worker_num(const uint32_t &worker_num); 123 void set_server_num(const uint32_t &server_num); 124 125 std::string scheduler_ip() const; 126 void set_scheduler_ip(const std::string &scheduler_ip); 127 128 uint16_t scheduler_port() const; 129 void set_scheduler_port(const uint16_t &scheduler_port); 130 131 ClusterState cluster_state() const; 132 133 void set_handler(const RequestHandler &handler); 134 void Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, const void *data, 135 size_t size); 136 137 bool HasIterationFailed(uint32_t iteration_num) const; 138 // register cancel SafeMode function to node SetCancelSafeModeCallBack(const CancelSafeModeFn & fn)139 void SetCancelSafeModeCallBack(const CancelSafeModeFn &fn) { cancelSafeModeFn_ = fn; } 140 141 // server node and worker node send exception message to scheduler 142 void SendFailMessageToScheduler(const std::string &node_role, const std::string &event_info); 143 144 protected: 145 virtual void Register(const std::shared_ptr<TcpClient> &client); 146 bool Heartbeat(const std::shared_ptr<TcpClient> &client); 147 void FetchServers(const std::shared_ptr<TcpClient> &client); 148 149 void ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); 150 void ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); 151 void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); 152 153 // Process the response messages about actor route table service. 154 void ProcessReceiveSchedulerResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); 155 156 void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 157 const Protos &protos, const void *data, size_t size); 158 void ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 159 const Protos &protos, const void *data, size_t size); 160 161 void ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 162 const Protos &protos, const void *data, size_t size); 163 164 void ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 165 const Protos &protos, const void *data, size_t size); 166 167 // The worker/server processes the scale_out_done message from scheduelr 168 void ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 169 const Protos &protos, const void *data, size_t size); 170 // The worker/server processes the scale_in_done message from scheduelr 171 void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 172 const Protos &protos, const void *data, size_t size); 173 174 // The worker/server processes the scheduler recovery message from scheduelr 175 void ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 176 const Protos &, const void *data, size_t size); 177 178 // The worker/server processes the SEND_EVENT message from scheduelr 179 void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 180 const Protos &protos, const void *data, size_t size); 181 182 void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); 183 void UpdateSchedulerTime(); 184 bool CheckSchedulerTimeout() const; 185 bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); 186 bool WaitForDisconnect(const uint32_t &timeout); 187 virtual bool InitClientToScheduler(); 188 void InitClientToServer(); 189 const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id, 190 const NodeRole &role = NodeRole::SERVER); 191 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, 192 const uint32_t &timeout = kCommTimeoutInSeconds); 193 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta, 194 const Protos &, const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds); 195 uint64_t SendCollectiveMeta(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta, 196 const Protos &protos, const void *data, size_t size); 197 void ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 198 const Protos &protos, const void *data, size_t size); 199 void ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 200 const Protos &protos, const void *data, size_t size); 201 void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta); 202 void RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, 203 size_t size); 204 uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); 205 uint64_t NextActualRankRequestId(const uint32_t &rank_id); 206 void InitCommandHandler(); 207 void RegisterActorRouteTableRspHandler(); 208 void InitServerHandler(); 209 210 // Register collective communication initialization response methods. RegisterInitCollectCommResphandler()211 virtual void RegisterInitCollectCommResphandler() {} 212 213 // Register recovery response methods. RegisterRecoveryRespHandler()214 virtual void RegisterRecoveryRespHandler() {} 215 216 // when initializing the node, should initializing the node info. 217 void InitNodeInfo(const NodeRole &role); 218 // Initialize worker num and server num by cluster config. 219 void InitNodeNum(); 220 // Node recover by cluster config. 221 bool Recover(); 222 223 // Trigger the callback corresponding to the event. 224 void OnEventCallback(const ClusterEvent &event); 225 // Trigger the callback corresponding to the custom event. 226 void OnCustomEventCallback(const uint32_t &event); 227 228 bool IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info); 229 230 void CreateTcpServer(const std::pair<uint32_t, uint32_t> &port_range = {}); 231 232 void UpdateClusterState(const ClusterState &state); 233 234 void PersistMetaData(); 235 236 void ProcessPrepareBuildingNetwork(const std::shared_ptr<TcpConnection> &conn, 237 const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, 238 size_t size); 239 240 bool FlCollectiveWaitInner(const CollectiveMessageMeta &expect_meta, VectorPtr *output, const uint32_t &timeout); 241 void OnRecvCollectiveData(const MessageMeta &message_meta, const VectorPtr &data); 242 void ConnectToScheduler(); 243 244 void ProcessScaleOutRollback(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 245 const Protos &, const void *data, size_t size); 246 247 std::unique_ptr<std::thread> heart_beat_thread_; 248 std::unique_ptr<std::thread> client_to_scheduler_thread_; 249 std::shared_ptr<TcpClient> client_to_scheduler_; 250 std::shared_ptr<TcpClient> client_to_server_; 251 // the key is: <node_role,rank_id>, the value is: <ip, port> 252 std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; 253 // the map's key is: rank_id 254 std::map<std::pair<NodeRole, uint32_t>, std::shared_ptr<TcpClient>> connected_nodes_; 255 256 // the key is <rank_id, rank_request_id> 257 std::map<std::pair<uint32_t, uint64_t>, VectorPtr> received_data_; 258 std::mutex receive_callbacks_mutex_; 259 // the key is <rank_id, rank_request_id> 260 std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_; 261 std::condition_variable receive_cond_; 262 263 // the key is rank_id, the value is rank_id's expected request_id 264 std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_; 265 // the key is rank_id, the value is rank_id's actual request_id 266 std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_; 267 std::mutex rank_request_ids_mutex; 268 timeval scheduler_time_{0, 0}; 269 std::unordered_map<NodeCommand, ResponseHandler> handlers_; 270 std::unordered_map<NodeCommand, ServerHandler> server_handler_; 271 272 // send_rank_id, recv CollectiveMessageMeta and data 273 std::unordered_map<uint32_t, std::vector<std::pair<CollectiveMessageMeta, std::shared_ptr<std::vector<uint8_t>>>>> 274 fl_received_data_; 275 std::mutex fl_receive_mutex_; 276 std::condition_variable fl_receive_cond_; 277 278 // Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA 279 std::shared_ptr<TcpServer> server_; 280 std::unique_ptr<std::thread> server_thread_; 281 std::unique_ptr<std::thread> message_callback_thread_; 282 283 uint32_t worker_num_; 284 uint32_t server_num_; 285 std::atomic<bool> is_connected_to_scheduler_; 286 // Identify whether the current node is a scale in node. 287 std::atomic<bool> is_current_node_scale_in_; 288 289 // Each ClusterEvent corresponds to a EventCallback to process the event. 290 std::map<ClusterEvent, EventCallback> event_to_callback_; 291 292 // Each custom event corresponds to a EventCallback to process the event. 293 // This event is sent to the scheduler, and then the scheduler broadcasts this event to all nodes. 294 // for example: 295 // In order to ensure the consistency of the cluster, the server broadcasts an iteration_end event to notify all other 296 // nodes to modify the iteration status 297 std::map<uint32_t, EventCallback> custom_event_to_callback_; 298 299 // Recovery for worker/server node. 300 std::unique_ptr<RecoveryBase> node_recovery_; 301 302 // The state of the persistent storage, such as ready to be persisted, in the process of being persisted, has 303 // completed the persistence, etc. 304 std::atomic<PersistentState> persistent_state_; 305 306 // The ip of scheduler. 307 std::string scheduler_ip_; 308 // The port of scheduler. 309 uint16_t scheduler_port_; 310 311 // Synchronize all node metadata from the scheduler. 312 std::unordered_map<std::string, NodeInfo> all_nodes_info_; 313 RequestHandler request_handler_; 314 315 std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_; 316 std::mutex communicator_mutex_; 317 std::mutex cluster_state_mutex_; 318 319 size_t failed_iteration_num_ = 0; 320 bool iteration_failed_ = false; 321 CancelSafeModeFn cancelSafeModeFn_; 322 323 std::atomic<bool> is_recover; 324 }; 325 using AbstractNodePtr = std::shared_ptr<AbstractNode>; 326 } // namespace core 327 } // namespace ps 328 } // namespace mindspore 329 #endif // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ 330