# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for parameter server training mode"""

import os
from mindspore._checkparam import Validator
from mindspore._c_expression import PSContext

_ps_context = None

_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
                            "start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window",
                            "fl_iteration_num", "client_epoch_num", "client_batch_size", "scheduler_manage_port",
                            "cipher_time_window", "reconstruct_secrets_threshold"]

_check_non_negative_int_keys = ["worker_num"]

_check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]

_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"]


def ps_context():
    """
    Get the global _ps_context. If it is not created, create a new one.

    Returns:
        _ps_context, the global parameter server training mode context.
    """
    global _ps_context
    if _ps_context is None:
        _ps_context = PSContext.get_instance()
    return _ps_context


_set_ps_context_func_map = {
    "server_mode": ps_context().set_server_mode,
    "ms_role": ps_context().set_ms_role,
    "enable_ps": ps_context().set_ps_enable,
    "enable_fl": ps_context().set_ps_enable,
    "worker_num": ps_context().set_worker_num,
    "server_num": ps_context().set_server_num,
    "scheduler_ip": ps_context().set_scheduler_ip,
    "scheduler_port": ps_context().set_scheduler_port,
    "fl_server_port": ps_context().set_fl_server_port,
    "enable_fl_client": ps_context().set_fl_client_enable,
    "start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
    "start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
    "update_model_ratio": ps_context().set_update_model_ratio,
    "update_model_time_window": ps_context().set_update_model_time_window,
    "share_secrets_ratio": ps_context().set_share_secrets_ratio,
    "cipher_time_window": ps_context().set_cipher_time_window,
    "reconstruct_secrets_threshold": ps_context().set_reconstruct_secrets_threshold,
    "fl_name": ps_context().set_fl_name,
    "fl_iteration_num": ps_context().set_fl_iteration_num,
    "client_epoch_num": ps_context().set_client_epoch_num,
    "client_batch_size": ps_context().set_client_batch_size,
    "client_learning_rate": ps_context().set_client_learning_rate,
    "worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration,
    "enable_ssl": ps_context().set_enable_ssl,
    "client_password": ps_context().set_client_password,
    "server_password": ps_context().set_server_password,
    "scheduler_manage_port": ps_context().set_scheduler_manage_port,
    "config_file_path": ps_context().set_config_file_path,
    "dp_eps": ps_context().set_dp_eps,
    "dp_delta": ps_context().set_dp_delta,
    "dp_norm_clip": ps_context().set_dp_norm_clip,
    "encrypt_type": ps_context().set_encrypt_type
}
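
# A minimal dispatch sketch (illustrative only, not part of the original module):
# each keyword accepted by _set_ps_context resolves through the map above to a
# bound setter on the PSContext singleton, so
#
#     _set_ps_context_func_map["server_num"](2)
#
# is equivalent to ps_context().set_server_num(2). Note that "enable_ps" and
# "enable_fl" intentionally share the same setter, set_ps_enable.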

_get_ps_context_func_map = {
    "server_mode": ps_context().server_mode,
    "ms_role": ps_context().ms_role,
    "enable_ps": ps_context().is_ps_mode,
    "enable_fl": ps_context().is_ps_mode,
    "worker_num": ps_context().worker_num,
    "server_num": ps_context().server_num,
    "scheduler_ip": ps_context().scheduler_ip,
    "scheduler_port": ps_context().scheduler_port,
    "fl_server_port": ps_context().fl_server_port,
    "enable_fl_client": ps_context().fl_client_enable,
    "start_fl_job_threshold": ps_context().start_fl_job_threshold,
    "start_fl_job_time_window": ps_context().start_fl_job_time_window,
    "update_model_ratio": ps_context().update_model_ratio,
    "update_model_time_window": ps_context().update_model_time_window,
    "share_secrets_ratio": ps_context().share_secrets_ratio,
    "cipher_time_window": ps_context().cipher_time_window,
    "reconstruct_secrets_threshold": ps_context().reconstruct_secrets_threshold,
    "fl_name": ps_context().fl_name,
    "fl_iteration_num": ps_context().fl_iteration_num,
    "client_epoch_num": ps_context().client_epoch_num,
    "client_batch_size": ps_context().client_batch_size,
    "client_learning_rate": ps_context().client_learning_rate,
    "worker_step_num_per_iteration": ps_context().worker_step_num_per_iteration,
    "enable_ssl": ps_context().enable_ssl,
    "client_password": ps_context().client_password,
    "server_password": ps_context().server_password,
    "scheduler_manage_port": ps_context().scheduler_manage_port,
    "config_file_path": ps_context().config_file_path
}


def _get_ps_mode_rank():
    ps_rank = ps_context().ps_rank_id()
    if ps_rank == -1:
        raise RuntimeError("The parameter server mode training is not enabled yet.")
    return ps_rank


def _set_ps_context(**kwargs):
    """
    Set parameter server training mode context.

    Note:
        Some other environment variables should also be set for parameter server training mode.
        These environment variables are listed below:

        .. code-block::

            MS_SERVER_NUM  # Server number
            MS_WORKER_NUM  # Worker number
            MS_SCHED_HOST  # Scheduler IP address
            MS_SCHED_PORT  # Scheduler port
            MS_ROLE        # The role of this process:
                           # MS_SCHED represents the scheduler,
                           # MS_WORKER represents the worker,
                           # MS_PSERVER represents the server

    Args:
        enable_ps (bool): Whether to enable parameter server training mode.
            The environment variables take effect only after enable_ps is set to True.
            Default: False.

    Raises:
        ValueError: If the input key is not an attribute of the parameter server training mode context.

    Examples:
        >>> context.set_ps_context(enable_ps=True)
    """
    for key, value in kwargs.items():
        if key not in _set_ps_context_func_map:
            raise ValueError("Set PS context keyword %s is not recognized!" % key)
        _check_value(key, value)
        set_func = _set_ps_context_func_map[key]
        set_func(value)


def _get_ps_context(attr_key):
    """
    Get the parameter server training mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Returns the attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute of the parameter server training mode context.
    """
    if attr_key not in _get_ps_context_func_map:
        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
    get_func = _get_ps_context_func_map[attr_key]
    value = get_func()
    return value
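

# A minimal configuration sketch (illustrative only, not part of the original
# module): the environment variables documented in _set_ps_context could be set
# from Python before enabling parameter server mode, e.g. for a single-worker
# local run. The host and port values below are placeholder choices:
#
#     os.environ['MS_SERVER_NUM'] = '1'
#     os.environ['MS_WORKER_NUM'] = '1'
#     os.environ['MS_SCHED_HOST'] = '127.0.0.1'
#     os.environ['MS_SCHED_PORT'] = '8081'
#     os.environ['MS_ROLE'] = 'MS_WORKER'
#     _set_ps_context(enable_ps=True)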


def _reset_ps_context():
    """
    Reset parameter server training mode context attributes to the default values:

    - enable_ps: False.
    """
    ps_context().reset()


def _is_role_worker():
    return ps_context().is_worker()


def _is_role_pserver():
    return ps_context().is_server()


def _is_role_sched():
    return ps_context().is_scheduler()


def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
    ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)


def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
    ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)


def _insert_weight_init_info(name, global_seed, op_seed):
    ps_context().insert_weight_init_info(name, global_seed, op_seed)


def _insert_accumu_init_info(name, init_val):
    ps_context().insert_accumu_init_info(name, init_val)


def _clone_hash_table(dest_param_name, src_param_name):
    ps_context().clone_hash_table(dest_param_name, src_param_name)


def _set_cache_enable(cache_enable):
    # Cap the maximum number of OpenBLAS/OpenMP threads via environment variables:
    # in a Ubuntu (GPU) environment, numpy would otherwise use too many threads
    # for computing.
    if cache_enable:
        os.environ['OPENBLAS_NUM_THREADS'] = '2'
        os.environ['GOTO_NUM_THREADS'] = '2'
        os.environ['OMP_NUM_THREADS'] = '2'
    ps_context().set_cache_enable(cache_enable)


def _set_rank_id(rank_id):
    ps_context().set_rank_id(rank_id)


def _check_value(key, value):
    """
    Validate the value for parameter server context keys.
    """
    if key in _check_positive_int_keys:
        Validator.check_positive_int(value, key)

    if key in _check_non_negative_int_keys:
        Validator.check_non_negative_int(value, key)

    if key in _check_positive_float_keys:
        Validator.check_positive_float(value, key)

    if key in _check_port_keys:
        if value < 1 or value > 65535:
            raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))
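

# A minimal usage sketch (illustrative only, not part of the original module):
# a round trip through the set/get/reset helpers above. The concrete values are
# placeholders; _check_value rejects ports outside 1-65535 before any setter runs.
if __name__ == "__main__":
    _set_ps_context(enable_ps=True, server_num=1, worker_num=1,
                    scheduler_ip="127.0.0.1", scheduler_port=8081)
    print(_get_ps_context("enable_ps"))       # expected: True
    print(_get_ps_context("scheduler_port"))  # expected: 8081
    _reset_ps_context()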