# mypy: allow-untyped-defs
import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers


class PostLocalSGDOptimizer(torch.optim.Optimizer):
    r"""
    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.
    This optimizer runs the local optimizer at every step.
    After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

    Args:
        optim: The local optimizer.
        averager: A model averager instance to run the post-localSGD algorithm.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import PostLocalSGDOptimizer
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>   PostLocalSGDState,
        >>>   post_localSGD_hook,
        >>> )
        >>>
        >>> model = nn.parallel.DistributedDataParallel(
        >>>    module, device_ids=[rank], output_device=rank
        >>> )
        >>>
        >>> # Register a post-localSGD communication hook.
        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
        >>> opt = PostLocalSGDOptimizer(
        >>>     optim=local_optim,
        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> )
        >>>
        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
        >>> # and the post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
        >>> for step in range(0, 200):
        >>>    opt.zero_grad()
        >>>    loss = loss_fn(output, labels)
        >>>    loss.backward()
        >>>    opt.step()
    """

    def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
        self.optim = optim
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        r"""Forwards to the wrapped local optimizer's ``state``."""
        return self.optim.state

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record the model averager's step to the checkpoint
        to ensure that reloading does not cause an unnecessary warm-up again.
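
        For example, the averager step is stored alongside the regular optimizer state
        (a minimal sketch; ``opt`` is assumed to be an already constructed
        ``PostLocalSGDOptimizer`` and ``CHECKPOINT`` a writable file path)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> state_dict = opt.state_dict()
            >>> # The extra "step" entry records the averager's step counter.
            >>> assert "step" in state_dict
            >>> torch.save(state_dict, CHECKPOINT)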
        """
        optim_state_dict = self.optim.state_dict()
        optim_state_dict["step"] = self.averager.step
        return optim_state_dict

    def load_state_dict(self, state_dict):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores the model averager's step value to the one
        saved in the provided ``state_dict``.

        If there is no ``"step"`` entry in ``state_dict``,
        it will issue a warning and initialize the model averager's step to 0.
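
        For example, reloading a checkpoint saved by :meth:`state_dict` also resumes the
        averager at the saved step instead of warming up again (a minimal sketch;
        ``opt`` is assumed to be an already constructed ``PostLocalSGDOptimizer`` and
        ``CHECKPOINT`` a path written by ``torch.save``)::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> opt.load_state_dict(torch.load(CHECKPOINT))
            >>> # The averager's step counter is restored from the checkpoint.
            >>> print(opt.averager.step)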
        """
        self.optim.load_state_dict(state_dict)
        if "step" in state_dict:
            self.averager.step = state_dict["step"]
        else:
            warnings.warn(
                "Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0."
            )
            self.averager.step = 0

    def step(self):
        r"""
        Performs a single local optimization step (parameter update),
        then invokes the model averager on the parameter groups.
        """
        self.optim.step()
        self.averager.average_parameters(params=self.param_groups)

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        r"""Clears the gradients of all optimized parameters via the local optimizer."""
        self.optim.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        r"""Adds a parameter group to the local optimizer's ``param_groups``."""
        self.optim.add_param_group(param_group)