1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CORE_NODE_H_ 18 #define MINDSPORE_CCSRC_PS_CORE_NODE_H_ 19 20 #include <atomic> 21 #include <cstdlib> 22 #include <functional> 23 #include <iostream> 24 #include <memory> 25 #include <string> 26 #include <thread> 27 #include <unordered_map> 28 #include <vector> 29 #include <condition_variable> 30 #include <utility> 31 #include <tuple> 32 #include <map> 33 34 #include "ps/core/cluster_metadata.h" 35 #include "ps/core/cluster_config.h" 36 #include "ps/ps_context.h" 37 #include "ps/core/node_info.h" 38 #include "ps/core/communicator/tcp_client.h" 39 #include "ps/core/communicator/tcp_server.h" 40 #include "ps/core/file_configuration.h" 41 42 namespace mindspore { 43 namespace ps { 44 namespace core { 45 constexpr int kTimeoutInSeconds = 30; 46 constexpr int kCommTimeoutInSeconds = 3; 47 class Node { 48 public: Node()49 Node() 50 : is_ready_(false), 51 is_finish_(false), 52 is_already_stopped_(true), 53 is_already_finished_(false), 54 next_request_id_(0), 55 current_node_state_(NodeState::NODE_STARTING), 56 current_cluster_state_(ClusterState::ClUSTER_STARTING) {} 57 virtual ~Node() = default; 58 59 using MessageCallback = std::function<void()>; 60 61 virtual bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) = 0; 62 virtual bool Stop() = 0; 63 virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; 64 65 std::string node_id() const; 66 uint32_t rank_id() const; 67 NodeRole role() const; 68 uint16_t BoundPort() const; 69 std::string BoundIp() const; 70 71 bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); 72 73 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &, const Protos &, 74 const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds); 75 76 protected: 77 bool WaitForStart(const uint32_t &timeout); 78 79 // Send data synchronously 80 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, 81 const uint32_t &timeout = kCommTimeoutInSeconds); 82 // Send data asynchronously 83 uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta, 84 const Protos &protos, const void *data, size_t size); 85 86 uint64_t AddMessageTrack(const uint32_t &expected_response); 87 bool CheckMessageTrack(const uint64_t &request_id); 88 void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta); 89 void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); 90 void ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, 91 size_t size); 92 void RunMessageCallback(const uint64_t &request_id); 93 94 NodeInfo node_info_; 95 // Whether the cluster is ready 96 std::atomic<bool> is_ready_; 97 // Whether the cluster is finished. 98 std::atomic<bool> is_finish_; 99 100 std::atomic<bool> is_already_stopped_; 101 std::atomic<bool> is_already_finished_; 102 std::atomic_uint64_t next_request_id_; 103 104 std::mutex wait_start_mutex_; 105 std::condition_variable wait_start_cond_; 106 std::mutex wait_finish_mutex_; 107 std::condition_variable wait_finish_cond_; 108 std::mutex finish_mutex_; 109 110 // the key is: request_id, the value is: <expected responses, actual responses> 111 std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; 112 std::mutex message_tracker_mutex_; 113 std::condition_variable message_tracker_cond_; 114 115 // Worker and server receive the node state and cluster state from the scheduler. 116 NodeState current_node_state_; 117 ClusterState current_cluster_state_; 118 119 // Configuration file,The format is as follows 120 //{ 121 // "recovery": { 122 // "storage_type": 1, 123 // "storge_file_path": "/home/cds/config.json" 124 // } 125 // } 126 std::unique_ptr<Configuration> config_; 127 // Used to synchronize the connected nodes 128 std::mutex client_mutex_; 129 130 // the key is: request_id 131 std::unordered_map<uint64_t, MessageCallback> message_callbacks_; 132 std::mutex message_callbacks_mutex_; 133 134 // the key is: request_id, the value is: <rank_id, RecvMessage> 135 std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_; 136 // the key is: request_id, the value is: <rank_id, RecvMessage> 137 std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> workder_receive_messages_; 138 std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_; 139 std::mutex receive_messages_mutex_; 140 }; 141 } // namespace core 142 } // namespace ps 143 } // namespace mindspore 144 #endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_ 145