# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for parameter server training mode"""

import os
from mindspore._c_expression import PSContext
from mindspore import context
from mindspore import log as logger

# Lazily populated handle returned by PSContext.get_instance(); see ps_context().
_ps_context = None


def ps_context():
    """
    Get the global _ps_context, if it is not created, create a new one.

    Returns:
        _ps_context, the global parameter server training mode context.
    """
    global _ps_context
    if _ps_context is None:
        _ps_context = PSContext.get_instance()
    return _ps_context


def _need_reset_device_target_for_ps(target):
    """
    Tell whether the device target has to be reset for this process.

    For Ascend backend, the card can't be occupied by multiple processes in distributed training,
    so we need to reset the device target for some roles.

    Args:
        target (str): The currently configured device target.

    Returns:
        bool, True when the ``MS_ROLE`` environment variable is one of ``MS_PSERVER``, ``MS_SERVER``
        or ``MS_SCHED`` and `target` equals ``"Ascend"``.
    """
    is_server = os.getenv('MS_ROLE') in ("MS_PSERVER", "MS_SERVER", "MS_SCHED")
    return is_server and target == "Ascend"


def set_ps_enable(enable):
    """
    Set ps enable flag.

    Args:
        enable (bool): Value forwarded to ``PSContext.set_ps_enable``.
    """
    ps_context().set_ps_enable(enable)
    # If this is Server or Scheduler and device target is Ascend, reset the target to CPU
    if _need_reset_device_target_for_ps(context.get_context("device_target")):
        logger.info("Reset device target to CPU when set_ps_enable.")
        context.set_context(device_target="CPU")


# Dispatch table: context keyword -> setter taking the new value as its only argument.
_set_ps_context_func_map = {
    "server_mode": ps_context().set_server_mode,
    "ms_role": ps_context().set_ms_role,
    "enable_ps": set_ps_enable,
    "worker_num": ps_context().set_worker_num,
    "server_num": ps_context().set_server_num,
    "scheduler_ip": ps_context().set_scheduler_ip,
    "scheduler_port": ps_context().set_scheduler_port,
    "enable_ssl": ps_context().set_enable_ssl,
    "client_password": ps_context().set_client_password,
    "server_password": ps_context().set_server_password,
    "scheduler_manage_port": ps_context().set_scheduler_manage_port,
    "config_file_path": ps_context().set_config_file_path,
}

# Dispatch table: context keyword -> zero-argument getter returning the current value.
_get_ps_context_func_map = {
    "server_mode": ps_context().server_mode,
    "ms_role": ps_context().ms_role,
    "enable_ps": ps_context().is_ps_mode,
    "worker_num": ps_context().worker_num,
    "server_num": ps_context().server_num,
    "scheduler_ip": ps_context().scheduler_ip,
    "scheduler_port": ps_context().scheduler_port,
    "enable_ssl": ps_context().enable_ssl,
    "client_password": ps_context().client_password,
    "server_password": ps_context().server_password,
    "scheduler_manage_port": ps_context().scheduler_manage_port,
    "config_file_path": ps_context().config_file_path,
}


def _get_ps_mode_rank():
    """
    Get the rank id reported by ``PSContext.ps_rank_id``.

    Returns:
        The value returned by ``PSContext.ps_rank_id``.

    Raises:
        RuntimeError: If the reported rank id is -1, which this function treats as
            parameter server mode training not being enabled yet.
    """
    ps_rank = ps_context().ps_rank_id()
    if ps_rank == -1:
        raise RuntimeError("The parameter server mode training is not enabled yet.")
    return ps_rank


def _set_ps_context(**kwargs):
    """
    Set parameter server training mode context.

    All keywords are validated before any setter runs, so an unrecognized keyword
    leaves the context untouched instead of partially updated.

    Note:
        Some other environment variables should also be set for parameter server training mode.
        These environment variables are listed below:

        .. code-block::

            MS_SERVER_NUM  # Server number
            MS_WORKER_NUM  # Worker number
            MS_SCHED_HOST  # Scheduler IP address
            MS_SCHED_PORT  # Scheduler port
            MS_ROLE        # The role of this process:
                           # MS_SCHED represents the scheduler,
                           # MS_WORKER represents the worker,
                           # MS_PSERVER/MS_SERVER represents the Server


    Args:
        enable_ps (bool): Whether to enable parameter server training mode.
                          Only after enable_ps is set True, the environment variables will be effective.
                          Default: ``False``.
        config_file_path (string): Configuration file path used by recovery. Default: ''.
        scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
        enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False``.
        client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ''.
        server_password (str): Password to decrypt the secret key stored in the server certificate. Default: ''.

        The keywords ``server_mode``, ``ms_role``, ``worker_num``, ``server_num``, ``scheduler_ip``
        and ``scheduler_port`` are accepted as well and forwarded to the matching setter.

    Raises:
        ValueError: If input key is not the attribute in parameter server training mode context.

    Examples:
        >>> import mindspore as ms
        >>> ms.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
    """
    # Reject unknown keywords up front; the first offending key (in call order) is reported.
    for key in kwargs:
        if key not in _set_ps_context_func_map:
            raise ValueError("Set PS context keyword %s is not recognized!" % key)
    for key, value in kwargs.items():
        _set_ps_context_func_map[key](value)


