# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def get_cur_mem(rank, result, prefix):
    """Collect memory allocated values in a result dict in MB"""
    torch._C._cuda_clearCublasWorkspaces()
    result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)


class Model(nn.Module):
    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()
        if with_fsdp:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                FSDP(nn.BatchNorm2d(64)),
                nn.ReLU(inplace=True),
            )
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        if with_fsdp:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )
        else:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )

        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        if self.with_checkpoint:
            return self.head(checkpoint(self.blocks, self.stem(x), use_reentrant=True))
        else:
            return self.head(self.blocks(self.stem(x)))


def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)

    return model


class TestFSDPMemory(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        gpu_id = self.rank
        world_size = self.world_size

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
        model = FSDP(model)

        # We enable momentum so that after the first iteration, the optimizer state is
        # added to the total memory used.
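        # (SGD allocates its momentum buffers lazily, at the first optimizer.step(), one
        # buffer per sharded parameter; so from the second iteration onward the measured
        # memory includes roughly one extra sharded model size of optimizer state.)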
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # results of memory stats
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # It is important to use `set_to_none=True` below, rather than
            # optimizer.zero_grad(), so that the gradient memory is actually reclaimed.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")

    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128

        model = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del model

        sharded_model_size_mb = int(model_size_mb / self.world_size)
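        # With hidden_dim=128 the unsharded model is roughly 4MB, so with world_size=2 each
        # rank holds roughly 2MB of sharded parameters; the expected values below are expressed
        # in multiples of that per-rank size plus ~1MB of temporary memory.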
        # We have observed that sometimes after the 3rd iteration, the 4th one can fail (not on
        # this test but on much bigger scale tests). We run 4 iterations here just in case it
        # happens.
        iterations = 4

        expected = {}

        for iteration in range(iterations):
            if iteration == 0:
                # sharded model size + 1MB temp memory
                expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
                # this activation memory is hard to compute analytically; the values come from
                # observed (printed) memory usage
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = 51
                    expected[f"iter {iteration}: after loss"] = 51
                else:
                    expected[f"iter {iteration}: after fwd"] = 340
                    expected[f"iter {iteration}: after loss"] = 340
                # sharded model size + sharded grad size + 1MB temp memory
                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
            else:
                # after the optimizer step in the first iteration, memory usage grows by
                # sharded_model_size_mb because of the added optimizer state
                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = (
                        51 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        51 + sharded_model_size_mb
                    )
                else:
                    expected[f"iter {iteration}: after fwd"] = (
                        340 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        340 + sharded_model_size_mb
                    )
                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1

            # sharded model size + sharded grad size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed after setting the grads to None
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        # Get the checkpoint flag.
        with_ckpt = ckpt == "ckpt"

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


instantiate_parametrized_tests(TestFSDPMemory)


if __name__ == "__main__":
    run_tests()
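# A sketch of how to exercise this file directly (assuming it sits at its usual location in the
# PyTorch repo, e.g. test/distributed/fsdp/test_fsdp_memory.py, and that at least two CUDA GPUs
# are available):
#
#   python test/distributed/fsdp/test_fsdp_memory.py
#
# run_tests() collects the parametrized ckpt/no_ckpt variants of test_fsdp_memory, and FSDPTest
# spawns one process per rank (world_size=2) for each of them.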