# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, overload

import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer


class _ZeROJoinHook(JoinHook):
    """Join hook allowing a ZeroRedundancyOptimizer to take part in a join context."""

    zero: Any = ...  # the ZeroRedundancyOptimizer instance this hook shadows
    def __init__(self, zero: Any) -> None: ...
    def main_hook(self) -> None: ...


class _DDPBucketAssignment:
    """Assignment of (part of) one DDP gradient bucket to a rank."""

    bucket_index: int  # index identifying the DDP bucket
    parameters: list[torch.Tensor]  # parameters covered by this assignment
    offset: int  # offset into the bucket's flattened data — presumably element offset; confirm against caller
    device: torch.device
    tensor: torch.Tensor | None  # flattened bucket tensor, if materialized


class _OverlapStatus(enum.IntEnum):
    """State of the ZeRO/DDP-overlap initialization handshake."""

    UNINITIALIZED: int = ...
    DDP_HAS_REBUILT_BUCKETS: int = ...
    INITIALIZED: int = ...


class _OverlapInfo:
    """Bookkeeping shared between ZeRO and DDP when overlapping communication."""

    status: Any = ...
    params_per_bucket: Any = ...
    params_per_rank: Any = ...
    offsets: Any = ...
    broadcast_handles: Any = ...
    bucket_index_to_future: Any = ...
    bucket_index_to_bucket: Any = ...
    bucket_indices_seen: Any = ...
    assigned_ranks_per_bucket: list[set[int]] = ...  # ranks responsible for each bucket
    total_size: int = ...
    shard_buckets: bool = ...
    def __init__(self) -> None: ...
    def wait_for_broadcasts(self) -> None: ...
    def clear_per_iter_info(self) -> None: ...


class ZeroRedundancyOptimizer(Optimizer, Joinable):
    """Optimizer wrapper that shards optimizer state across ranks (ZeRO stage 1).

    Wraps an ``optimizer_class`` instance so each rank only holds the shard of
    optimizer state for its assigned parameters; participates in the DDP join
    protocol via ``Joinable``.
    """

    functional_optim_map: Any = ...
    initialized: bool = ...
    process_group: Any = ...
    world_size: int = ...
    rank: int = ...  # rank within process_group
    global_rank: int = ...  # rank within the default (world) group
    parameters_as_bucket_view: bool = ...
    optim: Any = ...  # the wrapped local optimizer instance
    _device_to_device_index: dict[torch.device, int] = ...
    _overlap_with_ddp: bool = ...
    _overlap_info: _OverlapInfo = ...
    _buckets: list[list[torch.Tensor]] = ...
    _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...
    def __init__(
        self,
        params: Any,
        optimizer_class: type[Optimizer],
        process_group: Any | None = ...,
        parameters_as_bucket_view: bool = ...,
        overlap_with_ddp: bool = ...,
        **defaults: Any,
    ) -> None: ...
    def add_param_group(self, param_group: dict[str, Any]) -> None: ...
    def consolidate_state_dict(self, to: int = ...) -> None: ...
    # step() returns None when no closure is given, else the closure's loss.
    @overload
    def step(self, closure: None = ..., **kwargs: Any) -> None: ...
    @overload
    def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
    def state_dict(self) -> dict[str, Any]: ...
    def _local_step(
        self,
        gradients: list[torch.Tensor | None] | None = None,
        closure: Callable[[], float] | None = None,
        **kwargs: Any,
    ) -> float | None: ...
    def _get_assigned_rank(self, bucket_index: int) -> int: ...
    def _init_zero_for_overlap(self) -> None: ...
    def join_hook(self, **kwargs): ...
    @property
    def join_device(self) -> torch.device: ...
    # Fixed: Joinable declares join_process_group as an abstract *property*
    # (join_device above already has the decorator); without @property the stub
    # misreports it as a method to type checkers.
    @property
    def join_process_group(self) -> Any: ...