# NOTE: HTML page-navigation residue from extraction removed here.
# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, overload

import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer

class _ZeROJoinHook(JoinHook):
    # The optimizer instance this join hook services; typed ``Any`` in the
    # stub — presumably a ZeroRedundancyOptimizer (TODO confirm against the
    # implementation module).
    zero: Any = ...
    def __init__(self, zero: Any) -> None: ...
    # Per-iteration shadow step executed by an already-joined process
    # (contract inherited from JoinHook.main_hook).
    def main_hook(self) -> None: ...
13
class _DDPBucketAssignment:
    """Stub for one DDP gradient-bucket assignment handled by a rank.

    Attribute meanings below are inferred from names — verify against the
    implementation module before relying on them.
    """
    bucket_index: int  # index of the bucket in DDP's bucket ordering
    parameters: list[torch.Tensor]  # parameters whose gradients live in this bucket
    offset: int  # presumably the start offset into the flattened bucket — confirm
    device: torch.device  # device on which the bucket's storage resides
    tensor: torch.Tensor | None  # flattened bucket storage; None presumably until allocated
20
class _OverlapStatus(enum.IntEnum):
    """Progress states for overlapping the ZeRO step with DDP communication."""
    UNINITIALIZED: int = ...
    DDP_HAS_REBUILT_BUCKETS: int = ...  # intermediate state after DDP's bucket rebuild
    INITIALIZED: int = ...
25
class _OverlapInfo:
    """Bookkeeping stub for overlapping ZeRO's optimizer step with DDP.

    Most fields are left as ``Any`` in this stub; precise types live in the
    implementation module.
    """
    status: Any = ...  # presumably an _OverlapStatus value — confirm
    params_per_bucket: Any = ...  # per-bucket parameter grouping (name-based assumption)
    params_per_rank: Any = ...  # per-rank parameter grouping (name-based assumption)
    offsets: Any = ...
    broadcast_handles: Any = ...  # presumably async work handles consumed by wait_for_broadcasts()
    bucket_index_to_future: Any = ...
    bucket_index_to_bucket: Any = ...
    bucket_indices_seen: Any = ...
    assigned_ranks_per_bucket: list[set[int]] = ...  # ranks responsible for each bucket
    total_size: int = ...
    shard_buckets: bool = ...
    def __init__(self) -> None: ...
    # Presumably blocks until all outstanding broadcasts finish — verify.
    def wait_for_broadcasts(self) -> None: ...
    # Presumably resets state valid only within one training iteration — verify.
    def clear_per_iter_info(self) -> None: ...
41
class ZeroRedundancyOptimizer(Optimizer, Joinable):
    """Stub for the ZeRO optimizer wrapper.

    Wraps an arbitrary ``optimizer_class`` and shards its optimizer state
    across the ranks of ``process_group``; also participates in the
    ``Joinable`` protocol for uneven-input training.
    """

    functional_optim_map: Any = ...
    initialized: bool = ...
    process_group: Any = ...
    world_size: int = ...
    rank: int = ...  # rank within process_group
    global_rank: int = ...  # rank within the global (world) group
    parameters_as_bucket_view: bool = ...
    optim: Any = ...  # the wrapped local optimizer instance
    _device_to_device_index: dict[torch.device, int] = ...
    _overlap_with_ddp: bool = ...
    _overlap_info: _OverlapInfo = ...
    _buckets: list[list[torch.Tensor]] = ...
    _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...

    def __init__(
        self,
        params: Any,
        optimizer_class: type[Optimizer],
        process_group: Any | None = ...,
        parameters_as_bucket_view: bool = ...,
        overlap_with_ddp: bool = ...,
        **defaults: Any,
    ) -> None: ...
    def add_param_group(self, param_group: dict[str, Any]) -> None: ...
    # Gathers the full optimizer state onto rank ``to``.
    def consolidate_state_dict(self, to: int = ...) -> None: ...
    # ``step`` returns the closure's loss when a closure is given, else None.
    @overload
    def step(self, closure: None = ..., **kwargs: Any) -> None: ...
    @overload
    def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
    def state_dict(self) -> dict[str, Any]: ...
    def _local_step(
        self,
        gradients: list[torch.Tensor | None] | None = None,
        closure: Callable[[], float] | None = None,
        **kwargs: Any,
    ) -> float | None: ...
    def _get_assigned_rank(self, bucket_index: int) -> int: ...
    def _init_zero_for_overlap(self) -> None: ...
    # Joinable interface; return annotation added to match Joinable.join_hook.
    def join_hook(self, **kwargs: Any) -> JoinHook: ...
    @property
    def join_device(self) -> torch.device: ...
    # Fix: Joinable exposes ``join_process_group`` as a property (parallel to
    # ``join_device``); the stub previously declared it as a plain method,
    # which mismatches the runtime implementation.
    @property
    def join_process_group(self) -> Any: ...
85