1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/ps_context.h"
18 #include "utils/log_adapter.h"
19 #include "utils/ms_utils.h"
20 #include "backend/kernel_compiler/kernel.h"
21 #if ((defined ENABLE_CPU) && (!defined _WIN32))
22 #include "ps/ps_cache/ps_cache_manager.h"
23 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
24 #endif
25
26 namespace mindspore {
27 namespace ps {
instance()28 std::shared_ptr<PSContext> PSContext::instance() {
29 static std::shared_ptr<PSContext> ps_instance = nullptr;
30 if (ps_instance == nullptr) {
31 ps_instance.reset(new (std::nothrow) PSContext());
32 }
33 return ps_instance;
34 }
35
SetPSEnable(bool enabled)36 void PSContext::SetPSEnable(bool enabled) {
37 ps_enabled_ = enabled;
38 if (ps_enabled_) {
39 std::string ms_role = common::GetEnv(kEnvRole);
40 MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
41 if (ms_role == kEnvRoleOfWorker) {
42 is_worker_ = true;
43 } else if (ms_role == kEnvRoleOfPServer) {
44 is_pserver_ = true;
45 } else if (ms_role == kEnvRoleOfScheduler) {
46 is_sched_ = true;
47 } else {
48 MS_LOG(INFO) << "MS_ROLE is " << ms_role;
49 }
50
51 worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
52 server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
53 scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
54 if (scheduler_host_.length() > kLength) {
55 MS_LOG(EXCEPTION) << "The scheduler host's length can not exceed " << kLength;
56 }
57 scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, kBase);
58 if (scheduler_port_ > kMaxPort) {
59 MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is illegal.";
60 }
61 scheduler_manage_port_ =
62 static_cast<uint16_t>((std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, kBase)));
63 if (scheduler_manage_port_ > kMaxPort) {
64 MS_LOG(EXCEPTION) << "The port << " << scheduler_manage_port_ << " is illegal.";
65 }
66 cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
67 node_id_ = common::GetEnv(kEnvNodeId);
68 if (node_id_.length() > kLength) {
69 MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
70 }
71 } else {
72 MS_LOG(INFO) << "PS mode is disabled.";
73 is_worker_ = false;
74 is_pserver_ = false;
75 is_sched_ = false;
76 }
77 }
78
is_ps_mode() const79 bool PSContext::is_ps_mode() const {
80 if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
81 return true;
82 }
83 return ps_enabled_;
84 }
85
Reset()86 void PSContext::Reset() {
87 ps_enabled_ = false;
88 is_worker_ = false;
89 is_pserver_ = false;
90 is_sched_ = false;
91 #if ((defined ENABLE_CPU) && (!defined _WIN32))
92 if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
93 ps_cache_instance.Finalize();
94 set_cache_enable(false);
95 }
96 #endif
97 }
98
ms_role() const99 std::string PSContext::ms_role() const {
100 if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
101 return role_;
102 }
103 if (is_worker_) {
104 return kEnvRoleOfWorker;
105 } else if (is_pserver_) {
106 return kEnvRoleOfPServer;
107 } else if (is_sched_) {
108 return kEnvRoleOfScheduler;
109 } else {
110 return kEnvRoleOfNotPS;
111 }
112 }
113
is_worker() const114 bool PSContext::is_worker() const {
115 if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
116 return role_ == kEnvRoleOfWorker;
117 }
118 return is_worker_;
119 }
120
is_server() const121 bool PSContext::is_server() const {
122 if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
123 return role_ == kEnvRoleOfServer;
124 }
125 return is_pserver_;
126 }
127
is_scheduler() const128 bool PSContext::is_scheduler() const {
129 if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
130 return role_ == kEnvRoleOfScheduler;
131 }
132 return is_sched_;
133 }
134
initial_worker_num() const135 uint32_t PSContext::initial_worker_num() const { return worker_num_; }
136
initial_server_num() const137 uint32_t PSContext::initial_server_num() const { return server_num_; }
138
scheduler_host() const139 std::string PSContext::scheduler_host() const { return scheduler_host_; }
140
SetPSRankId(uint32_t rank_id)141 void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
142
ps_rank_id() const143 uint32_t PSContext::ps_rank_id() const { return rank_id_; }
144
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size) const145 void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
146 size_t vocab_size) const {
147 #if ((defined ENABLE_CPU) && (!defined _WIN32))
148 ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
149 #endif
150 }
151
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name,size_t cache_vocab_size,size_t embedding_size) const152 void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
153 size_t cache_vocab_size, size_t embedding_size) const {
154 #if ((defined ENABLE_CPU) && (!defined _WIN32))
155 ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
156 #endif
157 }
158
InsertWeightInitInfo(const std::string & param_name,size_t global_seed,size_t op_seed) const159 void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const {
160 #if ((defined ENABLE_CPU) && (!defined _WIN32))
161 ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
162 #endif
163 }
164
InsertAccumuInitInfo(const std::string & param_name,float init_val) const165 void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const {
166 #if ((defined ENABLE_CPU) && (!defined _WIN32))
167 ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
168 #endif
169 }
170
CloneHashTable(const std::string & dest_param_name,const std::string & src_param_name) const171 void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
172 #if ((defined ENABLE_CPU) && (!defined _WIN32))
173 ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
174 #endif
175 }
176
set_cache_enable(bool cache_enable) const177 void PSContext::set_cache_enable(bool cache_enable) const {
178 #if ((defined ENABLE_CPU) && (!defined _WIN32))
179 PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
180 #endif
181 }
182
set_rank_id(uint32_t rank_id) const183 void PSContext::set_rank_id(uint32_t rank_id) const {
184 #if ((defined ENABLE_CPU) && (!defined _WIN32))
185 ps_cache_instance.set_rank_id(rank_id);
186 #endif
187 }
188
set_server_mode(const std::string & server_mode)189 void PSContext::set_server_mode(const std::string &server_mode) {
190 if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
191 MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
192 << " or " << kServerModeHybrid;
193 return;
194 }
195 MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
196 server_mode_ = server_mode;
197 }
198
server_mode() const199 const std::string &PSContext::server_mode() const { return server_mode_; }
200
set_encrypt_type(const std::string & encrypt_type)201 void PSContext::set_encrypt_type(const std::string &encrypt_type) {
202 if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
203 MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
204 << kDPEncryptType << " or " << kPWEncryptType;
205 return;
206 }
207 encrypt_type_ = encrypt_type;
208 }
encrypt_type() const209 const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
210
set_dp_eps(float dp_eps)211 void PSContext::set_dp_eps(float dp_eps) {
212 if (dp_eps > 0) {
213 dp_eps_ = dp_eps;
214 } else {
215 MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
216 return;
217 }
218 }
219
dp_eps() const220 float PSContext::dp_eps() const { return dp_eps_; }
221
set_dp_delta(float dp_delta)222 void PSContext::set_dp_delta(float dp_delta) {
223 if (dp_delta > 0 && dp_delta < 1) {
224 dp_delta_ = dp_delta;
225 } else {
226 MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
227 return;
228 }
229 }
dp_delta() const230 float PSContext::dp_delta() const { return dp_delta_; }
231
set_dp_norm_clip(float dp_norm_clip)232 void PSContext::set_dp_norm_clip(float dp_norm_clip) {
233 if (dp_norm_clip > 0) {
234 dp_norm_clip_ = dp_norm_clip;
235 } else {
236 MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
237 return;
238 }
239 }
dp_norm_clip() const240 float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
241
set_ms_role(const std::string & role)242 void PSContext::set_ms_role(const std::string &role) {
243 if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
244 MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
245 return;
246 }
247 if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
248 MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
249 return;
250 }
251 role_ = role;
252 }
253
set_worker_num(uint32_t worker_num)254 void PSContext::set_worker_num(uint32_t worker_num) {
255 // Hybrid training mode only supports one worker for now.
256 if (server_mode_ == kServerModeHybrid && worker_num != 1) {
257 MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
258 return;
259 }
260 worker_num_ = worker_num;
261 }
worker_num() const262 uint32_t PSContext::worker_num() const { return worker_num_; }
263
set_server_num(uint32_t server_num)264 void PSContext::set_server_num(uint32_t server_num) {
265 if (server_num == 0) {
266 MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
267 return;
268 }
269 server_num_ = server_num;
270 }
server_num() const271 uint32_t PSContext::server_num() const { return server_num_; }
272
set_scheduler_ip(const std::string & sched_ip)273 void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
274
scheduler_ip() const275 std::string PSContext::scheduler_ip() const { return scheduler_host_; }
276
set_scheduler_port(uint16_t sched_port)277 void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
278
scheduler_port() const279 uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
280
GenerateResetterRound()281 void PSContext::GenerateResetterRound() {
282 uint32_t binary_server_context = 0;
283 bool is_parameter_server_mode = false;
284 bool is_federated_learning_mode = false;
285 bool is_mixed_training_mode = false;
286 bool use_pairwise_encrypt = (encrypt_type_ == kPWEncryptType);
287
288 if (server_mode_ == kServerModePS) {
289 is_parameter_server_mode = true;
290 } else if (server_mode_ == kServerModeFL) {
291 is_federated_learning_mode = true;
292 } else if (server_mode_ == kServerModeHybrid) {
293 is_mixed_training_mode = true;
294 } else {
295 MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
296 << " or " << kServerModeHybrid;
297 return;
298 }
299
300 binary_server_context = ((unsigned int)is_parameter_server_mode) | ((unsigned int)is_federated_learning_mode << 1) |
301 ((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)use_pairwise_encrypt << 3);
302 if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
303 resetter_round_ = ResetterRound::kNoNeedToReset;
304 } else {
305 resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
306 }
307 MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
308 return;
309 }
310
resetter_round() const311 ResetterRound PSContext::resetter_round() const { return resetter_round_; }
312
set_fl_server_port(uint16_t fl_server_port)313 void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
314
fl_server_port() const315 uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
316
set_fl_client_enable(bool enabled)317 void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
318
fl_client_enable() const319 bool PSContext::fl_client_enable() const { return fl_client_enable_; }
320
set_start_fl_job_threshold(uint64_t start_fl_job_threshold)321 void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
322 start_fl_job_threshold_ = start_fl_job_threshold;
323 }
324
start_fl_job_threshold() const325 uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
326
set_start_fl_job_time_window(uint64_t start_fl_job_time_window)327 void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
328 start_fl_job_time_window_ = start_fl_job_time_window;
329 }
330
start_fl_job_time_window() const331 uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
332
set_update_model_ratio(float update_model_ratio)333 void PSContext::set_update_model_ratio(float update_model_ratio) {
334 if (update_model_ratio > 1.0) {
335 MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1.";
336 return;
337 }
338 update_model_ratio_ = update_model_ratio;
339 }
340
update_model_ratio() const341 float PSContext::update_model_ratio() const { return update_model_ratio_; }
342
set_update_model_time_window(uint64_t update_model_time_window)343 void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
344 update_model_time_window_ = update_model_time_window;
345 }
346
update_model_time_window() const347 uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
348
set_share_secrets_ratio(float share_secrets_ratio)349 void PSContext::set_share_secrets_ratio(float share_secrets_ratio) {
350 if (share_secrets_ratio > 0 && share_secrets_ratio <= 1) {
351 share_secrets_ratio_ = share_secrets_ratio;
352 } else {
353 MS_LOG(EXCEPTION) << share_secrets_ratio << " is invalid, share_secrets_ratio must be in range of (0, 1].";
354 return;
355 }
356 }
357
share_secrets_ratio() const358 float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
359
set_cipher_time_window(uint64_t cipher_time_window)360 void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
361 if (cipher_time_window_ < 0) {
362 MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0.";
363 return;
364 }
365 cipher_time_window_ = cipher_time_window;
366 }
367
cipher_time_window() const368 uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
369
set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold)370 void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
371 if (reconstruct_secrets_threshold == 0) {
372 MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
373 return;
374 }
375 reconstruct_secrets_threshold_ = reconstruct_secrets_threshold;
376 }
377
reconstruct_secrets_threshold() const378 uint64_t PSContext::reconstruct_secrets_threshold() const { return reconstruct_secrets_threshold_; }
379
set_fl_name(const std::string & fl_name)380 void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
381
fl_name() const382 const std::string &PSContext::fl_name() const { return fl_name_; }
383
set_fl_iteration_num(uint64_t fl_iteration_num)384 void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
385
fl_iteration_num() const386 uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
387
set_client_epoch_num(uint64_t client_epoch_num)388 void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
389
client_epoch_num() const390 uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
391
set_client_batch_size(uint64_t client_batch_size)392 void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
393
client_batch_size() const394 uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
395
set_client_learning_rate(float client_learning_rate)396 void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
397
client_learning_rate() const398 float PSContext::client_learning_rate() const { return client_learning_rate_; }
399
set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration)400 void PSContext::set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration) {
401 worker_step_num_per_iteration_ = worker_step_num_per_iteration;
402 }
403
worker_step_num_per_iteration() const404 uint64_t PSContext::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
405
enable_ssl() const406 bool PSContext::enable_ssl() const { return enable_ssl_; }
407
set_enable_ssl(bool enabled)408 void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
409
cluster_config()410 core::ClusterConfig &PSContext::cluster_config() {
411 if (cluster_config_ == nullptr) {
412 MS_LOG(EXCEPTION) << "The cluster config is empty.";
413 }
414 return *cluster_config_;
415 }
416
set_scheduler_manage_port(uint16_t sched_port)417 void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
418
scheduler_manage_port() const419 uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
420
set_config_file_path(const std::string & path)421 void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
422
config_file_path() const423 std::string PSContext::config_file_path() const { return config_file_path_; }
424
set_node_id(const std::string & node_id)425 void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }
426
node_id() const427 const std::string &PSContext::node_id() const { return node_id_; }
428
client_password() const429 std::string PSContext::client_password() const { return client_password_; }
set_client_password(const std::string & password)430 void PSContext::set_client_password(const std::string &password) { client_password_ = password; }
431
server_password() const432 std::string PSContext::server_password() const { return server_password_; }
set_server_password(const std::string & password)433 void PSContext::set_server_password(const std::string &password) { server_password_ = password; }
434 } // namespace ps
435 } // namespace mindspore
436