In this tutorial we will learn to do the following:

- Set the privacy-specific hyper-parameters used when training with differential privacy
- Check a model's compatibility with Opacus using ModuleValidator and fix incompatible layers
- Train a differentially private ResNet18 on CIFAR10 with PrivacyEngine and BatchMemoryManager, and compare it against a non-private baseline

To train a model with Opacus there are three privacy-specific hyper-parameters that must be tuned for better performance:

- MAX_GRAD_NORM: the maximum L2 norm of per-sample gradients; any per-sample gradient with a larger norm is clipped to this value before aggregation
- EPSILON: the target ε of the (ε, δ)-differential privacy guarantee for the whole training run
- DELTA: the target δ of the (ε, δ) guarantee, typically chosen to be smaller than the inverse of the training-set size
We use the hyper-parameter values below to obtain results in the last section:
import warnings
warnings.simplefilter("ignore")
MAX_GRAD_NORM = 1.2
EPSILON = 50.0
DELTA = 1e-5
EPOCHS = 20
LR = 1e-3
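A quick sanity check on DELTA (a rule of thumb, not part of the original setup): δ is usually chosen to be smaller than the inverse of the training-set size, and CIFAR10 has 50,000 training images, so 1e-5 comfortably satisfies this.

# Rule-of-thumb check (assumption): DELTA should be below 1 / (number of training samples).
N_TRAIN_SAMPLES = 50_000  # CIFAR10 training-set size
assert DELTA < 1 / N_TRAIN_SAMPLES  # 1e-5 < 2e-5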
There's another constraint we should be mindful of: memory. To balance the peak memory requirement, which is proportional to batch_size^2, against training performance, we will be using BatchMemoryManager. It separates the logical batch size (which defines how often the model is updated and how much DP noise is added) from the physical batch size (which defines how many samples we process at a time).
With BatchMemoryManager you will create your DataLoader with a logical batch size, and then provide the maximum physical batch size to the memory manager.
BATCH_SIZE = 512
MAX_PHYSICAL_BATCH_SIZE = 128
Now, let's load the CIFAR10 dataset. We don't use data augmentation here because, in our experiments, we found that data augmentation lowers utility when training with DP.
import torch
import torchvision
import torchvision.transforms as transforms
# These values, specific to the CIFAR10 dataset, are assumed to be known.
# If necessary, they can be computed with modest privacy budgets.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV),
])
Using torchvision datasets, we can load CIFAR10 and transform the PIL images into tensors normalized with the per-channel mean and standard deviation defined above.
from torchvision.datasets import CIFAR10
DATA_ROOT = '../cifar10'
train_dataset = CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
)

test_dataset = CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
Files already downloaded and verified
Files already downloaded and verified
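As an optional sanity check (not part of the original tutorial; the inspection below is only illustrative), we can confirm the dataset sizes and the shape of one logical batch:

# Optional check (illustrative): dataset sizes and the shape of one logical batch.
print(len(train_dataset), len(test_dataset))  # 50000 10000
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([512, 3, 32, 32])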
from torchvision import models
model = models.resnet18(num_classes=10)
Now, let's check whether the model is compatible with Opacus. Opacus does not support all types of PyTorch layers. To check whether your model is compatible with the privacy engine, Opacus provides a utility class that validates your model.
When you run the code below, you're presented with a list of errors, indicating which modules are incompatible.
from opacus.validators import ModuleValidator
errors = ModuleValidator.validate(model, strict=False)
errors[-5:]
[opacus.validators.errors.ShouldReplaceModuleError("BatchNorm cannot support training with differential privacy. The reason for it is that BatchNorm makes each sample's normalized value depend on its peers in a batch, ie the same sample x will get normalized to a different value depending on who else is on its batch. Privacy-wise, this means that we would have to put a privacy mechanism there too. While it can in principle be done, there are now multiple normalization layers that do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm are all privacy-safe since they don't have this property.We offer utilities to automatically replace BatchNorms to GroupNorms and we will release pretrained models to help transition, such as GN-ResNet ie a ResNet using GroupNorm, pretrained on ImageNet"),
 ...]

(The remaining four entries are identical ShouldReplaceModuleError messages, one for each of the last five BatchNorm layers in the model; they are elided here.)
Let us modify the model to work with Opacus. From the output above, you can see that the BatchNorm layers are not supported because they compute the mean and variance across the batch, creating a dependency between the samples in a batch, which violates the per-sample privacy guarantee.
The recommended way to deal with this is to call ModuleValidator.fix(model), which tries to find the best replacement for each incompatible module. For BatchNorm modules, for example, it substitutes GroupNorm. After the fix, validation no longer reports any errors:
model = ModuleValidator.fix(model)
ModuleValidator.validate(model, strict=False)
[]
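If you want to see what the fix changed (an optional check, not part of the original tutorial), you can confirm that no BatchNorm layers remain in the model:

# Optional check (illustrative): after ModuleValidator.fix, the ResNet18 should contain
# GroupNorm layers in place of BatchNorm.
bn_remaining = [name for name, module in model.named_modules()
                if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d))]
print(bn_remaining)  # expected: []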
For maximal speed, we can check whether CUDA is available and supported by the PyTorch installation. If a GPU is available, we set the device variable to the CUDA device and then transfer the neural network onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
We then define our optimizer and loss function. Opacus’ privacy engine can attach to any (first-order) optimizer. You can use your favorite—Adam, Adagrad, RMSprop—as long as it has an implementation derived from torch.optim.Optimizer. In this tutorial, we're going to use RMSprop.
import torch.nn as nn
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=LR)
We will define a utility function to calculate accuracy.
def accuracy(preds, labels):
    return (preds == labels).mean()
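For example (a quick illustrative check, not part of the original tutorial), with NumPy arrays of predicted and true labels:

import numpy as np

# accuracy() compares the two label arrays element-wise and averages the matches.
print(accuracy(np.array([1, 2, 3, 3]), np.array([1, 2, 0, 3])))  # 0.75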
We now attach the privacy engine initialized with the privacy hyperparameters defined earlier.
from opacus import PrivacyEngine
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=EPOCHS,
    target_epsilon=EPSILON,
    target_delta=DELTA,
    max_grad_norm=MAX_GRAD_NORM,
)
print(f"Using sigma={optimizer.noise_multiplier} and C={MAX_GRAD_NORM}")
Using sigma=0.39066894531249996 and C=1.2
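make_private_with_epsilon derives the noise multiplier (sigma) from the target ε, δ, and the number of training steps. If you would rather set the noise level directly, Opacus also provides PrivacyEngine.make_private, which takes a noise_multiplier instead of a privacy target. A minimal sketch (the noise value is only illustrative, and this call replaces, rather than follows, the make_private_with_epsilon call above):

# Alternative (sketch): fix the noise multiplier directly instead of targeting an epsilon.
# Use this *instead of* make_private_with_epsilon, not in addition to it.
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.1,        # illustrative value
    max_grad_norm=MAX_GRAD_NORM,
)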
We will then define our train function. This function will train the model for one epoch.
import numpy as np
from opacus.utils.batch_memory_manager import BatchMemoryManager
def train(model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    with BatchMemoryManager(
        data_loader=train_loader,
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
        optimizer=optimizer
    ) as memory_safe_data_loader:

        for i, (images, target) in enumerate(memory_safe_data_loader):
            optimizer.zero_grad()
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()

            # measure accuracy and record loss
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

            loss.backward()
            optimizer.step()

            if (i+1) % 200 == 0:
                epsilon = privacy_engine.get_epsilon(DELTA)
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                    f"(ε = {epsilon:.2f}, δ = {DELTA})"
                )
Next, we will define our test function to validate our model on our test dataset.
def test(model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

    top1_avg = np.mean(top1_acc)

    print(
        f"\tTest set:"
        f"Loss: {np.mean(losses):.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return np.mean(top1_acc)
from tqdm.notebook import tqdm
for epoch in tqdm(range(EPOCHS), desc="Epoch", unit="epoch"):
    train(model, train_loader, optimizer, epoch + 1, device)
Epoch: 0%| | 0/20 [00:00<?, ?epoch/s]
Train Epoch: 1 	Loss: 2.771490 Acc@1: 15.429688 (ε = 13.64, δ = 1e-05)
Train Epoch: 2 	Loss: 1.755194 Acc@1: 38.804688 (ε = 17.68, δ = 1e-05)
Train Epoch: 3 	Loss: 1.724797 Acc@1: 45.769531 (ε = 20.62, δ = 1e-05)
Train Epoch: 4 	Loss: 1.706076 Acc@1: 48.921875 (ε = 22.94, δ = 1e-05)
Train Epoch: 5 	Loss: 1.682414 Acc@1: 51.664062 (ε = 25.25, δ = 1e-05)
Train Epoch: 6 	Loss: 1.671187 Acc@1: 53.234375 (ε = 26.99, δ = 1e-05)
Train Epoch: 7 	Loss: 1.657112 Acc@1: 55.324219 (ε = 28.73, δ = 1e-05)
Train Epoch: 8 	Loss: 1.633768 Acc@1: 56.277344 (ε = 30.46, δ = 1e-05)
Train Epoch: 9 	Loss: 1.647288 Acc@1: 57.203125 (ε = 32.20, δ = 1e-05)
Train Epoch: 10 	Loss: 1.639933 Acc@1: 58.191406 (ε = 33.83, δ = 1e-05)
Train Epoch: 11 	Loss: 1.639214 Acc@1: 59.140625 (ε = 35.17, δ = 1e-05)
Train Epoch: 12 	Loss: 1.629374 Acc@1: 59.613281 (ε = 36.51, δ = 1e-05)
Train Epoch: 13 	Loss: 1.636143 Acc@1: 60.246094 (ε = 37.85, δ = 1e-05)
Train Epoch: 14 	Loss: 1.634575 Acc@1: 60.550781 (ε = 39.18, δ = 1e-05)
Train Epoch: 15 	Loss: 1.611133 Acc@1: 61.378906 (ε = 40.52, δ = 1e-05)
Train Epoch: 16 	Loss: 1.604075 Acc@1: 62.015625 (ε = 41.86, δ = 1e-05)
Train Epoch: 17 	Loss: 1.601270 Acc@1: 62.140625 (ε = 43.20, δ = 1e-05)
Train Epoch: 18 	Loss: 1.599596 Acc@1: 62.437500 (ε = 44.53, δ = 1e-05)
Train Epoch: 19 	Loss: 1.587946 Acc@1: 63.097656 (ε = 45.87, δ = 1e-05)
Train Epoch: 20 	Loss: 1.583897 Acc@1: 63.250000 (ε = 47.21, δ = 1e-05)
top1_acc = test(model, test_loader, device)
Test set:Loss: 1.711833 Acc: 60.753676
Now let us see how our private model compares with a non-private ResNet18.

We trained a non-private ResNet18 model for 20 epochs using the same hyper-parameters as above and with BatchNorm replaced by GroupNorm. The results of that run and of the training discussed in this tutorial are summarized in the table below:
| Model          | Top 1 Accuracy (%) | ϵ     |
|----------------|--------------------|-------|
| ResNet         | 76                 | ∞     |
| Private ResNet | 61                 | 47.21 |
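For reference, the non-private baseline can be trained with the same pipeline by simply skipping the PrivacyEngine and BatchMemoryManager steps. The following is only a sketch under that assumption (the exact script used for the baseline above is not shown in this tutorial); it builds a fresh GroupNorm ResNet18 and a plain DataLoader, and reuses the test function defined earlier:

# Sketch (assumption): non-private baseline with the same architecture and hyper-parameters,
# but without attaching the PrivacyEngine.
baseline = ModuleValidator.fix(models.resnet18(num_classes=10)).to(device)
baseline_optimizer = optim.RMSprop(baseline.parameters(), lr=LR)
baseline_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(EPOCHS):
    baseline.train()
    for images, target in baseline_loader:
        images, target = images.to(device), target.to(device)
        baseline_optimizer.zero_grad()
        loss = criterion(baseline(images), target)
        loss.backward()
        baseline_optimizer.step()

test(baseline, test_loader, device)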