# mypy: allow-untyped-defs

# pyre-unsafe
import argparse
import io
import os
import random
import shlex
import subprocess
import time

import numpy as np

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions
from torch.distributed.rpc.backend_registry import BackendType
from torch.nn.parallel import DistributedDataParallel as DDP


# Config
NUM_TRAINERS = 8
NUM_PS = 8

NUM_EMBEDDINGS = 300
EMBEDDING_DIM = 64

WARMUP_CYCLES = 5


class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.

    The dense part (a stack of nn.Linear layers) is replicated across all
    trainers using DistributedDataParallel. The sparse part consists of
    nn.EmbeddingBag tables stored on multiple parameter servers.

    The model holds Remote References (RRefs) to the embedding tables on the
    parameter servers.
    """

    def __init__(self, emb_rref_list, device):
        super().__init__()
        self.emb_rref_list = emb_rref_list
        fc1 = torch.nn.Linear(512, 256)
        fc2 = torch.nn.Linear(256, 128)
        relu = torch.nn.ReLU()
        fc3 = torch.nn.Linear(128, 64)
        fc4 = torch.nn.Linear(64, 32)
        fc5 = torch.nn.Linear(32, 8)
        sec = nn.Sequential(fc1, fc2, relu, fc3, fc4, fc5)
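        # The dense stack maps the 512-wide reshaped embedding output down to 8
        # logits (512 -> 256 -> 128 -> 64 -> 32 -> 8), with a single ReLU
        # between fc2 and fc3.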
        # `device` is the trainer's rank, used here as its CUDA device index.
        self.ddp = DDP(sec.to(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookups = []

        for emb_rref in self.emb_rref_list:
            emb_lookups.append(
                emb_rref.rpc_sync().forward(
                    indices, offsets
                )  # embedding_sum(input, offsets)
            )
        emb_lookups_cat = torch.cat(emb_lookups, dim=1)

        # Make sure the combined PS dimension is always greater than or equal
        # to the FC input dimension (512).
        assert NUM_PS * EMBEDDING_DIM >= 512
        dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
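        # Worked example with the defaults: NUM_PS * EMBEDDING_DIM = 8 * 64 = 512,
        # so emb_lookups_cat is [batch, 512], dim_normalizer == 1, and the reshape
        # below is a no-op. With 16 parameter servers the concatenation would be
        # [batch, 1024] and dim_normalizer == 2, so the reshape folds the extra
        # width into the batch dimension: [batch, 1024] -> [2 * batch, 512].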
        emb_lookups_reshaped = emb_lookups_cat.reshape(
            [emb_lookups_cat.shape[0] * dim_normalizer, 512]
        )

        return self.ddp(emb_lookups_reshaped)


def _retrieve_embedding_parameters(emb_rref):
    # Wrap each embedding parameter in an RRef so that DistributedOptimizer
    # can refer to parameters that live on the parameter servers.
    return [RRef(p) for p in emb_rref.local_value().parameters()]


def _print_header():
    _print_cont("\n")
    _print_cont("%10s" % "")
    for _ in [50, 75, 90, 95]:  # one column pair per reported percentile
        _print_cont("%14s%10s" % ("sec/epoch", "epoch/sec"))
    _print_cont("\n")


def _print_benchmark(prefix, nelem, measurements):
    measurements = sorted(measurements)
    _print_cont("%8s:" % prefix)
    for p in [50, 75, 90, 95]:
        v = np.percentile(measurements, p)
        _print_cont("  p%02d:  %1.3fs  %6d/s" % (p, v, nelem / v))
    _print_cont("\n")


def _print_cont(msg):
    print(msg, end="", flush=True)


def _run_printable(cmd):
    # Run `cmd` and round-trip its stdout through torch serialization.
    proc = subprocess.run(shlex.split(cmd), capture_output=True, check=False)  # type: ignore[call-overload]
    assert proc.returncode == 0

    # Serialize the decoded output into a ByteTensor ...
    buffer = io.BytesIO()
    torch.save(proc.stdout.decode("utf-8"), buffer)
    input_tensor = torch.ByteTensor(list(buffer.getvalue()))
    input_length = torch.IntTensor([input_tensor.size(0)])

    # ... and deserialize it back into a single-element list holding the string.
    output = []
    buffer = io.BytesIO(np.asarray(input_tensor).tobytes())
    output.append(torch.load(buffer))
    return output


def _run_trainer(emb_rref_list, rank):
    r"""
    Each trainer runs a forward pass that involves an embedding lookup on the
    parameter servers and running nn.Linear locally.

    During the backward pass, DDP is responsible for aggregating the gradients
    of the dense part (nn.Linear), and distributed autograd ensures that
    gradient updates are propagated to the parameter servers.
    """
    # Set up the model.
    model = HybridModel(emb_rref_list, rank)

    # Retrieve all model parameters as RRefs for DistributedOptimizer.

    # Retrieve parameters from all embedding tables for the current trainer.
    model_parameter_rrefs = []
    for ind, emb_rref in enumerate(emb_rref_list):
        ps_name = f"ps{ind}"
        model_parameter_rrefs.extend(
            rpc.rpc_sync(ps_name, _retrieve_embedding_parameters, args=(emb_rref,))
        )

    # model.parameters() only includes local parameters.
    for param in model.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Set up the distributed optimizer.
    opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)
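    # DistributedOptimizer creates an optim.SGD instance on every worker that
    # owns one of the referenced parameters (the parameter servers for the
    # embedding tables, this trainer for the DDP-wrapped dense layers) and runs
    # all of them when opt.step(context_id) is called below.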

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
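            # nn.EmbeddingBag takes a flat `indices` tensor plus an `offsets`
            # tensor marking where each bag starts; e.g. indices=[4, 9, 3, 1]
            # with offsets=[0, 2] describes two bags, {4, 9} and {3, 1}, each
            # reduced with the "sum" mode configured on the parameter servers.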
            offsets = []
            start = 0
            batch_size = 0

            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)

            yield indices, offsets_tensor, target

    measurements = []
    # Include warm-up cycles during training.
    for epoch in range(100 + WARMUP_CYCLES):
        start = time.time()
        batch_size = 0

        # Each batch runs inside its own distributed autograd context.
        for indices, offsets, target in get_next_batch(rank):
            batch_size += len(target)

            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run the distributed backward pass.
                dist_autograd.backward(context_id, [loss])

                # Run the distributed optimizer. Gradients are propagated all
                # the way to the parameter servers.
                opt.step(context_id)

                # Not necessary to zero gradients, since each iteration creates
                # a new distributed autograd context which hosts its own grads.

        measurements.append(time.time() - start)
        # print("Training done for epoch {}".format(epoch))

    # Throw away warm-up measurements.
    measurements = measurements[WARMUP_CYCLES:]
    return rank, measurements, batch_size  # type: ignore[possibly-undefined]


def run_worker(rank, world_size):
    r"""
    Initializes RPC, runs the rank-specific work (master, trainer, or parameter
    server), and shuts down RPC.
    """
    # Use different port numbers in the TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29500"

    # Rank NUM_TRAINERS + NUM_PS (16 by default): master.
    if rank == (NUM_TRAINERS + NUM_PS):
        rpc.init_rpc(
            "master",
            rank=rank,
            backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
            world_size=world_size,
        )

        # Build the embedding tables on the parameter servers. rpc.remote
        # constructs each nn.EmbeddingBag on its parameter server and returns
        # an RRef to it.
        emb_rref_list = []
        for index in range(NUM_PS):
            ps_name = f"ps{index}"
            emb_rref = rpc.remote(
                ps_name,
                torch.nn.EmbeddingBag,
                args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
                kwargs={"mode": "sum"},
            )
            emb_rref_list.append(emb_rref)

        # Run the training loop on the trainers.
        futs = []
        for trainer_rank in range(NUM_TRAINERS):
            trainer_name = f"trainer{trainer_rank}"
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(emb_rref_list, trainer_rank)
            )
            futs.append(fut)

        _print_header()

        measurements_all_trainers = []
        batch_size_all_trainers = 0
        # Wait for all training to finish.
        for fut in futs:
            rank, measurements, batch_size = fut.wait()
            _print_benchmark(f"Trainer{rank}", batch_size, measurements)
            batch_size_all_trainers += batch_size
            measurements_all_trainers.append(measurements)

        _print_benchmark("All", batch_size_all_trainers, measurements_all_trainers)

    # Ranks 0-7 (by default): trainers.
    elif 0 <= rank < NUM_TRAINERS:
        # Initialize the process group for DistributedDataParallel on the trainers.
        dist.init_process_group(
            backend=dist.Backend.GLOO,
            rank=rank,
            world_size=NUM_TRAINERS,
            init_method="tcp://localhost:29501",
        )

        # Initialize RPC. The trainer just waits for RPCs from the master.
        trainer_name = f"trainer{rank}"
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

    # Ranks 8-15 (by default): parameter servers.
    elif NUM_TRAINERS <= rank < NUM_TRAINERS + NUM_PS:
        ps_name = f"ps{rank - NUM_TRAINERS}"
        rpc.init_rpc(
            ps_name,
            rank=rank,
            world_size=world_size,
            backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
            rpc_backend_options=rpc_backend_options,
        )
        # Parameter servers do nothing; they just wait for RPCs from the trainers.

    # Block until all RPCs finish.
    rpc.shutdown()


if __name__ == "__main__":
    """Initializing the distributed environment."""

    output = _run_printable("nvidia-smi topo -m")
    print("-------------------------------------------")
    print("                  Info                     ")
    print("-------------------------------------------")
    print()
    print(f"* PyTorch version: {torch.__version__}")
    print(f"* CUDA version: {torch.version.cuda}")
    print()
    print("------------ nvidia-smi topo -m -----------")
    print()
    print(output[0])
    print("-------------------------------------------")
    print("PyTorch Distributed Benchmark (DDP and RPC)")
    print("-------------------------------------------")
    # Command-line arguments to enable automated runs (e.g., Chronos, SSH, etc.).
    parser = argparse.ArgumentParser(description="PyTorch DDP and RPC Benchmark")
    parser.add_argument(
        "--master-addr", type=str, default="localhost", help="Address of master node."
    )
    parser.add_argument("--master-port", type=str, default="29500", help="Master port.")

    parser.add_argument(
        "--number-trainers",
        type=int,
        default=NUM_TRAINERS,
        help="Number of Trainer Nodes.",
    )
    parser.add_argument(
        "--number-ps", type=int, default=NUM_PS, help="Number of Parameter Servers."
    )
    parser.add_argument(
        "--number-embeddings",
        type=int,
        default=NUM_EMBEDDINGS,
        help="Number of test embeddings to be generated.",
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=EMBEDDING_DIM,
        help="Number of embedding dimensions.",
    )
    parser.add_argument(
        "--warmup-cycles",
        type=int,
        default=WARMUP_CYCLES,
        help="Number of cycles to warm up each process before running the benchmark.",
    )
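
    # Example single-host invocation (hypothetical script name; every flag is
    # optional and falls back to the defaults defined at the top of the file):
    #   python ddp_rpc_benchmark.py --master-addr localhost --master-port 29500 \
    #       --number-trainers 8 --number-ps 8 --warmup-cycles 5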

    args = parser.parse_args()

    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port

    NUM_TRAINERS = args.number_trainers
    NUM_PS = args.number_ps

    NUM_EMBEDDINGS = args.number_embeddings
    EMBEDDING_DIM = args.embedding_dim

    WARMUP_CYCLES = args.warmup_cycles

    # Defaults:
    #  8 trainers (rank 0-7),
    #  8 parameter servers (rank 8-15),
    #  1 master (rank 16).
    world_size = NUM_TRAINERS + NUM_PS + 1  # Trainers + PS + Master
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)