Opacus
Table of Contents

  • Building an Image Classifier with Differential Privacy
    • Overview
    • Hyper-parameters
    • Data
    • Model
    • Prepare for Training
    • Train the network
    • Test the network on test data
    • Tips and Tricks
    • Private Model vs Non-Private Model Performance

Building an Image Classifier with Differential Privacy

Overview

In this tutorial we will learn to do the following:

  1. Learn about the privacy-specific hyper-parameters related to DP-SGD.
  2. Learn about ModuleValidator, incompatible layers, and how to use the model-fixing utility.
  3. Train a differentially private ResNet18 for image classification.

Hyper-parameters

To train a model with Opacus there are three privacy-specific hyper-parameters that must be tuned for better performance:

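  • Max Grad Norm: The maximum L2 norm of each per-sample gradient. Per-sample gradients with a larger norm are clipped to this value before they are aggregated by the averaging step.
  • Noise Multiplier: The ratio of the standard deviation of the Gaussian noise to the max grad norm. The noise is added to the sum of the clipped per-sample gradients in a batch before averaging, so a larger value means more noise and stronger privacy.
  • Delta: The target δ of the (ϵ,δ)-differential privacy guarantee. Generally, it should be set to less than the inverse of the size of the training dataset. In this tutorial, it is set to $10^{-5}$ because the CIFAR10 dataset has 50,000 training points ($1/50{,}000 = 2 \times 10^{-5}$).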
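To make the first two hyper-parameters concrete, here is a minimal NumPy sketch of a single DP-SGD aggregation step: clip each per-sample gradient to `max_grad_norm`, sum, add Gaussian noise with standard deviation `noise_multiplier * max_grad_norm`, then average. The function name `dp_sgd_aggregate` and the small epsilon in the denominator are illustrative assumptions; this is not Opacus's implementation, which operates on PyTorch tensors and, under Poisson sampling, divides by the expected batch size rather than the actual one.

```python
import numpy as np


def dp_sgd_aggregate(per_sample_grads, max_grad_norm, noise_multiplier, rng):
    """Illustrative DP-SGD aggregation (hypothetical helper, not an Opacus API).

    per_sample_grads: array of shape (batch_size, num_params), one row per sample.
    """
    # L2 norm of each sample's gradient
    norms = np.linalg.norm(per_sample_grads, axis=1, keepdims=True)
    # Rows with norm above max_grad_norm are scaled down to max_grad_norm; others are unchanged
    clip_factors = np.minimum(1.0, max_grad_norm / (norms + 1e-6))
    clipped = per_sample_grads * clip_factors
    # Gaussian noise, std = noise_multiplier * max_grad_norm, is added to the *sum*
    noise = rng.normal(0.0, noise_multiplier * max_grad_norm, size=per_sample_grads.shape[1])
    # Average over the batch
    return (clipped.sum(axis=0) + noise) / per_sample_grads.shape[0]


rng = np.random.default_rng(0)
grads = np.array([[3.0, 4.0], [0.3, 0.4]])  # norms 5.0 and 0.5
# With noise_multiplier=0 only clipping acts: the first row is scaled to norm 1.0
out = dp_sgd_aggregate(grads, max_grad_norm=1.0, noise_multiplier=0.0, rng=rng)
print(out)
```

Because every clipped row has norm at most `max_grad_norm`, a single sample can shift the sum by a bounded amount, which is what lets the added noise mask any individual's contribution.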

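We use the hyper-parameter values below to obtain the results in the last section. Instead of setting the noise multiplier directly, we specify a target privacy budget EPSILON and let Opacus compute the noise multiplier needed to reach it after EPOCHS epochs: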

In [1]:
import warnings
warnings.simplefilter("ignore")

MAX_GRAD_NORM = 1.2
EPSILON = 50.0
DELTA = 1e-5
EPOCHS = 20

LR = 1e-3

There's another constraint we should be mindful of: memory. Opacus computes and stores a gradient for every sample in a batch, so peak memory grows with the batch size. To balance this peak memory requirement against training performance, we will be using BatchMemoryManager. It separates the logical batch size (which defines how often the model is updated and how much DP noise is added) from the physical batch size (which defines how many samples we process at a time).

With BatchMemoryManager you will create your DataLoader with a logical batch size, and then provide the maximum physical batch size to the memory manager.

In [2]:
BATCH_SIZE = 512
MAX_PHYSICAL_BATCH_SIZE = 128
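As a quick sanity check of what these two numbers mean, the arithmetic below counts logical batches per epoch and the physical chunks needed per full logical batch. It assumes the 50,000-image CIFAR10 training set; under Opacus's default Poisson sampling the logical batch size varies from step to step, so these counts are approximate averages.

```python
import math

TRAIN_SET_SIZE = 50_000         # CIFAR10 training images
BATCH_SIZE = 512                # logical batch size (same value as above)
MAX_PHYSICAL_BATCH_SIZE = 128   # physical batch size (same value as above)

# Logical batches (optimizer updates) per epoch for a DataLoader with batch_size=BATCH_SIZE
logical_steps_per_epoch = math.ceil(TRAIN_SET_SIZE / BATCH_SIZE)

# Physical chunks BatchMemoryManager needs to process one full-size logical batch
physical_per_logical = math.ceil(BATCH_SIZE / MAX_PHYSICAL_BATCH_SIZE)

print(logical_steps_per_epoch, physical_per_logical, logical_steps_per_epoch * physical_per_logical)
```

So each epoch performs roughly 98 parameter updates while the loop body runs on roughly 392 physical batches of up to 128 images each.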

Data

Now, let's load the CIFAR10 dataset. We don't use data augmentation here because, in our experiments, we found that data augmentation lowers utility when training with DP.

In [3]:
import torch
import torchvision
import torchvision.transforms as transforms

# These values, specific to the CIFAR10 dataset, are assumed to be known.
# If necessary, they can be computed with modest privacy budgets.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV),
])

Using torchvision datasets, we can load CIFAR10 and transform the PIL images into tensors, normalizing each channel with the mean and standard deviation above so that it has approximately zero mean and unit variance.
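Because `Normalize` applies `(x - mean) / std` per channel to pixel values that `ToTensor` has already scaled into [0, 1], the normalized values are not confined to [-1, 1]. The pure-Python sketch below computes the resulting per-channel range from the constants defined earlier; it reproduces the formula only and does not call torchvision.

```python
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)

# ToTensor maps pixels into [0, 1]; Normalize then computes (x - mean) / std per channel,
# so the extremes come from x = 0 and x = 1.
normalized_ranges = [
    ((0.0 - m) / s, (1.0 - m) / s)
    for m, s in zip(CIFAR10_MEAN, CIFAR10_STD_DEV)
]

for channel, (low, high) in zip("RGB", normalized_ranges):
    print(f"{channel}: [{low:.2f}, {high:.2f}]")
```

Each channel spans roughly [-2.4, 2.7], which is the expected outcome of standardization rather than min-max scaling.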

In [4]:
from torchvision.datasets import CIFAR10

DATA_ROOT = '../cifar10'

train_dataset = CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
)

test_dataset = CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
Files already downloaded and verified
Files already downloaded and verified

Model

In [5]:
from torchvision import models

model = models.resnet18(num_classes=10)

Now, let's check whether the model is compatible with Opacus. Opacus does not support all types of PyTorch layers, so we provide a utility class, ModuleValidator, to check whether your model is compatible with the privacy engine.

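When you run the code below, you are presented with a list of errors indicating which modules are incompatible. Passing strict=False makes validate return this list instead of raising an exception; here we display only the last five entries.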

In [6]:
from opacus.validators import ModuleValidator

errors = ModuleValidator.validate(model, strict=False)
errors[-5:]
Out[6]:
[opacus.validators.errors.ShouldReplaceModuleError("BatchNorm cannot support training with differential privacy. The reason for it is that BatchNorm makes each sample's normalized value depend on its peers in a batch, ie the same sample x will get normalized to a different value depending on who else is on its batch. Privacy-wise, this means that we would have to put a privacy mechanism there too. While it can in principle be done, there are now multiple normalization layers that do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm are all privacy-safe since they don't have this property.We offer utilities to automatically replace BatchNorms to GroupNorms and we will release pretrained models to help transition, such as GN-ResNet ie a ResNet using GroupNorm, pretrained on ImageNet"),
 opacus.validators.errors.ShouldReplaceModuleError("BatchNorm cannot support training with differential privacy. The reason for it is that BatchNorm makes each sample's normalized value depend on its peers in a batch, ie the same sample x will get normalized to a different value depending on who else is on its batch. Privacy-wise, this means that we would have to put a privacy mechanism there too. While it can in principle be done, there are now multiple normalization layers that do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm are all privacy-safe since they don't have this property.We offer utilities to automatically replace BatchNorms to GroupNorms and we will release pretrained models to help transition, such as GN-ResNet ie a ResNet using GroupNorm, pretrained on ImageNet"),
 opacus.validators.errors.ShouldReplaceModuleError("BatchNorm cannot support training with differential privacy. The reason for it is that BatchNorm makes each sample's normalized value depend on its peers in a batch, ie the same sample x will get normalized to a different value depending on who else is on its batch. Privacy-wise, this means that we would have to put a privacy mechanism there too. While it can in principle be done, there are now multiple normalization layers that do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm are all privacy-safe since they don't have this property.We offer utilities to automatically replace BatchNorms to GroupNorms and we will release pretrained models to help transition, such as GN-ResNet ie a ResNet using GroupNorm, pretrained on ImageNet"),
 opacus.validators.errors.ShouldReplaceModuleError("BatchNorm cannot support training with differential privacy. The reason for it is that BatchNorm makes each sample's normalized value depend on its peers in a batch, ie the same sample x will get normalized to a different value depending on who else is on its batch. Privacy-wise, this means that we would have to put a privacy mechanism there too. While it can in principle be done, there are now multiple normalization layers that do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm are all privacy-safe since they don't have this property.We offer utilities to automatically replace BatchNorms to GroupNorms and we will release pretrained models to help transition, such as GN-ResNet ie a ResNet using GroupNorm, pretrained on ImageNet"),
 opacus.validators.errors.ShouldReplaceModuleError("BatchNorm cannot support training with differential privacy. The reason for it is that BatchNorm makes each sample's normalized value depend on its peers in a batch, ie the same sample x will get normalized to a different value depending on who else is on its batch. Privacy-wise, this means that we would have to put a privacy mechanism there too. While it can in principle be done, there are now multiple normalization layers that do not have this issue: LayerNorm, InstanceNorm and their generalization GroupNorm are all privacy-safe since they don't have this property.We offer utilities to automatically replace BatchNorms to GroupNorms and we will release pretrained models to help transition, such as GN-ResNet ie a ResNet using GroupNorm, pretrained on ImageNet")]

Let us modify the model to work with Opacus. From the output above, you can see that BatchNorm layers are not supported: they compute the mean and variance across the batch, which makes each sample's normalized value depend on the other samples in that batch. This is a privacy violation, because DP-SGD relies on bounding each sample's contribution independently of its peers.
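The batch dependence can be demonstrated with a simplified NumPy sketch (no learnable affine parameters or running statistics, so it is not the exact PyTorch behavior). `batch_norm` normalizes each feature using statistics across the batch; `per_sample_norm` normalizes each sample using only its own features, in the spirit of LayerNorm or GroupNorm with a single group. Both helpers are illustrative. We keep the first sample fixed, swap out its peers, and check whether its normalized output changes.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    # Statistics computed across the batch (axis 0), one per feature
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def per_sample_norm(x, eps=1e-5):
    # Statistics computed within each sample (axis 1), independent of other samples
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 6))
other_batch = batch.copy()
other_batch[1:] = rng.normal(size=(3, 6))  # same first sample, different peers

bn_changed = not np.allclose(batch_norm(batch)[0], batch_norm(other_batch)[0])
psn_changed = not np.allclose(per_sample_norm(batch)[0], per_sample_norm(other_batch)[0])
print(bn_changed, psn_changed)
```

The first sample's BatchNorm-style output changes when its peers change, while the per-sample normalization output does not.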

The recommended way to deal with this is to call ModuleValidator.fix(model), which tries to find the best replacement for incompatible modules. For example, it replaces BatchNorm modules with GroupNorm. After this, validate returns an empty list, meaning no incompatible modules remain.

In [7]:
model = ModuleValidator.fix(model)
ModuleValidator.validate(model, strict=False)
Out[7]:
[]

For maximal speed, we can check whether CUDA is available and supported by the PyTorch installation. If a GPU is available, set the device variable to your CUDA-compatible device. We can then transfer the neural network onto that device.

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

We then define our optimizer and loss function. Opacus' privacy engine can attach to any first-order optimizer. You can use your favorite one (Adam, Adagrad, RMSprop) as long as its implementation derives from torch.optim.Optimizer. In this tutorial, we're going to use RMSprop.

In [9]:
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=LR)

Prepare for Training

We will define a utility function that calculates accuracy from NumPy arrays of predicted and true labels.

In [10]:
def accuracy(preds, labels):
    return (preds == labels).mean()

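We now attach the privacy engine using the privacy hyper-parameters defined earlier. make_private_with_epsilon computes the noise multiplier needed to reach the target EPSILON (at the given DELTA) after EPOCHS epochs, and returns a wrapped model, a DP optimizer, and a data loader that uses Poisson sampling by default.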

In [11]:
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()

model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=EPOCHS,
    target_epsilon=EPSILON,
    target_delta=DELTA,
    max_grad_norm=MAX_GRAD_NORM,
)

print(f"Using sigma={optimizer.noise_multiplier} and C={MAX_GRAD_NORM}")
Using sigma=0.39066894531249996 and C=1.2
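To get a feel for what sigma means, the arithmetic below converts the printed noise multiplier into the standard deviation of the noise added to the summed gradient, and then into the noise per coordinate of the averaged gradient. The divisor is an assumption: an expected logical batch size of 50,000 × (1/98) ≈ 510 samples, matching the Poisson sampling rate implied by 98 logical batches per epoch.

```python
sigma = 0.39066894531249996  # noise multiplier printed above
C = 1.2                      # MAX_GRAD_NORM

# Std of the Gaussian noise added to the sum of clipped per-sample gradients
noise_std_on_sum = sigma * C

# Assumption: averaging divides by the expected logical batch size (50,000 images, rate 1/98)
expected_batch_size = 50_000 / 98
noise_std_on_mean = noise_std_on_sum / expected_batch_size

print(f"noise std on sum: {noise_std_on_sum:.4f}, on mean: {noise_std_on_mean:.6f}")
```

Each coordinate of the averaged gradient therefore receives noise with a standard deviation of roughly 0.0009, which stays small only because the logical batch is large. This is one reason DP training favors big logical batches.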

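We will then define our train function, which trains the model for one epoch. Inside it, BatchMemoryManager splits each logical batch into physical batches of at most MAX_PHYSICAL_BATCH_SIZE samples, and the optimizer only updates the parameters once a full logical batch has been processed.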

In [12]:
import numpy as np
from opacus.utils.batch_memory_manager import BatchMemoryManager


def train(model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []
    
    with BatchMemoryManager(
        data_loader=train_loader, 
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE, 
        optimizer=optimizer
    ) as memory_safe_data_loader:

        for i, (images, target) in enumerate(memory_safe_data_loader):   
            optimizer.zero_grad()
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()

            # measure accuracy and record loss
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

            loss.backward()
            optimizer.step()

            if (i+1) % 200 == 0:
                epsilon = privacy_engine.get_epsilon(DELTA)
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                    f"(ε = {epsilon:.2f}, δ = {DELTA})"
                )
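Splitting a logical batch into physical batches does not change the DP computation, because clipping is applied to each sample independently: the sum of clipped gradients over the whole logical batch equals the sum of the per-chunk clipped sums. The NumPy sketch below checks this with random stand-in gradients; the helper `clipped_sum` is hypothetical and only illustrates the idea, not how BatchMemoryManager is implemented. Noise is then added once, to the accumulated sum.

```python
import numpy as np


def clipped_sum(per_sample_grads, max_grad_norm):
    """Sum of per-sample gradients after clipping each row to max_grad_norm (L2)."""
    norms = np.linalg.norm(per_sample_grads, axis=1, keepdims=True)
    factors = np.minimum(1.0, max_grad_norm / (norms + 1e-6))
    return (per_sample_grads * factors).sum(axis=0)


rng = np.random.default_rng(0)
logical_batch = rng.normal(size=(512, 10)) * 3  # stand-in per-sample gradients
max_grad_norm = 1.2

# Clip and sum the whole logical batch at once
full = clipped_sum(logical_batch, max_grad_norm)

# Clip and sum physical chunks of 128, accumulating the results
accumulated = np.zeros(10)
for start in range(0, len(logical_batch), 128):
    accumulated += clipped_sum(logical_batch[start:start + 128], max_grad_norm)

print(np.allclose(full, accumulated))
```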

Next, we will define our test function to evaluate our model on the test dataset.

In [13]:
def test(model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

    top1_avg = np.mean(top1_acc)

    print(
        f"\tTest set:"
        f"Loss: {np.mean(losses):.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return np.mean(top1_acc)
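Note that test reports the unweighted mean of per-batch accuracies. With 10,000 test images and a batch size of 512, the loader yields 19 full batches and one final batch of 272, so this can differ slightly from the accuracy over all samples. The sketch below compares the two with hypothetical per-batch accuracies; `weighted_accuracy` is an illustrative helper, not part of the tutorial's code.

```python
import numpy as np


def weighted_accuracy(batch_accuracies, batch_sizes):
    """Accuracy over all samples, weighting each batch by its size."""
    return np.average(batch_accuracies, weights=batch_sizes)


# Hypothetical accuracies: 19 full batches of 512 and one final batch of 272
batch_sizes = [512] * 19 + [272]
batch_accuracies = [0.60] * 19 + [0.70]

unweighted = np.mean(batch_accuracies)  # what test() computes
weighted = weighted_accuracy(batch_accuracies, batch_sizes)
print(unweighted, weighted)
```

The gap is small here, but weighting by batch size gives the exact fraction of correctly classified test images.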

Train the network

In [14]:
from tqdm.notebook import tqdm

for epoch in tqdm(range(EPOCHS), desc="Epoch", unit="epoch"):
    train(model, train_loader, optimizer, epoch + 1, device)
Epoch:   0%|          | 0/20 [00:00<?, ?epoch/s]
	Train Epoch: 1 	Loss: 2.771490 Acc@1: 15.429688 (ε = 13.64, δ = 1e-05)
	Train Epoch: 2 	Loss: 1.755194 Acc@1: 38.804688 (ε = 17.68, δ = 1e-05)
	Train Epoch: 3 	Loss: 1.724797 Acc@1: 45.769531 (ε = 20.62, δ = 1e-05)
	Train Epoch: 4 	Loss: 1.706076 Acc@1: 48.921875 (ε = 22.94, δ = 1e-05)
	Train Epoch: 5 	Loss: 1.682414 Acc@1: 51.664062 (ε = 25.25, δ = 1e-05)
	Train Epoch: 6 	Loss: 1.671187 Acc@1: 53.234375 (ε = 26.99, δ = 1e-05)
	Train Epoch: 7 	Loss: 1.657112 Acc@1: 55.324219 (ε = 28.73, δ = 1e-05)
	Train Epoch: 8 	Loss: 1.633768 Acc@1: 56.277344 (ε = 30.46, δ = 1e-05)
	Train Epoch: 9 	Loss: 1.647288 Acc@1: 57.203125 (ε = 32.20, δ = 1e-05)
	Train Epoch: 10 	Loss: 1.639933 Acc@1: 58.191406 (ε = 33.83, δ = 1e-05)
	Train Epoch: 11 	Loss: 1.639214 Acc@1: 59.140625 (ε = 35.17, δ = 1e-05)
	Train Epoch: 12 	Loss: 1.629374 Acc@1: 59.613281 (ε = 36.51, δ = 1e-05)
	Train Epoch: 13 	Loss: 1.636143 Acc@1: 60.246094 (ε = 37.85, δ = 1e-05)
	Train Epoch: 14 	Loss: 1.634575 Acc@1: 60.550781 (ε = 39.18, δ = 1e-05)
	Train Epoch: 15 	Loss: 1.611133 Acc@1: 61.378906 (ε = 40.52, δ = 1e-05)
	Train Epoch: 16 	Loss: 1.604075 Acc@1: 62.015625 (ε = 41.86, δ = 1e-05)
	Train Epoch: 17 	Loss: 1.601270 Acc@1: 62.140625 (ε = 43.20, δ = 1e-05)
	Train Epoch: 18 	Loss: 1.599596 Acc@1: 62.437500 (ε = 44.53, δ = 1e-05)
	Train Epoch: 19 	Loss: 1.587946 Acc@1: 63.097656 (ε = 45.87, δ = 1e-05)
	Train Epoch: 20 	Loss: 1.583897 Acc@1: 63.250000 (ε = 47.21, δ = 1e-05)

Test the network on test data

In [15]:
top1_acc = test(model, test_loader, device)
	Test set:Loss: 1.711833 Acc: 60.753676 

Tips and Tricks

  1. Generally speaking, differentially private training is enough of a regularizer by itself. Adding more regularization (such as dropout or data augmentation) is unnecessary and typically hurts performance.
  2. Tuning MAX_GRAD_NORM is very important. Start with a low noise multiplier such as 0.1 (set it directly with privacy_engine.make_private instead of make_private_with_epsilon); this should give performance comparable to a non-private model. Then grid-search for the optimal MAX_GRAD_NORM value. The grid can span the range [0.1, 10].
  3. You can play around with the level of privacy, EPSILON. A smaller EPSILON means stronger privacy and more noise, and hence lower accuracy. Reducing EPSILON to 5.0 reduces the Top 1 Accuracy to around 53%. One useful technique is to pre-train a model on public (non-private) data before completing the training on the private training data. See the workbook at bit.ly/opacus-dev-day for an example.
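The grid search over MAX_GRAD_NORM can be sketched as follows. `train_and_evaluate` is a hypothetical callable you would supply: for example, it could build a fresh model, call privacy_engine.make_private with noise_multiplier=0.1 and the candidate max_grad_norm, train, and return test accuracy. Log-spaced candidates are an assumption; they cover [0.1, 10] evenly across orders of magnitude.

```python
import numpy as np


def grid_search_max_grad_norm(train_and_evaluate, grid):
    """Return (best_max_grad_norm, best_accuracy) over the candidate grid.

    train_and_evaluate: hypothetical callable mapping a max_grad_norm value
    to a validation/test accuracy.
    """
    results = {float(c): train_and_evaluate(float(c)) for c in grid}
    best = max(results, key=results.get)
    return best, results[best]


# Five log-spaced candidates: 0.1, ~0.316, 1.0, ~3.16, 10.0
grid = np.logspace(-1, 1, num=5)

# Toy stand-in objective for illustration only: peaks at max_grad_norm = 1.0
best, acc = grid_search_max_grad_norm(lambda c: -abs(np.log10(c)), grid)
print(best, acc)
```

Each grid point is a full training run, so a coarse grid like this is usually refined around the best value afterwards.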

Private Model vs Non-Private Model Performance

Now let us compare our private model with the non-private ResNet18.

We trained a non-private ResNet18 model for 20 epochs using the same hyper-parameters as above and with BatchNorm replaced with GroupNorm. The results of that training and the training that is discussed in this tutorial are summarized in the table below:

Model            Top 1 Accuracy (%)   ϵ
ResNet           76                   ∞
Private ResNet   61                   47.21
Copyright © 2025 Meta Platforms, Inc.