# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context for parameter server training mode."""

import os
from mindspore._checkparam import Validator
from mindspore._c_expression import PSContext

_ps_context = None

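# Keys whose values are validated by _check_value before being passed to the backend.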
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
                            "start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window",
                            "fl_iteration_num", "client_epoch_num", "client_batch_size", "scheduler_manage_port",
                            "cipher_time_window", "reconstruct_secrets_threshold"]

_check_non_negative_int_keys = ["worker_num"]

_check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]

_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"]


def ps_context():
36    """
37    Get the global _ps_context, if it is not created, create a new one.
38
39    Returns:
40        _ps_context, the global parameter server training mode context.
41    """
42    global _ps_context
43    if _ps_context is None:
44        _ps_context = PSContext.get_instance()
45    return _ps_context
46
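# Maps each supported context key to the corresponding PSContext setter.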
_set_ps_context_func_map = {
    "server_mode": ps_context().set_server_mode,
    "ms_role": ps_context().set_ms_role,
    "enable_ps": ps_context().set_ps_enable,
    "enable_fl": ps_context().set_ps_enable,
    "worker_num": ps_context().set_worker_num,
    "server_num": ps_context().set_server_num,
    "scheduler_ip": ps_context().set_scheduler_ip,
    "scheduler_port": ps_context().set_scheduler_port,
    "fl_server_port": ps_context().set_fl_server_port,
    "enable_fl_client": ps_context().set_fl_client_enable,
    "start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
    "start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
    "update_model_ratio": ps_context().set_update_model_ratio,
    "update_model_time_window": ps_context().set_update_model_time_window,
    "share_secrets_ratio": ps_context().set_share_secrets_ratio,
    "cipher_time_window": ps_context().set_cipher_time_window,
    "reconstruct_secrets_threshold": ps_context().set_reconstruct_secrets_threshold,
    "fl_name": ps_context().set_fl_name,
    "fl_iteration_num": ps_context().set_fl_iteration_num,
    "client_epoch_num": ps_context().set_client_epoch_num,
    "client_batch_size": ps_context().set_client_batch_size,
    "client_learning_rate": ps_context().set_client_learning_rate,
    "worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration,
    "enable_ssl": ps_context().set_enable_ssl,
    "client_password": ps_context().set_client_password,
    "server_password": ps_context().set_server_password,
    "scheduler_manage_port": ps_context().set_scheduler_manage_port,
    "config_file_path": ps_context().set_config_file_path,
    "dp_eps": ps_context().set_dp_eps,
    "dp_delta": ps_context().set_dp_delta,
    "dp_norm_clip": ps_context().set_dp_norm_clip,
    "encrypt_type": ps_context().set_encrypt_type
}

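# Maps each supported context key to the corresponding PSContext getter.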
_get_ps_context_func_map = {
    "server_mode": ps_context().server_mode,
    "ms_role": ps_context().ms_role,
    "enable_ps": ps_context().is_ps_mode,
    "enable_fl": ps_context().is_ps_mode,
    "worker_num": ps_context().worker_num,
    "server_num": ps_context().server_num,
    "scheduler_ip": ps_context().scheduler_ip,
    "scheduler_port": ps_context().scheduler_port,
    "fl_server_port": ps_context().fl_server_port,
    "enable_fl_client": ps_context().fl_client_enable,
    "start_fl_job_threshold": ps_context().start_fl_job_threshold,
    "start_fl_job_time_window": ps_context().start_fl_job_time_window,
    "update_model_ratio": ps_context().update_model_ratio,
    "update_model_time_window": ps_context().update_model_time_window,
    "share_secrets_ratio": ps_context().share_secrets_ratio,
98    "cipher_time_window": ps_context().set_cipher_time_window,
99    "reconstruct_secrets_threshold": ps_context().reconstruct_secrets_threshold,
100    "fl_name": ps_context().fl_name,
101    "fl_iteration_num": ps_context().fl_iteration_num,
102    "client_epoch_num": ps_context().client_epoch_num,
103    "client_batch_size": ps_context().client_batch_size,
104    "client_learning_rate": ps_context().client_learning_rate,
105    "worker_step_num_per_iteration": ps_context().worker_step_num_per_iteration,
106    "enable_ssl": ps_context().enable_ssl,
107    "client_password": ps_context().client_password,
108    "server_password": ps_context().server_password,
109    "scheduler_manage_port": ps_context().scheduler_manage_port,
110    "config_file_path": ps_context().config_file_path
111}
112
113
114def _get_ps_mode_rank():
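    """Get the rank id of this process in parameter server training mode."""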
    ps_rank = ps_context().ps_rank_id()
    if ps_rank == -1:
        raise RuntimeError("The parameter server training mode is not enabled yet.")
    return ps_rank


def _set_ps_context(**kwargs):
122    """
123    Set parameter server training mode context.
124
125    Note:
126        Some other environment variables should also be set for parameter server training mode.
127        These environment variables are listed below:
128
129        .. code-block::
130
131            MS_SERVER_NUM  # Server number
132            MS_WORKER_NUM  # Worker number
133            MS_SCHED_HOST  # Scheduler IP address
134            MS_SCHED_PORT  # Scheduler port
135            MS_ROLE        # The role of this process:
136                           # MS_SCHED represents the scheduler,
137                           # MS_WORKER represents the worker,
138                           # MS_PSERVER represents the Server
139
140
141    Args:
142        enable_ps (bool): Whether to enable parameter server training mode.
143                          Only after enable_ps is set True, the environment variables will be effective.
144                          Default: False.
145
146    Raises:
147        ValueError: If input key is not the attribute in parameter server training mode context.
148
149    Examples:
150        >>> context.set_ps_context(enable_ps=True)
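        >>> # A fuller sketch with illustrative values (the address and port below are examples only):
        >>> context.set_ps_context(enable_ps=True, worker_num=1, server_num=1,
        ...                        scheduler_ip="127.0.0.1", scheduler_port=8113)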
151    """
152    for key, value in kwargs.items():
153        if key not in _set_ps_context_func_map:
154            raise ValueError("Set PS context keyword %s is not recognized!" % key)
155        _check_value(key, value)
156        set_func = _set_ps_context_func_map[key]
157        set_func(value)
158
159
160def _get_ps_context(attr_key):
161    """
162    Get parameter server training mode context attribute value according to the key.
163
164    Args:
165        attr_key (str): The key of the attribute.
166
167    Returns:
168        Returns attribute value according to the key.
169
170    Raises:
171        ValueError: If input key is not attribute in auto parallel context.
172    """
    if attr_key not in _get_ps_context_func_map:
        raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
    get_func = _get_ps_context_func_map[attr_key]
    value = get_func()
    return value


def _reset_ps_context():
181    """
182    Reset parameter server training mode context attributes to the default values:
183
184    - enable_ps: False.
185    """
186    ps_context().reset()
187
188
189def _is_role_worker():
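    """Return whether the role of this process is worker."""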
    return ps_context().is_worker()


def _is_role_pserver():
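    """Return whether the role of this process is parameter server."""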
    return ps_context().is_server()


def _is_role_sched():
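    """Return whether the role of this process is scheduler."""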
    return ps_context().is_scheduler()


def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
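    """Register the hash table size of the embedding table `name` with the PS context."""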
    ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)


def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
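    """Re-register the hash table size of `cur_name` under the new name `new_name`."""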
    ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)


def _insert_weight_init_info(name, global_seed, op_seed):
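    """Record the global seed and op seed used to initialize the weight `name`."""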
    ps_context().insert_weight_init_info(name, global_seed, op_seed)


def _insert_accumu_init_info(name, init_val):
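    """Record the initial value of the accumulation parameter `name`."""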
    ps_context().insert_accumu_init_info(name, init_val)


def _clone_hash_table(dest_param_name, src_param_name):
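    """Clone the hash table info of `src_param_name` to `dest_param_name`."""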
    ps_context().clone_hash_table(dest_param_name, src_param_name)


def _set_cache_enable(cache_enable):
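    """Enable or disable the embedding cache in parameter server training mode."""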
    # Limit the number of OpenBLAS/OpenMP threads. In a Ubuntu (GPU) environment,
    # NumPy would otherwise spawn too many threads for computing.
    if cache_enable:
        os.environ['OPENBLAS_NUM_THREADS'] = '2'
        os.environ['GOTO_NUM_THREADS'] = '2'
        os.environ['OMP_NUM_THREADS'] = '2'
    ps_context().set_cache_enable(cache_enable)


def _set_rank_id(rank_id):
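    """Set the rank id of this process for parameter server training mode."""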
    ps_context().set_rank_id(rank_id)


def _check_value(key, value):
236    """
237    Validate the value for parameter server context keys.
238    """
239    if key in _check_positive_int_keys:
240        Validator.check_positive_int(value, key)
241
242    if key in _check_non_negative_int_keys:
243        Validator.check_non_negative_int(value, key)
244
245    if key in _check_positive_float_keys:
246        Validator.check_positive_float(value, key)
247
248    if key in _check_port_keys:
249        if value < 1 or value > 65535:
250            raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))
251