1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ 18 #define MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ 19 20 #include <utility> 21 #include <string> 22 #include <memory> 23 #include <map> 24 #include <vector> 25 #include <unordered_map> 26 27 #include "ps/core/node.h" 28 #include "ps/core/communicator/message.h" 29 #include "ps/core/follower_scaler.h" 30 #include "utils/ms_exception.h" 31 #include "ps/constants.h" 32 #include "ps/core/node_info.h" 33 #include "ps/core/recovery_base.h" 34 #include "ps/core/communicator/task_executor.h" 35 #include "ps/core/communicator/communicator_base.h" 36 37 namespace mindspore { 38 namespace ps { 39 namespace core { 40 class FollowerScaler; 41 class AbstractNode : public Node { 42 public: AbstractNode()43 AbstractNode() 44 : heart_beat_thread_(nullptr), 45 client_to_scheduler_thread_(nullptr), 46 client_to_scheduler_(nullptr), 47 server_(nullptr), 48 server_thread_(nullptr), 49 worker_num_(-1), 50 server_num_(-1), 51 is_current_node_scale_in_(false), 52 follower_scaler_(nullptr), 53 node_recovery_(nullptr), 54 scheduler_ip_(""), 55 scheduler_port_(0) {} 56 ~AbstractNode() override = default; 57 58 typedef void (AbstractNode::*ResponseHandler)(const std::shared_ptr<MessageMeta> &meta, const void *data, 59 size_t size); 60 typedef void (AbstractNode::*ServerHandler)(const std::shared_ptr<TcpConnection> &conn, 61 const std::shared_ptr<MessageMeta> &meta, const Protos &protos, 62 const void *data, size_t size); 63 64 using DataPtr = std::shared_ptr<unsigned char[]>; 65 using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; 66 using RequestHandler = 67 std::function<void(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 68 const DataPtr &data, size_t size)>; 69 70 bool Broadcast(const NodeRole &node_role, const DataPtr &message, size_t size, int command, 71 const uint32_t &timeout = kCommTimeoutInSeconds); 72 73 // When the business layer finish scale out, it should call this function 74 void set_ready_for_scale_out(); 75 // When the business layer finish scale in, it should call this function 76 void set_ready_for_scale_in(); 77 78 // Send scale_out_done instructions to the scheduler. 79 void set_scale_out_done(); 80 81 // Send scale_in_done instructions to the scheduler. 82 void set_scale_in_done(); 83 84 // The worker/server sends the event to the scheduler, and then the scheduler broadcasts this event to all nodes. 85 void BroadcastEvent(const uint32_t &event); 86 87 // Set the callback corresponding to the event. 88 void RegisterEventCallback(const ClusterEvent &event, const EventCallback &event_cb); 89 // Set the callback corresponding to the custom event. 90 void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb); 91 92 bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command, 93 const uint32_t &timeout = kTimeoutInSeconds); 94 bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data, 95 const std::vector<size_t> &lens, int command, const uint32_t &timeout = kTimeoutInSeconds); 96 bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command, 97 VectorPtr *output, const uint32_t &timeout = kTimeoutInSeconds); 98 bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data, 99 const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output, 100 const uint32_t &timeout = kTimeoutInSeconds); 101 102 uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size); 103 std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id, 104 VectorPtr *output); 105 bool CollectiveWait(const std::pair<uint32_t, uint64_t> &request_id, const uint32_t &timeout = kCommTimeoutInSeconds); 106 107 // Initialize the scaler for server to process before/after scaling operations. 108 bool InitFollowerScaler(); 109 110 // Register barriers before scaling operations for server. 111 void RegisterFollowerScalerBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier); 112 void RegisterFollowerScalerBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier); 113 114 // Register handlers after scaling operations for server. 115 void RegisterFollowerScalerHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler); 116 void RegisterFollowerScalerHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler); 117 118 int32_t worker_num() const; 119 int32_t server_num() const; 120 121 void set_worker_num(const int32_t &worker_num); 122 void set_server_num(const int32_t &server_num); 123 124 std::string scheduler_ip() const; 125 void set_scheduler_ip(const std::string &scheduler_ip); 126 127 uint16_t scheduler_port() const; 128 void set_scheduler_port(const uint16_t &scheduler_port); 129 130 ClusterState cluster_state() const; 131 132 void set_handler(const RequestHandler &handler); 133 void Response(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, const void *data, 134 size_t size); 135 136 std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port, 137 const std::shared_ptr<TaskExecutor> &task_executor); 138 std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port, 139 uint32_t worker_num, uint32_t server_num, 140 const std::shared_ptr<TaskExecutor> &task_executor); 141 142 protected: 143 void Register(const std::shared_ptr<TcpClient> &client); 144 bool Heartbeat(const std::shared_ptr<TcpClient> &client); 145 void FetchServers(const std::shared_ptr<TcpClient> &client); 146 147 void ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); 148 void ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); 149 void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); 150 151 void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 152 const Protos &protos, const void *data, size_t size); 153 void ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 154 const Protos &protos, const void *data, size_t size); 155 156 void ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 157 const Protos &protos, const void *data, size_t size); 158 159 void ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 160 const Protos &protos, const void *data, size_t size); 161 162 // The worker/server processes the scale_out_done message from scheduelr 163 void ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 164 const Protos &protos, const void *data, size_t size); 165 // The worker/server processes the scale_in_done message from scheduelr 166 void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 167 const Protos &protos, const void *data, size_t size); 168 169 // The worker/server processes the SEND_EVENT message from scheduelr 170 void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 171 const Protos &protos, const void *data, size_t size); 172 173 void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client); 174 void UpdateSchedulerTime(); 175 bool CheckSchedulerTimeout() const; 176 bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); 177 bool WaitForDisconnect(const uint32_t &timeout); 178 bool InitClientToScheduler(); 179 const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id); 180 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, 181 const uint32_t &timeout = kCommTimeoutInSeconds); 182 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta, 183 const Protos &, const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds); 184 uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta, 185 const Protos &protos, const void *data, size_t size); 186 void ProcessCollectiveSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 187 const void *data, size_t size); 188 void ProcessSendData(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, 189 const Protos &protos, const void *data, size_t size); 190 void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta); 191 void RunReceiveCallback(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, 192 size_t size); 193 uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); 194 uint64_t NextActualRankRequestId(const uint32_t &rank_id); 195 void InitCommandHandler(); 196 void InitServerHandler(); 197 198 // when initializing the node, should initializing the node info. 199 void InitNodeInfo(const NodeRole &role); 200 // Initialize worker num and server num by cluster config. 201 void InitNodeNum(); 202 // Node recover by cluster config. 203 bool Recover(); 204 205 // Trigger the callback corresponding to the event. 206 void OnEventCallback(const ClusterEvent &event); 207 // Trigger the callback corresponding to the custom event. 208 void OnCustomEventCallback(const uint32_t &event); 209 210 bool IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info); 211 212 void CreateTcpServer(); 213 214 std::unique_ptr<std::thread> heart_beat_thread_; 215 std::unique_ptr<std::thread> client_to_scheduler_thread_; 216 std::shared_ptr<TcpClient> client_to_scheduler_; 217 218 // the key is: <node_role,rank_id>, the value is: <ip, port> 219 std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; 220 // the map's key is: rank_id 221 std::unordered_map<uint32_t, std::shared_ptr<TcpClient>> connected_nodes_; 222 223 // the key is <rank_id, rank_request_id> 224 std::map<std::pair<uint32_t, uint64_t>, std::shared_ptr<std::vector<unsigned char>>> received_data_; 225 std::mutex receive_callbacks_mutex_; 226 // the key is <rank_id, rank_request_id> 227 std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_; 228 std::condition_variable receive_cond_; 229 230 // the key is rank_id, the value is rank_id's expected request_id 231 std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_; 232 // the key is rank_id, the value is rank_id's actual request_id 233 std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_; 234 std::mutex rank_request_ids_mutex; 235 timeval scheduler_time_{0, 0}; 236 std::unordered_map<NodeCommand, ResponseHandler> handlers_; 237 std::unordered_map<NodeCommand, ServerHandler> server_handler_; 238 239 // Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA 240 std::shared_ptr<TcpServer> server_; 241 std::unique_ptr<std::thread> server_thread_; 242 243 int32_t worker_num_; 244 int32_t server_num_; 245 246 // Identify whether the current node is a scale in node. 247 std::atomic<bool> is_current_node_scale_in_; 248 249 // Each ClusterEvent corresponds to a EventCallback to process the event. 250 std::map<ClusterEvent, EventCallback> event_to_callback_; 251 252 // Each custom event corresponds to a EventCallback to process the event. 253 // This event is sent to the scheduler, and then the scheduler broadcasts this event to all nodes. 254 // for example: 255 // In order to ensure the consistency of the cluster, the server broadcasts an iteration_end event to notify all other 256 // nodes to modify the iteration status 257 std::map<uint32_t, EventCallback> custom_event_to_callback_; 258 259 // Scaler for worker/server node. 260 std::unique_ptr<FollowerScaler> follower_scaler_; 261 262 // Recovery for worker/server node. 263 std::unique_ptr<RecoveryBase> node_recovery_; 264 265 // The ip of scheduler. 266 std::string scheduler_ip_; 267 // The port of scheduler. 268 uint16_t scheduler_port_; 269 270 // Synchronize all node metadata from the scheduler. 271 std::unordered_map<std::string, NodeInfo> all_nodes_info_; 272 RequestHandler request_handler_; 273 274 std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_; 275 std::mutex communicator_mutex_; 276 }; 277 } // namespace core 278 } // namespace ps 279 } // namespace mindspore 280 #endif // MINDSPORE_CCSRC_PS_CORE_ABSTRACT_NODE_H_ 281