# Owner(s): ["oncall: distributed"]

import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.utils import CheckpointException
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)


def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Decorator that creates a temporary checkpoint directory on rank 0 and
    broadcasts its path to all other ranks before running the wrapped test.
    """
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:
        # Only create temp_dir when rank is 0
        if dist.get_rank() == 0:
            temp_dir = tempfile.mkdtemp()
            print(f"Using temp directory: {temp_dir}")
        else:
            temp_dir = ""
        object_list = [temp_dir]

        # Broadcast temp_dir to all the other ranks
        dist.broadcast_object_list(object_list)
        self.temp_dir = object_list[0]
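        # Every rank now targets the same checkpoint directory; rank 0 alone
        # removes it in the finally block below.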

        try:
            func(self, *args, **kwargs)
        finally:
            if dist.get_rank() == 0:
                shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper


class MyTestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))


class TestFSSpec(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    @with_temp_dir
    def test_fsspec(self):
        CHECKPOINT_DIR = self.temp_dir

        model = FSDP(MyTestModule().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=0.1)
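        # Run one forward/backward/step so Adam materializes its per-parameter
        # state (exp_avg, exp_avg_sq) before the checkpoint is taken.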
        model(torch.rand(8, 8, device=dist.get_rank())).sum().backward()
        optim.step()

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            state_dict = {
                "model": model.state_dict(),
                "optim": FSDP.optim_state_dict(model, optim),
            }

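            # Persist both the sharded model state and the FSDP optimizer state
            # into a single fsspec-backed checkpoint.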
            dcp.save(
                state_dict=state_dict,
                storage_writer=FsspecWriter(CHECKPOINT_DIR),
                planner=dcp.DefaultSavePlanner(),
            )

        model_2 = FSDP(MyTestModule().cuda())
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)

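        # Sanity check: the freshly initialized model_2 should not match the
        # trained model before the checkpoint is loaded.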
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_2):
                for n_p1, n_p2 in zip(
                    model.named_parameters(), model_2.named_parameters()
                ):
                    self.assertNotEqual(n_p1[1], n_p2[1])

        # Now load the checkpoint into model_2 and ensure the values match
        with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
            state_dict = {
                "model": model_2.state_dict(),
            }

            dcp.load(
                state_dict=state_dict,
                storage_reader=FsspecReader(CHECKPOINT_DIR),
                planner=dcp.DefaultLoadPlanner(),
            )
            model_2.load_state_dict(state_dict["model"])

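            # The optimizer state is read separately against model_2's sharded
            # state dict, then converted to the flattened form that
            # optim_2.load_state_dict expects.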
            optim_state = load_sharded_optimizer_state_dict(
                model_state_dict=state_dict["model"],
                optimizer_key="optim",
                storage_reader=FsspecReader(CHECKPOINT_DIR),
            )

            flattened_osd = FSDP.optim_state_dict_to_load(
                model_2, optim_2, optim_state["optim"]
            )
            optim_2.load_state_dict(flattened_osd)

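        # After loading, the two models should agree parameter-for-parameter
        # when their full (unsharded) parameters are gathered.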
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_2):
                for n_p1, n_p2 in zip(
                    model.named_parameters(), model_2.named_parameters()
                ):
                    self.assertEqual(n_p1[1], n_p2[1])

        def opt_at(opt, idx):
            return list(opt.state.values())[idx]

        # Adam lazily creates its state
        self.assertEqual(opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"])
        self.assertEqual(
            opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    @with_temp_dir
    def test_overwrite(self):
        t1, t2 = torch.randn(10), torch.randn(10)

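        # The first save creates the checkpoint; saving again with
        # overwrite=True replaces its contents in place.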
        dcp.save(
            {"random": t1}, storage_writer=FsspecWriter(self.temp_dir, overwrite=False)
        )
        dcp.save(
            {"random": t2}, storage_writer=FsspecWriter(self.temp_dir, overwrite=True)
        )

        sd = {"random": torch.zeros(10)}
        dcp.load(sd, checkpoint_id=self.temp_dir)
        self.assertTrue(torch.allclose(sd["random"], t2))

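        # Saving to an existing checkpoint without overwrite=True should fail.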
        with self.assertRaisesRegex(
            CheckpointException, ".*Checkpoint already exists.*"
        ):
            dcp.save(
                {"random": t2},
                storage_writer=FsspecWriter(self.temp_dir, overwrite=False),
            )


if __name__ == "__main__":
    run_tests()