In this tutorial we'll go over the basics you need to know to start using Opacus in your distributed model training pipeline. As state-of-the-art models and datasets get bigger, multi-GPU training has become the norm, and Opacus comes with seamless, out-of-the-box support for Distributed Data Parallel (DDP).
This tutorial requires basic knowledge of Opacus and DDP. If you're new to either of these tools, we suggest starting with the following tutorials: Building an Image Classifier with Differential Privacy and Getting Started with Distributed Data Parallel.
In Chapter 1 we'll start with a minimal working example to demonstrate what exactly you need to do in order to make Opacus work in a distributed setting. This should be enough to get started for most common scenarios.
In Chapters 2 and 3 we'll take a closer look at the implementation and talk about technical details. We'll see what the differences are between private DDP and regular DDP and why we need to introduce them.
Before we begin, there are a few things we need to mention.
First, this tutorial is written to be executed on a single Linux machine with at least 2 GPUs. The general principles remain the same for Windows environments and/or multi-node training, but you'll need to slightly modify the DDP code to make it work.
Second, Jupyter notebooks are known not to support DDP training. Throughout the tutorial, we'll use the %%writefile magic command to write code to separate files and then execute them via the terminal. These files will be cleaned up in the last cell of this notebook.
First, let's initialize the distributed environment.
%%writefile opacus_ddp_demo.py
import os
import torch.distributed as dist
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
Overwriting opacus_ddp_demo.py
We'll be using MNIST as a toy example, so let's also initialize a simple convolutional network and download the dataset.
%%writefile -a opacus_ddp_demo.py
import torch.nn as nn
import torch.nn.functional as F
class SampleConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 8, 2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, 4, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        # x of shape [B, 1, 28, 28]
        x = F.relu(self.conv1(x))   # -> [B, 16, 14, 14]
        x = F.max_pool2d(x, 2, 1)   # -> [B, 16, 13, 13]
        x = F.relu(self.conv2(x))   # -> [B, 32, 5, 5]
        x = F.max_pool2d(x, 2, 1)   # -> [B, 32, 4, 4]
        x = x.view(-1, 32 * 4 * 4)  # -> [B, 512]
        x = F.relu(self.fc1(x))     # -> [B, 32]
        x = self.fc2(x)             # -> [B, 10]
        return x
Appending to opacus_ddp_demo.py
%%writefile -a opacus_ddp_demo.py
from torchvision import datasets, transforms
# Precomputed characteristics of the MNIST dataset
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081
DATA_ROOT = "./mnist"
mnist_train_ds = datasets.MNIST(
    DATA_ROOT,
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
        ]
    ),
)

mnist_test_ds = datasets.MNIST(
    DATA_ROOT,
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
        ]
    ),
)
Appending to opacus_ddp_demo.py
Coming next is the key bit - and the only one that's different from non-private DDP.
First, instead of wrapping the model with DistributedDataParallel, we'll wrap it with DifferentiallyPrivateDistributedDataParallel from the opacus.distributed package. Simple as that.
The second difference comes when initializing the DataLoader. Normally, for distributed training you would initialize a data loader specific to your distributed setup, which affects two parameters:
- batch_size becomes the per-GPU batch size, so the logical batch size seen during training is local_batch_size*num_gpus.
- sampler=DistributedSampler(dataset) is passed to distribute the training dataset across GPUs.

With Opacus you don't need to do either of those things. The make_private method expects the user-provided DataLoader to be non-distributed, initialized as if you're training on a single GPU.
The code below highlights the changes you need to make to a normal DDP training pipeline by commenting out the lines you need to replace or remove.
%%writefile -a opacus_ddp_demo.py
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
from torch.nn.parallel import DistributedDataParallel as DDP
from opacus import PrivacyEngine
LR = 0.1
BATCH_SIZE = 200
N_GPUS = torch.cuda.device_count()
def init_training(rank):
    model = SampleConvNet()
    #model = DDP(model) -- non-private
    model = DPDDP(model)

    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0)
    data_loader = DataLoader(
        mnist_train_ds,
        #batch_size=BATCH_SIZE // N_GPUS, -- non-private
        batch_size=BATCH_SIZE,
        #sampler=DistributedSampler(mnist_train_ds) -- non-private
    )

    if rank == 0:
        logger.info(
            f"(rank {rank}) Initialized model ({type(model).__name__}), "
            f"optimizer ({type(optimizer).__name__}), "
            f"data loader ({type(data_loader).__name__}, len={len(data_loader)})"
        )

    privacy_engine = PrivacyEngine()

    # PrivacyEngine looks at the model's class and enables
    # distributed processing if it's wrapped with DPDDP
    model, optimizer, data_loader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=data_loader,
        noise_multiplier=1.,
        max_grad_norm=1.,
    )

    if rank == 0:
        logger.info(
            f"(rank {rank}) After privatization: model ({type(model).__name__}), "
            f"optimizer ({type(optimizer).__name__}), "
            f"data loader ({type(data_loader).__name__}, len={len(data_loader)})"
        )

    logger.info(f"(rank {rank}) Average batch size per GPU: {int(optimizer.expected_batch_size)}")

    return model, optimizer, data_loader, privacy_engine
Appending to opacus_ddp_demo.py
Now we just need to define the training loop and launch it.
%%writefile -a opacus_ddp_demo.py
import numpy as np
def test(model, device):
    test_loader = DataLoader(
        mnist_test_ds,
        batch_size=BATCH_SIZE,
    )

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(
                dim=1, keepdim=True
            )
            correct += pred.eq(target.view_as(pred)).sum().item()

    model.train()
    return correct / len(mnist_test_ds)

def launch(rank, world_size, epochs):
    setup(rank, world_size)

    criterion = nn.CrossEntropyLoss()
    model, optimizer, data_loader, privacy_engine = init_training(rank)
    model.to(rank)
    model.train()

    for e in range(epochs):
        losses = []
        correct = 0
        total = 0

        for data, target in data_loader:
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()

            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(data)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        test_accuracy = test(model, rank)
        train_accuracy = correct / total
        epsilon = privacy_engine.get_epsilon(delta=1e-5)

        if rank == 0:
            print(
                f"Epoch: {e} \t"
                f"Train Loss: {np.mean(losses):.4f} | "
                f"Train Accuracy: {train_accuracy:.2f} | "
                f"Test Accuracy: {test_accuracy:.2f} |"
                f"(ε = {epsilon:.2f})"
            )

    cleanup()
Appending to opacus_ddp_demo.py
%%writefile -a opacus_ddp_demo.py
import torch.multiprocessing as mp
EPOCHS = 10
world_size = torch.cuda.device_count()
if __name__ == '__main__':
    mp.spawn(
        launch,
        args=(world_size, EPOCHS),
        nprocs=world_size,
        join=True,
    )
Appending to opacus_ddp_demo.py
And, finally, running the script. Notice that we've initialized our DataLoader with batch_size=200, which is equivalent to 300 batches over the full dataset (60000 images). After passing it to make_private, each worker ends up with a data loader with batch_size=100, but each of these data loaders still goes over 300 batches.
!python -W ignore opacus_ddp_demo.py
05/13/2022 11:13:16:INFO:(rank 0) Initialized model (DifferentiallyPrivateDistributedDataParallel), optimizer (SGD), data loader (DataLoader, len=300)
05/13/2022 11:13:16:INFO:(rank 1) Average batch size per GPU: 100
05/13/2022 11:13:16:INFO:(rank 0) After privatization: model (GradSampleModule), optimizer (DistributedDPOptimizer), data loader (DPDataLoader, len=300)
05/13/2022 11:13:16:INFO:(rank 0) Average batch size per GPU: 100
Epoch: 0 	Train Loss: 1.5412 | Train Accuracy: 0.57 | Test Accuracy: 0.73 |(ε = 0.87)
Epoch: 1 	Train Loss: 0.6717 | Train Accuracy: 0.79 | Test Accuracy: 0.83 |(ε = 0.91)
Epoch: 2 	Train Loss: 0.5659 | Train Accuracy: 0.85 | Test Accuracy: 0.86 |(ε = 0.96)
Epoch: 3 	Train Loss: 0.5347 | Train Accuracy: 0.87 | Test Accuracy: 0.88 |(ε = 1.00)
Epoch: 4 	Train Loss: 0.5178 | Train Accuracy: 0.88 | Test Accuracy: 0.90 |(ε = 1.03)
Epoch: 5 	Train Loss: 0.4750 | Train Accuracy: 0.90 | Test Accuracy: 0.91 |(ε = 1.07)
Epoch: 6 	Train Loss: 0.4502 | Train Accuracy: 0.90 | Test Accuracy: 0.91 |(ε = 1.11)
Epoch: 7 	Train Loss: 0.4358 | Train Accuracy: 0.91 | Test Accuracy: 0.92 |(ε = 1.14)
Epoch: 8 	Train Loss: 0.4186 | Train Accuracy: 0.92 | Test Accuracy: 0.92 |(ε = 1.18)
Epoch: 9 	Train Loss: 0.4129 | Train Accuracy: 0.92 | Test Accuracy: 0.93 |(ε = 1.21)
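To double-check the numbers in the log above, here is a quick standalone sketch of the arithmetic (it is not part of the demo script and assumes the 2-GPU setup used in this tutorial):

DATASET_SIZE = 60000        # MNIST training set
LOGICAL_BATCH_SIZE = 200    # what we passed to the DataLoader
N_GPUS = 2                  # assumption: the 2-GPU machine used here

print(DATASET_SIZE // LOGICAL_BATCH_SIZE)  # 300 batches per epoch
print(LOGICAL_BATCH_SIZE // N_GPUS)        # 100 -- the average batch size per GPU reported above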
Note: The following two chapters discuss advanced usage of Opacus and its implementation details. We strongly recommend reading the tutorial on Advanced Features of Opacus before proceeding.
Now let's look inside the make_private method and see what it does to enable DDP processing. We'll start with the modifications made to the DataLoader.
As a reminder, DPDataLoader differs from a regular DataLoader in only one aspect: it samples data with a uniform-with-replacement random sampler (a.k.a. "Poisson sampling"). This means that instead of a fixed batch size we have a sampling rate: the probability with which each sample is included in the next batch.
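To make this concrete, here is a minimal sketch of Poisson sampling (purely illustrative; it is not the Opacus implementation):

import torch

dataset_size = 60000
sample_rate = 64 / dataset_size                  # targets an average batch size of 64

mask = torch.rand(dataset_size) < sample_rate    # one independent Bernoulli draw per example
batch_indices = mask.nonzero(as_tuple=True)[0]   # a variable-sized batch, ~64 elements on average
print(len(batch_indices))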
Let's now initialize a regular data loader and then transform it into a DPDataLoader. This is exactly how we do it in the make_private() method.
Below we'll initialize three data loaders:
- a non-distributed, non-private data loader
- a distributed, non-private data loader
- a distributed, private data loader (DPDataLoader)

All three are initialized so that the logical batch size is 64.
%%writefile opacus_distributed_data_loader_demo.py
from opacus_ddp_demo import setup, cleanup, mnist_train_ds
import logging
from torch.utils.data import DataLoader, DistributedSampler
from opacus.data_loader import DPDataLoader
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
BATCH_SIZE = 64
def init_data(rank, world_size):
    setup(rank, world_size)

    non_distributed_dl = DataLoader(
        mnist_train_ds,
        batch_size=BATCH_SIZE
    )

    distributed_non_private_dl = DataLoader(
        mnist_train_ds,
        batch_size=BATCH_SIZE // world_size,
        sampler=DistributedSampler(mnist_train_ds),
    )

    private_dl = DPDataLoader.from_data_loader(non_distributed_dl, distributed=True)

    if rank == 0:
        logger.info(
            f"(rank {rank}) Non-distributed non-private data loader. "
            f"#batches: {len(non_distributed_dl)}, "
            f"#data points: {len(non_distributed_dl.sampler)}, "
            f"batch_size: {non_distributed_dl.batch_size}"
        )

        logger.info(
            f"(rank {rank}) Distributed, non-private data loader. "
            f"#batches: {len(distributed_non_private_dl)}, "
            f"#data points: {len(distributed_non_private_dl.sampler)}, "
            f"batch_size: {distributed_non_private_dl.batch_size}"
        )

        logger.info(
            f"(rank {rank}) Distributed, private data loader. "
            f"#batches: {len(private_dl)}, "
            f"#data points: {private_dl.batch_sampler.num_samples}, "
            f"sample_rate: {private_dl.sample_rate:4f}, "
            f"avg batch_size (=sample_rate*num_data_points): {int(private_dl.sample_rate*private_dl.batch_sampler.num_samples)}"
        )
Writing opacus_distributed_data_loader_demo.py
%%writefile -a opacus_distributed_data_loader_demo.py
import torch
import torch.multiprocessing as mp
world_size = torch.cuda.device_count()
if __name__ == '__main__':
    mp.spawn(
        init_data,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
Appending to opacus_distributed_data_loader_demo.py
Let's see what happens when we run it, and what exactly the from_data_loader factory did.
Notice that our private DataLoader was initialized with a non-distributed, non-private data loader, yet all the basic parameters (per-GPU batch size and number of data points per GPU) match those of the distributed, non-private data loader.
!python -W ignore opacus_distributed_data_loader_demo.py
05/13/2022 11:14:53:INFO:(rank 0) Non-distributed non-private data loader. #batches: 938, #data points: 60000, batch_size: 64
05/13/2022 11:14:53:INFO:(rank 0) Distributed, non-private data loader. #batches: 938, #data points: 30000, batch_size: 32
05/13/2022 11:14:53:INFO:(rank 0) Distributed, private data loader. #batches: 938, #data points: 30000, sample_rate: 0.001066, avg batch_size (=sample_rate*num_data_points): 31
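The sample_rate and the average batch size in the last log line follow directly from the original loader's parameters. A quick standalone check (again assuming the 2-GPU setup):

import math

num_batches = math.ceil(60000 / 64)    # 938 batches in the original non-distributed loader
sample_rate = 1 / num_batches          # ~0.001066, the per-example inclusion probability
points_per_gpu = 60000 // 2            # 30000 data points per worker
print(f"{sample_rate:.6f}", int(sample_rate * points_per_gpu))  # 0.001066 31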
One significant difference between DDP and DPDDP is how they approach synchronisation.
Normally with Distributed Data Parallel, the forward and backward passes are synchronisation points, and the DDP wrapper ensures that the gradients are synchronised across workers at the end of the backward pass.
Opacus, however, needs a later synchronisation point. Before we can use the gradients, we need to clip them and add noise. This is done in the optimizer, which moves the synchronisation point from the backward pass to the optimization step.
Additionally, to simplify the calculations, we only add noise on the worker with rank=0, and use a noise scale calibrated to the combined batch across all workers.
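Conceptually, the private synchronisation point looks roughly like the sketch below. This is only an illustration of the behaviour described above, not the actual Opacus code; the function name private_step_sketch and the summed_clipped_grads argument are made up for this example:

import torch
import torch.distributed as dist

def private_step_sketch(params, summed_clipped_grads, noise_multiplier, max_grad_norm,
                        rank, world_size, expected_batch_size):
    # summed_clipped_grads: per-parameter sums of already-clipped per-sample gradients
    for p, grad in zip(params, summed_clipped_grads):
        if rank == 0 and noise_multiplier > 0:
            # noise is added on rank 0 only, calibrated to the combined batch
            grad = grad + noise_multiplier * max_grad_norm * torch.randn_like(grad)
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)           # synchronisation happens here, at step time
        p.grad = grad / (expected_batch_size * world_size)    # average over the global batch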
%%writefile opacus_sync_demo.py
from opacus_ddp_demo import setup, cleanup, mnist_train_ds, SampleConvNet
import logging
from torch.utils.data import DataLoader
import torch.optim as optim
from opacus.data_loader import DPDataLoader
from opacus import GradSampleModule
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
from opacus.optimizers import DistributedDPOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
BATCH_SIZE = 64
LR = 64
def init_training(rank, world_size):
    model = SampleConvNet()
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0)

    model = GradSampleModule(model)
    model = DPDDP(model)
    optimizer = DistributedDPOptimizer(
        optimizer=optimizer,
        noise_multiplier=0.,
        max_grad_norm=100.,
        expected_batch_size=BATCH_SIZE // world_size,
    )

    data_loader = DPDataLoader.from_data_loader(
        data_loader=DataLoader(
            mnist_train_ds,
            batch_size=BATCH_SIZE,
        ),
        distributed=True,
    )

    return model, optimizer, data_loader
Writing opacus_sync_demo.py
Now that we've initialized the DifferentiallyPrivateDistributedDataParallel model and the DistributedDPOptimizer, let's see how they work together.
DifferentiallyPrivateDistributedDataParallel is essentially a no-op: we only perform model synchronisation on initialization and do nothing on the forward and backward passes.
DistributedDPOptimizer, on the other hand, does all the heavy lifting:
- It clips per-sample gradients and adds noise, with the noise added on the worker with rank=0 only.
- It synchronises the gradients across workers by calling torch.distributed.all_reduce on step(), right before applying the gradients.
%%writefile -a opacus_sync_demo.py
import torch
import torch.nn as nn
import numpy as np

def launch(rank, world_size):
    setup(rank, world_size)

    criterion = nn.CrossEntropyLoss()
    model, optimizer, data_loader = init_training(rank, world_size)
    model.to(rank)
    model.train()

    for data, target in data_loader:
        target = torch.tensor(target)
        data, target = data.to(rank), target.to(rank)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        flat_grad = torch.cat([p.grad_sample.sum(dim=0).view(-1) for p in model.parameters()]).cpu().numpy() / optimizer.expected_batch_size
        logger.info(
            f"(rank={rank}) Gradient norm before optimizer.step(): {np.linalg.norm(flat_grad):.4f}"
        )
        logger.info(
            f"(rank={rank}) Gradient sample before optimizer.step(): {flat_grad[:3]}"
        )

        optimizer.step()

        flat_grad = torch.cat([p.grad.view(-1) for p in model.parameters()]).cpu().numpy()
        logger.info(
            f"(rank={rank}) Gradient norm after optimizer.step(): {np.linalg.norm(flat_grad):.4f}"
        )
        logger.info(
            f"(rank={rank}) Gradient sample after optimizer.step(): {flat_grad[:3]}"
        )

        break

    cleanup()
Appending to opacus_sync_demo.py
%%writefile -a opacus_sync_demo.py
import torch.multiprocessing as mp
import torch
world_size = torch.cuda.device_count()
if __name__ == '__main__':
    mp.spawn(
        launch,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
Appending to opacus_sync_demo.py
When we run the code, notice that the gradients are not synchronised after loss.backward(), but only after optimizer.step(). For this example, we've set the privacy parameters to effectively disable noise and clipping, so the synchronised gradient is indeed the average of the individual workers' gradients.
!python -W ignore opacus_sync_demo.py
05/13/2022 11:15:22:INFO:(rank=1) Gradient norm before optimizer.step(): 0.9924
05/13/2022 11:15:22:INFO:(rank=1) Gradient sample before optimizer.step(): [-0.00525815 -0.01079952 -0.01051272]
05/13/2022 11:15:22:INFO:(rank=0) Gradient norm before optimizer.step(): 1.7812
05/13/2022 11:15:22:INFO:(rank=0) Gradient sample before optimizer.step(): [-0.0181896  -0.02559735 -0.02745825]
05/13/2022 11:15:22:INFO:(rank=0) Gradient norm after optimizer.step(): 1.2387
05/13/2022 11:15:22:INFO:(rank=1) Gradient norm after optimizer.step(): 1.2387
05/13/2022 11:15:22:INFO:(rank=0) Gradient sample after optimizer.step(): [-0.01172432 -0.01819846 -0.01898623]
05/13/2022 11:15:22:INFO:(rank=1) Gradient sample after optimizer.step(): [-0.01172432 -0.01819846 -0.01898623]
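You can verify the averaging directly from the log above with a short standalone check (not part of the demo script):

import numpy as np

rank0 = np.array([-0.0181896, -0.02559735, -0.02745825])   # rank 0 sample before step
rank1 = np.array([-0.00525815, -0.01079952, -0.01051272])  # rank 1 sample before step
print((rank0 + rank1) / 2)  # close to the post-step sample logged on both ranks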
%%bash
rm opacus_ddp_demo.py
rm opacus_distributed_data_loader_demo.py
rm opacus_sync_demo.py