Opacus
  • Introduction
  • FAQ
  • Tutorials
  • API Reference
  • GitHub

›

Tutorials

  • Overview

Using Opacus

  • Building image classifier with Differential Privacy
  • Building text classifier with Differential Privacy
  • Training a differentially private LSTM model for name classification

Table of Contents

  • Building an Image Classifier with Differential Privacy
    • Overview
    • Hyper-parameters
    • Data
    • Model
    • Prepare for Training
    • Train the network
    • Test the network on test data
    • Tips and Tricks
    • Private Model vs Non-Private Model Performance

Building an Image Classifier with Differential Privacy¶

Overview¶

In this tutorial we will learn to do the following:

  1. Learn about privacy specific hyper-parameters related to DP-SGD
  2. Learn about ModelInspector, incompatible layers, and use model rewriting utility.
  3. Train a differentially private ResNet18 for image classification.

Hyper-parameters¶

To train a model with Opacus there are three privacy-specific hyper-parameters that must be tuned for better performance:

  • Max Grad Norm: The maximum L2 norm of per-sample gradients before they are aggregated by the averaging step.
  • Noise Multiplier: The amount of noise sampled and added to the average of the gradients in a batch.
  • Delta: The target δ of the (ϵ,δ)-differential privacy guarantee. Generally, it should be set to be less than the inverse of the size of the training dataset. In this tutorial, it is set to $10^{−5}$ as the CIFAR10 dataset has 50,000 training points.

We use the hyper-parameter values below to obtain results in the last section:

In [1]:
MAX_GRAD_NORM = 1.2
NOISE_MULTIPLIER = .38
DELTA = 1e-5

LR = 1e-3
NUM_WORKERS = 2

There's another constraint we should be mindful of—memory. To balance peak memory requirement, which is proportional to batch_size^2, and training performance, we use virtual batches. With virtual batches we can separate physical steps (gradient computation) and logical steps (noise addition and parameter updates): use larger batches for training, while keeping memory footprint low. Below we will specify two constants:

In [2]:
BATCH_SIZE = 128
VIRTUAL_BATCH_SIZE = 512

Data¶

Now, let's load the CIFAR10 dataset. We don't use data augmentation here because, in our experiments, we found that data augmentation lowers utility when training with DP.

In [3]:
import torch
import torchvision
import torchvision.transforms as transforms

# These values, specific to the CIFAR10 dataset, are assumed to be known.
# If necessary, they can be computed with modest privacy budget.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV),
])

Using torchvision datasets, we can load CIFAR10 and transform the PILImage images to Tensors of normalized range [-1, 1]

In [ ]:
from torchvision.datasets import CIFAR10

DATA_ROOT = '../cifar10'

train_dataset = CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

test_dataset = CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

Model¶

In [5]:
from torchvision import models

model = models.resnet18(num_classes=10)

Now, let’s check if the model is compatible with Opacus. Opacus does not support all type of Pytorch layers. To check if your model is compatible with the privacy engine, we have provided a util class to validate your model.

If you run these commands, you will get the following error:

In [6]:
from opacus.dp_model_inspector import DPModelInspector

inspector = DPModelInspector()
inspector.validate(model)
---------------------------------------------------------------------------
IncompatibleModuleException               Traceback (most recent call last)
<ipython-input-11-c3648acc319a> in <module>
      2 
      3 inspector = DPModelInspector()
----> 4 inspector.validate(model)

/mnt/xarfuse/uid-179429/a7a74ae2-seed-nspid4026531836-ns-4026531840/opacus/dp_model_inspector.py in validate(self, model)
    115                 if inspector.violators:
    116                     message += f"\n{inspector.message}: {inspector.violators}"
--> 117             raise IncompatibleModuleException(message)
    118         return valid
    119 

IncompatibleModuleException: Model contains incompatible modules.
Some modules are not valid.: ['Main.bn1', 'Main.layer1.0.bn1', 'Main.layer1.0.bn2', 'Main.layer1.1.bn1', 'Main.layer1.1.bn2', 'Main.layer2.0.bn1', 'Main.layer2.0.bn2', 'Main.layer2.0.downsample.1', 'Main.layer2.1.bn1', 'Main.layer2.1.bn2', 'Main.layer3.0.bn1', 'Main.layer3.0.bn2', 'Main.layer3.0.downsample.1', 'Main.layer3.1.bn1', 'Main.layer3.1.bn2', 'Main.layer4.0.bn1', 'Main.layer4.0.bn2', 'Main.layer4.0.downsample.1', 'Main.layer4.1.bn1', 'Main.layer4.1.bn2']

Let us modify the model to work with Opacus. From the output above, you can see that the BatchNorm layers are not supported because they compute the mean and variance across the batch, creating a dependency between samples in a batch, a privacy violation. One way to modify our model is to replace all the BatchNorm layers with GroupNorm using the convert_batchnorm_modules util function.

