#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

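# A typical invocation might look like this (illustrative only: the filename
# cifar10.py is assumed, and every flag shown is defined in parse_args below):
#
#   python cifar10.py --epochs 20 --sigma 1.5 -c 10.0 --data-root ./cifar10
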
import argparse
import logging
import shutil
import sys
from datetime import datetime, timedelta

import numpy as np
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from torchvision import models
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data


logging.basicConfig(
    format="%(asctime)s:%(levelname)s:%(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    stream=sys.stdout,
)
logger = logging.getLogger("ddp")
logger.setLevel(level=logging.INFO)


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
    start_time = datetime.now()

    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)
        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        losses.append(loss.item())
        top1_acc.append(acc1)

        # compute gradient and do SGD step
        loss.backward()

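        # Note: when DP is enabled, `optimizer` is the DPOptimizer returned by
        # PrivacyEngine.make_private below, so step() clips each per-sample
        # gradient and adds Gaussian noise before applying the update.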
        # take an optimizer step and reset gradients after every mini-batch so
        # that the next batch starts from a clean state
        optimizer.step()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            if not args.disable_dp:
                epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent(
                    delta=args.delta,
                    alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
                )
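                # epsilon is the privacy budget spent so far at the target
                # delta; best_alpha is the Rényi order that gives the tightest
                # bound among the candidate alphas above.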
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc):.6f} "
                    f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}"
                )
            else:
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc):.6f} "
                )
    train_duration = datetime.now() - start_time
    return train_duration


def test(args, model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in tqdm(test_loader):
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc1 = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc1)

    top1_avg = np.mean(top1_acc)

    print(f"\tTest set: Loss: {np.mean(losses):.6f} Acc@1: {top1_avg:.6f}")
    return top1_avg


# flake8: noqa: C901
def main():
    args = parse_args()

    if args.debug >= 1:
        logger.setLevel(level=logging.DEBUG)

    device = args.device

    if args.secure_rng:
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")
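        # torchcsprng draws entropy from /dev/urandom; this generator drives
        # data-loader sampling, while noise generation is secured separately
        # via secure_mode on the PrivacyEngine below.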

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize
    )
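    # Random augmentations are applied only when DP is disabled
    # (--disable-dp); with DP enabled, images are only normalized.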

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=int(args.sample_rate * len(train_dataset)),
        generator=generator,
        num_workers=args.workers,
        pin_memory=True,
    )

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0

    model = models.__dict__[args.architecture](
        pretrained=False, norm_layer=(lambda c: nn.GroupNorm(args.gn_groups, c))
    )
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling")

    privacy_engine = None
    if not args.disable_dp:
        if args.clip_per_layer:
            # Each layer has the same clipping threshold. The total grad norm
            # is still bounded by `args.max_per_sample_grad_norm`.
            n_layers = len(
                [(n, p) for n, p in model.named_parameters() if p.requires_grad]
            )
            max_grad_norm = [
                args.max_per_sample_grad_norm / np.sqrt(n_layers)
            ] * n_layers
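            # With a per-layer threshold of C / sqrt(n_layers), the combined
            # norm is at most sqrt(n_layers * (C / sqrt(n_layers))**2) = C.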
        else:
            max_grad_norm = args.max_per_sample_grad_norm

        privacy_engine = PrivacyEngine(
            secure_mode=args.secure_rng,
        )
        clipping = "per_layer" if args.clip_per_layer else "flat"
        model, optimizer, train_loader = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=train_loader,
            noise_multiplier=args.sigma,
            max_grad_norm=max_grad_norm,
            clipping=clipping,
        )
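        # make_private wraps the model in a GradSampleModule, the optimizer in
        # a DPOptimizer, and swaps the data loader for a Poisson-sampling
        # DPDataLoader, so the training loop needs no DP-specific code.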

    # Store some logs
    accuracy_per_epoch = []
    time_per_epoch = []

    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.lr_schedule == "cos":
            # cosine decay from args.lr toward zero over the course of training
            lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1)))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        train_duration = train(
            args, model, train_loader, optimizer, privacy_engine, epoch, device
        )
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        time_per_epoch.append(train_duration)
        accuracy_per_epoch.append(float(top1_acc))

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.architecture,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )

    time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch]
    avg_time_per_epoch = sum(time_per_epoch_seconds) / len(time_per_epoch_seconds)
    metrics = {
        "accuracy": best_acc1,
        "accuracy_per_epoch": accuracy_per_epoch,
        "avg_time_per_epoch_str": str(timedelta(seconds=int(avg_time_per_epoch))),
        "time_per_epoch": time_per_epoch_seconds,
    }

    logger.info(
        "\nNote:\n"
        "- 'total_time' includes the data loading time, training time and testing time.\n"
        "- 'time_per_epoch' measures the training time only.\n"
    )
    logger.info(metrics)


def parse_args():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size-test",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size for test dataset (default: 256)",
    )
    parser.add_argument(
        "--sample-rate",
        default=0.005,
        type=float,
        metavar="SR",
        help="sample rate used for batch construction (default: 0.005)",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=0,
        type=float,
        metavar="W",
        help="SGD weight decay",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training"
    )

    parser.add_argument(
        "--sigma",
        type=float,
        default=1.5,
        metavar="S",
        help="Noise multiplier (default: 1.5)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=10.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default: 10.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable Secure RNG to have trustworthy privacy guarantees. "
        "Comes at a performance cost. Opacus will emit a warning if secure RNG is off, "
        "indicating that for production use it's recommended to turn it on.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="/tmp/stat/tensorboard",
        help="Where Tensorboard log will be stored",
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="SGD",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )
    parser.add_argument(
        "--lr-schedule", type=str, choices=["constant", "cos"], default="cos"
    )

    parser.add_argument(
        "--device", type=str, default="cuda", help="Device on which to run the code."
    )

    parser.add_argument(
        "--architecture",
        type=str,
        default="resnet18",
        help="model from torchvision to run",
    )

    parser.add_argument(
        "--gn-groups",
        type=int,
        default=8,
        help="Number of groups in GroupNorm",
    )

    parser.add_argument(
        "--clip-per-layer",
        "--clip_per_layer",
        action="store_true",
        default=False,
        help="Use static per-layer clipping with the same clipping threshold for "
        "each layer. Necessary for DDP. If `False` (default), uses flat clipping.",
    )
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        help="debug level (default: 0)",
    )

    return parser.parse_args()


if __name__ == "__main__":
    main()