# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Union

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils


__all__ = ["ModelAverager", "PeriodicModelAverager"]


class ModelAverager(ABC):
    r"""Base class for all model averagers.

    Args:
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)
    """

    def __init__(self, process_group=None):
        self.process_group = (
            process_group if process_group is not None else dist.group.WORLD
        )
        self.step = 0

    @abstractmethod
    def average_parameters(self, params):
        raise NotImplementedError


class PeriodicModelAverager(ModelAverager):
    r"""
    Averages parameters periodically after the warm-up stage.

    This can be used for running `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
    by running :class:`~torch.nn.parallel.DistributedDataParallel` (DDP)
    using the subgroups created by :func:`~torch.distributed.new_subgroups`.

    Args:
        period (int): The number of steps per model averaging.
                      Usually the period should be greater than ``1`` to reduce
                      the communication cost; otherwise, DDP alone is sufficient.
        warmup_steps (int): The number of warm-up steps. During this stage,
                            model averaging is skipped.
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).cuda()
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
        >>>
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging every 4 steps.
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> for step in range(0, 200):
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(model(inputs), labels)
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     # Will average model parameters globally every 4 steps. Thus,
        >>>     # inter-node communication only occurs every 4 iterations after
        >>>     # the initial ``warmup_steps`` period.
        >>>     averager.average_parameters(model.parameters())
    """

    def __init__(self, period, warmup_steps=0, process_group=None):
        super().__init__(process_group)
        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps
        if period < 1:
            raise ValueError("Arg ``period`` must be a positive value.")
        elif period == 1:
            warnings.warn(
                "When period is 1, model averaging is unnecessary because the communication cost "
                "of all-reducing parameters is no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        self.period = period

    def average_parameters(
        self,
        params: Union[
            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
        ],
    ):
        """
        Averages parameters of a model or parameter groups of an optimizer.

        Averaging only occurs if ``step`` is no less than ``warmup_steps``
        and ``step - warmup_steps`` is divisible by ``period``, where ``step``
        is increased by 1 at each iteration in the training loop.

        Args:
            params: The parameters of a model or the parameter groups of an optimizer.
        """
        if (
            self.step >= self.warmup_steps
            and (self.step - self.warmup_steps) % self.period == 0
        ):
            utils.average_parameters_or_parameter_groups(params, self.process_group)
        self.step += 1
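

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the library above. It assumes a
# single-machine job launched with ``torchrun`` using the "gloo" backend, a
# toy ``nn.Linear`` model, and made-up hyperparameters; it only illustrates
# that ``average_parameters`` accepts an optimizer's ``param_groups`` as well
# as ``model.parameters()``.
if __name__ == "__main__":
    import torch.nn as nn

    # ``torchrun`` supplies MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE.
    dist.init_process_group("gloo")
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # No warm-up, so averaging runs every ``period`` steps starting at step 0.
    averager = PeriodicModelAverager(period=4, warmup_steps=0)
    for _ in range(8):
        optimizer.zero_grad()
        model(torch.randn(8, 4)).sum().backward()
        optimizer.step()
        # Passing the optimizer's parameter groups averages the same tensors
        # as passing ``model.parameters()``.
        averager.average_parameters(optimizer.param_groups)
    dist.destroy_process_group()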