1#!/usr/bin/env python3 2# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 4""" 5Runs CIFAR10 training with differential privacy. 6""" 7 8import argparse 9import logging 10import shutil 11import sys 12from datetime import datetime, timedelta 13 14import numpy as np 15import torchvision.transforms as transforms 16from opacus import PrivacyEngine 17from torchvision import models 18from torchvision.datasets import CIFAR10 19from tqdm import tqdm 20 21import torch 22import torch.nn as nn 23import torch.optim as optim 24import torch.utils.data 25 26 27logging.basicConfig( 28 format="%(asctime)s:%(levelname)s:%(message)s", 29 datefmt="%m/%d/%Y %H:%M:%S", 30 stream=sys.stdout, 31) 32logger = logging.getLogger("ddp") 33logger.setLevel(level=logging.INFO) 34 35 36def save_checkpoint(state, is_best, filename="checkpoint.tar"): 37 torch.save(state, filename) 38 if is_best: 39 shutil.copyfile(filename, "model_best.pth.tar") 40 41 42def accuracy(preds, labels): 43 return (preds == labels).mean() 44 45 46def train(args, model, train_loader, optimizer, privacy_engine, epoch, device): 47 start_time = datetime.now() 48 49 model.train() 50 criterion = nn.CrossEntropyLoss() 51 52 losses = [] 53 top1_acc = [] 54 55 for i, (images, target) in enumerate(tqdm(train_loader)): 56 images = images.to(device) 57 target = target.to(device) 58 59 # compute output 60 output = model(images) 61 loss = criterion(output, target) 62 preds = np.argmax(output.detach().cpu().numpy(), axis=1) 63 labels = target.detach().cpu().numpy() 64 65 # measure accuracy and record loss 66 acc1 = accuracy(preds, labels) 67 68 losses.append(loss.item()) 69 top1_acc.append(acc1) 70 71 # compute gradient and do SGD step 72 loss.backward() 73 74 # make sure we take a step after processing the last mini-batch in the 75 # epoch to ensure we start the next epoch with a clean state 76 optimizer.step() 77 optimizer.zero_grad() 78 79 if i % args.print_freq == 0: 80 if not args.disable_dp: 81 epsilon, best_alpha = privacy_engine.accountant.get_privacy_spent( 82 delta=args.delta, 83 alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)), 84 ) 85 print( 86 f"\tTrain Epoch: {epoch} \t" 87 f"Loss: {np.mean(losses):.6f} " 88 f"Acc@1: {np.mean(top1_acc):.6f} " 89 f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}" 90 ) 91 else: 92 print( 93 f"\tTrain Epoch: {epoch} \t" 94 f"Loss: {np.mean(losses):.6f} " 95 f"Acc@1: {np.mean(top1_acc):.6f} " 96 ) 97 train_duration = datetime.now() - start_time 98 return train_duration 99 100 101def test(args, model, test_loader, device): 102 model.eval() 103 criterion = nn.CrossEntropyLoss() 104 losses = [] 105 top1_acc = [] 106 107 with torch.no_grad(): 108 for images, target in tqdm(test_loader): 109 images = images.to(device) 110 target = target.to(device) 111 112 output = model(images) 113 loss = criterion(output, target) 114 preds = np.argmax(output.detach().cpu().numpy(), axis=1) 115 labels = target.detach().cpu().numpy() 116 acc1 = accuracy(preds, labels) 117 118 losses.append(loss.item()) 119 top1_acc.append(acc1) 120 121 top1_avg = np.mean(top1_acc) 122 123 print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ") 124 return np.mean(top1_acc) 125 126 127# flake8: noqa: C901 128def main(): 129 args = parse_args() 130 131 if args.debug >= 1: 132 logger.setLevel(level=logging.DEBUG) 133 134 device = args.device 135 136 if args.secure_rng: 137 try: 138 import torchcsprng as prng 139 except ImportError as e: 140 msg = ( 141 "To use secure RNG, you must install the torchcsprng package! " 142 "Check out the instructions here: https://github.com/pytorch/csprng#installation" 143 ) 144 raise ImportError(msg) from e 145 146 generator = prng.create_random_device_generator("/dev/urandom") 147 148 else: 149 generator = None 150 151 augmentations = [ 152 transforms.RandomCrop(32, padding=4), 153 transforms.RandomHorizontalFlip(), 154 ] 155 normalize = [ 156 transforms.ToTensor(), 157 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 158 ] 159 train_transform = transforms.Compose( 160 augmentations + normalize if args.disable_dp else normalize 161 ) 162 163 test_transform = transforms.Compose(normalize) 164 165 train_dataset = CIFAR10( 166 root=args.data_root, train=True, download=True, transform=train_transform 167 ) 168 169 train_loader = torch.utils.data.DataLoader( 170 train_dataset, 171 batch_size=int(args.sample_rate * len(train_dataset)), 172 generator=generator, 173 num_workers=args.workers, 174 pin_memory=True, 175 ) 176 177 test_dataset = CIFAR10( 178 root=args.data_root, train=False, download=True, transform=test_transform 179 ) 180 test_loader = torch.utils.data.DataLoader( 181 test_dataset, 182 batch_size=args.batch_size_test, 183 shuffle=False, 184 num_workers=args.workers, 185 ) 186 187 best_acc1 = 0 188 189 model = models.__dict__[args.architecture]( 190 pretrained=False, norm_layer=(lambda c: nn.GroupNorm(args.gn_groups, c)) 191 ) 192 model = model.to(device) 193 194 if args.optim == "SGD": 195 optimizer = optim.SGD( 196 model.parameters(), 197 lr=args.lr, 198 momentum=args.momentum, 199 weight_decay=args.weight_decay, 200 ) 201 elif args.optim == "RMSprop": 202 optimizer = optim.RMSprop(model.parameters(), lr=args.lr) 203 elif args.optim == "Adam": 204 optimizer = optim.Adam(model.parameters(), lr=args.lr) 205 else: 206 raise NotImplementedError("Optimizer not recognized. Please check spelling") 207 208 privacy_engine = None 209 if not args.disable_dp: 210 if args.clip_per_layer: 211 # Each layer has the same clipping threshold. The total grad norm is still bounded by `args.max_per_sample_grad_norm`. 212 n_layers = len( 213 [(n, p) for n, p in model.named_parameters() if p.requires_grad] 214 ) 215 max_grad_norm = [ 216 args.max_per_sample_grad_norm / np.sqrt(n_layers) 217 ] * n_layers 218 else: 219 max_grad_norm = args.max_per_sample_grad_norm 220 221 privacy_engine = PrivacyEngine( 222 secure_mode=args.secure_rng, 223 ) 224 clipping = "per_layer" if args.clip_per_layer else "flat" 225 model, optimizer, train_loader = privacy_engine.make_private( 226 module=model, 227 optimizer=optimizer, 228 data_loader=train_loader, 229 noise_multiplier=args.sigma, 230 max_grad_norm=max_grad_norm, 231 clipping=clipping, 232 ) 233 234 # Store some logs 235 accuracy_per_epoch = [] 236 time_per_epoch = [] 237 238 for epoch in range(args.start_epoch, args.epochs + 1): 239 if args.lr_schedule == "cos": 240 lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1))) 241 for param_group in optimizer.param_groups: 242 param_group["lr"] = lr 243 244 train_duration = train( 245 args, model, train_loader, optimizer, privacy_engine, epoch, device 246 ) 247 top1_acc = test(args, model, test_loader, device) 248 249 # remember best acc@1 and save checkpoint 250 is_best = top1_acc > best_acc1 251 best_acc1 = max(top1_acc, best_acc1) 252 253 time_per_epoch.append(train_duration) 254 accuracy_per_epoch.append(float(top1_acc)) 255 256 save_checkpoint( 257 { 258 "epoch": epoch + 1, 259 "arch": "Convnet", 260 "state_dict": model.state_dict(), 261 "best_acc1": best_acc1, 262 "optimizer": optimizer.state_dict(), 263 }, 264 is_best, 265 filename=args.checkpoint_file + ".tar", 266 ) 267 268 time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch] 269 avg_time_per_epoch = sum(time_per_epoch_seconds) / len(time_per_epoch_seconds) 270 metrics = { 271 "accuracy": best_acc1, 272 "accuracy_per_epoch": accuracy_per_epoch, 273 "avg_time_per_epoch_str": str(timedelta(seconds=int(avg_time_per_epoch))), 274 "time_per_epoch": time_per_epoch_seconds, 275 } 276 277 logger.info( 278 "\nNote:\n- 'total_time' includes the data loading time, training time and testing time.\n- 'time_per_epoch' measures the training time only.\n" 279 ) 280 logger.info(metrics) 281 282 283def parse_args(): 284 parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training") 285 parser.add_argument( 286 "-j", 287 "--workers", 288 default=2, 289 type=int, 290 metavar="N", 291 help="number of data loading workers (default: 2)", 292 ) 293 parser.add_argument( 294 "--epochs", 295 default=90, 296 type=int, 297 metavar="N", 298 help="number of total epochs to run", 299 ) 300 parser.add_argument( 301 "--start-epoch", 302 default=1, 303 type=int, 304 metavar="N", 305 help="manual epoch number (useful on restarts)", 306 ) 307 parser.add_argument( 308 "-b", 309 "--batch-size-test", 310 default=256, 311 type=int, 312 metavar="N", 313 help="mini-batch size for test dataset (default: 256)", 314 ) 315 parser.add_argument( 316 "--sample-rate", 317 default=0.005, 318 type=float, 319 metavar="SR", 320 help="sample rate used for batch construction (default: 0.005)", 321 ) 322 parser.add_argument( 323 "--lr", 324 "--learning-rate", 325 default=0.1, 326 type=float, 327 metavar="LR", 328 help="initial learning rate", 329 dest="lr", 330 ) 331 parser.add_argument( 332 "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum" 333 ) 334 parser.add_argument( 335 "--wd", 336 "--weight-decay", 337 default=0, 338 type=float, 339 metavar="W", 340 help="SGD weight decay", 341 dest="weight_decay", 342 ) 343 parser.add_argument( 344 "-p", 345 "--print-freq", 346 default=10, 347 type=int, 348 metavar="N", 349 help="print frequency (default: 10)", 350 ) 351 parser.add_argument( 352 "--resume", 353 default="", 354 type=str, 355 metavar="PATH", 356 help="path to latest checkpoint (default: none)", 357 ) 358 parser.add_argument( 359 "-e", 360 "--evaluate", 361 dest="evaluate", 362 action="store_true", 363 help="evaluate model on validation set", 364 ) 365 parser.add_argument( 366 "--seed", default=None, type=int, help="seed for initializing training. " 367 ) 368 369 parser.add_argument( 370 "--sigma", 371 type=float, 372 default=1.5, 373 metavar="S", 374 help="Noise multiplier (default 1.0)", 375 ) 376 parser.add_argument( 377 "-c", 378 "--max-per-sample-grad_norm", 379 type=float, 380 default=10.0, 381 metavar="C", 382 help="Clip per-sample gradients to this norm (default 1.0)", 383 ) 384 parser.add_argument( 385 "--disable-dp", 386 action="store_true", 387 default=False, 388 help="Disable privacy training and just train with vanilla SGD", 389 ) 390 parser.add_argument( 391 "--secure-rng", 392 action="store_true", 393 default=False, 394 help="Enable Secure RNG to have trustworthy privacy guarantees." 395 "Comes at a performance cost. Opacus will emit a warning if secure rng is off," 396 "indicating that for production use it's recommender to turn it on.", 397 ) 398 parser.add_argument( 399 "--delta", 400 type=float, 401 default=1e-5, 402 metavar="D", 403 help="Target delta (default: 1e-5)", 404 ) 405 406 parser.add_argument( 407 "--checkpoint-file", 408 type=str, 409 default="checkpoint", 410 help="path to save check points", 411 ) 412 parser.add_argument( 413 "--data-root", 414 type=str, 415 default="../cifar10", 416 help="Where CIFAR10 is/will be stored", 417 ) 418 parser.add_argument( 419 "--log-dir", 420 type=str, 421 default="/tmp/stat/tensorboard", 422 help="Where Tensorboard log will be stored", 423 ) 424 parser.add_argument( 425 "--optim", 426 type=str, 427 default="SGD", 428 help="Optimizer to use (Adam, RMSprop, SGD)", 429 ) 430 parser.add_argument( 431 "--lr-schedule", type=str, choices=["constant", "cos"], default="cos" 432 ) 433 434 parser.add_argument( 435 "--device", type=str, default="cuda", help="Device on which to run the code." 436 ) 437 438 parser.add_argument( 439 "--architecture", 440 type=str, 441 default="resnet18", 442 help="model from torchvision to run", 443 ) 444 445 parser.add_argument( 446 "--gn-groups", 447 type=int, 448 default=8, 449 help="Number of groups in GroupNorm", 450 ) 451 452 parser.add_argument( 453 "--clip-per-layer", 454 "--clip_per_layer", 455 action="store_true", 456 default=False, 457 help="Use static per-layer clipping with the same clipping threshold for each layer. Necessary for DDP. If `False` (default), uses flat clipping.", 458 ) 459 parser.add_argument( 460 "--debug", 461 type=int, 462 default=0, 463 help="debug level (default: 0)", 464 ) 465 466 return parser.parse_args() 467 468 469if __name__ == "__main__": 470 main() 471