# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Context of cost_model in auto_parallel"""
import threading
from mindspore._c_expression import CostModelContext
from mindspore._checkparam import args_type_check


class _CostModelContext:
    """
    _CostModelContext is the environment in which operations are executed

    Every setter/getter forwards to the handle obtained from
    CostModelContext.get_instance() and raises ValueError when that handle is None.

    Note:
        Creating a context through instantiating Context object is not recommended.
        Use cost_model_context() to get the context since Context is singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._context_handle = CostModelContext.get_instance()

    def __new__(cls):
        # Fast path without the lock; the second check under the lock makes sure
        # only one instance is ever created when several threads race here.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def _checked_handle(self):
        """
        Return the underlying context handle.

        Raises:
            ValueError: If context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("Context handle is none in context!!!")
        return self._context_handle

    def set_device_memory_capacity(self, dev_mem_cap):
        """
        Set device memory capacity.

        Args:
            dev_mem_cap (float): The memory capacity for each device.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_device_memory_capacity(dev_mem_cap)

    def get_device_memory_capacity(self):
        """
        Get device memory capacity.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_device_memory_capacity()

    def set_costmodel_alpha(self, alpha):
        """
        Set costmodel alpha.

        Args:
            alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_alpha(alpha)

    def get_costmodel_alpha(self):
        """
        Get costmodel alpha.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_alpha()

    def set_costmodel_beta(self, beta):
        """
        Set costmodel beta.

        Args:
            beta (float): The parameter costmodel_beta used in strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_beta(beta)

    def get_costmodel_beta(self):
        """
        Get costmodel beta.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_beta()

    def set_costmodel_gamma(self, gamma):
        """
        Set costmodel gamma.

        Args:
            gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_gamma(gamma)

    def get_costmodel_gamma(self):
        """
        Get costmodel gamma.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_gamma()

    def set_costmodel_communi_threshold(self, threshold):
        """
        Set costmodel communication threshold.

        Args:
            threshold (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_communi_threshold(threshold)

    def get_costmodel_communi_threshold(self):
        """
        Get costmodel communication threshold.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_communi_threshold()

    def set_costmodel_communi_const(self, communi_const):
        """
        Set costmodel communication const.

        Args:
            communi_const (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_communi_const(communi_const)

    def get_costmodel_communi_const(self):
        """
        Get costmodel communication const.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_communi_const()

    def set_costmodel_communi_bias(self, communi_bias):
        """
        Set costmodel communication bias.

        Args:
            communi_bias (float): A parameter used in adjusting communication calculation for practice.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_communi_bias(communi_bias)

    def get_costmodel_communi_bias(self):
        """
        Get costmodel communication bias.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_communi_bias()

    def set_multi_subgraphs(self, multi_subgraph):
        """
        Set the flag of ANF graph containing multiple subgraphs.

        Args:
            multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_multi_subgraphs(multi_subgraph)

    def get_multi_subgraphs(self):
        """
        Get the flag of ANF graph containing multiple subgraphs.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_multi_subgraphs()

    def set_run_phase(self, phase):
        """
        Set the flag of running phase: training (0) or inference (1)

        Args:
            phase (int): A parameter indicating which phase is running.

        Raises:
            TypeError: If phase is not an int (bool is rejected as well).
            ValueError: If context handle is none, or phase is not in {0, 1}.
        """
        # bool is a subclass of int, so it has to be excluded explicitly.
        if not isinstance(phase, int) or isinstance(phase, bool):
            raise TypeError(f"The type of phase must be int, but got {type(phase)}.")
        # The handle is validated before the value range, matching the documented order of checks.
        handle = self._checked_handle()
        if phase not in (0, 1):
            raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
        handle.set_run_phase(phase)

    def get_run_phase(self):
        """
        Get the flag of running phase.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_run_phase()

    def set_dp_algo_single_loop(self, single_loop):
        """
        Set the flag of generating a single suite of OperatorInfos in for-loop.

        Args:
            single_loop (bool): The parameter for the single loop flag.

        Raises:
            TypeError: If single_loop is not a bool.
            ValueError: If context handle is none.
        """
        if not isinstance(single_loop, bool):
            raise TypeError(f"The type of single_loop must be bool, but got {type(single_loop)}.")
        self._checked_handle().set_dp_algo_single_loop(single_loop)

    def get_dp_algo_single_loop(self):
        """
        Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_dp_algo_single_loop()

    def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
        """
        Set costmodel allreduce fusion algorithm.

        Args:
            algorithm (int): The AllReduce fusion algorithm of parameter gradients.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_algorithm(algorithm)

    def get_costmodel_allreduce_fusion_algorithm(self):
        """
        Get costmodel allreduce fusion algorithm.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_algorithm()

    def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
        """
        Set costmodel allreduce fusion times.

        Args:
            allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_times(allreduce_fusion_times)

    def get_costmodel_allreduce_fusion_times(self):
        """
        Get costmodel allreduce fusion times.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_times()

    def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
        """
        Set costmodel allreduce fusion tail percent.

        Args:
            tail_percent (float): The percentage of backward computing time corresponding to the last parameter
                gradients AllReduce in the whole backward computing time.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_tail_percent(tail_percent)

    def get_costmodel_allreduce_fusion_tail_percent(self):
        """
        Get costmodel allreduce fusion tail percent.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_tail_percent()

    def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
        """
        Set costmodel allreduce fusion tail time.

        Args:
            tail_time (float): The tail time of the last parameter gradients AllReduce after the end of backward
                computation.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_tail_time(tail_time)

    def get_costmodel_allreduce_fusion_tail_time(self):
        """
        Get costmodel allreduce fusion tail time.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_tail_time()

    def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
        """
        Set costmodel allreduce fusion allreduce inherent time.

        Args:
            allreduce_inherent_time (float): The inherent cost time of AllReduce.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)

    def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
        """
        Get costmodel allreduce fusion allreduce inherent time.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_allreduce_inherent_time()

    def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
        """
        Set costmodel allreduce fusion allreduce bandwidth.

        Args:
            allreduce_bandwidth (float): The bandwidth of AllReduce.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)

    def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
        """
        Get costmodel allreduce fusion allreduce bandwidth.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_allreduce_bandwidth()

    def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
        """
        Set costmodel allreduce fusion computation time parameter.

        Args:
            computation_time_parameter (float): The parameter used to compute backward computation time.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)

    def get_costmodel_allreduce_fusion_computation_time_parameter(self):
        """
        Get costmodel allreduce fusion computation time parameter.

        Raises:
            ValueError: If context handle is none.
        """
        return self._checked_handle().get_costmodel_allreduce_fusion_computation_time_parameter()

    def reset_cost_model(self):
        """
        Reset cost model settings.

        Raises:
            ValueError: If context handle is none.
        """
        self._checked_handle().reset_cost_model()


_cost_model_context = None


def cost_model_context():
    """
    Get the global _cost_model_context. If it is not created, create a new one.

    Returns:
        The global cost_model context.
    """
    global _cost_model_context
    if _cost_model_context is None:
        _cost_model_context = _CostModelContext()
    return _cost_model_context


# Keyword -> bound setter used by set_cost_model_context().
# NOTE(review): "multi_subgraphs" is accepted by the args_type_check decorator on
# set_cost_model_context() but has no entry here (nor in the getter map), so passing
# it raises ValueError. Confirm whether that omission is intentional.
set_cost_model_context_func_map = {
    "device_memory_capacity": cost_model_context().set_device_memory_capacity,
    "costmodel_alpha": cost_model_context().set_costmodel_alpha,
    "costmodel_beta": cost_model_context().set_costmodel_beta,
    "costmodel_gamma": cost_model_context().set_costmodel_gamma,
    "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
    "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
    "run_phase": cost_model_context().set_run_phase,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
    "costmodel_allreduce_fusion_tail_time": cost_model_context().set_costmodel_allreduce_fusion_tail_time,
    "costmodel_allreduce_fusion_allreduce_inherent_time":
        cost_model_context().set_costmodel_allreduce_fusion_allreduce_inherent_time,
    "costmodel_allreduce_fusion_allreduce_bandwidth":
        cost_model_context().set_costmodel_allreduce_fusion_allreduce_bandwidth,
    "costmodel_allreduce_fusion_computation_time_parameter":
        cost_model_context().set_costmodel_allreduce_fusion_computation_time_parameter}


# Keyword -> bound getter used by get_cost_model_context().
get_cost_model_context_func_map = {
    "device_memory_capacity": cost_model_context().get_device_memory_capacity,
    "costmodel_alpha": cost_model_context().get_costmodel_alpha,
    "costmodel_beta": cost_model_context().get_costmodel_beta,
    "costmodel_gamma": cost_model_context().get_costmodel_gamma,
    "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
    "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
    "run_phase": cost_model_context().get_run_phase,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
    "costmodel_allreduce_fusion_tail_time": cost_model_context().get_costmodel_allreduce_fusion_tail_time,
    "costmodel_allreduce_fusion_allreduce_inherent_time":
        cost_model_context().get_costmodel_allreduce_fusion_allreduce_inherent_time,
    "costmodel_allreduce_fusion_allreduce_bandwidth":
        cost_model_context().get_costmodel_allreduce_fusion_allreduce_bandwidth,
    "costmodel_allreduce_fusion_computation_time_parameter":
        cost_model_context().get_costmodel_allreduce_fusion_computation_time_parameter}


@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                 costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
                 multi_subgraphs=bool, run_phase=int,
                 costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
                 costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
                 costmodel_allreduce_fusion_allreduce_inherent_time=float,
                 costmodel_allreduce_fusion_allreduce_bandwidth=float,
                 costmodel_allreduce_fusion_computation_time_parameter=float)
def set_cost_model_context(**kwargs):
    """
    Set cost model context.

    Note:
        Attribute name is needed.

    Args:
        device_memory_capacity (float): The memory capacity for each device.
        costmodel_alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
        costmodel_beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
        costmodel_gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
        costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
        run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
        costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
            0: bypass allreduce fusion;
            1: only use backward computation time to group allreduce;
            2: use backward computation time and parameter gradient allreduce time to group allreduce.
        costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
        costmodel_allreduce_fusion_tail_percent (float): A parameter used in allreduce fusion algorithm. The percentage
            of backward computing time corresponding to the last parameter gradients AllReduce in the whole backward
            computing time.
        costmodel_allreduce_fusion_tail_time (float): A parameter used in allreduce fusion algorithm. The tail time of
            the last parameter gradients AllReduce after the end of backward computation.
        costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in allreduce fusion algorithm. The
            inherent cost time of AllReduce.
        costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in allreduce fusion algorithm. The
            bandwidth of AllReduce.
        costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in allreduce fusion algorithm.
            The parameter used to compute backward computation time.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    # Settings are applied in keyword order; an unknown keyword aborts at that point,
    # leaving the keywords before it already applied.
    for key, value in kwargs.items():
        if key not in set_cost_model_context_func_map:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        set_func = set_cost_model_context_func_map[key]
        set_func(value)


def get_cost_model_context(attr_key):
    """
    Get cost model context attributes.

    Note:
        Return value according to the attribute value.

    Args:
        attr_key (str): The key of the attribute.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    if attr_key not in get_cost_model_context_func_map:
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
    get_func = get_cost_model_context_func_map[attr_key]
    return get_func()


def reset_cost_model_context():
    """Reset cost model context attributes."""
    cost_model_context().reset_cost_model()


def _set_multi_subgraphs(multi_subgraph=True):
    """
    Set the flag of ANF graph containing multiple subgraphs.

    Args:
        multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
    """
    cost_model_context().set_multi_subgraphs(multi_subgraph)


def _get_multi_subgraphs():
    """
    Get the flag of ANF graph containing multiple subgraphs.
    """
    return cost_model_context().get_multi_subgraphs()


def _set_algo_single_loop(single_loop=True):
    """
    Set the flag of generating a single suite of OperatorInfos in for-loop.

    Args:
        single_loop (bool): The parameter for the single loop flag.
    """
    cost_model_context().set_dp_algo_single_loop(single_loop)


def _get_algo_single_loop():
    """
    Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.
    """
    return cost_model_context().get_dp_algo_single_loop()