#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

import argparse
import logging
import shutil
import sys
from datetime import datetime, timedelta

import numpy as np
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.func import functional_call, grad_and_value, vmap


logging.basicConfig(
    format="%(asctime)s:%(levelname)s:%(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    stream=sys.stdout,
)
logger = logging.getLogger("ddp")
logger.setLevel(level=logging.INFO)


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def compute_norms(sample_grads):
    batch_size = sample_grads[0].shape[0]
    norms = [
        sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads
    ]
    norms = torch.stack(norms, dim=0).norm(2, dim=0)
    return norms, batch_size


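# The function below performs the core DP-SGD gradient step on the per-sample
# gradients stashed in `param.grad_sample`: each sample's gradient is rescaled
# by min(1, max_per_sample_grad_norm / ||g_i||_2), the clipped gradients are
# summed across the batch, Gaussian noise with standard deviation
# noise_multiplier * max_per_sample_grad_norm is added, and the result is
# divided by the batch size before being written to `param.grad`.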
def clip_and_accumulate_and_add_noise(
    model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0
):
    sample_grads = tuple(param.grad_sample for param in model.parameters())

    # step 0: compute the norms
    sample_norms, batch_size = compute_norms(sample_grads)

    # step 1: compute clipping factors
    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
    clip_factor = clip_factor.clamp(max=1.0)

    # step 2: clip
    grads = tuple(
        torch.einsum("i,i...", clip_factor, sample_grad) for sample_grad in sample_grads
    )

    # step 3: add gaussian noise
    stddev = max_per_sample_grad_norm * noise_multiplier
    noises = tuple(
        torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
        for grad_param in grads
    )
    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))

    # step 4: assign the new grads, delete the sample grads
    for param, param_grad in zip(model.parameters(), grads):
        param.grad = param_grad / batch_size
        del param.grad_sample


def train(args, model, train_loader, optimizer, epoch, device):
    start_time = datetime.now()

    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # To use vmap+grad to compute per-sample gradients, the forward pass
        # must be reformulated to operate on a single example.
        # We use the `grad` operator to compute forward+backward on a single example,
        # and then `vmap` to do forward+backward over the whole batch of examples.
        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = functional_call(model, weights, images)
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # `grad(f)` is a functional API that returns a function `f'` which
        # computes gradients by running both the forward and backward pass.
        # We also want to extract some intermediate values from the
        # computation (i.e. the loss and the output).
        #
        # To extract the loss, we use the `grad_and_value` API, which returns
        # the gradient of the loss w.r.t. the weights together with the loss
        # value itself.
        #
        # To extract the output, we use the `has_aux=True` flag.
        # `has_aux=True` assumes that `f` returns a tuple of two values,
        # where the first is to be differentiated and the second, the
        # "auxiliary value", is not. The transformed function then returns
        # the gradients along with a tuple of (loss, auxiliary value).
        grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)
        weights = dict(model.named_parameters())

        # detaching weights since we don't need to track gradients outside of transforms
        # and this is more performant
        detached_weights = {k: v.detach() for k, v in weights.items()}
        sample_grads, (sample_loss, output) = vmap(grads_loss_output, (None, 0, 0))(
            detached_weights, images, target
        )
        loss = sample_loss.mean()

        for name, grad_sample in sample_grads.items():
            weights[name].grad_sample = grad_sample.detach()
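        # At this point every parameter carries a `grad_sample` attribute of
        # shape (batch_size, *param.shape): one gradient per example, produced
        # by vmap-ing `grad` over the batch dimension.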

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
        clip_and_accumulate_and_add_noise(
            model, args.max_per_sample_grad_norm, args.sigma
        )

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        top1_acc.append(acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        optimizer.step()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc):.6f} "
            )
    train_duration = datetime.now() - start_time
    return train_duration


def test(args, model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in tqdm(test_loader):
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc1 = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc1)

    top1_avg = np.mean(top1_acc)

    print(f"\tTest set: Loss: {np.mean(losses):.6f} Acc@1: {top1_avg:.6f}")
    return top1_avg


# flake8: noqa: C901
def main():
    args = parse_args()

    if args.debug >= 1:
        logger.setLevel(level=logging.DEBUG)

    device = args.device

    if args.secure_rng:
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(normalize)
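    # Note: the `augmentations` defined above are not applied to the training
    # set here; only normalization is used for both train and test transforms.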

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=int(args.sample_rate * len(train_dataset)),
        generator=generator,
        num_workers=args.workers,
        pin_memory=True,
    )
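    # With the default --sample-rate of 0.005 and CIFAR10's 50,000 training
    # images, this yields a training batch size of int(0.005 * 50000) = 250.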

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0

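    # The torchvision model is built with GroupNorm instead of its default
    # BatchNorm layers: BatchNorm mixes statistics across samples in a batch,
    # which conflicts with computing independent per-sample gradients.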
    model = models.__dict__[args.architecture](
        pretrained=False, norm_layer=(lambda c: nn.GroupNorm(args.gn_groups, c))
    )
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling")

    # Store some logs
    accuracy_per_epoch = []
    time_per_epoch = []

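    # With --lr-schedule=cos the learning rate follows a cosine decay:
    #   lr_epoch = lr * 0.5 * (1 + cos(pi * epoch / (epochs + 1)))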
    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.lr_schedule == "cos":
            lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1)))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        train_duration = train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        time_per_epoch.append(train_duration)
        accuracy_per_epoch.append(float(top1_acc))

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.architecture,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )

    time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch]
    avg_time_per_epoch = sum(time_per_epoch_seconds) / len(time_per_epoch_seconds)
    metrics = {
        "accuracy": best_acc1,
        "accuracy_per_epoch": accuracy_per_epoch,
        "avg_time_per_epoch_str": str(timedelta(seconds=int(avg_time_per_epoch))),
        "time_per_epoch": time_per_epoch_seconds,
    }

    logger.info(
        "\nNote: 'time_per_epoch' measures the training time only; it excludes "
        "data loading and testing time.\n"
    )
    logger.info(metrics)


def parse_args():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size-test",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size for test dataset (default: 256)",
    )
    parser.add_argument(
        "--sample-rate",
        default=0.005,
        type=float,
        metavar="SR",
        help="sample rate used for batch construction (default: 0.005)",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=0,
        type=float,
        metavar="W",
        help="SGD weight decay",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training"
    )

    parser.add_argument(
        "--sigma",
        type=float,
        default=1.5,
        metavar="S",
        help="Noise multiplier (default 1.5)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=10.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 10.0)",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable secure RNG to have trustworthy privacy guarantees. "
        "Comes at a performance cost. Opacus will emit a warning if secure RNG is off, "
        "indicating that for production use it is recommended to turn it on.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="/tmp/stat/tensorboard",
        help="Where Tensorboard log will be stored",
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="SGD",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )
    parser.add_argument(
        "--lr-schedule", type=str, choices=["constant", "cos"], default="cos"
    )

    parser.add_argument(
        "--device", type=str, default="cpu", help="Device on which to run the code."
    )

    parser.add_argument(
        "--architecture",
        type=str,
        default="resnet18",
        help="model from torchvision to run",
    )

    parser.add_argument(
        "--gn-groups",
        type=int,
        default=8,
        help="Number of groups in GroupNorm",
    )

    parser.add_argument(
        "--clip-per-layer",
        "--clip_per_layer",
        action="store_true",
        default=False,
        help="Use static per-layer clipping with the same clipping threshold for each layer. Necessary for DDP. If `False` (default), uses flat clipping.",
    )
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        help="debug level (default: 0)",
    )

    return parser.parse_args()


if __name__ == "__main__":
    main()