In [7]:
from opacus.utils import module_modification

model = module_modification.convert_batchnorm_modules(model)
inspector = DPModelInspector()
print(f"Is the model valid? {inspector.validate(model)}")
Is the model valid? True

For maximal speed, we can check if CUDA is available and supported by the PyTorch installation. If GPU is available, set the device variable to your CUDA-compatible device. We can then transfer the neural network onto that device.

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

We then define our optimizer and loss function. Opacus’ privacy engine can attach to any (first-order) optimizer. You can use your favorite—Adam, Adagrad, RMSprop—as long as it has an implementation derived from torch.optim.Optimizer. In this tutorial, we're going to use RMSprop.

In [36]:
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=LR)

Prepare for Training¶

We will define a util function to calculate accuracy

In [10]:
def accuracy(preds, labels):
    return (preds == labels).mean()

We now attach the privacy engine initialized with the privacy hyperparameters defined earlier. There’s also the enigmatic-looking parameter alphas, which we won’t touch for the time being.

In [11]:
from opacus import PrivacyEngine

print(f"Using sigma={NOISE_MULTIPLIER} and C={MAX_GRAD_NORM}")

privacy_engine = PrivacyEngine(
    model,
    batch_size=VIRTUAL_BATCH_SIZE,
    sample_size=len(train_dataset),
    alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
    noise_multiplier=NOISE_MULTIPLIER,
    max_grad_norm=MAX_GRAD_NORM,
)
privacy_engine.attach(optimizer)
Using sigma=0.38 and C=1.2

We will then define our train function. This function will train the model for one epoch.

In [13]:
assert VIRTUAL_BATCH_SIZE % BATCH_SIZE == 0 # VIRTUAL_BATCH_SIZE should be divisible by BATCH_SIZE
virtual_batch_rate = int(VIRTUAL_BATCH_SIZE / BATCH_SIZE)
In [12]:
import numpy as np

def train(model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(train_loader):        
        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)
        
        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        
        # measure accuracy and record loss
        acc = accuracy(preds, labels)

        losses.append(loss.item())
        top1_acc.append(acc)
        
        loss.backward()
        	
        # take a real optimizer step after N_VIRTUAL_STEP steps t
        if ((i + 1) % virtual_batch_rate == 0) or ((i + 1) == len(train_loader)):
            optimizer.step()
        else:
            optimizer.virtual_step() # take a virtual step

        if i % 200 == 0:
            epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(DELTA)
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"[email protected]: {np.mean(top1_acc) * 100:.6f} "
                f"(ε = {epsilon:.2f}, δ = {DELTA})"
            )

Next, we will define our test function to validate our model on our test dataset.

