• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_
18 #define MINDSPORE_CCSRC_PS_CONTEXT_H_
19 
20 #include <map>
21 #include <string>
22 #include <memory>
23 #include "ps/constants.h"
24 #include "ps/core/cluster_metadata.h"
25 #include "ps/core/cluster_config.h"
26 
27 namespace mindspore {
28 namespace ps {
// Server modes: parameter server, federated learning, or hybrid training (see PSContext::server_mode()).
constexpr char kServerModePS[]{"PARAMETER_SERVER"};
constexpr char kServerModeFL[]{"FEDERATED_LEARNING"};
constexpr char kServerModeHybrid[]{"HYBRID_TRAINING"};

// Process roles. kEnvRole is presumably the name of the environment variable carrying one of the role values that
// follow; kEnvRoleOfNotPS is the role PSContext starts with.
constexpr char kEnvRole[]{"MS_ROLE"};
constexpr char kEnvRoleOfPServer[]{"MS_PSERVER"};
constexpr char kEnvRoleOfServer[]{"MS_SERVER"};
constexpr char kEnvRoleOfWorker[]{"MS_WORKER"};
constexpr char kEnvRoleOfScheduler[]{"MS_SCHED"};
constexpr char kEnvRoleOfNotPS[]{"MS_NOT_PS"};

// Secure mechanisms for federated learning: differential privacy, pairwise encryption, or none (the default).
constexpr char kDPEncryptType[]{"DP_ENCRYPT"};
constexpr char kPWEncryptType[]{"PW_ENCRYPT"};
constexpr char kNotEncryptType[]{"NOT_ENCRYPT"};
41 
// A federated learning server's context is encoded as a bit mask so that we can tell which round resets the
// iteration. Bits, counted from the least significant (rightmost) one:
//   bit 0: server is in parameter server mode.
//   bit 1: server is in federated learning mode.
//   bit 2: server is in mixed training mode (presumably the HYBRID_TRAINING server mode).
//   bit 3: server enables the pairwise encrypt algorithm.
// Example: 0b1010 means the server is in federated learning mode with pairwise encryption enabled.
//
// NOTE: `kReconstructSeccrets` is misspelled; the name is kept unchanged because callers refer to it.
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight, kPushMetrics };

// Server-context bit mask -> round that resets the iteration. Only the four contexts listed here are mapped.
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
  {0b0010, ResetterRound::kUpdateModel},           // federated learning
  {0b0100, ResetterRound::kPushMetrics},           // mixed training
  {0b1010, ResetterRound::kReconstructSeccrets},   // federated learning + pairwise encryption
  {0b1100, ResetterRound::kPushMetrics},           // mixed training + pairwise encryption
};
54 
55 class PSContext {
56  public:
57   ~PSContext() = default;
58   PSContext(PSContext const &) = delete;
59   PSContext &operator=(const PSContext &) = delete;
60   static std::shared_ptr<PSContext> instance();
61 
62   void SetPSEnable(bool enabled);
63   bool is_ps_mode() const;
64   void Reset();
65   std::string ms_role() const;
66   bool is_worker() const;
67   bool is_server() const;
68   bool is_scheduler() const;
69   uint32_t initial_worker_num() const;
70   uint32_t initial_server_num() const;
71   std::string scheduler_host() const;
72   void SetPSRankId(uint32_t rank_id);
73   uint32_t ps_rank_id() const;
74   void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
75                            size_t vocab_size) const;
76   void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
77                              size_t cache_vocab_size, size_t embedding_size) const;
78   void InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const;
79   void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
80   void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
81   void set_cache_enable(bool cache_enable) const;
82   void set_rank_id(uint32_t rank_id) const;
83   bool enable_ssl() const;
84   void set_enable_ssl(bool enabled);
85 
86   std::string client_password() const;
87   void set_client_password(const std::string &password);
88   std::string server_password() const;
89   void set_server_password(const std::string &password);
90 
91   // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set
92   // by ps_context.
93   void set_server_mode(const std::string &server_mode);
94   const std::string &server_mode() const;
95 
96   void set_ms_role(const std::string &role);
97 
98   void set_worker_num(uint32_t worker_num);
99   uint32_t worker_num() const;
100 
101   void set_server_num(uint32_t server_num);
102   uint32_t server_num() const;
103 
104   void set_scheduler_ip(const std::string &sched_ip);
105   std::string scheduler_ip() const;
106 
107   void set_scheduler_port(uint16_t sched_port);
108   uint16_t scheduler_port() const;
109 
110   // Methods federated learning.
111 
112   // Generate which round should reset the iteration.
113   void GenerateResetterRound();
114   ResetterRound resetter_round() const;
115 
116   void set_fl_server_port(uint16_t fl_server_port);
117   uint16_t fl_server_port() const;
118 
119   // Set true if this process is a federated learning worker in cross-silo scenario.
120   void set_fl_client_enable(bool enabled);
121   bool fl_client_enable() const;
122 
123   void set_start_fl_job_threshold(uint64_t start_fl_job_threshold);
124   uint64_t start_fl_job_threshold() const;
125 
126   void set_start_fl_job_time_window(uint64_t start_fl_job_time_window);
127   uint64_t start_fl_job_time_window() const;
128 
129   void set_update_model_ratio(float update_model_ratio);
130   float update_model_ratio() const;
131 
132   void set_update_model_time_window(uint64_t update_model_time_window);
133   uint64_t update_model_time_window() const;
134 
135   void set_share_secrets_ratio(float share_secrets_ratio);
136   float share_secrets_ratio() const;
137 
138   void set_cipher_time_window(uint64_t cipher_time_window);
139   uint64_t cipher_time_window() const;
140 
141   void set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold);
142   uint64_t reconstruct_secrets_threshold() const;
143 
144   void set_fl_name(const std::string &fl_name);
145   const std::string &fl_name() const;
146 
147   // Set the iteration number of the federated learning.
148   void set_fl_iteration_num(uint64_t fl_iteration_num);
149   uint64_t fl_iteration_num() const;
150 
151   // Set the training epoch number of the client.
152   void set_client_epoch_num(uint64_t client_epoch_num);
153   uint64_t client_epoch_num() const;
154 
155   // Set the data batch size of the client.
156   void set_client_batch_size(uint64_t client_batch_size);
157   uint64_t client_batch_size() const;
158 
159   void set_client_learning_rate(float client_learning_rate);
160   float client_learning_rate() const;
161 
162   void set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration);
163   uint64_t worker_step_num_per_iteration() const;
164 
165   core::ClusterConfig &cluster_config();
166 
167   void set_scheduler_manage_port(uint16_t sched_port);
168   uint16_t scheduler_manage_port() const;
169 
170   void set_config_file_path(const std::string &path);
171   std::string config_file_path() const;
172 
173   void set_dp_eps(float dp_eps);
174   float dp_eps() const;
175 
176   void set_dp_delta(float dp_delta);
177   float dp_delta() const;
178 
179   void set_dp_norm_clip(float dp_norm_clip);
180   float dp_norm_clip() const;
181 
182   void set_encrypt_type(const std::string &encrypt_type);
183   const std::string &encrypt_type() const;
184 
185   void set_node_id(const std::string &node_id);
186   const std::string &node_id() const;
187 
188  private:
PSContext()189   PSContext()
190       : ps_enabled_(false),
191         is_worker_(false),
192         is_pserver_(false),
193         is_sched_(false),
194         enable_ssl_(false),
195         rank_id_(0),
196         worker_num_(0),
197         server_num_(0),
198         scheduler_host_("0.0.0.0"),
199         scheduler_port_(6667),
200         role_(kEnvRoleOfNotPS),
201         server_mode_(""),
202         resetter_round_(ResetterRound::kNoNeedToReset),
203         fl_server_port_(6668),
204         fl_client_enable_(false),
205         fl_name_(""),
206         start_fl_job_threshold_(0),
207         start_fl_job_time_window_(3000),
208         update_model_ratio_(1.0),
209         update_model_time_window_(3000),
210         share_secrets_ratio_(1.0),
211         cipher_time_window_(300000),
212         reconstruct_secrets_threshold_(2000),
213         fl_iteration_num_(20),
214         client_epoch_num_(25),
215         client_batch_size_(32),
216         client_learning_rate_(0.001),
217         worker_step_num_per_iteration_(65),
218         secure_aggregation_(false),
219         cluster_config_(nullptr),
220         scheduler_manage_port_(11202),
221         config_file_path_(""),
222         dp_eps_(50),
223         dp_delta_(0.01),
224         dp_norm_clip_(1.0),
225         encrypt_type_(kNotEncryptType),
226         node_id_(""),
227         client_password_(""),
228         server_password_("") {}
229   bool ps_enabled_;
230   bool is_worker_;
231   bool is_pserver_;
232   bool is_sched_;
233   bool enable_ssl_;
234   uint32_t rank_id_;
235   uint32_t worker_num_;
236   uint32_t server_num_;
237   std::string scheduler_host_;
238   uint16_t scheduler_port_;
239 
240   // The server process's role.
241   std::string role_;
242 
243   // Server mode which could be Parameter Server, Federated Learning and Hybrid Training mode.
244   std::string server_mode_;
245 
246   // The round which will reset the iteration. Used in federated learning for now.
247   ResetterRound resetter_round_;
248 
249   // Http port of federated learning server.
250   uint16_t fl_server_port_;
251 
252   // Whether this process is the federated client. Used in cross-silo scenario of federated learning.
253   bool fl_client_enable_;
254 
255   // Federated learning job name.
256   std::string fl_name_;
257 
258   // The threshold count of startFLJob round. Used in federated learning for now.
259   uint64_t start_fl_job_threshold_;
260 
261   // The time window of startFLJob round in millisecond.
262   uint64_t start_fl_job_time_window_;
263 
264   // Update model threshold is a certain ratio of start_fl_job threshold which is set as update_model_ratio_.
265   float update_model_ratio_;
266 
267   // The time window of updateModel round in millisecond.
268   uint64_t update_model_time_window_;
269 
270   // Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_.
271   float share_secrets_ratio_;
272 
273   // The time window of each cipher round in millisecond.
274   uint64_t cipher_time_window_;
275 
276   // The threshold count of reconstruct secrets round. Used in federated learning for now.
277   uint64_t reconstruct_secrets_threshold_;
278 
279   // Iteration number of federeated learning, which is the number of interactions between client and server.
280   uint64_t fl_iteration_num_;
281 
282   // Client training epoch number. Used in federated learning for now.
283   uint64_t client_epoch_num_;
284 
285   // Client training data batch size. Used in federated learning for now.
286   uint64_t client_batch_size_;
287 
288   // Client training learning rate. Used in federated learning for now.
289   float client_learning_rate_;
290 
291   // The worker standalone training step number before communicating with server.
292   uint64_t worker_step_num_per_iteration_;
293 
294   // Whether to use secure aggregation algorithm. Used in federated learning for now.
295   bool secure_aggregation_;
296 
297   // The cluster config read through environment variables, the value does not change.
298   std::unique_ptr<core::ClusterConfig> cluster_config_;
299 
300   // The port used by scheduler to receive http requests for scale out or scale in.
301   uint16_t scheduler_manage_port_;
302 
303   // The path of the configuration file, used to configure the certification path and persistent storage type, etc.
304   std::string config_file_path_;
305 
306   // Epsilon budget of differential privacy mechanism. Used in federated learning for now.
307   float dp_eps_;
308 
309   // Delta budget of differential privacy mechanism. Used in federated learning for now.
310   float dp_delta_;
311 
312   // Norm clip factor of differential privacy mechanism. Used in federated learning for now.
313   float dp_norm_clip_;
314 
315   // Secure mechanism for federated learning. Used in federated learning for now.
316   std::string encrypt_type_;
317 
318   // Unique id of the node
319   std::string node_id_;
320 
321   // Password used to decode p12 file.
322   std::string client_password_;
323   // Password used to decode p12 file.
324   std::string server_password_;
325 };
326 }  // namespace ps
327 }  // namespace mindspore
328 #endif  // MINDSPORE_CCSRC_PS_CONTEXT_H_
329