def _get_ps_context(attr_key):
    """
    Get parameter server training mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If input key is not attribute in parameter server training mode context.
    """
    if attr_key not in _get_ps_context_func_map:
        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
    get_func = _get_ps_context_func_map[attr_key]
    return get_func()


def _reset_ps_context():
    """
    Reset parameter server training mode context attributes to the default values:

    - enable_ps: False.
    """
    ps_context().reset()


def _is_role_worker():
    """Return the result of ``PSContext.is_worker``."""
    return ps_context().is_worker()


def _is_role_pserver():
    """Return the result of ``PSContext.is_server``."""
    return ps_context().is_server()


def _is_role_sched():
    """Return the result of ``PSContext.is_scheduler``."""
    return ps_context().is_scheduler()


def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size, param_key=-1):
    """Forward all arguments unchanged to ``PSContext.insert_hash_table_size``."""
    ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size, param_key)


def _reinsert_hash_table_size(new_name, cur_name):
    """Forward both names unchanged to ``PSContext.reinsert_hash_table_size``."""
    ps_context().reinsert_hash_table_size(new_name, cur_name)


def _insert_accumu_init_info(name, init_val):
    """Forward both arguments unchanged to ``PSContext.insert_accumu_init_info``."""
    ps_context().insert_accumu_init_info(name, init_val)


def _clone_hash_table(dest_param_name, dest_param_key, src_param_name, src_param_key):
    """Forward all arguments unchanged to ``PSContext.clone_hash_table``."""
    ps_context().clone_hash_table(dest_param_name, dest_param_key, src_param_name, src_param_key)


def _set_cache_enable(cache_enable):
    """
    Set the cache enable flag on the parameter server context.

    When `cache_enable` is truthy, the ``OPENBLAS_NUM_THREADS``, ``GOTO_NUM_THREADS`` and
    ``OMP_NUM_THREADS`` environment variables are set to ``'2'`` before the flag is forwarded.

    Args:
        cache_enable (bool): Value forwarded to ``PSContext.set_cache_enable``.
    """
    # Environment variables are used to specify a maximum number of OpenBLAS threads:
    # in the Ubuntu (GPU) environment, numpy will otherwise use too many threads for computing.
    if cache_enable:
        os.environ['OPENBLAS_NUM_THREADS'] = '2'
        os.environ['GOTO_NUM_THREADS'] = '2'
        os.environ['OMP_NUM_THREADS'] = '2'
    ps_context().set_cache_enable(cache_enable)


def _cache_enable():
    """Return the result of ``PSContext.cache_enable``."""
    return ps_context().cache_enable()


def _set_cache_size(cache_size):
    """Forward `cache_size` unchanged to ``PSContext.set_cache_size``."""
    ps_context().set_cache_size(cache_size)


def _set_sparse_format(sparse_format):
    """Forward `sparse_format` unchanged to ``PSContext.set_sparse_format``."""
    ps_context().set_sparse_format(sparse_format)


def _set_rank_id(rank_id):
    """Forward `rank_id` unchanged to ``PSContext.set_rank_id``."""
    ps_context().set_rank_id(rank_id)


def _is_ps_mode():
    """Return True when the ``server_mode`` context attribute equals ``"PARAMETER_SERVER"``."""
    return _get_ps_context("server_mode") == "PARAMETER_SERVER"


def _enable_distributed_mindrt():
    """
    Whether the distributed MindRT is enabled.
    This method is used to distinguish from old distributed training mode.
    """
    return ps_context().enable_distributed_mindrt()


def _set_checkpoint_load_status(status):
    """Forward `status` to ``PSContext.set_checkpoint_load_status`` and return its result."""
    return ps_context().set_checkpoint_load_status(status)


def _store_warm_up_ptr_by_tensor(param_key, tensor):
    """Forward both arguments to ``PSContext.store_warm_up_ptr_by_tensor`` and return its result."""
    return ps_context().store_warm_up_ptr_by_tensor(param_key, tensor)


def _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor):
    """Forward all arguments to ``PSContext.store_warm_up_ptr_by_tensor_list`` and return its result."""
    return ps_context().store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)