1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CORE_NODE_H_ 18 #define MINDSPORE_CCSRC_PS_CORE_NODE_H_ 19 20 #include <atomic> 21 #include <cstdlib> 22 #include <functional> 23 #include <iostream> 24 #include <memory> 25 #include <string> 26 #include <thread> 27 #include <unordered_map> 28 #include <vector> 29 #include <condition_variable> 30 #include <utility> 31 #include <tuple> 32 #include <map> 33 34 #include "ps/core/cluster_metadata.h" 35 #include "ps/core/cluster_config.h" 36 #include "include/backend/distributed/ps/ps_context.h" 37 #include "ps/core/node_info.h" 38 #include "ps/core/communicator/tcp_client.h" 39 #include "ps/core/communicator/tcp_server.h" 40 #include "ps/core/file_configuration.h" 41 42 namespace mindspore { 43 namespace ps { 44 namespace core { 45 constexpr int kTimeoutInSeconds = 30; 46 constexpr int kCommTimeoutInSeconds = 10; 47 constexpr int kCommTimeoutInThreeSeconds = 3; 48 class Node { 49 public: Node()50 Node() 51 : is_ready_(false), 52 is_finish_(false), 53 is_already_stopped_(true), 54 is_already_finished_(false), 55 next_request_id_(0), 56 current_cluster_state_(ClusterState::CLUSTER_STARTING) {} 57 virtual ~Node() = default; 58 59 using MessageCallback = std::function<void()>; 60 61 virtual bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) = 0; 62 virtual bool Stop() = 0; 63 virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; 64 65 std::string node_id() const; 66 uint32_t rank_id() const; 67 NodeRole role() const; 68 69 bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); 70 71 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &, const Protos &, 72 const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds); 73 74 // Whether to enable disaster recovery. 75 bool EnableRecovery() const; 76 77 protected: 78 bool WaitForStart(const uint32_t &timeout); 79 80 // Send data synchronously 81 bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, 82 const uint32_t &timeout = kCommTimeoutInSeconds); 83 // Send data asynchronously 84 bool SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta, 85 const Protos &protos, const void *data, size_t size); 86 87 uint64_t AddMessageTrack(const uint32_t &expected_response); 88 bool CheckMessageTrack(const uint64_t &request_id); 89 void NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta); 90 void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); 91 void ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, 92 size_t size); 93 void RunMessageCallback(const uint64_t &request_id); 94 95 NodeInfo node_info_; 96 std::atomic<bool> is_ready_; 97 std::atomic<bool> is_finish_; 98 99 std::atomic<bool> is_already_stopped_; 100 std::atomic<bool> is_already_finished_; 101 std::atomic_uint64_t next_request_id_; 102 103 std::mutex wait_start_mutex_; 104 std::condition_variable wait_start_cond_; 105 std::mutex wait_finish_mutex_; 106 std::condition_variable wait_finish_cond_; 107 std::mutex finish_mutex_; 108 109 // the key is: request_id, the value is: <expected responses, actual responses> 110 std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_; 111 std::mutex message_tracker_mutex_; 112 std::condition_variable message_tracker_cond_; 113 114 ClusterState current_cluster_state_; 115 116 // Configuration file,The format is as follows 117 //{ 118 // "recovery": { 119 // "storage_type": 1, 120 // "storge_file_path": "/home/cds/config.json" 121 // } 122 // } 123 std::unique_ptr<Configuration> config_; 124 // Used to synchronize the connected nodes 125 std::mutex client_mutex_; 126 127 // the key is: request_id 128 std::unordered_map<uint64_t, MessageCallback> message_callbacks_; 129 std::mutex message_callbacks_mutex_; 130 131 // the key is: request_id, the value is: <rank_id, RecvMessage> 132 std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_; 133 // the key is: request_id, the value is: <rank_id, RecvMessage> 134 std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> workder_receive_messages_; 135 std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_; 136 std::mutex receive_messages_mutex_; 137 138 // Message from the scheduler. The key is: request_id, the value is:RecvMessage. 139 std::unordered_map<uint64_t, VectorPtr> received_scheduler_messages_; 140 }; 141 } // namespace core 142 } // namespace ps 143 } // namespace mindspore 144 #endif // MINDSPORE_CCSRC_PS_CORE_NODE_H_ 145