/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "ps/ps_context.h"
18 #include "utils/log_adapter.h"
19 #include "utils/ms_utils.h"
20 #include "backend/kernel_compiler/kernel.h"
21 #if ((defined ENABLE_CPU) && (!defined _WIN32))
22 #include "ps/ps_cache/ps_cache_manager.h"
23 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
24 #endif
25 
26 namespace mindspore {
27 namespace ps {
instance()28 std::shared_ptr<PSContext> PSContext::instance() {
29   static std::shared_ptr<PSContext> ps_instance = nullptr;
30   if (ps_instance == nullptr) {
31     ps_instance.reset(new (std::nothrow) PSContext());
32   }
33   return ps_instance;
34 }
35 
SetPSEnable(bool enabled)36 void PSContext::SetPSEnable(bool enabled) {
37   ps_enabled_ = enabled;
38   if (ps_enabled_) {
39     std::string ms_role = common::GetEnv(kEnvRole);
40     MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
41     if (ms_role == kEnvRoleOfWorker) {
42       is_worker_ = true;
43     } else if (ms_role == kEnvRoleOfPServer) {
44       is_pserver_ = true;
45     } else if (ms_role == kEnvRoleOfScheduler) {
46       is_sched_ = true;
47     } else {
48       MS_LOG(INFO) << "MS_ROLE is " << ms_role;
49     }
50 
51     worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
52     server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
53     scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
54     if (scheduler_host_.length() > kLength) {
55       MS_LOG(EXCEPTION) << "The scheduler host's length can not exceed " << kLength;
56     }
57     scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, kBase);
58     if (scheduler_port_ > kMaxPort) {
59       MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is illegal.";
60     }
61     scheduler_manage_port_ =
62       static_cast<uint16_t>((std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, kBase)));
63     if (scheduler_manage_port_ > kMaxPort) {
64       MS_LOG(EXCEPTION) << "The port << " << scheduler_manage_port_ << " is illegal.";
65     }
66     cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
67     node_id_ = common::GetEnv(kEnvNodeId);
68     if (node_id_.length() > kLength) {
69       MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
70     }
71   } else {
72     MS_LOG(INFO) << "PS mode is disabled.";
73     is_worker_ = false;
74     is_pserver_ = false;
75     is_sched_ = false;
76   }
77 }
78 
is_ps_mode() const79 bool PSContext::is_ps_mode() const {
80   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
81     return true;
82   }
83   return ps_enabled_;
84 }
85 
Reset()86 void PSContext::Reset() {
87   ps_enabled_ = false;
88   is_worker_ = false;
89   is_pserver_ = false;
90   is_sched_ = false;
91 #if ((defined ENABLE_CPU) && (!defined _WIN32))
92   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
93     ps_cache_instance.Finalize();
94     set_cache_enable(false);
95   }
96 #endif
97 }
98 
ms_role() const99 std::string PSContext::ms_role() const {
100   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
101     return role_;
102   }
103   if (is_worker_) {
104     return kEnvRoleOfWorker;
105   } else if (is_pserver_) {
106     return kEnvRoleOfPServer;
107   } else if (is_sched_) {
108     return kEnvRoleOfScheduler;
109   } else {
110     return kEnvRoleOfNotPS;
111   }
112 }
113 
is_worker() const114 bool PSContext::is_worker() const {
115   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
116     return role_ == kEnvRoleOfWorker;
117   }
118   return is_worker_;
119 }
120 
is_server() const121 bool PSContext::is_server() const {
122   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
123     return role_ == kEnvRoleOfServer;
124   }
125   return is_pserver_;
126 }
127 
is_scheduler() const128 bool PSContext::is_scheduler() const {
129   if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
130     return role_ == kEnvRoleOfScheduler;
131   }
132   return is_sched_;
133 }
134 
initial_worker_num() const135 uint32_t PSContext::initial_worker_num() const { return worker_num_; }
136 
initial_server_num() const137 uint32_t PSContext::initial_server_num() const { return server_num_; }
138 
scheduler_host() const139 std::string PSContext::scheduler_host() const { return scheduler_host_; }
140 
SetPSRankId(uint32_t rank_id)141 void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
142 
ps_rank_id() const143 uint32_t PSContext::ps_rank_id() const { return rank_id_; }
144 
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size) const145 void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
146                                     size_t vocab_size) const {
147 #if ((defined ENABLE_CPU) && (!defined _WIN32))
148   ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
149 #endif
150 }
151 
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name,size_t cache_vocab_size,size_t embedding_size) const152 void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
153                                       size_t cache_vocab_size, size_t embedding_size) const {
154 #if ((defined ENABLE_CPU) && (!defined _WIN32))
155   ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
156 #endif
157 }
158 
InsertWeightInitInfo(const std::string & param_name,size_t global_seed,size_t op_seed) const159 void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
160 #if ((defined ENABLE_CPU) && (!defined _WIN32))
161   ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
162 #endif
163 }
164 
InsertAccumuInitInfo(const std::string & param_name,float init_val) const165 void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
166 #if ((defined ENABLE_CPU) && (!defined _WIN32))
167   ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
168 #endif
169 }
170 
CloneHashTable(const std::string & dest_param_name,const std::string & src_param_name) const171 void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
172 #if ((defined ENABLE_CPU) && (!defined _WIN32))
173   ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
174 #endif
175 }
176 
set_cache_enable(bool cache_enable) const177 void PSContext::set_cache_enable(bool cache_enable) const {
178 #if ((defined ENABLE_CPU) && (!defined _WIN32))
179   PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
180 #endif
181 }
182 
set_rank_id(uint32_t rank_id) const183 void PSContext::set_rank_id(uint32_t rank_id) const {
184 #if ((defined ENABLE_CPU) && (!defined _WIN32))
185   ps_cache_instance.set_rank_id(rank_id);
186 #endif
187 }
188 
set_server_mode(const std::string & server_mode)189 void PSContext::set_server_mode(const std::string &server_mode) {
190   if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
191     MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
192                       << " or " << kServerModeHybrid;
193     return;
194   }
195   MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
196   server_mode_ = server_mode;
197 }
198 
server_mode() const199 const std::string &PSContext::server_mode() const { return server_mode_; }
200 
set_encrypt_type(const std::string & encrypt_type)201 void PSContext::set_encrypt_type(const std::string &encrypt_type) {
202   if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
203     MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
204                       << kDPEncryptType << " or " << kPWEncryptType;
205     return;
206   }
207   encrypt_type_ = encrypt_type;
208 }
encrypt_type() const209 const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
210 
set_dp_eps(float dp_eps)211 void PSContext::set_dp_eps(float dp_eps) {
212   if (dp_eps > 0) {
213     dp_eps_ = dp_eps;
214   } else {
215     MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
216     return;
217   }
218 }
219 
dp_eps() const220 float PSContext::dp_eps() const { return dp_eps_; }
221 
set_dp_delta(float dp_delta)222 void PSContext::set_dp_delta(float dp_delta) {
223   if (dp_delta > 0 && dp_delta < 1) {
224     dp_delta_ = dp_delta;
225   } else {
226     MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
227     return;
228   }
229 }
dp_delta() const230 float PSContext::dp_delta() const { return dp_delta_; }
231 
set_dp_norm_clip(float dp_norm_clip)232 void PSContext::set_dp_norm_clip(float dp_norm_clip) {
233   if (dp_norm_clip > 0) {
234     dp_norm_clip_ = dp_norm_clip;
235   } else {
236     MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
237     return;
238   }
239 }
dp_norm_clip() const240 float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
241 
set_ms_role(const std::string & role)242 void PSContext::set_ms_role(const std::string &role) {
243   if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
244     MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
245     return;
246   }
247   if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
248     MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
249     return;
250   }
251   role_ = role;
252 }
253 
set_worker_num(uint32_t worker_num)254 void PSContext::set_worker_num(uint32_t worker_num) {
255   // Hybrid training mode only supports one worker for now.
256   if (server_mode_ == kServerModeHybrid && worker_num != 1) {
257     MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
258     return;
259   }
260   worker_num_ = worker_num;
261 }
worker_num() const262 uint32_t PSContext::worker_num() const { return worker_num_; }
263 
set_server_num(uint32_t server_num)264 void PSContext::set_server_num(uint32_t server_num) {
265   if (server_num == 0) {
266     MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
267     return;
268   }
269   server_num_ = server_num;
270 }
server_num() const271 uint32_t PSContext::server_num() const { return server_num_; }
272 
set_scheduler_ip(const std::string & sched_ip)273 void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
274 
scheduler_ip() const275 std::string PSContext::scheduler_ip() const { return scheduler_host_; }
276 
set_scheduler_port(uint16_t sched_port)277 void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
278 
scheduler_port() const279 uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
280 
GenerateResetterRound()281 void PSContext::GenerateResetterRound() {
282   uint32_t binary_server_context = 0;
283   bool is_parameter_server_mode = false;
284   bool is_federated_learning_mode = false;
285   bool is_mixed_training_mode = false;
286   bool use_pairwise_encrypt = (encrypt_type_ == kPWEncryptType);
287 
288   if (server_mode_ == kServerModePS) {
289     is_parameter_server_mode = true;
290   } else if (server_mode_ == kServerModeFL) {
291     is_federated_learning_mode = true;
292   } else if (server_mode_ == kServerModeHybrid) {
293     is_mixed_training_mode = true;
294   } else {
295     MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
296                       << " or " << kServerModeHybrid;
297     return;
298   }
299 
300   binary_server_context = ((unsigned int)is_parameter_server_mode) | ((unsigned int)is_federated_learning_mode << 1) |
301                           ((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)use_pairwise_encrypt << 3);
302   if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
303     resetter_round_ = ResetterRound::kNoNeedToReset;
304   } else {
305     resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
306   }
307   MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
308   return;
309 }
310 
resetter_round() const311 ResetterRound PSContext::resetter_round() const { return resetter_round_; }
312 
set_fl_server_port(uint16_t fl_server_port)313 void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
314 
fl_server_port() const315 uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
316 
set_fl_client_enable(bool enabled)317 void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
318 
fl_client_enable() const319 bool PSContext::fl_client_enable() const { return fl_client_enable_; }
320 
set_start_fl_job_threshold(uint64_t start_fl_job_threshold)321 void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
322   start_fl_job_threshold_ = start_fl_job_threshold;
323 }
324 
start_fl_job_threshold() const325 uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
326 
set_start_fl_job_time_window(uint64_t start_fl_job_time_window)327 void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
328   start_fl_job_time_window_ = start_fl_job_time_window;
329 }
330 
start_fl_job_time_window() const331 uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
332 
set_update_model_ratio(float update_model_ratio)333 void PSContext::set_update_model_ratio(float update_model_ratio) {
334   if (update_model_ratio > 1.0) {
335     MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1.";
336     return;
337   }
338   update_model_ratio_ = update_model_ratio;
339 }
340 
update_model_ratio() const341 float PSContext::update_model_ratio() const { return update_model_ratio_; }
342 
set_update_model_time_window(uint64_t update_model_time_window)343 void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
344   update_model_time_window_ = update_model_time_window;
345 }
346 
update_model_time_window() const347 uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
348 
set_share_secrets_ratio(float share_secrets_ratio)349 void PSContext::set_share_secrets_ratio(float share_secrets_ratio) {
350   if (share_secrets_ratio > 0 && share_secrets_ratio <= 1) {
351     share_secrets_ratio_ = share_secrets_ratio;
352   } else {
353     MS_LOG(EXCEPTION) << share_secrets_ratio << " is invalid, share_secrets_ratio must be in range of (0, 1].";
354     return;
355   }
356 }
357 
share_secrets_ratio() const358 float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
359 
set_cipher_time_window(uint64_t cipher_time_window)360 void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
361   if (cipher_time_window_ < 0) {
362     MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0.";
363     return;
364   }
365   cipher_time_window_ = cipher_time_window;
366 }
367 
cipher_time_window() const368 uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
369 
set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold)370 void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
371   if (reconstruct_secrets_threshold == 0) {
372     MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
373     return;
374   }
375   reconstruct_secrets_threshold_ = reconstruct_secrets_threshold;
376 }
377 
reconstruct_secrets_threshold() const378 uint64_t PSContext::reconstruct_secrets_threshold() const { return reconstruct_secrets_threshold_; }
379 
set_fl_name(const std::string & fl_name)380 void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
381 
fl_name() const382 const std::string &PSContext::fl_name() const { return fl_name_; }
383 
set_fl_iteration_num(uint64_t fl_iteration_num)384 void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
385 
fl_iteration_num() const386 uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
387 
set_client_epoch_num(uint64_t client_epoch_num)388 void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
389 
client_epoch_num() const390 uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
391 
set_client_batch_size(uint64_t client_batch_size)392 void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
393 
client_batch_size() const394 uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
395 
set_client_learning_rate(float client_learning_rate)396 void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
397 
client_learning_rate() const398 float PSContext::client_learning_rate() const { return client_learning_rate_; }
399 
set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration)400 void PSContext::set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration) {
401   worker_step_num_per_iteration_ = worker_step_num_per_iteration;
402 }
403 
worker_step_num_per_iteration() const404 uint64_t PSContext::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
405 
enable_ssl() const406 bool PSContext::enable_ssl() const { return enable_ssl_; }
407 
set_enable_ssl(bool enabled)408 void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
409 
cluster_config()410 core::ClusterConfig &PSContext::cluster_config() {
411   if (cluster_config_ == nullptr) {
412     MS_LOG(EXCEPTION) << "The cluster config is empty.";
413   }
414   return *cluster_config_;
415 }
416 
set_scheduler_manage_port(uint16_t sched_port)417 void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
418 
scheduler_manage_port() const419 uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
420 
set_config_file_path(const std::string & path)421 void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
422 
config_file_path() const423 std::string PSContext::config_file_path() const { return config_file_path_; }
424 
set_node_id(const std::string & node_id)425 void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }
426 
node_id() const427 const std::string &PSContext::node_id() const { return node_id_; }
428 
client_password() const429 std::string PSContext::client_password() const { return client_password_; }
set_client_password(const std::string & password)430 void PSContext::set_client_password(const std::string &password) { client_password_ = password; }
431 
server_password() const432 std::string PSContext::server_password() const { return server_password_; }
set_server_password(const std::string & password)433 void PSContext::set_server_password(const std::string &password) { server_password_ = password; }
434 }  // namespace ps
435 }  // namespace mindspore
436