• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Context of cost_model in auto_parallel"""
16import threading
17from mindspore._c_expression import CostModelContext
18from mindspore._checkparam import args_type_check
19
20
21class _CostModelContext:
22    """
23    _CostModelContext is the environment in which operations are executed
24
25    Note:
26        Creating a context through instantiating Context object is not recommended.
27        Use cost_model_context() to get the context since Context is singleton.
28    """
29    _instance = None
30    _instance_lock = threading.Lock()
31
32    def __init__(self):
33        self._context_handle = CostModelContext.get_instance()
34
35    def __new__(cls):
36        if cls._instance is None:
37            cls._instance_lock.acquire()
38            cls._instance = object.__new__(cls)
39            cls._instance_lock.release()
40        return cls._instance
41
42    def set_device_memory_capacity(self, dev_mem_cap):
43        """
44        Set device memory capacity.
45
46        Args:
47            dev_mem_cap (float): The memory capacity for each device.
48
49        Raises:
50            ValueError: If context handle is none.
51        """
52        if self._context_handle is None:
53            raise ValueError("Context handle is none in context!!!")
54        self._context_handle.set_device_memory_capacity(dev_mem_cap)
55
56    def get_device_memory_capacity(self):
57        """
58        Get device memory capacity.
59
60        Raises:
61            ValueError: If context handle is none.
62        """
63        if self._context_handle is None:
64            raise ValueError("Context handle is none in context!!!")
65        return self._context_handle.get_device_memory_capacity()
66
67    def set_costmodel_alpha(self, alpha):
68        """
69        Set costmodel alpha.
70
71        Args:
72            alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
73
74        Raises:
75            ValueError: If context handle is none.
76        """
77        if self._context_handle is None:
78            raise ValueError("Context handle is none in context!!!")
79        self._context_handle.set_costmodel_alpha(alpha)
80
81    def get_costmodel_alpha(self):
82        """
83        Get costmodel alpha.
84
85        Raises:
86            ValueError: If context handle is none.
87        """
88        if self._context_handle is None:
89            raise ValueError("Context handle is none in context!!!")
90        return self._context_handle.get_costmodel_alpha()
91
92    def set_costmodel_beta(self, beta):
93        """
94        Set costmodel beta.
95
96        Args:
97            beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
98
99        Raises:
100            ValueError: If context handle is none.
101        """
102        if self._context_handle is None:
103            raise ValueError("Context handle is none in context!!!")
104        self._context_handle.set_costmodel_beta(beta)
105
106    def get_costmodel_beta(self):
107        """
108        Get costmodel beta.
109
110        Raises:
111            ValueError: If context handle is none.
112        """
113        if self._context_handle is None:
114            raise ValueError("Context handle is none in context!!!")
115        return self._context_handle.get_costmodel_beta()
116
117    def set_costmodel_gamma(self, gamma):
118        """
119        Set costmodel gamma.
120
121        Args:
122            gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
123
124        Raises:
125            ValueError: If context handle is none.
126        """
127        if self._context_handle is None:
128            raise ValueError("Context handle is none in context!!!")
129        self._context_handle.set_costmodel_gamma(gamma)
130
131    def get_costmodel_gamma(self):
132        """
133        Get costmodel gamma.
134
135        Raises:
136            ValueError: If context handle is none.
137        """
138        if self._context_handle is None:
139            raise ValueError("Context handle is none in context!!!")
140        return self._context_handle.get_costmodel_gamma()
141
142    def set_costmodel_communi_threshold(self, threshold):
143        """
144        Set costmodel communication threshold.
145
146        Args:
147            threshold (float): A parameter used in adjusting communication calculation for practice.
148
149        Raises:
150            ValueError: If context handle is none.
151        """
152        if self._context_handle is None:
153            raise ValueError("Context handle is none in context!!!")
154        self._context_handle.set_costmodel_communi_threshold(threshold)
155
156    def get_costmodel_communi_threshold(self):
157        """
158        Get costmodel communication threshold.
159
160        Raises:
161            ValueError: If context handle is none.
162        """
163        if self._context_handle is None:
164            raise ValueError("Context handle is none in context!!!")
165        return self._context_handle.get_costmodel_communi_threshold()
166
167    def set_costmodel_communi_const(self, communi_const):
168        """
169        Set costmodel communication const.
170
171        Args:
172            const (float): A parameter used in adjusting communication calculation for practice.
173
174        Raises:
175            ValueError: If context handle is none.
176        """
177        if self._context_handle is None:
178            raise ValueError("Context handle is none in context!!!")
179        self._context_handle.set_costmodel_communi_const(communi_const)
180
181    def get_costmodel_communi_const(self):
182        """
183        Get costmodel communication const.
184
185        Raises:
186            ValueError: If context handle is none.
187        """
188        if self._context_handle is None:
189            raise ValueError("Context handle is none in context!!!")
190        return self._context_handle.get_costmodel_communi_const()
191
192    def set_costmodel_communi_bias(self, communi_bias):
193        """
194        Set costmodel communication bias.
195
196        Args:
197            communi_bias (float): A parameter used in adjusting communication calculation for practice.
198
199        Raises:
200            ValueError: If context handle is none.
201        """
202        if self._context_handle is None:
203            raise ValueError("Context handle is none in context!!!")
204        self._context_handle.set_costmodel_communi_bias(communi_bias)
205
206    def get_costmodel_communi_bias(self):
207        """
208        Get costmodel communication bias.
209
210        Raises:
211            ValueError: If context handle is none.
212        """
213        if self._context_handle is None:
214            raise ValueError("Context handle is none in context!!!")
215        return self._context_handle.get_costmodel_communi_bias()
216
217    def set_multi_subgraphs(self, multi_subgraph):
218        """
219        Set the flag of ANF graph containing multiple subgraphs.
220
221        Args:
222            multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
223
224        Raises:
225            ValueError: If context handle is none.
226        """
227        if self._context_handle is None:
228            raise ValueError("Context handle is none in context!!!")
229        self._context_handle.set_multi_subgraphs(multi_subgraph)
230
231    def get_multi_subgraphs(self):
232        """
233        Get the flag of ANF graph containing multiple subgraphs.
234
235        Raises:
236            ValueError: If context handle is none.
237        """
238        if self._context_handle is None:
239            raise ValueError("Context handle is none in context!!!")
240        return self._context_handle.get_multi_subgraphs()
241
242    def set_run_phase(self, phase):
243        """
244        Set the flag of running phase: training (0) or inference (1)
245
246        Args:
247            phase (int): A parameter indicating which phase is running.
248
249        Raises:
250            ValueError: If context handle is none, or phase is not in {0, 1}.
251        """
252        if not isinstance(phase, int) or isinstance(phase, bool):
253            raise TypeError(f"The type of communi_const must be int, but got {type(phase)}.")
254        if self._context_handle is None:
255            raise ValueError("Context handle is none in context!!!")
256        if phase not in (0, 1):
257            raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
258        self._context_handle.set_run_phase(phase)
259
260    def get_run_phase(self):
261        """
262        Get the flag of running phase.
263
264        Raises:
265            ValueError: If context handle is none.
266        """
267        if self._context_handle is None:
268            raise ValueError("Context handle is none in context!!!")
269        return self._context_handle.get_run_phase()
270
271    def set_dp_algo_single_loop(self, single_loop):
272        """
273        Set the flag of generating a single suite of OperatorInfos in for-loop.
274
275        Args:
276            single_loop (bool): The parameter for the single loop flag.
277
278        Raises:
279            ValueError: If context handle is none.
280        """
281        if not isinstance(single_loop, bool):
282            raise TypeError(f"The type of single_loop must be bool, but got {type(single_loop)}.")
283        if self._context_handle is None:
284            raise ValueError("Context handle is none in context!!!")
285        self._context_handle.set_dp_algo_single_loop(single_loop)
286
287    def get_dp_algo_single_loop(self):
288        """
289        Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.
290
291        Raises:
292            ValueError: If context handle is none.
293        """
294        if self._context_handle is None:
295            raise ValueError("Context handle is none in context!!!")
296        return self._context_handle.get_dp_algo_single_loop()
297
298    def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
299        """
300        Set costmodel allreduce fusion algorithm.
301
302        Args:
303            algorithm (int): The AllReduce fusion algorithm of parameter gradients.
304
305        Raises:
306            ValueError: If context handle is none.
307        """
308        if self._context_handle is None:
309            raise ValueError("Context handle is none in context!!!")
310        self._context_handle.set_costmodel_allreduce_fusion_algorithm(algorithm)
311
312    def get_costmodel_allreduce_fusion_algorithm(self):
313        """
314        Get costmodel allreduce fusion algorithm.
315
316        Raises:
317            ValueError: If context handle is none.
318        """
319        if self._context_handle is None:
320            raise ValueError("Context handle is none in context!!!")
321        return self._context_handle.get_costmodel_allreduce_fusion_algorithm()
322
323    def set_costmodel_allreduce_fusion_times(self, allreduce_fusion_times):
324        """
325        Set costmodel allreduce fusion times.
326
327        Args:
328            allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
329
330        Raises:
331            ValueError: If context handle is none.
332        """
333        if self._context_handle is None:
334            raise ValueError("Context handle is none in context!!!")
335        self._context_handle.set_costmodel_allreduce_fusion_times(allreduce_fusion_times)
336
337    def get_costmodel_allreduce_fusion_times(self):
338        """
339        Get costmodel allreduce fusion times.
340
341        Raises:
342            ValueError: If context handle is none.
343        """
344        if self._context_handle is None:
345            raise ValueError("Context handle is none in context!!!")
346        return self._context_handle.get_costmodel_allreduce_fusion_times()
347
348    def set_costmodel_allreduce_fusion_tail_percent(self, tail_percent):
349        """
350        Set costmodel allreduce fusion tail percent.
351
352        Args:
353            tail_percent (int): The percentage of backward computing time corresponding to the last parameter gradients
354                AllReduce in the whole backward computing time.
355
356        Raises:
357            ValueError: If context handle is none.
358        """
359        if self._context_handle is None:
360            raise ValueError("Context handle is none in context!!!")
361        self._context_handle.set_costmodel_allreduce_fusion_tail_percent(tail_percent)
362
363    def get_costmodel_allreduce_fusion_tail_percent(self):
364        """
365        Get costmodel allreduce fusion tail percent.
366
367        Raises:
368            ValueError: If context handle is none.
369        """
370        if self._context_handle is None:
371            raise ValueError("Context handle is none in context!!!")
372        return self._context_handle.get_costmodel_allreduce_fusion_tail_percent()
373
374    def set_costmodel_allreduce_fusion_tail_time(self, tail_time):
375        """
376        Set costmodel allreduce fusion tail time.
377
378        Args:
379            tail_time (int): The tail time of the last parameter gradients AllReduce after the end of backward
380                computation.
381
382        Raises:
383            ValueError: If context handle is none.
384        """
385        if self._context_handle is None:
386            raise ValueError("Context handle is none in context!!!")
387        self._context_handle.set_costmodel_allreduce_fusion_tail_time(tail_time)
388
389    def get_costmodel_allreduce_fusion_tail_time(self):
390        """
391        Get costmodel allreduce fusion tail time.
392
393        Raises:
394            ValueError: If context handle is none.
395        """
396        if self._context_handle is None:
397            raise ValueError("Context handle is none in context!!!")
398        return self._context_handle.get_costmodel_allreduce_fusion_tail_time()
399
400    def set_costmodel_allreduce_fusion_allreduce_inherent_time(self, allreduce_inherent_time):
401        """
402        Set costmodel allreduce fusion allreduce inherent time.
403
404        Args:
405            allreduce_inherent_time (int): The inherent cost time of AllReduce.
406
407        Raises:
408            ValueError: If context handle is none.
409        """
410        if self._context_handle is None:
411            raise ValueError("Context handle is none in context!!!")
412        self._context_handle.set_costmodel_allreduce_fusion_allreduce_inherent_time(allreduce_inherent_time)
413
414    def get_costmodel_allreduce_fusion_allreduce_inherent_time(self):
415        """
416        Get costmodel allreduce fusion allreduce inherent time.
417
418        Raises:
419            ValueError: If context handle is none.
420        """
421        if self._context_handle is None:
422            raise ValueError("Context handle is none in context!!!")
423        return self._context_handle.get_costmodel_allreduce_fusion_allreduce_inherent_time()
424
425    def set_costmodel_allreduce_fusion_allreduce_bandwidth(self, allreduce_bandwidth):
426        """
427        Set costmodel allreduce fusion allreduce bandwidth.
428
429        Args:
430            allreduce_bandwidth (int): The bandwidth of AllReduce.
431
432        Raises:
433            ValueError: If context handle is none.
434        """
435        if self._context_handle is None:
436            raise ValueError("Context handle is none in context!!!")
437        self._context_handle.set_costmodel_allreduce_fusion_allreduce_bandwidth(allreduce_bandwidth)
438
439    def get_costmodel_allreduce_fusion_allreduce_bandwidth(self):
440        """
441        Get costmodel allreduce fusion allreduce bandwidth.
442
443        Raises:
444            ValueError: If context handle is none.
445        """
446        if self._context_handle is None:
447            raise ValueError("Context handle is none in context!!!")
448        return self._context_handle.get_costmodel_allreduce_fusion_allreduce_bandwidth()
449
450    def set_costmodel_allreduce_fusion_computation_time_parameter(self, computation_time_parameter):
451        """
452        Set costmodel allreduce fusion computation time parameter.
453
454        Args:
455            computation_time_parameter (int): The parameter used to compute backward computation time.
456
457        Raises:
458            ValueError: If context handle is none.
459        """
460        if self._context_handle is None:
461            raise ValueError("Context handle is none in context!!!")
462        self._context_handle.set_costmodel_allreduce_fusion_computation_time_parameter(computation_time_parameter)
463
464    def get_costmodel_allreduce_fusion_computation_time_parameter(self):
465        """
466        Get costmodel allreduce fusion computation time parameter.
467
468        Raises:
469            ValueError: If context handle is none.
470        """
471        if self._context_handle is None:
472            raise ValueError("Context handle is none in context!!!")
473        return self._context_handle.get_costmodel_allreduce_fusion_computation_time_parameter()
474
475    def reset_cost_model(self):
476        """
477        Reset cost model settings.
478
479        Raises:
480            ValueError: If context handle is none.
481        """
482        if self._context_handle is None:
483            raise ValueError("Context handle is none in context!!!")
484        self._context_handle.reset_cost_model()
485
486
# Lazily created module-level singleton; always access via cost_model_context().
_cost_model_context = None


def cost_model_context():
    """
    Return the global cost model context, creating it on first access.

    Returns:
        The global _CostModelContext singleton.
    """
    global _cost_model_context
    if _cost_model_context is not None:
        return _cost_model_context
    _cost_model_context = _CostModelContext()
    return _cost_model_context
501
502
# Mapping from each public context keyword to its setter on the singleton
# cost model context; consumed by set_cost_model_context().
# NOTE: "multi_subgraphs" is included here because the args_type_check
# decorator on set_cost_model_context() accepts multi_subgraphs=bool.
set_cost_model_context_func_map = {
    "device_memory_capacity": cost_model_context().set_device_memory_capacity,
    "costmodel_alpha": cost_model_context().set_costmodel_alpha,
    "costmodel_beta": cost_model_context().set_costmodel_beta,
    "costmodel_gamma": cost_model_context().set_costmodel_gamma,
    "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
    "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
    "multi_subgraphs": cost_model_context().set_multi_subgraphs,
    "run_phase": cost_model_context().set_run_phase,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
    "costmodel_allreduce_fusion_tail_time": cost_model_context().set_costmodel_allreduce_fusion_tail_time,
    "costmodel_allreduce_fusion_allreduce_inherent_time":
        cost_model_context().set_costmodel_allreduce_fusion_allreduce_inherent_time,
    "costmodel_allreduce_fusion_allreduce_bandwidth":
        cost_model_context().set_costmodel_allreduce_fusion_allreduce_bandwidth,
    "costmodel_allreduce_fusion_computation_time_parameter":
        cost_model_context().set_costmodel_allreduce_fusion_computation_time_parameter}
522
523
# Mapping from each public context keyword to its getter on the singleton
# cost model context; consumed by get_cost_model_context().
# NOTE: "multi_subgraphs" is included so the getter map mirrors the setter
# keywords accepted by set_cost_model_context().
get_cost_model_context_func_map = {
    "device_memory_capacity": cost_model_context().get_device_memory_capacity,
    "costmodel_alpha": cost_model_context().get_costmodel_alpha,
    "costmodel_beta": cost_model_context().get_costmodel_beta,
    "costmodel_gamma": cost_model_context().get_costmodel_gamma,
    "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
    "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
    "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
    "multi_subgraphs": cost_model_context().get_multi_subgraphs,
    "run_phase": cost_model_context().get_run_phase,
    "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
    "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
    "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
    "costmodel_allreduce_fusion_tail_time": cost_model_context().get_costmodel_allreduce_fusion_tail_time,
    "costmodel_allreduce_fusion_allreduce_inherent_time":
        cost_model_context().get_costmodel_allreduce_fusion_allreduce_inherent_time,
    "costmodel_allreduce_fusion_allreduce_bandwidth":
        cost_model_context().get_costmodel_allreduce_fusion_allreduce_bandwidth,
    "costmodel_allreduce_fusion_computation_time_parameter":
        cost_model_context().get_costmodel_allreduce_fusion_computation_time_parameter}
543
544
@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                 costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
                 multi_subgraphs=bool, run_phase=int,
                 costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
                 costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
                 costmodel_allreduce_fusion_allreduce_inherent_time=float,
                 costmodel_allreduce_fusion_allreduce_bandwidth=float,
                 costmodel_allreduce_fusion_computation_time_parameter=float)
def set_cost_model_context(**kwargs):
    """
    Set cost model context attributes by keyword.

    Note:
        Attribute name is needed.

    Args:
        device_memory_capacity (float): The memory capacity for each device.
        costmodel_alpha (float): The parameter costmodel_alpha used in strategy-searching algorithm.
        costmodel_beta (float): The parameter costmodel_beta used in strategy-searching algorithm.
        costmodel_gamma (float): The parameter costmodel_gamma used in strategy-searching algorithm.
        costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
        costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
        run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
        costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
            0: bypass allreduce fusion;
            1: only use backward computation time to group allreduce;
            2: use backward computation time and parameter gradient allreduce time to group allreduce.
        costmodel_allreduce_fusion_times (int): The AllReduce fusion times of parameter gradients.
        costmodel_allreduce_fusion_tail_percent (float): A parameter used in allreduce fusion algorithm. The percentage
            of backward computing time corresponding to the last parameter gradients AllReduce in the whole backward
            computing time.
        costmodel_allreduce_fusion_tail_time (float): A parameter used in allreduce fusion algorithm. The tail time of
            the last parameter gradients AllReduce after the end of backward computation.
        costmodel_allreduce_fusion_allreduce_inherent_time (float): A parameter used in allreduce fusion algorithm. The
            inherent cost time of AllReduce.
        costmodel_allreduce_fusion_allreduce_bandwidth (float): A parameter used in allreduce fusion algorithm. The
            bandwidth of AllReduce.
        costmodel_allreduce_fusion_computation_time_parameter (float): A parameter used in allreduce fusion algorithm.
            The parameter used to compute backward computation time.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    # Apply each keyword in order; earlier valid keywords take effect even if
    # a later unrecognized one raises.
    for attr_name, attr_value in kwargs.items():
        if attr_name not in set_cost_model_context_func_map:
            raise ValueError("Set context keyword %s is not recognized!" % attr_name)
        set_cost_model_context_func_map[attr_name](attr_value)