In [14]:
def test(model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()
            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

    top1_avg = np.mean(top1_acc)

    print(
        f"\tTest set:"
        f"Loss: {np.mean(losses):.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return np.mean(top1_acc)

Train the network¶

In [ ]:
from tqdm import tqdm_notebook

for epoch in tqdm_notebook(range(20), desc="Epoch", unit="epoch"):
    train(model, train_loader, optimizer, epoch + 1, device)
	Train Epoch: 1 	Loss: 2.732535 [email protected]: 7.812500 (ε = 0.19, δ = 1e-05)
	Train Epoch: 1 	Loss: 2.782132 [email protected]: 14.128576 (ε = 16.04, δ = 1e-05)
	Train Epoch: 2 	Loss: 1.730190 [email protected]: 36.718750 (ε = 18.66, δ = 1e-05)
	Train Epoch: 2 	Loss: 1.761741 [email protected]: 38.355100 (ε = 20.83, δ = 1e-05)
	Train Epoch: 3 	Loss: 1.542238 [email protected]: 50.781250 (ε = 22.43, δ = 1e-05)
	Train Epoch: 3 	Loss: 1.723531 [email protected]: 45.565143 (ε = 23.96, δ = 1e-05)
	Train Epoch: 4 	Loss: 1.949715 A[email protected]: 42.187500 (ε = 25.43, δ = 1e-05)
	Train Epoch: 4 	Loss: 1.720563 [email protected]: 49.109919 (ε = 26.85, δ = 1e-05)
	Train Epoch: 5 	Loss: 1.774229 [email protected]: 51.562500 (ε = 27.92, δ = 1e-05)
	Train Epoch: 5 	Loss: 1.714577 [email protected]: 51.698539 (ε = 29.03, δ = 1e-05)
	Train Epoch: 6 	Loss: 1.405191 [email protected]: 60.156250 (ε = 30.10, δ = 1e-05)
	Train Epoch: 6 	Loss: 1.675316 [email protected]: 53.976213 (ε = 31.21, δ = 1e-05)
	Train Epoch: 7 	Loss: 1.554456 [email protected]: 52.343750 (ε = 32.28, δ = 1e-05)
	Train Epoch: 7 	Loss: 1.685776 [email protected]: 54.909049 (ε = 33.39, δ = 1e-05)
	Train Epoch: 8 	Loss: 1.778377 [email protected]: 53.125000 (ε = 34.46, δ = 1e-05)
	Train Epoch: 8 	Loss: 1.696506 [email protected]: 56.012904 (ε = 35.36, δ = 1e-05)
	Train Epoch: 9 	Loss: 1.904890 [email protected]: 54.687500 (ε = 36.17, δ = 1e-05)
	Train Epoch: 9 	Loss: 1.639168 [email protected]: 57.668688 (ε = 37.00, δ = 1e-05)
	Train Epoch: 10 	Loss: 1.614231 [email protected]: 53.906250 (ε = 37.81, δ = 1e-05)
	Train Epoch: 10 	Loss: 1.641604 [email protected]: 58.652052 (ε = 38.65, δ = 1e-05)
	Train Epoch: 11 	Loss: 1.678383 [email protected]: 55.468750 (ε = 39.45, δ = 1e-05)
	Train Epoch: 11 	Loss: 1.627174 [email protected]: 59.343905 (ε = 40.29, δ = 1e-05)
	Train Epoch: 12 	Loss: 1.534639 [email protected]: 62.500000 (ε = 41.09, δ = 1e-05)
	Train Epoch: 12 	Loss: 1.627389 [email protected]: 59.732587 (ε = 41.93, δ = 1e-05)
	Train Epoch: 13 	Loss: 1.565778 [email protected]: 63.281250 (ε = 42.74, δ = 1e-05)
	Train Epoch: 13 	Loss: 1.611232 [email protected]: 60.743159 (ε = 43.58, δ = 1e-05)
	Train Epoch: 14 	Loss: 1.529223 [email protected]: 62.500000 (ε = 44.38, δ = 1e-05)
	Train Epoch: 14 	Loss: 1.614406 [email protected]: 60.886971 (ε = 45.22, δ = 1e-05)
	Train Epoch: 15 	Loss: 1.445196 [email protected]: 66.406250 (ε = 46.02, δ = 1e-05)
	Train Epoch: 15 	Loss: 1.589295 [email protected]: 61.427239 (ε = 46.86, δ = 1e-05)
	Train Epoch: 16 	Loss: 1.733013 [email protected]: 59.375000 (ε = 47.67, δ = 1e-05)
	Train Epoch: 16 	Loss: 1.577568 [email protected]: 61.866449 (ε = 48.46, δ = 1e-05)
	Train Epoch: 17 	Loss: 1.845230 [email protected]: 55.468750 (ε = 49.09, δ = 1e-05)
	Train Epoch: 17 	Loss: 1.578402 [email protected]: 62.010261 (ε = 49.73, δ = 1e-05)
	Train Epoch: 18 	Loss: 1.565229 [email protected]: 64.843750 (ε = 50.35, δ = 1e-05)
	Train Epoch: 18 	Loss: 1.542843 [email protected]: 63.432836 (ε = 51.00, δ = 1e-05)
	Train Epoch: 19 	Loss: 1.447579 [email protected]: 65.625000 (ε = 51.62, δ = 1e-05)
	Train Epoch: 19 	Loss: 1.541139 [email protected]: 63.219061 (ε = 52.27, δ = 1e-05)
	Train Epoch: 20 	Loss: 1.561364 [email protected]: 65.625000 (ε = 52.89, δ = 1e-05)
	Train Epoch: 20 	Loss: 1.576318 [email protected]: 62.624378 (ε = 53.54, δ = 1e-05)

Test the network on test data¶

In [41]:
top1_acc = test(model, test_loader, device)
	Test set:Loss: 2.015386 Acc: 56.615902 

Tips and Tricks¶

  1. Generally speaking, differentially private training is enough of a regularizer by itself. Adding any more regularization (such as dropouts or data augmentation) is unnecessary and typically hurts performance.
  2. Tuning MAX_GRAD_NORM is very important. Start with a low noise multiplier like .1, this should give comparable performance to a non-private model. Then do a grid search for the optimal MAX_GRAD_NORM value. The grid can be in the range [.1, 10]

Private Model vs Non-Private Model Performance¶

Now let us compare how our private model compares with the non-private ResNet18.

We trained a non-private ResNet18 model for 20 epochs using the same hyper-parameters as above and with BatchNorm replaced with GroupNorm. The results of that training and the training that is discussed in this tutorial are summarized in the table below:

Model Top 1 Accuracy (%) ϵ
ResNet 76 ∞
Private ResNet 56.61 53.54
Download Tutorial Jupyter Notebook
Opacus
Docs
IntroductionFAQTutorialsAPI Reference
Social
opacus
Legal
PrivacyTerms
Facebook Open Source
Copyright © 2021 Facebook Inc.