# mypy: allow-untyped-defs
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.default_planner import (
    _EmptyStateDictLoadPlanner,
    DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future


__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]


class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader reads the entire checkpoint
    on the coordinator rank, and then broadcasts and shards each tensor to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>    sd,
    >>>    storage_reader=BroadcastingTorchSaveReader(),
    >>>    planner=DynamicMetaLoadPlanner(),
    >>>    checkpoint_id="path_to_model.pt"
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.coordinator_rank = coordinator_rank

    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file"""
        # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from
        # the disk
        return Metadata(state_dict_metadata={})

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads torch save data on the coordinator rank, then broadcasts it to the other ranks.
        This incurs a communication cost, but avoids having to load
        the entire checkpoint on each rank, hopefully preventing OOM issues.
        """
        planner = cast(DefaultLoadPlanner, planner)

        # data is read in on the coordinator rank, and broadcast afterwards;
        # this incurs a communication cost, but it avoids having to load
        # the entire checkpoint on each rank, hopefully preventing OOM issues
        # TODO: read on each host, instead of only the coordinator
        if self.is_coordinator:
            assert self.checkpoint_id is not None
            torch_state_dict = torch.load(
                self.checkpoint_id, map_location="cpu", weights_only=False
            )
            if planner.flatten_state_dict:
                torch_state_dict, _ = flatten_state_dict(torch_state_dict)
        else:
            torch_state_dict = None

        for req in plan.items:
            if req.type == LoadItemType.BYTE_IO:
                raise RuntimeError(
                    f"Non-tensor value identified at {req.storage_index.fqn}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            # Broadcast the tensor from the coordinator rank
            if self.is_coordinator:
                pg_device = dist.distributed_c10d._get_pg_default_device()
                tensor = torch_state_dict[req.storage_index.fqn].to(pg_device)
            else:
                tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])

            dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)

            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
            target_tensor = planner.resolve_tensor(req).detach()
            assert target_tensor.size() == tensor.size(), (
                f"req {req.storage_index} mismatched sizes, "
                f"{target_tensor.size()} vs {tensor.size()}"
            )
            target_tensor.copy_(tensor)
            planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Implementation of the StorageReader method"""
        self.is_coordinator = is_coordinator
        if self.is_coordinator:
            assert dist.get_rank() == self.coordinator_rank

        assert self.checkpoint_id is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """Implementation of the StorageReader method"""
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """Implementation of the StorageReader method"""
        return global_plan

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Implementation of the StorageReader method"""
        self.checkpoint_id = checkpoint_id

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Implementation of the StorageReader method"""
        return os.path.isfile(checkpoint_id)


class DynamicMetaLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed-in state dict,
    avoiding the need to read metadata from disk. This is useful when reading formats which don't have a
    metadata file, like Torch Save files.

    .. note:: Intended to be used with BroadcastingTorchSaveReader

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>    sd,
    >>>    storage_reader=BroadcastingTorchSaveReader(),
    >>>    planner=DynamicMetaLoadPlanner(),
    >>>    checkpoint_id="path_to_model.pt"
    >>> )
    """

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """Sets up the planner, extending default behavior by creating the Metadata object from the state dict"""
        super().set_up_planner(state_dict, metadata, is_coordinator)

        state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
        for key, tensor in self.state_dict.items():
            if not torch.is_tensor(tensor):
                raise RuntimeError(
                    f"Non-tensor value identified at {key}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            state_dict_metadata[key] = TensorStorageMetadata(
                TensorProperties(dtype=tensor.dtype),
                tensor.size(),
                _create_chunk_list(tensor),
            )
        self.metadata = Metadata(state_dict_metadata=state_dict_metadata)


def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
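
    Example usage (a sketch in the doctest style used above; the paths are
    illustrative placeholders):

    >>> # xdoctest: +SKIP("filesystem paths are placeholders")
    >>> dcp_to_torch_save("checkpoint/dcp_dir", "checkpoint/model.pt")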
    """
    sd: STATE_DICT_TYPE = {}
    _load_state_dict(
        sd,
        storage_reader=FileSystemReader(dcp_checkpoint_dir),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)


def torch_save_to_dcp(
    torch_save_path: Union[str, os.PathLike],
    dcp_checkpoint_dir: Union[str, os.PathLike],
):
    """
    Given the location of a torch save file, converts it into a DCP checkpoint.

    Args:
        torch_save_path: Filename of the Torch save file.
        dcp_checkpoint_dir: Directory to store the DCP checkpoint.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
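
    Example usage (a sketch mirroring the doctest style above; the paths are
    illustrative placeholders):

    >>> # xdoctest: +SKIP("filesystem paths are placeholders")
    >>> torch_save_to_dcp("checkpoint/model.pt", "checkpoint/dcp_dir")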
    """

    state_dict = torch.load(torch_save_path, weights_only=False)
    # we don't need stateful behavior here because the expectation is anything loaded by
    # torch.load would not contain stateful objects.
    _save_state_dict(
        state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
    )


if __name__ == "__main__":

    class FormatMode(Enum):
        TORCH_TO_DCP = "torch_to_dcp"
        DCP_TO_TORCH = "dcp_to_torch"
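
    # Example invocations (the module path below is an assumption about where this
    # file lives; adjust to however the script is actually launched):
    #   python -m torch.distributed.checkpoint.format_utils torch_to_dcp model.pt dcp_dir
    #   python -m torch.distributed.checkpoint.format_utils dcp_to_torch dcp_dir model.pt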

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode",
        type=str,
        help="Conversion mode",
        choices=[m.value for m in FormatMode],
        default=FormatMode.TORCH_TO_DCP.value,
    )
    parser.add_argument("src", type=str, help="Path to the source model")
    parser.add_argument("dst", type=str, help="Path to the destination model")
    args = parser.parse_args()

    print(
        f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
    )
    checkpoint_missing_warning = (
        f"No checkpoint found at {args.src}. Skipping conversion."
    )
    if args.mode == FormatMode.TORCH_TO_DCP.value:
        if os.path.isfile(args.src):
            torch_save_to_dcp(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    elif args.mode == FormatMode.DCP_TO_TORCH.value:
        if os.path.isdir(args.src):
            dcp_to_torch_save(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    else:
        raise ValueError(f"Unknown conversion mode: {args.mode}")