596
597
def get_cost_model_context(attr_key):
    """
    Get a cost model context attribute by name.

    Note:
        Return value according to the attribute value.

    Args:
        attr_key (str): The key of the attribute.

    Raises:
        ValueError: If context keyword is not recognized.
    """
    if attr_key in get_cost_model_context_func_map:
        return get_cost_model_context_func_map[attr_key]()
    raise ValueError("Get context keyword %s is not recognized!" % attr_key)
615
616
def reset_cost_model_context():
    """Reset all cost model context attributes to their defaults."""
    ctx = cost_model_context()
    ctx.reset_cost_model()
620
621
def _set_multi_subgraphs(multi_subgraph=True):
    """
    Mark whether the ANF graph contains multiple subgraphs.

    Args:
        multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag. Default: True.
    """
    ctx = cost_model_context()
    ctx.set_multi_subgraphs(multi_subgraph)
630
631
def _get_multi_subgraphs():
    """Return the flag marking whether the ANF graph contains multiple subgraphs."""
    ctx = cost_model_context()
    return ctx.get_multi_subgraphs()
637
638
def _set_algo_single_loop(single_loop=True):
    """
    Set whether a single suite of OperatorInfos is generated in a for-loop.

    Args:
        single_loop (bool): The parameter for the single loop flag. Default: True.
    """
    ctx = cost_model_context()
    ctx.set_dp_algo_single_loop(single_loop)
647
648
def _get_algo_single_loop():
    """Return whether a single suite of OperatorInfos is generated in a for-loop."""
    ctx = cost_model_context()
    return ctx.get_dp_algo_single_loop()
654