In this tutorial we will learn how to set Opacus' privacy-specific hyper-parameters, make a model compatible with the privacy engine, and train a differentially private ResNet18 on CIFAR10.
To train a model with Opacus there are three privacy-specific hyper-parameters that must be tuned for better performance:

- MAX_GRAD_NORM: the maximum L2 norm to which each per-sample gradient is clipped.
- NOISE_MULTIPLIER: how much noise is added to the clipped gradients, expressed as the ratio of the noise standard deviation to the clipping norm.
- DELTA: the target δ of the (ε, δ)-differential-privacy guarantee. It should be smaller than the inverse of the training set size; we use 1e-5 since CIFAR10 has 50,000 training images.

We use the hyper-parameter values below to obtain the results in the last section; a short sketch after these constants illustrates how the first two enter the gradient computation:
MAX_GRAD_NORM = 1.2
NOISE_MULTIPLIER = 0.38
DELTA = 1e-5
LR = 1e-3
NUM_WORKERS = 2
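To make the roles of MAX_GRAD_NORM and NOISE_MULTIPLIER concrete, here is a minimal sketch of the DP-SGD gradient computation they parameterize. This is for intuition only and is not how Opacus implements it internally; the function name and shapes are hypothetical.

import torch

# Illustrative sketch (not Opacus internals): clip each per-sample gradient to
# max_grad_norm, sum the clipped gradients, add Gaussian noise with standard
# deviation noise_multiplier * max_grad_norm, then average over the batch.
def dp_sgd_gradient(per_sample_grads, max_grad_norm, noise_multiplier):
    # per_sample_grads: tensor of shape (batch_size, *param_shape)
    per_sample_norms = per_sample_grads.flatten(start_dim=1).norm(2, dim=1)
    scale = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
    clipped = per_sample_grads * scale.view(-1, *([1] * (per_sample_grads.dim() - 1)))
    summed = clipped.sum(dim=0)
    noise = torch.normal(0.0, noise_multiplier * max_grad_norm, size=summed.shape)
    return (summed + noise) / per_sample_grads.shape[0]

In practice you never call anything like this yourself: once the privacy engine is attached, Opacus computes per-sample gradients and performs the clipping and noise addition for you.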
There's another constraint we should be mindful of: memory. To balance the peak memory requirement, which is proportional to batch_size^2, against training performance, we use virtual batches. With virtual batches we can separate physical steps (gradient computation) from logical steps (noise addition and parameter updates): we train with larger logical batches while keeping the memory footprint low. Below we specify two constants:
BATCH_SIZE = 128
VIRTUAL_BATCH_SIZE = 512
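Before we get to the Opacus-specific training loop, the sketch below illustrates the physical-vs-logical-step pattern with plain PyTorch gradient accumulation on a toy model. The model, data, and names here are hypothetical and non-private; they only demonstrate the accumulation schedule we will reuse later with the privacy engine.

import torch
import torch.nn as nn

toy_model = nn.Linear(10, 2)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
accumulation_steps = VIRTUAL_BATCH_SIZE // BATCH_SIZE  # 512 / 128 = 4 physical batches per logical step

for i in range(8):  # 8 physical batches -> 2 logical (parameter-updating) steps
    x = torch.randn(BATCH_SIZE, 10)
    y = torch.randint(0, 2, (BATCH_SIZE,))
    loss = nn.functional.cross_entropy(toy_model(x), y)
    loss.backward()                # physical step: gradients accumulate in .grad
    if (i + 1) % accumulation_steps == 0:
        toy_optimizer.step()       # logical step: update parameters
        toy_optimizer.zero_grad()

With Opacus attached, the intermediate physical batches call optimizer.virtual_step() instead of relying on plain .grad accumulation, so per-sample clipping is applied before accumulation; the training loop later in this tutorial does exactly that.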
Now, let's load the CIFAR10 dataset. We don't use data augmentation here because, in our experiments, we found that data augmentation lowers utility when training with DP.
import torch
import torchvision
import torchvision.transforms as transforms
# These values, specific to the CIFAR10 dataset, are assumed to be known.
# If necessary, they can be computed with modest privacy budget.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV),
])
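As an aside on the comment above: the sketch below shows one way per-channel statistics could be estimated under differential privacy with the Gaussian mechanism, assuming pixel values are bounded in [0, 1]. This is an illustration with hypothetical names, not functionality provided by Opacus, and the noise scale would still need to be calibrated to the desired (ε, δ).

import torch

def dp_channel_mean(images, sigma):
    # images: tensor of shape (N, C, H, W) with values in [0, 1], so each image
    # contributes a per-channel value in [0, 1] to the sum (bounded sensitivity).
    per_image_means = images.mean(dim=(2, 3))  # shape (N, C)
    noisy_sum = per_image_means.sum(dim=0) + torch.normal(
        0.0, sigma, size=(images.shape[1],))
    # N is treated as public here (the CIFAR10 training set size is well known).
    return noisy_sum / images.shape[0]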
Using torchvision datasets, we can load CIFAR10 and transform the PIL images into normalized tensors.
from torchvision.datasets import CIFAR10
DATA_ROOT = '../cifar10'
train_dataset = CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

test_dataset = CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)
from torchvision import models
model = models.resnet18(num_classes=10)
Now, let's check if the model is compatible with Opacus. Opacus does not support all types of PyTorch layers. To check whether your model is compatible with the privacy engine, Opacus provides a utility class to validate it.
If you run the commands below, you will get the following error:
from opacus.dp_model_inspector import DPModelInspector
inspector = DPModelInspector()
inspector.validate(model)
---------------------------------------------------------------------------
IncompatibleModuleException               Traceback (most recent call last)
<ipython-input-11-c3648acc319a> in <module>
      2
      3 inspector = DPModelInspector()
----> 4 inspector.validate(model)

/mnt/xarfuse/uid-179429/a7a74ae2-seed-nspid4026531836-ns-4026531840/opacus/dp_model_inspector.py in validate(self, model)
    115 if inspector.violators:
    116     message += f"\n{inspector.message}: {inspector.violators}"
--> 117 raise IncompatibleModuleException(message)
    118 return valid
    119

IncompatibleModuleException: Model contains incompatible modules. Some modules are not valid.: ['Main.bn1', 'Main.layer1.0.bn1', 'Main.layer1.0.bn2', 'Main.layer1.1.bn1', 'Main.layer1.1.bn2', 'Main.layer2.0.bn1', 'Main.layer2.0.bn2', 'Main.layer2.0.downsample.1', 'Main.layer2.1.bn1', 'Main.layer2.1.bn2', 'Main.layer3.0.bn1', 'Main.layer3.0.bn2', 'Main.layer3.0.downsample.1', 'Main.layer3.1.bn1', 'Main.layer3.1.bn2', 'Main.layer4.0.bn1', 'Main.layer4.0.bn2', 'Main.layer4.0.downsample.1', 'Main.layer4.1.bn1', 'Main.layer4.1.bn2']
Let us modify the model to work with Opacus. From the output above, you can see that the BatchNorm layers are not supported: they compute the mean and variance across the batch, creating a dependency between samples that violates the per-sample privacy guarantee. One way to modify our model is to replace all the BatchNorm layers with GroupNorm using the convert_batchnorm_modules utility function.
from opacus.utils import module_modification
model = module_modification.convert_batchnorm_modules(model)
inspector = DPModelInspector()
print(f"Is the model valid? {inspector.validate(model)}")
Is the model valid? True
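For intuition, here is roughly the idea behind convert_batchnorm_modules, written as a hypothetical sketch (the function name is ours, and Opacus' actual utility also handles the other BatchNorm variants):

import torch.nn as nn

def replace_batchnorm_with_groupnorm(module: nn.Module) -> nn.Module:
    # Recursively swap each BatchNorm2d for a GroupNorm over the same number of
    # channels, so that no normalization statistics are shared across samples.
    # Channel counts are assumed divisible by the group count, as in ResNet18.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name,
                    nn.GroupNorm(min(32, child.num_features), child.num_features, affine=True))
        else:
            replace_batchnorm_with_groupnorm(child)
    return module

GroupNorm normalizes each sample independently using groups of channels, so it carries no cross-sample information and is compatible with DP-SGD.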
For maximal speed, we can check if CUDA is available and supported by the PyTorch installation. If a GPU is available, we set the device variable to the CUDA-compatible device and transfer the neural network onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
We then define our optimizer and loss function. Opacus’ privacy engine can attach to any (first-order) optimizer. You can use your favorite—Adam, Adagrad, RMSprop—as long as it has an implementation derived from torch.optim.Optimizer. In this tutorial, we're going to use RMSprop.
import torch.nn as nn
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=LR)
We will define a utility function to calculate accuracy.
def accuracy(preds, labels):
    return (preds == labels).mean()
We now attach the privacy engine, initialized with the privacy hyper-parameters defined earlier. There's also the enigmatic-looking parameter alphas, which we won't touch for the time being.
from opacus import PrivacyEngine
print(f"Using sigma={NOISE_MULTIPLIER} and C={MAX_GRAD_NORM}")
privacy_engine = PrivacyEngine(
    model,
    batch_size=VIRTUAL_BATCH_SIZE,
    sample_size=len(train_dataset),
    alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
    noise_multiplier=NOISE_MULTIPLIER,
    max_grad_norm=MAX_GRAD_NORM,
)
privacy_engine.attach(optimizer)
Using sigma=0.38 and C=1.2
We will then define our train function. This function will train the model for one epoch.
assert VIRTUAL_BATCH_SIZE % BATCH_SIZE == 0 # VIRTUAL_BATCH_SIZE should be divisible by BATCH_SIZE
virtual_batch_rate = VIRTUAL_BATCH_SIZE // BATCH_SIZE
import numpy as np
def train(model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)
        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()

        # measure accuracy and record loss
        acc = accuracy(preds, labels)
        losses.append(loss.item())
        top1_acc.append(acc)

        loss.backward()

        # take a real (noisy) optimizer step every virtual_batch_rate physical
        # batches, and on the last batch of the epoch
        if ((i + 1) % virtual_batch_rate == 0) or ((i + 1) == len(train_loader)):
            optimizer.step()
            optimizer.zero_grad()
        else:
            optimizer.virtual_step()  # take a virtual step: accumulate clipped gradients

        if i % 200 == 0:
            epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(DELTA)
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                f"(ε = {epsilon:.2f}, δ = {DELTA})"
            )
Next, we will define our test function to validate our model on our test dataset.
def test(model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

    top1_avg = np.mean(top1_acc)

    print(
        f"\tTest set: "
        f"Loss: {np.mean(losses):.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return top1_avg
from tqdm import tqdm_notebook
for epoch in tqdm_notebook(range(20), desc="Epoch", unit="epoch"):
    train(model, train_loader, optimizer, epoch + 1, device)
Train Epoch: 1 Loss: 2.732535 Acc@1: 7.812500 (ε = 0.19, δ = 1e-05)
Train Epoch: 1 Loss: 2.782132 Acc@1: 14.128576 (ε = 16.04, δ = 1e-05)
Train Epoch: 2 Loss: 1.730190 Acc@1: 36.718750 (ε = 18.66, δ = 1e-05)
Train Epoch: 2 Loss: 1.761741 Acc@1: 38.355100 (ε = 20.83, δ = 1e-05)
Train Epoch: 3 Loss: 1.542238 Acc@1: 50.781250 (ε = 22.43, δ = 1e-05)
Train Epoch: 3 Loss: 1.723531 Acc@1: 45.565143 (ε = 23.96, δ = 1e-05)
Train Epoch: 4 Loss: 1.949715 Acc@1: 42.187500 (ε = 25.43, δ = 1e-05)
Train Epoch: 4 Loss: 1.720563 Acc@1: 49.109919 (ε = 26.85, δ = 1e-05)
Train Epoch: 5 Loss: 1.774229 Acc@1: 51.562500 (ε = 27.92, δ = 1e-05)
Train Epoch: 5 Loss: 1.714577 Acc@1: 51.698539 (ε = 29.03, δ = 1e-05)
Train Epoch: 6 Loss: 1.405191 Acc@1: 60.156250 (ε = 30.10, δ = 1e-05)
Train Epoch: 6 Loss: 1.675316 Acc@1: 53.976213 (ε = 31.21, δ = 1e-05)
Train Epoch: 7 Loss: 1.554456 Acc@1: 52.343750 (ε = 32.28, δ = 1e-05)
Train Epoch: 7 Loss: 1.685776 Acc@1: 54.909049 (ε = 33.39, δ = 1e-05)
Train Epoch: 8 Loss: 1.778377 Acc@1: 53.125000 (ε = 34.46, δ = 1e-05)
Train Epoch: 8 Loss: 1.696506 Acc@1: 56.012904 (ε = 35.36, δ = 1e-05)
Train Epoch: 9 Loss: 1.904890 Acc@1: 54.687500 (ε = 36.17, δ = 1e-05)
Train Epoch: 9 Loss: 1.639168 Acc@1: 57.668688 (ε = 37.00, δ = 1e-05)
Train Epoch: 10 Loss: 1.614231 Acc@1: 53.906250 (ε = 37.81, δ = 1e-05)
Train Epoch: 10 Loss: 1.641604 Acc@1: 58.652052 (ε = 38.65, δ = 1e-05)
Train Epoch: 11 Loss: 1.678383 Acc@1: 55.468750 (ε = 39.45, δ = 1e-05)
Train Epoch: 11 Loss: 1.627174 Acc@1: 59.343905 (ε = 40.29, δ = 1e-05)
Train Epoch: 12 Loss: 1.534639 Acc@1: 62.500000 (ε = 41.09, δ = 1e-05)
Train Epoch: 12 Loss: 1.627389 Acc@1: 59.732587 (ε = 41.93, δ = 1e-05)
Train Epoch: 13 Loss: 1.565778 Acc@1: 63.281250 (ε = 42.74, δ = 1e-05)
Train Epoch: 13 Loss: 1.611232 Acc@1: 60.743159 (ε = 43.58, δ = 1e-05)
Train Epoch: 14 Loss: 1.529223 Acc@1: 62.500000 (ε = 44.38, δ = 1e-05)
Train Epoch: 14 Loss: 1.614406 Acc@1: 60.886971 (ε = 45.22, δ = 1e-05)
Train Epoch: 15 Loss: 1.445196 Acc@1: 66.406250 (ε = 46.02, δ = 1e-05)
Train Epoch: 15 Loss: 1.589295 Acc@1: 61.427239 (ε = 46.86, δ = 1e-05)
Train Epoch: 16 Loss: 1.733013 Acc@1: 59.375000 (ε = 47.67, δ = 1e-05)
Train Epoch: 16 Loss: 1.577568 Acc@1: 61.866449 (ε = 48.46, δ = 1e-05)
Train Epoch: 17 Loss: 1.845230 Acc@1: 55.468750 (ε = 49.09, δ = 1e-05)
Train Epoch: 17 Loss: 1.578402 Acc@1: 62.010261 (ε = 49.73, δ = 1e-05)
Train Epoch: 18 Loss: 1.565229 Acc@1: 64.843750 (ε = 50.35, δ = 1e-05)
Train Epoch: 18 Loss: 1.542843 Acc@1: 63.432836 (ε = 51.00, δ = 1e-05)
Train Epoch: 19 Loss: 1.447579 Acc@1: 65.625000 (ε = 51.62, δ = 1e-05)
Train Epoch: 19 Loss: 1.541139 Acc@1: 63.219061 (ε = 52.27, δ = 1e-05)
Train Epoch: 20 Loss: 1.561364 Acc@1: 65.625000 (ε = 52.89, δ = 1e-05)
Train Epoch: 20 Loss: 1.576318 Acc@1: 62.624378 (ε = 53.54, δ = 1e-05)
top1_acc = test(model, test_loader, device)
Test set: Loss: 2.015386 Acc: 56.615902
Now let us see how our private model compares with a non-private ResNet18.
We trained a non-private ResNet18 for 20 epochs using the same hyper-parameters as above, with BatchNorm replaced by GroupNorm. The results of that run and of the training discussed in this tutorial are summarized in the table below:
| Model | Top 1 Accuracy (%) | ε |
|---|---|---|
| ResNet | 76 | ∞ |
| Private ResNet | 56.61 | 53.54 |