# mypy: allow-untyped-defs
from typing import Any, Dict, List, Mapping, Union

import torch.optim as optim
from torch import Tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor


class ShardedOptimizer(optim.Optimizer):
    def __init__(
        self,
        named_params: Mapping[str, Union[Tensor, ShardedTensor]],
        optimizer_class,
        *optimizer_args,
        **optimizer_kwargs,
    ):
17        """
18        ShardedOptimizer collects all tensors and local shard tensors of
19        ShardedTensor, then use these tensors as ``params`` for optimizers
20
21        Args:
22            named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict
23                of parameters, where key is the parameter key, value is either
24                Tensor or ShardedTensor parameter.
25            optimizer_class (torch.optim.Optimizer): the Optimizer to use
26                locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc.
27            *optimizer_args: the arguments to initialize the optimizer.
28            **optimizer_kwargs: the key-word arguments to initialize the optimizer.
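
        Example (a minimal sketch; ``my_module``, ``inp`` and the chosen
        hyperparameters are hypothetical placeholders)::

            >>> # xdoctest: +SKIP
            >>> named_params = dict(my_module.named_parameters())
            >>> sharded_optim = ShardedOptimizer(
            ...     named_params, torch.optim.SGD, lr=0.01
            ... )
            >>> sharded_optim.zero_grad()
            >>> my_module(inp).sum().backward()
            >>> sharded_optim.step()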

        """
        tensors: List[Tensor] = []
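        # Flatten the named parameters into a single list of local tensors:
        # a plain Tensor is used as-is, while a ShardedTensor contributes each
        # of the local shards it stores on this rank.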
        for value in named_params.values():
            if isinstance(value, ShardedTensor):
                for local_shard in value.local_shards():
                    tensors.append(local_shard.tensor)
            else:
                tensors.append(value)

        self.named_params = named_params
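        # Build the wrapped local optimizer over the flattened tensors and
        # expose its param_groups and state so this wrapper behaves like a
        # regular torch.optim.Optimizer.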
        self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
        self.param_groups = self._optim.param_groups
        self.state = self._optim.state

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        r"""Resets the gradients of all optimized :class:`torch.Tensor` s.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        self._optim.zero_grad(set_to_none)

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
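
        Example of a closure (a minimal sketch; ``model``, ``loss_fn``, ``inp``
        and ``target`` are hypothetical placeholders)::

            >>> # xdoctest: +SKIP
            >>> def closure():
            ...     sharded_optim.zero_grad()
            ...     loss = loss_fn(model(inp), target)
            ...     loss.backward()
            ...     return loss
            >>> sharded_optim.step(closure)
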
71        """
72        self._optim.step(closure)
73
    def state_dict(self) -> Dict[str, Any]:
        """
        Unlike ``torch.optim.Optimizer``, the returned ``state`` and
        ``param_groups`` contain parameter keys instead of parameter indices.
        This allows advanced functionality such as optimizer re-sharding to be
        implemented.
        """
        # TODO: implement state_dict
        raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!")

    def load_state_dict(self, state_dict: Mapping[str, Any]):
        r"""Loads the ShardedOptimizer state.

        Args:
            state_dict (dict): ShardedOptimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # TODO: implement load_state_dict
        raise NotImplementedError(
            "ShardedOptimizer load_state_dict not implemented yet!"
        )

    def add_param_group(self, param_group: Any):
        r"""Add a new param group"""
        # TODO: implement add_param_group
        raise NotImplementedError(
            "ShardedOptimizer add_param_group not implemented yet!"
        )