# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Union

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils


__all__ = ["ModelAverager", "PeriodicModelAverager"]


class ModelAverager(ABC):
    r"""Base class for all model averagers.

    Args:
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)
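
    Subclasses must implement :meth:`average_parameters`. Below is a hypothetical
    sketch of a custom averager that averages at every call, assuming the process
    group has already been initialized; it is an illustration only
    (``EveryStepModelAverager`` and ``model`` are hypothetical, not part of this module).

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch.distributed.algorithms.model_averaging.utils as utils
        >>> class EveryStepModelAverager(ModelAverager):
        >>>     def average_parameters(self, params):
        >>>         # All-reduce and average ``params`` across ``self.process_group``.
        >>>         utils.average_parameters_or_parameter_groups(params, self.process_group)
        >>>         self.step += 1
        >>> averager = EveryStepModelAverager()
        >>> averager.average_parameters(model.parameters())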
    """

    def __init__(self, process_group=None):
        self.process_group = (
            process_group if process_group is not None else dist.group.WORLD
        )
        self.step = 0

    @abstractmethod
    def average_parameters(self, params):
        raise NotImplementedError


class PeriodicModelAverager(ModelAverager):
    r"""
    Averages parameters periodically after the warm-up stage.

    This can be used for running `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
    by running :class:`~torch.nn.parallel.DistributedDataParallel` (DDP)
    using the subgroups created by :func:`~torch.distributed.new_subgroups`.

    Args:
        period (int): The number of steps per model averaging.
                      Usually the period should be greater than ``1`` to reduce the communication cost.
                      Otherwise, DDP alone is sufficient.
        warmup_steps (int): The number of warm-up steps. During this stage,
                            model averaging is skipped.
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).cuda()
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
        >>>
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging every 4 steps.
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> for step in range(0, 200):
        >>>    optimizer.zero_grad()
        >>>    # ``data``, ``labels``, ``optimizer``, and ``loss_fn`` come from the user's training setup.
        >>>    output = model(data)
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    optimizer.step()
        >>>    # Will average model parameters globally every 4 steps. Thus,
        >>>    # inter-node communication only occurs every 4 iterations after
        >>>    # the initial ``warmup_steps`` period.
        >>>    averager.average_parameters(model.parameters())
    """

    def __init__(self, period, warmup_steps=0, process_group=None):
        super().__init__(process_group)
        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps
        if period < 1:
            raise ValueError("Arg ``period`` must be a positive value.")
        elif period == 1:
            warnings.warn(
                "When period is 1, no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        self.period = period

    def average_parameters(
        self,
        params: Union[
            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
        ],
    ):
        """
        Averages parameters or parameter groups of an optimizer.

        Averaging runs only if ``step`` is no less than ``warmup_steps``
        and ``step - warmup_steps`` is divisible by ``period``, where ``step``
        is increased by 1 at each iteration in the training loop.

        Args:
            params: The parameters of a model or parameter groups of an optimizer.
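
        For example, with ``warmup_steps=100`` and ``period=4``, averaging runs
        at steps 100, 104, 108, and so on. Below is a hypothetical sketch, assuming
        ``model`` and ``optimizer`` are defined elsewhere in the training script.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> averager = PeriodicModelAverager(period=4, warmup_steps=100)
            >>> for step in range(200):
            >>>     optimizer.step()
            >>>     # No-op during the first 100 steps; afterwards averages every 4th step.
            >>>     averager.average_parameters(model.parameters())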
        """
        # Skip averaging during warm-up; afterwards, average once every ``period`` steps.
        if (
            self.step >= self.warmup_steps
            and (self.step - self.warmup_steps) % self.period == 0
        ):
            utils.average_parameters_or_parameter_groups(params, self.process_group)
        self.step += 1