# mypy: allow-untyped-defs
import logging

import torch
import torch.distributed as dist

from . import default_hooks as default


logger = logging.getLogger(__name__)


class PostLocalSGDState:
    r"""
    Store state for all-reducing gradients globally until a given step, then locally afterwards.

    Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
    and for all-reducing gradients locally using ``subgroup`` afterwards.

    If ``process_group`` is ``None``, the global process group will be used.
    If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.

    Additionally, ``post_local_gradient_allreduce`` may be worth tuning,
    because either setting may lead to faster convergence.
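
    Example::
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch: assumes the default process group is initialized and
        >>> # ``ddp_model`` is a ``torch.nn.parallel.DistributedDataParallel`` instance.
        >>> # With ``subgroup=None``, the hook lazily creates an intra-node subgroup
        >>> # via ``dist.new_subgroups()`` when local allreduce first runs.
        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> ddp_model.register_comm_hook(state, post_localSGD_hook)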
25    """

    __slots__ = [
        "process_group",
        "subgroup",
        "start_localSGD_iter",
        "post_local_gradient_allreduce",
        "iter",
    ]

    def __init__(
        self,
        process_group,
        subgroup,
        start_localSGD_iter,
        post_local_gradient_allreduce=True,
    ):
        """Initialize the state object with the given parameters and log when local SGD will start."""
        logger.info(
            "Local SGD will be started after %s iterations", start_localSGD_iter
        )

        # The group used for all-reducing gradients globally.
        self.process_group = process_group
        # The group used for all-reducing gradients locally.
        self.subgroup = subgroup
        self.start_localSGD_iter = start_localSGD_iter
        # Allreduce gradients locally starting from iteration `start_localSGD_iter`.
        # This may improve convergence at the cost of relatively cheap intra-subgroup communication.
        self.post_local_gradient_allreduce = post_local_gradient_allreduce
        # Iteration/step in the training loop.
        self.iter = 0

    def maybe_increase_iter(self, bucket):
        """Track the iteration count and log a message when local SGD starts."""
        # Since bucket 0 is the last bucket to be all-reduced in an iteration,
        # only increase `iter` when bucket 0 is processed.
        if bucket.is_last():
            self.iter += 1

        if self.iter == self.start_localSGD_iter:
            logger.info("Start to apply local SGD after %s iterations.", self.iter)


def post_localSGD_hook(
    state: PostLocalSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Run the post-localSGD algorithm.

    This DDP communication hook is used for running the post-localSGD algorithm,
    in combination with a model averaging component (e.g.,
    :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`)
    that runs after the optimizer step.

    Args:
        state (PostLocalSGDState): State information to run post-localSGD.
            Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since the DDP comm hook only supports single-process single-device mode,
            exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup,
        ...                           start_localSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, post_localSGD_hook)
        >>> # A model averaging module also needs to be established and run after ``optimizer.step()``;
        >>> # see the sketch below and the examples in the ``torch.distributed.algorithms.model_averaging.averagers`` module.
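        >>> # A hedged sketch of that model averaging step, assuming a typical training loop
        >>> # where ``optimizer``, ``loss_fn``, ``inputs``, and ``labels`` are already defined
        >>> # (these names are illustrative, not part of this module).
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> # Use a warmup equal to ``start_localSGD_iter`` so global averaging only kicks in
        >>> # once the hook has switched to intra-node allreduce.
        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)
        >>> for step in range(20):
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(ddp_model(inputs), labels)
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     # Averages parameters across the global process group every ``period`` steps
        >>>     # after ``warmup_steps``.
        >>>     averager.average_parameters(ddp_model.parameters())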
97    """
    global_group_to_use = (
        state.process_group if state.process_group is not None else dist.group.WORLD
    )

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.buffer()

    # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
    if state.iter < state.start_localSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(global_group_to_use, input_tensor)

    # If `post_local_gradient_allreduce` is not set, there is no gradient
    # synchronization after the first `start_localSGD_iter` iterations.
    if not state.post_local_gradient_allreduce:
        fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        fut.set_result(input_tensor)
        return fut

    # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations.
    # Note that by default, a separate subgroup for each node is created, which
    # causes an intra-node allreduce to be done at each training step.
    # From this point on, model averaging should run after the optimizer step
    # to globally allreduce all the parameters.
    if state.subgroup is None:
        state.subgroup, _ = dist.new_subgroups()
    return default._allreduce_fut(state.subgroup, input_tensor)