1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_ 18 #define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include "ps/core/communicator/communicator_base.h" 24 #include "ps/core/communicator/tcp_communicator.h" 25 #include "ps/core/communicator/task_executor.h" 26 #include "ps/core/file_configuration.h" 27 #include "fl/server/common.h" 28 #include "fl/server/executor.h" 29 #include "fl/server/iteration.h" 30 #ifdef ENABLE_ARMOUR 31 #include "fl/armour/cipher/cipher_init.h" 32 #endif 33 34 namespace mindspore { 35 namespace fl { 36 namespace server { 37 // The sleeping time of the server thread before the networking is completed. 38 constexpr uint32_t kServerSleepTimeForNetworking = 1000; 39 40 // Class Server is the entrance of MindSpore's parameter server training mode and federated learning. 41 class Server { 42 public: GetInstance()43 static Server &GetInstance() { 44 static Server instance; 45 return instance; 46 } 47 48 void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config, 49 const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold); 50 51 // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be 52 // blocked until the server is finalized. 53 // func_graph is the frontend graph which will be parse in server's exector and aggregator. 54 void Run(); 55 56 void SwitchToSafeMode(); 57 void CancelSafeMode(); 58 bool IsSafeMode() const; 59 void WaitExitSafeMode() const; 60 61 // Whether the training job of the server is enabled. 62 InstanceState instance_state() const; 63 64 private: Server()65 Server() 66 : server_node_(nullptr), 67 task_executor_(nullptr), 68 use_tcp_(false), 69 use_http_(false), 70 http_port_(0), 71 func_graph_(nullptr), 72 executor_threshold_(0), 73 communicator_with_server_(nullptr), 74 communicators_with_worker_({}), 75 iteration_(nullptr), 76 safemode_(true), 77 scheduler_ip_(""), 78 scheduler_port_(0), 79 server_num_(0), 80 worker_num_(0), 81 fl_server_port_(0), 82 cipher_initial_client_cnt_(0), 83 cipher_exchange_keys_cnt_(0), 84 cipher_get_keys_cnt_(0), 85 cipher_share_secrets_cnt_(0), 86 cipher_get_secrets_cnt_(0), 87 cipher_get_clientlist_cnt_(0), 88 cipher_reconstruct_secrets_up_cnt_(0), 89 cipher_reconstruct_secrets_down_cnt_(0), 90 cipher_time_window_(0) {} 91 ~Server() = default; 92 Server(const Server &) = delete; 93 Server &operator=(const Server &) = delete; 94 95 // Load variables which is set by ps_context. 96 void InitServerContext(); 97 98 // Try to recover server config from persistent storage. 99 void Recovery(); 100 101 // Initialize the server cluster, server node and communicators. 102 void InitCluster(); 103 bool InitCommunicatorWithServer(); 104 bool InitCommunicatorWithWorker(); 105 106 // Initialize iteration with rounds. Which rounds to use could be set by ps_context as well. 107 void InitIteration(); 108 109 // Register all message and event callbacks for communicators(TCP and HTTP). This method must be called before 110 // communicators are started. 111 void RegisterCommCallbacks(); 112 113 // Register cluster exception callbacks. This method is called in RegisterCommCallbacks. 114 void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator); 115 116 // Register message callbacks. These messages are mainly from scheduler. 117 void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator); 118 119 // Initialize executor according to the server mode. 120 void InitExecutor(); 121 122 // Initialize cipher according to the public param. 123 void InitCipher(); 124 125 // Create round kernels and bind these kernels with corresponding Round. 126 void RegisterRoundKernel(); 127 128 void InitMetrics(); 129 130 // The communicators should be started after all initializations are completed. 131 void StartCommunicator(); 132 133 // The barriers before scaling operations. 134 void ProcessBeforeScalingOut(); 135 void ProcessBeforeScalingIn(); 136 137 // The handlers after scheduler's scaling operations are done. 138 void ProcessAfterScalingOut(); 139 void ProcessAfterScalingIn(); 140 141 // Handlers for enableFLS/disableFLS requests from the scheduler. 142 void HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message); 143 void HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message); 144 145 // Finish current instance and start a new one. FLPlan could be changed in this method. 146 void HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message); 147 148 // Query current instance information. 149 void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message); 150 151 // The server node is initialized in Server. 152 std::shared_ptr<ps::core::ServerNode> server_node_; 153 154 // The task executor of the communicators. This helps server to handle network message concurrently. The tasks 155 // submitted to this task executor is asynchronous. 156 std::shared_ptr<ps::core::TaskExecutor> task_executor_; 157 158 // Which protocol should communicators use. 159 bool use_tcp_; 160 bool use_http_; 161 uint16_t http_port_; 162 163 // The configure of all rounds. 164 std::vector<RoundConfig> rounds_config_; 165 CipherConfig cipher_config_; 166 167 // The graph passed by the frontend without backend optimizing. 168 FuncGraphPtr func_graph_; 169 170 // The threshold count for executor to do aggregation or optimizing. 171 size_t executor_threshold_; 172 173 // Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective 174 // operations, etc. 175 std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_; 176 177 // The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP. 178 // In some cases, both types should be supported in one distributed training job. So here we may have multiple 179 // communicators. 180 std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_; 181 182 // Mutex for scaling operations. We must wait server's initialization done before handle scaling events. 183 std::mutex scaling_mtx_; 184 185 // Iteration consists of multiple kinds of rounds. 186 Iteration *iteration_; 187 188 // The flag that represents whether server is in safemode. 189 // If true, the server is not available to workers and clients. 190 std::atomic_bool safemode_; 191 192 // Variables set by ps context. 193 #ifdef ENABLE_ARMOUR 194 armour::CipherInit *cipher_init_{nullptr}; 195 #endif 196 std::string scheduler_ip_; 197 uint16_t scheduler_port_; 198 uint32_t server_num_; 199 uint32_t worker_num_; 200 uint16_t fl_server_port_; 201 size_t cipher_initial_client_cnt_; 202 size_t cipher_exchange_keys_cnt_; 203 size_t cipher_get_keys_cnt_; 204 size_t cipher_share_secrets_cnt_; 205 size_t cipher_get_secrets_cnt_; 206 size_t cipher_get_clientlist_cnt_; 207 size_t cipher_reconstruct_secrets_up_cnt_; 208 size_t cipher_reconstruct_secrets_down_cnt_; 209 uint64_t cipher_time_window_; 210 }; 211 } // namespace server 212 } // namespace fl 213 } // namespace mindspore 214 #endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_ 215