# mypy: allow-untyped-defs
import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers


class PostLocalSGDOptimizer(torch.optim.Optimizer):
    r"""
    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.
    This optimizer runs the local optimizer at every step.
    After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

    Args:
        optim: The local optimizer.
        averager: A model averager instance to run the post-localSGD algorithm.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import PostLocalSGDOptimizer
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>>
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>>
        >>> # Register a post-localSGD communication hook.
        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
        >>> opt = PostLocalSGDOptimizer(
        >>>     optim=local_optim,
        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> )
        >>>
        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
        >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
        >>> for step in range(0, 200):
        >>>     opt.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     opt.step()
    """

    def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
        self.optim = optim
        # Expose the wrapped optimizer's param groups so this wrapper behaves as a
        # drop-in replacement for the local optimizer.
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        return self.optim.state

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record the model averager's step to the checkpoint
        to ensure reload does not cause an unnecessary warm-up again.
        """
        optim_state_dict = self.optim.state_dict()
        optim_state_dict["step"] = self.averager.step
        return optim_state_dict

    def load_state_dict(self, state_dict):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores the model averager's step value to the one
        saved in the provided ``state_dict``.

        If there is no ``"step"`` entry in ``state_dict``,
        it will raise a warning and initialize the model averager's step to 0.
        """
        self.optim.load_state_dict(state_dict)
        if "step" in state_dict:
            self.averager.step = state_dict["step"]
        else:
            warnings.warn(
                "Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0."
            )
            self.averager.step = 0

    def step(self):
        r"""
        Performs a single optimization step (parameter update).
        """
        # Apply the local optimizer first, then let the averager (periodically)
        # average parameters across ranks.
        self.optim.step()
        self.averager.average_parameters(params=self.param_groups)

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        self.optim.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)