#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Runs CIFAR10 training with differential privacy.
"""

import argparse
import logging
import shutil
import sys
from datetime import datetime, timedelta

import numpy as np
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.func import functional_call, grad_and_value, vmap


logging.basicConfig(
    format="%(asctime)s:%(levelname)s:%(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    stream=sys.stdout,
)
logger = logging.getLogger("ddp")
logger.setLevel(level=logging.INFO)


def save_checkpoint(state, is_best, filename="checkpoint.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


def accuracy(preds, labels):
    return (preds == labels).mean()


def compute_norms(sample_grads):
    batch_size = sample_grads[0].shape[0]
    norms = [
        sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads
    ]
    norms = torch.stack(norms, dim=0).norm(2, dim=0)
    return norms, batch_size


def clip_and_accumulate_and_add_noise(
    model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0
):
    sample_grads = tuple(param.grad_sample for param in model.parameters())

    # step 0: compute the norms
    sample_norms, batch_size = compute_norms(sample_grads)

    # step 1: compute clipping factors
    clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
    clip_factor = clip_factor.clamp(max=1.0)

    # step 2: clip
    grads = tuple(
        torch.einsum("i,i...", clip_factor, sample_grad) for sample_grad in sample_grads
    )

    # step 3: add gaussian noise
    stddev = max_per_sample_grad_norm * noise_multiplier
    noises = tuple(
        torch.normal(0, stddev, grad_param.shape, device=grad_param.device)
        for grad_param in grads
    )
    grads = tuple(noise + grad_param for noise, grad_param in zip(noises, grads))

    # step 4: assign the new grads, delete the sample grads
    for param, param_grad in zip(model.parameters(), grads):
        param.grad = param_grad / batch_size
        del param.grad_sample
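

# For reference, `clip_and_accumulate_and_add_noise` above performs the standard
# DP-SGD gradient sanitization. This is only a summary of the code above, not an
# extra processing step:
#
#   1. each per-sample gradient g_i is rescaled by min(1, C / ||g_i||_2), where
#      C = max_per_sample_grad_norm and ||g_i||_2 is the norm of that sample's
#      gradient taken over all parameters (computed in compute_norms);
#   2. the clipped per-sample gradients are summed over the batch (the einsum);
#   3. Gaussian noise with standard deviation sigma * C is added to every
#      parameter's summed gradient, where sigma = noise_multiplier;
#   4. the result is divided by the batch size and written into param.grad, so a
#      subsequent optimizer.step() consumes it like an ordinary gradient.
#
# A minimal usage sketch (names illustrative), assuming the model's parameters
# already carry populated `.grad_sample` attributes:
#
#     clip_and_accumulate_and_add_noise(
#         model, max_per_sample_grad_norm=1.0, noise_multiplier=1.5
#     )
#     optimizer.step()
#     optimizer.zero_grad()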


def train(args, model, train_loader, optimizer, epoch, device):
    start_time = datetime.now()

    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # To use vmap+grad to compute per-sample-grads, the forward pass
        # must be re-formulated on a single example.
        # We use the `grad` operator to compute forward+backward on a single example,
        # and finally `vmap` to do forward+backward on multiple examples.
        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = functional_call(model, weights, images)
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # `grad(f)` is a functional API that returns a function `f'` that
        # computes gradients by running both the forward and backward pass.
        # We also want to extract some intermediate
        # values from the computation (i.e. the loss and output).
        #
        # To extract the loss, we use the `grad_and_value` API, which returns both
        # the gradient of the loss w.r.t. the weights and the loss itself.
        #
        # To extract the output, we use the `has_aux=True` flag.
        # `has_aux=True` assumes that `f` returns a tuple of two values,
        # where the first is to be differentiated and the second "auxiliary value"
        # is not to be differentiated. `f'` then returns the gradient w.r.t. the
        # weights, the loss, and the auxiliary value.
        grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)
        weights = dict(model.named_parameters())

        # detaching weights since we don't need to track gradients outside of transforms
        # and this is more performant
        detached_weights = {k: v.detach() for k, v in weights.items()}
        sample_grads, (sample_loss, output) = vmap(grads_loss_output, (None, 0, 0))(
            detached_weights, images, target
        )
        loss = sample_loss.mean()

        # Each entry of sample_grads has shape [batch_size, *param.shape]: one
        # gradient per example, stacked along a new leading dimension.
        for name, grad_sample in sample_grads.items():
            weights[name].grad_sample = grad_sample.detach()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
        clip_and_accumulate_and_add_noise(
            model, args.max_per_sample_grad_norm, args.sigma
        )

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        top1_acc.append(acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        optimizer.step()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc):.6f} "
            )
    train_duration = datetime.now() - start_time
    return train_duration


def test(args, model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in tqdm(test_loader):
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc1 = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc1)

    top1_avg = np.mean(top1_acc)

    print(f"\tTest set: Loss: {np.mean(losses):.6f} Acc@1: {top1_avg:.6f}")
    return np.mean(top1_acc)


# flake8: noqa: C901
def main():
    args = parse_args()

    if args.debug >= 1:
        logger.setLevel(level=logging.DEBUG)

    device = args.device

    if args.secure_rng:
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    # Note: `augmentations` is defined above but not applied; both the train and
    # test pipelines use only ToTensor + Normalize.
    train_transform = transforms.Compose(normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(
        root=args.data_root, train=True, download=True, transform=train_transform
    )

    # The batch size is derived from the sampling rate q: with the default
    # --sample-rate 0.005 and CIFAR10's 50,000 training images this gives
    # int(0.005 * 50000) = 250 images per step.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=int(args.sample_rate * len(train_dataset)),
        generator=generator,
        num_workers=args.workers,
        pin_memory=True,
    )

    test_dataset = CIFAR10(
        root=args.data_root, train=False, download=True, transform=test_transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0

    # BatchNorm mixes information across the samples in a batch, which is
    # incompatible with per-sample gradient computation, so the torchvision
    # model is built with GroupNorm instead.
    model = models.__dict__[args.architecture](
        pretrained=False, norm_layer=(lambda c: nn.GroupNorm(args.gn_groups, c))
    )
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError("Optimizer not recognized. Please check spelling")

    # Store some logs
    accuracy_per_epoch = []
    time_per_epoch = []

    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.lr_schedule == "cos":
            lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1)))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        train_duration = train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        time_per_epoch.append(train_duration)
        accuracy_per_epoch.append(float(top1_acc))

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "Convnet",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )

    time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch]
    avg_time_per_epoch = sum(time_per_epoch_seconds) / len(time_per_epoch_seconds)
    metrics = {
        "accuracy": best_acc1,
        "accuracy_per_epoch": accuracy_per_epoch,
        "avg_time_per_epoch_str": str(timedelta(seconds=int(avg_time_per_epoch))),
        "time_per_epoch": time_per_epoch_seconds,
    }

    logger.info(
        "\nNote:\n- 'total_time' includes the data loading time, training time and testing time.\n- 'time_per_epoch' measures the training time only.\n"
    )
    logger.info(metrics)


def parse_args():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size-test",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size for test dataset (default: 256)",
    )
    parser.add_argument(
        "--sample-rate",
        default=0.005,
        type=float,
        metavar="SR",
        help="sample rate used for batch construction (default: 0.005)",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=0,
        type=float,
        metavar="W",
        help="SGD weight decay",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training"
    )

    parser.add_argument(
        "--sigma",
        type=float,
        default=1.5,
        metavar="S",
        help="Noise multiplier (default: 1.5)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=10.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default: 10.0)",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable Secure RNG to have trustworthy privacy guarantees. "
        "Comes at a performance cost. Opacus will emit a warning if secure RNG is off, "
        "indicating that for production use it's recommended to turn it on.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save check points",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default="/tmp/stat/tensorboard",
        help="Where Tensorboard log will be stored",
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="SGD",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )
    parser.add_argument(
        "--lr-schedule", type=str, choices=["constant", "cos"], default="cos"
    )

    parser.add_argument(
        "--device", type=str, default="cpu", help="Device on which to run the code."
    )

    parser.add_argument(
        "--architecture",
        type=str,
        default="resnet18",
        help="model from torchvision to run",
    )

    parser.add_argument(
        "--gn-groups",
        type=int,
        default=8,
        help="Number of groups in GroupNorm",
    )

    parser.add_argument(
        "--clip-per-layer",
        "--clip_per_layer",
        action="store_true",
        default=False,
        help="Use static per-layer clipping with the same clipping threshold for each layer. Necessary for DDP. If `False` (default), uses flat clipping.",
    )
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        help="debug level (default: 0)",
    )

    return parser.parse_args()


if __name__ == "__main__":
    main()