1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_ 18 #define MINDSPORE_CCSRC_PS_CONTEXT_H_ 19 20 #include <map> 21 #include <string> 22 #include <memory> 23 #include "ps/constants.h" 24 #include "ps/core/cluster_metadata.h" 25 #include "ps/core/cluster_config.h" 26 27 namespace mindspore { 28 namespace ps { 29 constexpr char kServerModePS[] = "PARAMETER_SERVER"; 30 constexpr char kServerModeFL[] = "FEDERATED_LEARNING"; 31 constexpr char kServerModeHybrid[] = "HYBRID_TRAINING"; 32 constexpr char kEnvRole[] = "MS_ROLE"; 33 constexpr char kEnvRoleOfPServer[] = "MS_PSERVER"; 34 constexpr char kEnvRoleOfServer[] = "MS_SERVER"; 35 constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; 36 constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; 37 constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS"; 38 constexpr char kDPEncryptType[] = "DP_ENCRYPT"; 39 constexpr char kPWEncryptType[] = "PW_ENCRYPT"; 40 constexpr char kNotEncryptType[] = "NOT_ENCRYPT"; 41 42 // Use binary data to represent federated learning server's context so that we can judge which round resets the 43 // iteration. From right to left, each bit stands for: 44 // 0: Server is in parameter server mode. 45 // 1: Server is in federated learning mode. 46 // 2: Server is in mixed training mode. 47 // 3: Server enables pairwise encrypt algorithm. 48 // For example: 1010 stands for that the server is in federated learning mode and pairwise encrypt algorithm is enabled. 49 enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight, kPushMetrics }; 50 const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel}, 51 {0b1010, ResetterRound::kReconstructSeccrets}, 52 {0b1100, ResetterRound::kPushMetrics}, 53 {0b0100, ResetterRound::kPushMetrics}}; 54 55 class PSContext { 56 public: 57 ~PSContext() = default; 58 PSContext(PSContext const &) = delete; 59 PSContext &operator=(const PSContext &) = delete; 60 static std::shared_ptr<PSContext> instance(); 61 62 void SetPSEnable(bool enabled); 63 bool is_ps_mode() const; 64 void Reset(); 65 std::string ms_role() const; 66 bool is_worker() const; 67 bool is_server() const; 68 bool is_scheduler() const; 69 uint32_t initial_worker_num() const; 70 uint32_t initial_server_num() const; 71 std::string scheduler_host() const; 72 void SetPSRankId(uint32_t rank_id); 73 uint32_t ps_rank_id() const; 74 void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, 75 size_t vocab_size) const; 76 void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, 77 size_t cache_vocab_size, size_t embedding_size) const; 78 void InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const; 79 void InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const; 80 void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const; 81 void set_cache_enable(bool cache_enable) const; 82 void set_rank_id(uint32_t rank_id) const; 83 bool enable_ssl() const; 84 void set_enable_ssl(bool enabled); 85 86 std::string client_password() const; 87 void set_client_password(const std::string &password); 88 std::string server_password() const; 89 void set_server_password(const std::string &password); 90 91 // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set 92 // by ps_context. 93 void set_server_mode(const std::string &server_mode); 94 const std::string &server_mode() const; 95 96 void set_ms_role(const std::string &role); 97 98 void set_worker_num(uint32_t worker_num); 99 uint32_t worker_num() const; 100 101 void set_server_num(uint32_t server_num); 102 uint32_t server_num() const; 103 104 void set_scheduler_ip(const std::string &sched_ip); 105 std::string scheduler_ip() const; 106 107 void set_scheduler_port(uint16_t sched_port); 108 uint16_t scheduler_port() const; 109 110 // Methods federated learning. 111 112 // Generate which round should reset the iteration. 113 void GenerateResetterRound(); 114 ResetterRound resetter_round() const; 115 116 void set_fl_server_port(uint16_t fl_server_port); 117 uint16_t fl_server_port() const; 118 119 // Set true if this process is a federated learning worker in cross-silo scenario. 120 void set_fl_client_enable(bool enabled); 121 bool fl_client_enable() const; 122 123 void set_start_fl_job_threshold(uint64_t start_fl_job_threshold); 124 uint64_t start_fl_job_threshold() const; 125 126 void set_start_fl_job_time_window(uint64_t start_fl_job_time_window); 127 uint64_t start_fl_job_time_window() const; 128 129 void set_update_model_ratio(float update_model_ratio); 130 float update_model_ratio() const; 131 132 void set_update_model_time_window(uint64_t update_model_time_window); 133 uint64_t update_model_time_window() const; 134 135 void set_share_secrets_ratio(float share_secrets_ratio); 136 float share_secrets_ratio() const; 137 138 void set_cipher_time_window(uint64_t cipher_time_window); 139 uint64_t cipher_time_window() const; 140 141 void set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold); 142 uint64_t reconstruct_secrets_threshold() const; 143 144 void set_fl_name(const std::string &fl_name); 145 const std::string &fl_name() const; 146 147 // Set the iteration number of the federated learning. 148 void set_fl_iteration_num(uint64_t fl_iteration_num); 149 uint64_t fl_iteration_num() const; 150 151 // Set the training epoch number of the client. 152 void set_client_epoch_num(uint64_t client_epoch_num); 153 uint64_t client_epoch_num() const; 154 155 // Set the data batch size of the client. 156 void set_client_batch_size(uint64_t client_batch_size); 157 uint64_t client_batch_size() const; 158 159 void set_client_learning_rate(float client_learning_rate); 160 float client_learning_rate() const; 161 162 void set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration); 163 uint64_t worker_step_num_per_iteration() const; 164 165 core::ClusterConfig &cluster_config(); 166 167 void set_scheduler_manage_port(uint16_t sched_port); 168 uint16_t scheduler_manage_port() const; 169 170 void set_config_file_path(const std::string &path); 171 std::string config_file_path() const; 172 173 void set_dp_eps(float dp_eps); 174 float dp_eps() const; 175 176 void set_dp_delta(float dp_delta); 177 float dp_delta() const; 178 179 void set_dp_norm_clip(float dp_norm_clip); 180 float dp_norm_clip() const; 181 182 void set_encrypt_type(const std::string &encrypt_type); 183 const std::string &encrypt_type() const; 184 185 void set_node_id(const std::string &node_id); 186 const std::string &node_id() const; 187 188 private: PSContext()189 PSContext() 190 : ps_enabled_(false), 191 is_worker_(false), 192 is_pserver_(false), 193 is_sched_(false), 194 enable_ssl_(false), 195 rank_id_(0), 196 worker_num_(0), 197 server_num_(0), 198 scheduler_host_("0.0.0.0"), 199 scheduler_port_(6667), 200 role_(kEnvRoleOfNotPS), 201 server_mode_(""), 202 resetter_round_(ResetterRound::kNoNeedToReset), 203 fl_server_port_(6668), 204 fl_client_enable_(false), 205 fl_name_(""), 206 start_fl_job_threshold_(0), 207 start_fl_job_time_window_(3000), 208 update_model_ratio_(1.0), 209 update_model_time_window_(3000), 210 share_secrets_ratio_(1.0), 211 cipher_time_window_(300000), 212 reconstruct_secrets_threshold_(2000), 213 fl_iteration_num_(20), 214 client_epoch_num_(25), 215 client_batch_size_(32), 216 client_learning_rate_(0.001), 217 worker_step_num_per_iteration_(65), 218 secure_aggregation_(false), 219 cluster_config_(nullptr), 220 scheduler_manage_port_(11202), 221 config_file_path_(""), 222 dp_eps_(50), 223 dp_delta_(0.01), 224 dp_norm_clip_(1.0), 225 encrypt_type_(kNotEncryptType), 226 node_id_(""), 227 client_password_(""), 228 server_password_("") {} 229 bool ps_enabled_; 230 bool is_worker_; 231 bool is_pserver_; 232 bool is_sched_; 233 bool enable_ssl_; 234 uint32_t rank_id_; 235 uint32_t worker_num_; 236 uint32_t server_num_; 237 std::string scheduler_host_; 238 uint16_t scheduler_port_; 239 240 // The server process's role. 241 std::string role_; 242 243 // Server mode which could be Parameter Server, Federated Learning and Hybrid Training mode. 244 std::string server_mode_; 245 246 // The round which will reset the iteration. Used in federated learning for now. 247 ResetterRound resetter_round_; 248 249 // Http port of federated learning server. 250 uint16_t fl_server_port_; 251 252 // Whether this process is the federated client. Used in cross-silo scenario of federated learning. 253 bool fl_client_enable_; 254 255 // Federated learning job name. 256 std::string fl_name_; 257 258 // The threshold count of startFLJob round. Used in federated learning for now. 259 uint64_t start_fl_job_threshold_; 260 261 // The time window of startFLJob round in millisecond. 262 uint64_t start_fl_job_time_window_; 263 264 // Update model threshold is a certain ratio of start_fl_job threshold which is set as update_model_ratio_. 265 float update_model_ratio_; 266 267 // The time window of updateModel round in millisecond. 268 uint64_t update_model_time_window_; 269 270 // Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_. 271 float share_secrets_ratio_; 272 273 // The time window of each cipher round in millisecond. 274 uint64_t cipher_time_window_; 275 276 // The threshold count of reconstruct secrets round. Used in federated learning for now. 277 uint64_t reconstruct_secrets_threshold_; 278 279 // Iteration number of federeated learning, which is the number of interactions between client and server. 280 uint64_t fl_iteration_num_; 281 282 // Client training epoch number. Used in federated learning for now. 283 uint64_t client_epoch_num_; 284 285 // Client training data batch size. Used in federated learning for now. 286 uint64_t client_batch_size_; 287 288 // Client training learning rate. Used in federated learning for now. 289 float client_learning_rate_; 290 291 // The worker standalone training step number before communicating with server. 292 uint64_t worker_step_num_per_iteration_; 293 294 // Whether to use secure aggregation algorithm. Used in federated learning for now. 295 bool secure_aggregation_; 296 297 // The cluster config read through environment variables, the value does not change. 298 std::unique_ptr<core::ClusterConfig> cluster_config_; 299 300 // The port used by scheduler to receive http requests for scale out or scale in. 301 uint16_t scheduler_manage_port_; 302 303 // The path of the configuration file, used to configure the certification path and persistent storage type, etc. 304 std::string config_file_path_; 305 306 // Epsilon budget of differential privacy mechanism. Used in federated learning for now. 307 float dp_eps_; 308 309 // Delta budget of differential privacy mechanism. Used in federated learning for now. 310 float dp_delta_; 311 312 // Norm clip factor of differential privacy mechanism. Used in federated learning for now. 313 float dp_norm_clip_; 314 315 // Secure mechanism for federated learning. Used in federated learning for now. 316 std::string encrypt_type_; 317 318 // Unique id of the node 319 std::string node_id_; 320 321 // Password used to decode p12 file. 322 std::string client_password_; 323 // Password used to decode p12 file. 324 std::string server_password_; 325 }; 326 } // namespace ps 327 } // namespace mindspore 328 #endif // MINDSPORE_CCSRC_PS_CONTEXT_H_ 329