# mypy: allow-untyped-defs
import logging

import torch
import torch.distributed as dist

from . import default_hooks as default


logger = logging.getLogger(__name__)


class PostLocalSGDState:
    r"""
    Store state for all-reducing gradients globally until a given step, then locally afterwards.

    Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
    and all-reducing gradients locally using ``subgroup`` afterwards.

    If ``process_group`` is ``None``, the global process group will be used.
    If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.

    Additionally, ``post_local_gradient_allreduce`` may be worth tuning,
    because either value may lead to faster convergence, depending on the workload.
    """

    __slots__ = [
        "process_group",
        "subgroup",
        "start_localSGD_iter",
        "post_local_gradient_allreduce",
        "iter",
    ]

    def __init__(
        self,
        process_group,
        subgroup,
        start_localSGD_iter,
        post_local_gradient_allreduce=True,
    ):
        """Initialize the state object with the given parameters and log when local SGD starts."""
        logger.info(
            "Local SGD will be started after %s iterations", start_localSGD_iter
        )

        # The group used for all-reducing gradients globally.
        self.process_group = process_group
        # The group used for all-reducing gradients locally.
        self.subgroup = subgroup
        self.start_localSGD_iter = start_localSGD_iter
        # Allreduce gradients locally starting from iteration `start_localSGD_iter`.
        # This may help with the convergence efficiency at the cost of relatively cheap intra-subgroup communication.
        self.post_local_gradient_allreduce = post_local_gradient_allreduce
        # Iteration/step in the training loop.
        self.iter = 0

    def maybe_increase_iter(self, bucket):
        """Track iterations and trigger a log message at the start of local SGD."""
        # Since bucket 0 is the last bucket to allreduce in an iteration,
        # only increase `iter` when bucket 0 is processed.
        if bucket.is_last():
            self.iter += 1

        if self.iter == self.start_localSGD_iter:
            logger.info("Start to apply local SGD after %s iterations.", self.iter)


def post_localSGD_hook(
    state: PostLocalSGDState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Run the post-localSGD algorithm.

    This DDP communication hook is used for running the post-localSGD algorithm,
    in combination with a model averaging component (e.g.,
    :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`)
    that runs after the optimizer step.

    Args:
        state (PostLocalSGDState): State information to run post-localSGD.
            Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD.
        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hooks only support single-process single-device mode,
            exactly one tensor is stored in this bucket.

    Returns:
        Future handle of the communication, which updates the gradients in place.

    Example::
        >>> # xdoctest: +SKIP
        >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup,
        ...                           start_localSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, post_localSGD_hook)
        >>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``.
        >>> # Please refer to the examples in the ``torch.distributed.algorithms.model_averaging.averagers`` module.
    """
    global_group_to_use = (
        state.process_group if state.process_group is not None else dist.group.WORLD
    )

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.buffer()

    # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
    if state.iter < state.start_localSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(global_group_to_use, input_tensor)

    # If `post_local_gradient_allreduce` is not set, then there is no gradient
    # synchronization after the first `start_localSGD_iter` iterations.
    if not state.post_local_gradient_allreduce:
        fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        fut.set_result(input_tensor)
        return fut

    # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations.
    # Note that by default, a separate subgroup is created for each node, which
    # causes an intra-node allreduce at each training step.
    # From this moment, model averaging should run after the optimizer step,
    # to globally allreduce all the parameters.
    if state.subgroup is None:
        state.subgroup, _ = dist.new_subgroups()
    return default._allreduce_fut(state.subgroup, input_tensor)
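

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's API): how the hook above
# is typically wired together with periodic model averaging. It assumes the
# default process group has already been initialized (e.g., via
# ``dist.init_process_group``) and that ``MyModel``, ``loss_fn``, and ``loader``
# are user-defined placeholders. Kept as a comment so that importing this
# module stays side-effect free.
#
#     import torch
#     import torch.distributed as dist
#     import torch.nn as nn
#     from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
#         PostLocalSGDState,
#         post_localSGD_hook,
#     )
#     from torch.distributed.algorithms.model_averaging.averagers import (
#         PeriodicModelAverager,
#     )
#
#     model = nn.parallel.DistributedDataParallel(MyModel().cuda())
#     subgroup, _ = dist.new_subgroups()  # one subgroup per node by default
#     state = PostLocalSGDState(
#         process_group=None, subgroup=subgroup, start_localSGD_iter=100
#     )
#     model.register_comm_hook(state, post_localSGD_hook)
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#     # Average parameters across the global group every 4 steps once local SGD
#     # has started (warmup_steps matches start_localSGD_iter).
#     averager = PeriodicModelAverager(period=4, warmup_steps=100)
#
#     for inputs, targets in loader:
#         optimizer.zero_grad()
#         loss_fn(model(inputs), targets).backward()
#         optimizer.step()
#         averager.average_parameters(model.parameters())
# -----------------------------------------------------------------------------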