import time

import torch
import torch.nn as nn
import torchvision.models as models
from functorch import grad, make_functional, vmap
from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules

device = "cuda"
batch_size = 128
torch.manual_seed(0)

# Opacus cannot handle BatchNorm (its statistics mix samples across the
# batch, so per-sample gradients are ill-defined), so swap BatchNorm layers
# out for GroupNorm in both models.
model_functorch = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_functorch = model_functorch.to(device)
criterion = nn.CrossEntropyLoss()

images = torch.randn(batch_size, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (batch_size,), device=device)

func_model, weights = make_functional(model_functorch)


def compute_loss(weights, image, target):
    # Treat a single (image, target) pair as a batch of size 1.
    images = image.unsqueeze(0)
    targets = target.unsqueeze(0)
    output = func_model(weights, images)
    loss = criterion(output, targets)
    return loss


def functorch_per_sample_grad():
    compute_grad = grad(compute_loss)
    # vmap over the batch dimension of images/targets; weights are shared.
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    start = time.time()
    result = compute_per_sample_grad(weights, images, targets)
    torch.cuda.synchronize()
    end = time.time()

    return result, end - start  # elapsed time in seconds


# Build an identical model for Opacus; reseeding gives the same initialization.
torch.manual_seed(0)
model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_opacus = model_opacus.to(device)
criterion = nn.CrossEntropyLoss()
for p_f, p_o in zip(model_functorch.parameters(), model_opacus.parameters()):
    assert torch.allclose(p_f, p_o)  # sanity check: same initialization

privacy_engine = PrivacyEngine(
    model_opacus,
    sample_rate=0.01,
    alphas=[10, 100],
    noise_multiplier=1,
    max_grad_norm=10000.0,
)


def opacus_per_sample_grad():
    start = time.time()
    output = model_opacus(images)
    loss = criterion(output, targets)
    loss.backward()
    torch.cuda.synchronize()
    end = time.time()
    expected = [p.grad_sample for p in model_opacus.parameters()]
    # Clear Opacus' accumulated state so repeated runs start fresh.
    for p in model_opacus.parameters():
        delattr(p, "grad_sample")
        p.grad = None
    return expected, end - start


# Warm up, then time the functorch implementation.
for _ in range(5):
    _, seconds = functorch_per_sample_grad()
    print(seconds)
result, seconds = functorch_per_sample_grad()
print(seconds)

# Warm up, then time the Opacus implementation.
for _ in range(5):
    _, seconds = opacus_per_sample_grad()
    print(seconds)
expected, seconds = opacus_per_sample_grad()
print(seconds)

result = [r.detach() for r in result]
print(len(result))  # number of parameter tensors with per-sample grads

# TODO: The following shows that the per-sample grads computed are different.
# This concerns me a little; we should compare to a source of truth.
# for i, (r, e) in enumerate(list(zip(result, expected))[::-1]):
#     if torch.allclose(r, e, rtol=1e-5):
#         continue
#     print(-(i + 1), ((r - e) / (e + 0.000001)).abs().max())
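
# A minimal "source of truth" sketch for the TODO above: compute per-sample
# gradients naively, one autograd call per example, and compare against it.
# This is an illustrative addition, not part of the original benchmark; the
# helper name `naive_per_sample_grad` is hypothetical, and the O(batch_size)
# loop is only meant for correctness checking, not for timing.
def naive_per_sample_grad():
    per_sample_grads = []
    for i in range(batch_size):
        loss = compute_loss(weights, images[i], targets[i])
        per_sample_grads.append(torch.autograd.grad(loss, weights))
    # Stack per parameter so each entry has shape (batch_size, *param_shape),
    # matching the layout returned by vmap(grad(...)) above.
    return [torch.stack(shards) for shards in zip(*per_sample_grads)]


# Example check (assumes `result` from the functorch run is still in scope):
# baseline = naive_per_sample_grad()
# for r, b in zip(result, baseline):
#     assert torch.allclose(r, b, atol=1e-4), "functorch disagrees with naive loop"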