# mypy: allow-untyped-defs

# pyre-unsafe
import argparse
import io
import os
import random
import shlex
import subprocess
import time

import numpy as np

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions
from torch.distributed.rpc.backend_registry import BackendType
from torch.nn.parallel import DistributedDataParallel as DDP


# Config
NUM_TRAINERS = 8
NUM_PS = 8

NUM_EMBEDDINGS = 300
EMBEDDING_DIM = 64

WARMUP_CYCLES = 5


class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.

    The dense part is a stack of nn.Linear modules that is replicated across all
    trainers using DistributedDataParallel. The sparse part consists of
    nn.EmbeddingBag tables stored on multiple parameter servers.

    The model holds Remote References (RRefs) to the embedding tables on the
    parameter servers.
    """

    def __init__(self, emb_rref_list, device):
        super().__init__()
        self.emb_rref_list = emb_rref_list
        fc1 = torch.nn.Linear(512, 256)
        fc2 = torch.nn.Linear(256, 128)
        relu = torch.nn.ReLU()
        fc3 = torch.nn.Linear(128, 64)
        fc4 = torch.nn.Linear(64, 32)
        fc5 = torch.nn.Linear(32, 8)
        sec = nn.Sequential(fc1, fc2, relu, fc3, fc4, fc5)
        self.ddp = DDP(sec.to(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookups = []

        for emb_rref in self.emb_rref_list:
            emb_lookups.append(
                emb_rref.rpc_sync().forward(
                    indices, offsets
                )  # embedding_sum(input, offsets)
            )
        emb_lookups_cat = torch.cat(emb_lookups, dim=1)

        # Make sure the combined PS embedding dimension is at least the FC input width (512).
        assert NUM_PS * EMBEDDING_DIM >= 512
        dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
        emb_lookups_reshaped = emb_lookups_cat.reshape(  # type: ignore[possibly-undefined]
            [emb_lookups_cat.shape[0] * dim_normalizer, 512]
        )

        # DDP moves the input onto device_ids[0] before running the wrapped module,
        # so the CPU embedding lookups end up on the trainer's GPU.
        return self.ddp(emb_lookups_reshaped)


def _retrieve_embedding_parameters(emb_rref):
    return [RRef(p) for p in emb_rref.local_value().parameters()]


def _print_header():
    _print_cont("\n")
    _print_cont("%10s" % "")
    # One column pair per reported percentile (p50/p75/p90/p95).
    for _ in [50, 75, 90, 95]:
        _print_cont("%14s%10s" % ("sec/epoch", "epoch/sec"))
    _print_cont("\n")


def _print_benchmark(prefix, nelem, measurements):
    measurements = sorted(measurements)
    _print_cont("%8s:" % prefix)
    for p in [50, 75, 90, 95]:
        v = np.percentile(measurements, p)
        _print_cont(" p%02d: %1.3fs %6d/s" % (p, v, nelem / v))
    _print_cont("\n")


def _print_cont(msg):
    print(msg, end="", flush=True)


def _run_printable(cmd):
    proc = subprocess.run(shlex.split(cmd), capture_output=True, check=False)  # type: ignore[call-overload]
    assert proc.returncode == 0

    # Round-trip the command output through torch.save/torch.load via a byte tensor.
    buffer = io.BytesIO()
    torch.save(proc.stdout.decode("utf-8"), buffer)
    input_tensor = torch.ByteTensor(list(buffer.getvalue()))
    input_length = torch.IntTensor([input_tensor.size(0)])

    output = []
    buffer = io.BytesIO(np.asarray(input_tensor).tobytes())
    output.append(torch.load(buffer))
    return output
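

# Illustrative sketch (not called anywhere in this benchmark): the parameter
# servers host nn.EmbeddingBag tables, and HybridModel.forward above sends them
# the same flat ``indices``/``offsets`` pair that get_next_batch() in
# _run_trainer below generates. The tensors and the local ``emb`` table here
# are made up purely to show how that packing works.
def _embedding_bag_packing_example():
    # One EmbeddingBag in "sum" mode, same shape as the tables built on the PSs.
    emb = torch.nn.EmbeddingBag(NUM_EMBEDDINGS, EMBEDDING_DIM, mode="sum")
    # Eight lookups packed into a single 1-D tensor ...
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
    # ... split into two "bags" starting at positions 0 and 4.
    offsets = torch.tensor([0, 4], dtype=torch.long)
    # Each bag is summed, so the result has one row per bag: (2, EMBEDDING_DIM).
    return emb(indices, offsets)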


def _run_trainer(emb_rref_list, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on each
    of the parameter servers and a forward pass through the local dense layers.

    During the backward pass, DDP aggregates the gradients for the dense part
    (the nn.Linear stack) and distributed autograd propagates the gradient
    updates to the parameter servers.
    """
    # Setup the model.
    model = HybridModel(emb_rref_list, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters from all embedding tables for the current trainer.
    model_parameter_rrefs = []
    for ind, emb_rref in enumerate(emb_rref_list):
        ps_name = f"ps{ind}"
        model_parameter_rrefs.extend(
            rpc.rpc_sync(ps_name, _retrieve_embedding_parameters, args=(emb_rref,))
        )

    # model.parameters() only includes local parameters.
    for param in model.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Set up the distributed optimizer.
    opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0

            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)

            yield indices, offsets_tensor, target

    measurements = []
    # Include warm-up cycles during training.
    for epoch in range(100 + WARMUP_CYCLES):
        start = time.time()
        batch_size = 0

        for indices, offsets, target in get_next_batch(rank):
            batch_size += len(target)

            # Create a distributed autograd context for this iteration.
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run the distributed backward pass.
                dist_autograd.backward(context_id, [loss])

                # Run the distributed optimizer. Gradients are propagated all the
                # way to the parameter servers.
                opt.step(context_id)

                # No need to zero grads: each iteration creates a new
                # distributed autograd context which hosts its own grads.

        measurements.append(time.time() - start)
        # print("Training done for epoch {}".format(epoch))

    # Throw away warm-up measurements.
    measurements = measurements[WARMUP_CYCLES:]
    return rank, measurements, batch_size  # type: ignore[possibly-undefined]
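

# Illustrative sketch (not called by the benchmark): the training loop above
# never calls zero_grad() because gradients are not accumulated into
# ``param.grad`` at all; each iteration's distributed autograd context owns its
# own gradient map, which DistributedOptimizer.step(context_id) consumes. The
# helper below shows how those gradients could be inspected; ``model``,
# ``loss_fn``, and the batch arguments are placeholders for illustration.
def _inspect_context_gradients_example(model, loss_fn, indices, offsets, target):
    with dist_autograd.context() as context_id:
        loss = loss_fn(model(indices, offsets), target)
        dist_autograd.backward(context_id, [loss])
        # Maps each tensor that received a gradient on this node, in this
        # context, to that gradient.
        return dist_autograd.get_gradients(context_id)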


def run_worker(rank, world_size):
    r"""
    Initializes RPC, runs the role assigned to ``rank`` (master, trainer, or
    parameter server), and shuts down RPC.
    """
    # Use different port numbers in the TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29500"

    # Master (rank NUM_TRAINERS + NUM_PS; rank 16 with the defaults).
    if rank == (NUM_TRAINERS + NUM_PS):
        rpc.init_rpc(
            "master",
            rank=rank,
            backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
            world_size=world_size,
        )

        # Build the embedding tables on the parameter servers.
        emb_rref_list = []
        for index in range(NUM_PS):
            ps_name = f"ps{index}"
            emb_rref = rpc.remote(
                ps_name,
                torch.nn.EmbeddingBag,
                args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
                kwargs={"mode": "sum"},
            )
            emb_rref_list.append(emb_rref)

        # Run the training loop on the trainers.
        futs = []
        for trainer_rank in range(NUM_TRAINERS):
            trainer_name = f"trainer{trainer_rank}"
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(emb_rref_list, trainer_rank)
            )
            futs.append(fut)

        _print_header()

        measurements_all_trainers = []
        batch_size_all_trainers = 0
        # Wait for all training to finish.
        for fut in futs:
            rank, measurements, batch_size = fut.wait()
            _print_benchmark(f"Trainer{rank}", batch_size, measurements)
            batch_size_all_trainers += batch_size
            measurements_all_trainers.append(measurements)

        _print_benchmark("All", batch_size_all_trainers, measurements_all_trainers)

    # Trainers (ranks 0 to NUM_TRAINERS - 1; ranks 0-7 with the defaults).
    elif 0 <= rank < NUM_TRAINERS:
        # Initialize the process group for DistributedDataParallel on the trainers.
        dist.init_process_group(
            backend=dist.Backend.GLOO,
            rank=rank,
            world_size=NUM_TRAINERS,
            init_method="tcp://localhost:29501",
        )

        # Initialize RPC. The trainer just waits for RPCs from the master.
        trainer_name = f"trainer{rank}"
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

    # Parameter servers (ranks NUM_TRAINERS to NUM_TRAINERS + NUM_PS - 1;
    # ranks 8-15 with the defaults).
    elif NUM_TRAINERS <= rank < NUM_TRAINERS + NUM_PS:
        ps_name = f"ps{rank - NUM_TRAINERS}"
        rpc.init_rpc(
            ps_name,
            rank=rank,
            world_size=world_size,
            backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
            rpc_backend_options=rpc_backend_options,
        )
        # Parameter servers do nothing themselves; they only serve RPCs from the
        # master and the trainers.

    # Block until all RPCs finish.
    rpc.shutdown()
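

# Illustrative helper (not used by run_worker above): the rank layout that
# run_worker assumes, written out as a plain function. With the defaults this
# puts trainers on ranks 0-7, parameter servers on ranks 8-15, and the master
# on rank 16.
def _role_for_rank(rank):
    if rank < NUM_TRAINERS:
        return f"trainer{rank}"
    if rank < NUM_TRAINERS + NUM_PS:
        return f"ps{rank - NUM_TRAINERS}"
    return "master"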


if __name__ == "__main__":
    # Initialize the distributed environment.

    output = _run_printable("nvidia-smi topo -m")
    print("-------------------------------------------")
    print("                   Info                    ")
    print("-------------------------------------------")
    print()
    print(f"* PyTorch version: {torch.__version__}")
    print(f"* CUDA version: {torch.version.cuda}")
    print()
    print("------------ nvidia-smi topo -m -----------")
    print()
    print(output[0])
    print("-------------------------------------------")
    print("PyTorch Distributed Benchmark (DDP and RPC)")
    print("-------------------------------------------")

    # Cmd arguments to enable automated runs (e.g. Chronos, SSH, etc.).
    parser = argparse.ArgumentParser(description="PyTorch DDP and RPC Benchmark")
    parser.add_argument(
        "--master-addr", type=str, default="localhost", help="Address of master node."
    )
    parser.add_argument("--master-port", type=str, default="29500", help="Master port.")

    parser.add_argument(
        "--number-trainers",
        type=int,
        default=NUM_TRAINERS,
        help="Number of Trainer Nodes.",
    )
    parser.add_argument(
        "--number-ps", type=int, default=NUM_PS, help="Number of Parameter Servers."
    )
    parser.add_argument(
        "--number-embeddings",
        type=int,
        default=NUM_EMBEDDINGS,
        help="Number of test embeddings to be generated.",
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=EMBEDDING_DIM,
        help="Number of embedding dimensions.",
    )
    parser.add_argument(
        "--warmup-cycles",
        type=int,
        default=WARMUP_CYCLES,
        help="Number of cycles to warm up each process before running the benchmark.",
    )

    args = parser.parse_args()

    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port

    NUM_TRAINERS = args.number_trainers
    NUM_PS = args.number_ps

    NUM_EMBEDDINGS = args.number_embeddings
    EMBEDDING_DIM = args.embedding_dim

    WARMUP_CYCLES = args.warmup_cycles

    # Defaults:
    #   8 trainers (ranks 0-7),
    #   8 parameter servers (ranks 8-15),
    #   1 master (rank 16).
    world_size = NUM_TRAINERS + NUM_PS + 1  # Trainers + PS + Master
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
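
# Example invocation (the file name below is a placeholder for wherever this
# script is saved). With the defaults this spawns 17 processes on a single host
# and expects one CUDA device per trainer (8 GPUs):
#
#   python ddp_rpc_benchmark.py --master-addr localhost --master-port 29500 \
#       --number-trainers 8 --number-ps 8 \
#       --number-embeddings 300 --embedding-dim 64 --warmup-cycles 5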