# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration module provides various functions to set and get the supported
configuration parameters, and read a configuration file.

Common imported modules in corresponding API examples are as follows:

.. code-block::

    import mindspore.dataset as ds
"""
import os
import platform
import random
import time
import numpy
import mindspore._c_dataengine as cde
from mindspore import log as logger

__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
           'get_num_parallel_workers', 'set_numa_enable', 'get_numa_enable', 'set_monitor_sampling_interval',
           'get_monitor_sampling_interval', 'set_callback_timeout', 'get_callback_timeout',
           'set_auto_num_workers', 'get_auto_num_workers', 'set_enable_shared_mem', 'get_enable_shared_mem',
           'set_sending_batches', 'load', '_init_device_info']

INT32_MAX = 2147483647
UINT32_MAX = 4294967295

_config = cde.GlobalContext.config_manager()


def _init_device_info():
    """
    INTERNAL USE ONLY!
    rank_id needs to be passed into the deeper layers for numa and device_queue.
    One process works with only one rank_id. In the standalone scenario,
    rank_id may come from the env 'CUDA_VISIBLE_DEVICES'; in the distributed
    scenario, rank_id comes from _get_global_rank().
    """
    from mindspore import context
    from mindspore.parallel._auto_parallel_context import auto_parallel_context
    from mindspore.parallel._utils import _get_global_rank
    numa_enable = False
    numa_enable_env = os.getenv("DATASET_ENABLE_NUMA", None)
    if numa_enable_env and numa_enable_env.strip() == 'True':
        numa_enable = True
    if context.get_context("device_target") == "GPU":
        rank_id = _get_global_rank()
        parallel_mode = auto_parallel_context().get_parallel_mode()
        if parallel_mode == "stand_alone":
            rank_id = context.get_context("device_id")
        if numa_enable:
            _config.set_numa_enable(True)
        _config.set_rank_id(rank_id)
    elif context.get_context("device_target") == "Ascend":
        # Ascend is a special scenario; rank info is better taken from the env
        env_rank_size = os.getenv("RANK_SIZE", None)
        env_rank_id = os.getenv("RANK_ID", None)
        rank_size = 0
        rank_id = 0
        if env_rank_size and env_rank_id:
            try:
                rank_size = int(env_rank_size.strip())
                rank_id = int(env_rank_id.strip())
            except ValueError:
                raise ValueError("rank_size or rank_id is not an int.")
        if rank_size > 1:
            if numa_enable:
                _config.set_numa_enable(True)
            _config.set_rank_id(rank_id)


def set_seed(seed):
    """
    If the seed is set, the generated random numbers will be fixed, which helps to
    produce deterministic results.

    Note:
        This set_seed function sets the seed in the Python random library and the numpy.random library
        for deterministic Python augmentations that use randomness. This set_seed function should
        be called with every iterator created to reset the random seed. In the pipeline, this
        does not guarantee deterministic results with num_parallel_workers > 1.

    Args:
        seed (int): Random number seed. It is used to generate deterministic random numbers.

    Raises:
        ValueError: If seed is invalid when seed < 0 or seed > MAX_UINT_32.

    Examples:
        >>> # Set a new global configuration value for the seed value.
        >>> # Operations with randomness will use the seed value to generate random values.
        >>> ds.config.set_seed(1000)
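        >>> # As the note above explains, calling set_seed again with the same value
        >>> # before creating another iterator resets the random state, so repeated
        >>> # runs of a single-worker pipeline stay reproducible (illustrative only).
        >>> ds.config.set_seed(1000)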
    """
    if seed < 0 or seed > UINT32_MAX:
        raise ValueError("Seed given is not within the required range.")
    _config.set_seed(seed)
    random.seed(seed)
    # numpy.random isn't thread safe
    numpy.random.seed(seed)


def get_seed():
    """
    Get the random number seed. If the seed has been set, the set value will be
    returned; otherwise, the default seed value, which equals
    std::mt19937::default_seed, will be returned.

    Returns:
        int, random number seed.

    Examples:
        >>> # Get the global configuration of seed.
        >>> # If set_seed() is never called before, the default value(std::mt19937::default_seed) will be returned.
        >>> seed = ds.config.get_seed()
    """
    return _config.get_seed()


def set_prefetch_size(size):
    """
    Set the queue capacity of the threads in the pipeline.

    Args:
        size (int): The length of the cache queue.

    Raises:
        ValueError: If the queue capacity of the thread is invalid when size <= 0 or size > MAX_INT_32.

    Note:
        Since the total memory used for prefetch can grow very large with a high number of workers,
        when the number of workers is greater than 4, the per-worker prefetch size will be reduced.
        The actual prefetch size at runtime per worker will be prefetch_size * (4 / num_parallel_workers).

    Examples:
        >>> # Set a new global configuration value for the prefetch size.
        >>> ds.config.set_prefetch_size(1000)
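        >>> # Following the note's formula (a worked illustration, not extra API):
        >>> # with 8 parallel workers, the effective per-worker prefetch size
        >>> # becomes 1000 * (4 / 8) = 500 rows.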
    """
    if size <= 0 or size > INT32_MAX:
        raise ValueError("Prefetch size given is not within the required range.")
    _config.set_op_connector_size(size)


def get_prefetch_size():
    """
    Get the prefetch size as a number of rows.

    Returns:
        int, total number of rows to be prefetched.

    Examples:
        >>> # Get the global configuration of prefetch size.
        >>> # If set_prefetch_size() is never called before, the default value(20) will be returned.
        >>> prefetch_size = ds.config.get_prefetch_size()
    """
    return _config.get_op_connector_size()


def set_num_parallel_workers(num):
    """
    Set a new global configuration default value for the number of parallel workers.
    This setting will affect the parallelism of all dataset operations.

    Args:
        num (int): Number of parallel workers to be used as a default for each operation.

    Raises:
        ValueError: If num_parallel_workers is invalid when num <= 0 or num > MAX_INT_32.

    Examples:
        >>> # Set a new global configuration value for the number of parallel workers.
        >>> # Now parallel dataset operators will run with 8 workers.
        >>> ds.config.set_num_parallel_workers(8)
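        >>> # Individual operations can still override this default through their
        >>> # own num_parallel_workers argument, e.g. (illustrative sketch, `dataset`
        >>> # and `transform` are placeholder names):
        >>> # dataset = dataset.map(operations=transform, num_parallel_workers=2)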
    """
    if num <= 0 or num > INT32_MAX:
        raise ValueError("Number of parallel workers given is not within the required range.")
    _config.set_num_parallel_workers(num)


def get_num_parallel_workers():
    """
    Get the global configuration of the number of parallel workers.
    This is the DEFAULT num_parallel_workers value used for each operation; it is not related
    to the AutoNumWorker feature.

    Returns:
        int, number of parallel workers to be used as a default for each operation.

    Examples:
        >>> # Get the global configuration of parallel workers.
        >>> # If set_num_parallel_workers() is never called before, the default value(8) will be returned.
        >>> num_parallel_workers = ds.config.get_num_parallel_workers()
    """
    return _config.get_num_parallel_workers()


def set_numa_enable(numa_enable):
    """
    Set the default state of numa enabled. If numa_enable is True, the numa library needs to be installed.

    Args:
        numa_enable (bool): Whether to use the numa bind feature.

    Raises:
        TypeError: If numa_enable is not a boolean data type.

    Examples:
        >>> # Set a new global configuration value for the state of numa enabled.
        >>> # Now parallel dataset operators will run with the numa bind function.
        >>> ds.config.set_numa_enable(True)
    """
    if not isinstance(numa_enable, bool):
        raise TypeError("numa_enable must be a boolean dtype.")
    _config.set_numa_enable(numa_enable)


def get_numa_enable():
    """
    Get the state of numa to indicate enabled/disabled.
    This is the DEFAULT numa enabled value used for the whole process.

    Returns:
        bool, the default state of numa enabled.

    Examples:
        >>> # Get the global configuration of numa.
        >>> numa_state = ds.config.get_numa_enable()
    """
    return _config.get_numa_enable()


def set_monitor_sampling_interval(interval):
    """
    Set the default interval (in milliseconds) for monitor sampling.

    Args:
        interval (int): Interval (in milliseconds) to be used for performance monitor sampling.

    Raises:
        ValueError: If interval is invalid when interval <= 0 or interval > MAX_INT_32.

    Examples:
        >>> # Set a new global configuration value for the monitor sampling interval.
        >>> ds.config.set_monitor_sampling_interval(100)
    """
    if interval <= 0 or interval > INT32_MAX:
        raise ValueError("Interval given is not within the required range.")
    _config.set_monitor_sampling_interval(interval)


def get_monitor_sampling_interval():
    """
    Get the global configuration of the sampling interval of the performance monitor.

    Returns:
        int, interval (in milliseconds) for performance monitor sampling.

    Examples:
        >>> # Get the global configuration of monitor sampling interval.
        >>> # If set_monitor_sampling_interval() is never called before, the default value(1000) will be returned.
        >>> sampling_interval = ds.config.get_monitor_sampling_interval()
    """
    return _config.get_monitor_sampling_interval()


def set_auto_num_workers(enable):
    """
    Set num_parallel_workers for each op automatically (this feature is turned off by default).

    If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting the
    num_parallel_workers passed in by the user or the default value (if the user doesn't pass anything) set by
    ds.config.set_num_parallel_workers().

    For now, this function is only optimized for the YoloV3 dataset with per_batch_map (running map in batch).
    This feature aims to provide a baseline for optimized num_workers assignment for each operation.
    Operations whose num_parallel_workers is adjusted to a new value will be logged.

    Args:
        enable (bool): Whether to enable the auto num_workers feature or not.

    Raises:
        TypeError: If enable is not of boolean type.

    Examples:
        >>> # Enable the auto_num_worker feature, this might override the num_parallel_workers passed in by the user.
        >>> ds.config.set_auto_num_workers(True)
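        >>> # The adjusted per-operation values are reported in the log; the flag
        >>> # itself can be read back afterwards (illustrative check):
        >>> # assert ds.config.get_auto_num_workers() is True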
    """
    if not isinstance(enable, bool):
        raise TypeError("enable must be of type bool.")
    _config.set_auto_num_workers(enable)


def _set_auto_workers_config(option):
    """
    INTERNAL USE ONLY!
    Select the weight profile of auto_num_workers. Currently, these 7 options are supported.
    Option #0 leaf_num_workers:batch_num_workers:map_num_workers=1:1:1
    Option #1 leaf_num_workers:batch_num_workers:map_num_workers=2:1:1
    Option #2 leaf_num_workers:batch_num_workers:map_num_workers=1:2:1
    Option #3 leaf_num_workers:batch_num_workers:map_num_workers=1:1:2
    Option #4 leaf_num_workers:batch_num_workers:map_num_workers=2:2:1
    Option #5 leaf_num_workers:batch_num_workers:map_num_workers=2:1:2
    Option #6 leaf_num_workers:batch_num_workers:map_num_workers=1:2:2

    Args:
        option (int): The id of the profile to use.

    Raises:
        ValueError: If option is not an int or not within the range of [0, 6].
    """
    if not isinstance(option, int):
        raise ValueError("option isn't of type int.")
    if option < 0 or option > 6:
        raise ValueError("option isn't within the required range of [0, 6].")
    _config.set_auto_worker_config(option)


def get_auto_num_workers():
    """
    Get the setting (turned on or off) of the automatic number of workers feature.

    Returns:
        bool, whether the auto number worker feature is turned on.

    Examples:
        >>> # Get the global configuration of the auto number worker feature.
        >>> num_workers = ds.config.get_auto_num_workers()
    """
    return _config.get_auto_num_workers()


def set_callback_timeout(timeout):
    """
    Set the default timeout (in seconds) for DSWaitedCallback.
    In case of a deadlock, the wait function will exit after the timeout period.

    Args:
        timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.

    Raises:
        ValueError: If timeout is invalid when timeout <= 0 or timeout > MAX_INT_32.

    Examples:
        >>> # Set a new global configuration value for the timeout value.
        >>> ds.config.set_callback_timeout(100)
    """
    if timeout <= 0 or timeout > INT32_MAX:
        raise ValueError("Timeout given is not within the required range.")
    _config.set_callback_timeout(timeout)


def get_callback_timeout():
    """
    Get the default timeout for DSWaitedCallback.
    In case of a deadlock, the wait function will exit after the timeout period.

    Returns:
        int, timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.

    Examples:
        >>> # Get the global configuration of callback timeout.
        >>> # If set_callback_timeout() is never called before, the default value(60) will be returned.
        >>> callback_timeout = ds.config.get_callback_timeout()
    """
    return _config.get_callback_timeout()


def __str__():
    """
    String representation of the configurations.

    Returns:
        str, configurations.
    """
    return str(_config)


def load(file):
    """
    Load the project configuration from the given file.

    Args:
        file (str): Path of the configuration file to be loaded.

    Raises:
        RuntimeError: If the file is invalid and parsing fails.

    Examples:
        >>> # Set new default configuration according to values in the configuration file.
        >>> # example config file:
        >>> # {
        >>> #     "logFilePath": "/tmp",
        >>> #     "numParallelWorkers": 4,
        >>> #     "seed": 5489,
        >>> #     "monitorSamplingInterval": 30
        >>> # }
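        >>> # A minimal sketch of producing such a file with the standard json module
        >>> # before loading it (the path below is a placeholder):
        >>> # import json
        >>> # with open("/path/to/config/file", "w") as f:
        >>> #     json.dump({"numParallelWorkers": 4, "seed": 5489}, f)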
        >>> config_file = "/path/to/config/file"
        >>> ds.config.load(config_file)
    """
    _config.load(file)


def _stop_dataset_profiler():
    """
    INTERNAL USE ONLY!
    Stop the dataset profiler and wait until the profiler file has been generated.
    """

    while not _config.get_profiler_file_status():
        _config.stop_dataset_profiler(True)
        logger.warning("Profiling: waiting for dataset part profiling stop.")
        time.sleep(1)


def get_enable_shared_mem():
    """
    Get the default state of the shared memory enabled variable.

    Returns:
        bool, the state of the shared memory enabled variable (default=True).

    Examples:
        >>> # Get the flag of the shared memory feature.
        >>> shared_mem_flag = ds.config.get_enable_shared_mem()
    """
    # Shared memory is temporarily disabled on Windows.
    if platform.system().lower() == 'windows':
        logger.warning("Shared memory is temporarily disabled on Windows, so False is returned.")
        return False
    return _config.get_enable_shared_mem()


def set_enable_shared_mem(enable):
    """
    Set the default state of the shared memory flag. If enable is True, shared memory queues
    will be used to pass data to processes that are created for operators that set python_multiprocessing=True.

    Args:
        enable (bool): Whether to use shared memory in operators when python_multiprocessing=True.

    Raises:
        TypeError: If enable is not a boolean data type.

    Examples:
        >>> # Enable the shared memory feature to improve the performance of Python multiprocessing.
        >>> ds.config.set_enable_shared_mem(True)
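        >>> # Shared memory only takes effect for operators that are later created
        >>> # with python_multiprocessing=True, e.g. (illustrative sketch, `dataset`
        >>> # and `transform` are placeholder names):
        >>> # dataset = dataset.map(operations=transform, python_multiprocessing=True)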
    """
    if not isinstance(enable, bool):
        raise TypeError("enable must be of type bool.")
    if enable:
        logger.warning("The shared memory is on, multiprocessing performance will be improved. "
                       "Note: the required shared memory can't exceed 80% of the available shared memory. "
                       "You can reduce max_rowsize or reduce num_parallel_workers to reduce shared memory usage.")
    _config.set_enable_shared_mem(enable)


def set_sending_batches(batch_num):
    """
    Set the default number of sending batches when training with sink_mode=True on Ascend devices.

    Args:
        batch_num (int): The total number of batches to be sent; once it is reached, sending waits
         until the value is increased. The default is 0, which means all batches in the dataset will be sent.

    Raises:
        TypeError: If batch_num is not of type int.

    Examples:
        >>> # Set a new global configuration value for the sending batches.
        >>> ds.config.set_sending_batches(10)
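        >>> # Setting it back to 0 (the default) removes the limit, so all batches
        >>> # in the dataset are sent without waiting:
        >>> ds.config.set_sending_batches(0)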
    """
    if not isinstance(batch_num, int):
        raise TypeError("batch_num must be an int dtype.")
    _config.set_sending_batches(batch_num)