# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration module provides various functions to set and get the supported
configuration parameters, and to read a configuration file.

Common imported modules in corresponding API examples are as follows:

.. code-block::

    import mindspore.dataset as ds
"""
import os
import platform
import random
import time
import numpy
import mindspore._c_dataengine as cde
from mindspore import log as logger

__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
           'get_num_parallel_workers', 'set_numa_enable', 'get_numa_enable', 'set_monitor_sampling_interval',
           'get_monitor_sampling_interval', 'set_callback_timeout', 'get_callback_timeout',
           'set_auto_num_workers', 'get_auto_num_workers', 'set_enable_shared_mem', 'get_enable_shared_mem',
           'set_sending_batches', 'load', '_init_device_info']

INT32_MAX = 2147483647
UINT32_MAX = 4294967295

_config = cde.GlobalContext.config_manager()


def _init_device_info():
    """
    INTERNAL USE ONLY!
    The rank_id needs to be passed into the deep layer for numa and device_queue.
    One process works with only one rank_id. In the standalone scenario, rank_id
    may come from the env 'CUDA_VISIBLE_DEVICES'; in the distributed scenario,
    rank_id comes from _get_global_rank().
    """
    from mindspore import context
    from mindspore.parallel._auto_parallel_context import auto_parallel_context
    from mindspore.parallel._utils import _get_global_rank
    numa_enable = False
    numa_enable_env = os.getenv("DATASET_ENABLE_NUMA", None)
    if numa_enable_env and numa_enable_env.strip() == 'True':
        numa_enable = True
    if context.get_context("device_target") == "GPU":
        rank_id = _get_global_rank()
        parallel_mode = auto_parallel_context().get_parallel_mode()
        if parallel_mode == "stand_alone":
            rank_id = context.get_context("device_id")
        if numa_enable:
            _config.set_numa_enable(True)
        _config.set_rank_id(rank_id)
    elif context.get_context("device_target") == "Ascend":
        # Ascend is a special scenario; rank info is best obtained from the environment.
        env_rank_size = os.getenv("RANK_SIZE", None)
        env_rank_id = os.getenv("RANK_ID", None)
        rank_size = 0
        rank_id = 0
        if env_rank_size and env_rank_id:
            try:
                rank_size = int(env_rank_size.strip())
                rank_id = int(env_rank_id.strip())
            except ValueError:
                raise ValueError("rank_size or rank_id is not an int.")
        if rank_size > 1:
            if numa_enable:
                _config.set_numa_enable(True)
            _config.set_rank_id(rank_id)


def set_seed(seed):
    """
    If the seed is set, the generated random numbers will be fixed, which helps to
    produce deterministic results.

    Note:
        This set_seed function sets the seed in the Python random library and the numpy.random library
        so that Python augmentations that use randomness are deterministic. It should
        be called with every iterator created to reset the random seed. In the pipeline, this
        does not guarantee deterministic results with num_parallel_workers > 1.

    Args:
        seed (int): Random number seed. It is used to generate deterministic random numbers.

    Raises:
        ValueError: If seed is invalid (seed < 0 or seed > MAX_UINT_32).

    Examples:
        >>> # Set a new global configuration value for the seed value.
        >>> # Operations with randomness will use the seed value to generate random values.
        >>> ds.config.set_seed(1000)
    """
    if seed < 0 or seed > UINT32_MAX:
        raise ValueError("Seed given is not within the required range.")
    _config.set_seed(seed)
    random.seed(seed)
    # numpy.random isn't thread safe
    numpy.random.seed(seed)
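
# A minimal reproducibility sketch (illustrative only; the NumpySlicesDataset source and column
# name below are assumptions, not part of this module). As noted in set_seed, determinism is not
# guaranteed with num_parallel_workers > 1, so the sketch also pins the worker count to 1:
#
#   import numpy as np
#   import mindspore.dataset as ds
#   ds.config.set_seed(1000)
#   ds.config.set_num_parallel_workers(1)
#   data = ds.NumpySlicesDataset(np.arange(10), column_names=["col"], shuffle=True)
#   rows = [row["col"] for row in data.create_dict_iterator(output_numpy=True)]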


def get_seed():
    """
    Get the random number seed. If the seed has been set, this will
    return the set value; otherwise it will return the default seed value,
    which equals std::mt19937::default_seed.

    Returns:
        int, random number seed.

    Examples:
        >>> # Get the global configuration of seed.
        >>> # If set_seed() is never called before, the default value(std::mt19937::default_seed) will be returned.
        >>> seed = ds.config.get_seed()
    """
    return _config.get_seed()


def set_prefetch_size(size):
    """
    Set the queue capacity of the threads in the pipeline.

    Args:
        size (int): The length of the cache queue.

    Raises:
        ValueError: If the queue capacity of the thread is invalid (size <= 0 or size > MAX_INT_32).

    Note:
        Since the total memory used for prefetch can grow very large with a high number of workers,
        when the number of workers is greater than 4, the per-worker prefetch size will be reduced.
        The actual prefetch size per worker at runtime will be prefetch_size * (4 / num_parallel_workers).

    Examples:
        >>> # Set a new global configuration value for the prefetch size.
        >>> ds.config.set_prefetch_size(1000)
    """
    if size <= 0 or size > INT32_MAX:
        raise ValueError("Prefetch size given is not within the required range.")
    _config.set_op_connector_size(size)
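
# A worked example of the note above (values chosen for illustration): with a prefetch size of 16
# and 8 parallel workers, the effective per-worker prefetch size becomes 16 * (4 / 8) = 8 rows.
#
#   ds.config.set_prefetch_size(16)
#   ds.config.set_num_parallel_workers(8)   # per-worker prefetch is then 16 * 4 / 8 = 8 rows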


def get_prefetch_size():
    """
    Get the prefetch size as the number of rows.

    Returns:
        int, total number of rows to be prefetched.

    Examples:
        >>> # Get the global configuration of prefetch size.
        >>> # If set_prefetch_size() is never called before, the default value(20) will be returned.
        >>> prefetch_size = ds.config.get_prefetch_size()
    """
    return _config.get_op_connector_size()


def set_num_parallel_workers(num):
    """
    Set a new global configuration default value for the number of parallel workers.
    This setting will affect the parallelism of all dataset operations.

    Args:
        num (int): Number of parallel workers to be used as a default for each operation.

    Raises:
        ValueError: If num_parallel_workers is invalid (num <= 0 or num > MAX_INT_32).

    Examples:
        >>> # Set a new global configuration value for the number of parallel workers.
        >>> # Now parallel dataset operators will run with 8 workers.
        >>> ds.config.set_num_parallel_workers(8)
    """
    if num <= 0 or num > INT32_MAX:
        raise ValueError("Number of parallel workers given is not within the required range.")
    _config.set_num_parallel_workers(num)


def get_num_parallel_workers():
    """
    Get the global configuration of the number of parallel workers.
    This is the DEFAULT num_parallel_workers value used for each operation; it is not related
    to the AutoNumWorker feature.

    Returns:
        int, number of parallel workers to be used as a default for each operation.

    Examples:
        >>> # Get the global configuration of parallel workers.
        >>> # If set_num_parallel_workers() is never called before, the default value(8) will be returned.
        >>> num_parallel_workers = ds.config.get_num_parallel_workers()
    """
    return _config.get_num_parallel_workers()


def set_numa_enable(numa_enable):
    """
    Set the default state of numa enabled. If numa_enable is True, the numa library must be installed.

    Args:
        numa_enable (bool): Whether to use the numa bind feature.

    Raises:
        TypeError: If numa_enable is not a boolean data type.

    Examples:
        >>> # Set a new global configuration value for the state of numa enabled.
        >>> # Now parallel dataset operators will run with the numa bind function.
        >>> ds.config.set_numa_enable(True)
    """
    if not isinstance(numa_enable, bool):
        raise TypeError("numa_enable must be a boolean dtype.")
    _config.set_numa_enable(numa_enable)


def get_numa_enable():
    """
    Get the state of numa to indicate enabled/disabled.
    This is the DEFAULT numa enabled value used for the whole process.

    Returns:
        bool, the default state of numa enabled.

    Examples:
        >>> # Get the global configuration of numa.
        >>> numa_state = ds.config.get_numa_enable()
    """
    return _config.get_numa_enable()


def set_monitor_sampling_interval(interval):
    """
    Set the default interval (in milliseconds) for monitor sampling.

    Args:
        interval (int): Interval (in milliseconds) to be used for performance monitor sampling.

    Raises:
        ValueError: If interval is invalid (interval <= 0 or interval > MAX_INT_32).

    Examples:
        >>> # Set a new global configuration value for the monitor sampling interval.
        >>> ds.config.set_monitor_sampling_interval(100)
    """
    if interval <= 0 or interval > INT32_MAX:
        raise ValueError("Interval given is not within the required range.")
    _config.set_monitor_sampling_interval(interval)


def get_monitor_sampling_interval():
    """
    Get the global configuration of the sampling interval of the performance monitor.

    Returns:
        int, interval (in milliseconds) for performance monitor sampling.

    Examples:
        >>> # Get the global configuration of monitor sampling interval.
        >>> # If set_monitor_sampling_interval() is never called before, the default value(1000) will be returned.
        >>> sampling_interval = ds.config.get_monitor_sampling_interval()
    """
    return _config.get_monitor_sampling_interval()


def set_auto_num_workers(enable):
    """
    Set num_parallel_workers for each op automatically (this feature is turned off by default).

    If turned on, the num_parallel_workers in each op will be adjusted automatically, possibly overwriting
    the num_parallel_workers passed in by the user or the default value (if the user doesn't pass anything)
    set by ds.config.set_num_parallel_workers().

    For now, this function is only optimized for the YoloV3 dataset with per_batch_map (running map in batch).
    This feature aims to provide a baseline for optimized num_workers assignment for each operation.
    Operations whose num_parallel_workers is adjusted to a new value will be logged.

    Args:
        enable (bool): Whether to enable the auto num_workers feature or not.

    Raises:
        TypeError: If enable is not of boolean type.

    Examples:
        >>> # Enable the auto_num_worker feature; this might override the num_parallel_workers passed in by the user.
        >>> ds.config.set_auto_num_workers(True)
    """
    if not isinstance(enable, bool):
        raise TypeError("enable must be of type bool.")
    _config.set_auto_num_workers(enable)


def _set_auto_workers_config(option):
    """
    INTERNAL USE ONLY!
    Select the weight profile of auto_num_workers. Currently, these 7 options are supported.
    Option #0 leaf_num_workers:batch_num_workers:map_num_workers=1:1:1
    Option #1 leaf_num_workers:batch_num_workers:map_num_workers=2:1:1
    Option #2 leaf_num_workers:batch_num_workers:map_num_workers=1:2:1
    Option #3 leaf_num_workers:batch_num_workers:map_num_workers=1:1:2
    Option #4 leaf_num_workers:batch_num_workers:map_num_workers=2:2:1
    Option #5 leaf_num_workers:batch_num_workers:map_num_workers=2:1:2
    Option #6 leaf_num_workers:batch_num_workers:map_num_workers=1:2:2
    Args:
        option (int): The id of the profile to use.
    Raises:
        ValueError: If option is not an int or not within the range of [0, 6].
    """
    if not isinstance(option, int):
        raise ValueError("option isn't of type int.")
    if option < 0 or option > 6:
        raise ValueError("option isn't within the required range of [0, 6].")
    _config.set_auto_worker_config(option)


def get_auto_num_workers():
    """
    Get the setting (turned on or off) of the automatic number of workers feature.

    Returns:
        bool, whether the auto number worker feature is turned on.

    Examples:
        >>> # Get the global configuration of the auto number worker feature.
        >>> num_workers = ds.config.get_auto_num_workers()
    """
    return _config.get_auto_num_workers()


def set_callback_timeout(timeout):
    """
    Set the default timeout (in seconds) for DSWaitedCallback.
    In case of a deadlock, the wait function will exit after the timeout period.

    Args:
        timeout (int): Timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.

    Raises:
        ValueError: If timeout is invalid (timeout <= 0 or timeout > MAX_INT_32).

    Examples:
        >>> # Set a new global configuration value for the timeout value.
        >>> ds.config.set_callback_timeout(100)
    """
    if timeout <= 0 or timeout > INT32_MAX:
        raise ValueError("Timeout given is not within the required range.")
    _config.set_callback_timeout(timeout)


def get_callback_timeout():
    """
    Get the default timeout for DSWaitedCallback.
    In case of a deadlock, the wait function will exit after the timeout period.

    Returns:
        int, timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.

    Examples:
        >>> # Get the global configuration of callback timeout.
        >>> # If set_callback_timeout() is never called before, the default value(60) will be returned.
        >>> callback_timeout = ds.config.get_callback_timeout()
    """
    return _config.get_callback_timeout()


def __str__():
    """
    String representation of the configurations.

    Returns:
        str, configurations.
    """
    return str(_config)


def load(file):
    """
    Load the project configuration from the given file.

    Args:
        file (str): Path of the configuration file to be loaded.

    Raises:
        RuntimeError: If the file is invalid and parsing fails.

    Examples:
        >>> # Set new default configuration according to values in the configuration file.
        >>> # example config file:
        >>> # {
        >>> #     "logFilePath": "/tmp",
        >>> #     "numParallelWorkers": 4,
        >>> #     "seed": 5489,
        >>> #     "monitorSamplingInterval": 30
        >>> # }
        >>> config_file = "/path/to/config/file"
        >>> ds.config.load(config_file)
    """
    _config.load(file)
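
# A sketch of producing the configuration file programmatically before loading it (the file path
# below is a placeholder, and writing the file this way is an assumption, not a requirement of load()):
#
#   import json
#   cfg = {"logFilePath": "/tmp", "numParallelWorkers": 4, "seed": 5489, "monitorSamplingInterval": 30}
#   with open("/path/to/config/file", "w") as f:
#       json.dump(cfg, f)
#   ds.config.load("/path/to/config/file")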


def _stop_dataset_profiler():
    """
    Mainly for stopping the dataset profiler. It blocks until the dataset profiling
    file has been generated.
    """

    while not _config.get_profiler_file_status():
        _config.stop_dataset_profiler(True)
        logger.warning("Profiling: waiting for dataset part profiling stop.")
        time.sleep(1)


def get_enable_shared_mem():
    """
    Get the default state of the shared memory enabled variable.

    Returns:
        bool, the state of the shared memory enabled variable (default=True).

    Examples:
        >>> # Get the flag of the shared memory feature.
        >>> shared_mem_flag = ds.config.get_enable_shared_mem()
    """
    # Shared memory is temporarily not supported on Windows.
    if platform.system().lower() == 'windows':
        logger.warning("For Windows, the shared memory function is temporarily disabled.")
        return False
    return _config.get_enable_shared_mem()


def set_enable_shared_mem(enable):
    """
    Set the default state of the shared memory flag. If shared_mem_enable is True, shared memory queues
    will be used to pass data to processes that are created for operators that set python_multiprocessing=True.

    Args:
        enable (bool): Whether to use shared memory in operators when python_multiprocessing=True.

    Raises:
        TypeError: If enable is not a boolean data type.

    Examples:
        >>> # Enable the shared memory feature to improve the performance of Python multiprocessing.
        >>> ds.config.set_enable_shared_mem(True)
    """
    if not isinstance(enable, bool):
        raise TypeError("enable must be of type bool.")
    if enable:
        logger.warning("The shared memory is on, multiprocessing performance will be improved. "
                       "Note: the required shared memory can't exceed 80% of the available shared memory. "
                       "You can reduce max_rowsize or reduce num_parallel_workers to reduce shared memory usage.")
    _config.set_enable_shared_mem(enable)
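
# A minimal sketch of the scenario described in set_enable_shared_mem (the dataset source and the
# lambda transform below are assumptions for illustration): shared memory queues are only used by
# operators that are launched with python_multiprocessing=True.
#
#   import numpy as np
#   import mindspore.dataset as ds
#   ds.config.set_enable_shared_mem(True)
#   data = ds.NumpySlicesDataset(np.arange(100), column_names=["col"])
#   data = data.map(operations=[lambda x: x * 2], input_columns=["col"],
#                   num_parallel_workers=4, python_multiprocessing=True)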


def set_sending_batches(batch_num):
    """
    Set the default number of batches to send when training with sink_mode=True on Ascend devices.

    Args:
        batch_num (int): The total number of batches to be sent. When batch_num batches have been sent,
            sending will wait until this value is increased. The default is 0, which means all batches
            in the dataset will be sent.

    Raises:
        TypeError: If batch_num is not of type int.

    Examples:
        >>> # Set a new global configuration value for the sending batches.
        >>> ds.config.set_sending_batches(10)
    """
    if not isinstance(batch_num, int):
        raise TypeError("batch_num must be an int dtype.")
    _config.set_sending_batches(batch_num)