# mypy: allow-untyped-defs
from typing import Any, Dict, List, Mapping, Union

import torch.optim as optim
from torch import Tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor


class ShardedOptimizer(optim.Optimizer):
    def __init__(
        self,
        named_params: Mapping[str, Union[Tensor, ShardedTensor]],
        optimizer_class,
        *optimizer_args,
        **optimizer_kwargs,
    ):
        """
        ShardedOptimizer collects all plain Tensors and the local shards of
        every ShardedTensor, then uses these tensors as the ``params`` for the
        local optimizer.

        Args:
            named_params (Mapping[str, Union[Tensor, ShardedTensor]]): a mapping
                from parameter key to either a Tensor or a ShardedTensor parameter.
            optimizer_class (torch.optim.Optimizer): the Optimizer class to use
                locally, e.g. torch.optim.SGD, torch.optim.Adagrad, etc.
            *optimizer_args: the arguments used to initialize the optimizer.
            **optimizer_kwargs: the keyword arguments used to initialize the optimizer.

        """
        tensors: List[Tensor] = []
        for value in named_params.values():
            if isinstance(value, ShardedTensor):
                # A ShardedTensor only owns its local shards on this rank, so
                # only those shards are handed to the local optimizer.
                for local_shard in value.local_shards():
                    tensors.append(local_shard.tensor)
            else:
                tensors.append(value)

        self.named_params = named_params
        self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
        self.param_groups = self._optim.param_groups
        self.state = self._optim.state

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        r"""Resets the gradients of all optimized :class:`torch.Tensor` s.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        self._optim.zero_grad(set_to_none)

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        self._optim.step(closure)

    def state_dict(self) -> Dict[str, Any]:
        """
        The returned ``state`` and ``param_groups`` will contain parameter keys
        instead of the parameter indices used by torch.optim.Optimizer.
        This allows advanced functionality like optimizer re-sharding to be implemented.
        """
        # TODO: implement state_dict
        raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!")

    def load_state_dict(self, state_dict: Mapping[str, Any]):
        r"""Loads the ShardedOptimizer state.

        Args:
            state_dict (dict): ShardedOptimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # TODO: implement load_state_dict
        raise NotImplementedError(
            "ShardedOptimizer load_state_dict not implemented yet!"
        )

    def add_param_group(self, param_group: Any):
        r"""Add a new param group"""
        # TODO: implement add_param_group
        raise NotImplementedError(
            "ShardedOptimizer add_param_group not implemented yet!"
        )
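

# Usage sketch (illustrative only, not part of the library API): with plain
# ``Tensor`` parameters no process group is needed, while ``ShardedTensor``
# parameters additionally require an initialized distributed environment and a
# sharding spec. The model, learning rate, and input sizes below are arbitrary
# assumptions chosen for the example.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    # ``named_parameters()`` yields plain Tensors, which ShardedOptimizer
    # forwards to the local optimizer unchanged.
    sharded_optim = ShardedOptimizer(
        dict(model.named_parameters()),
        optim.SGD,
        lr=0.01,
    )

    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    sharded_optim.step()
    sharded_optim.zero_grad()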