In this tutorial, we will train a text classifier with Differential Privacy by taking a model pre-trained on public text data and fine-tuning it for a different task.
When training a model with differential privacy, we almost always face a trade-off between model size and accuracy on the task. The exact details depend on the problem, but a rule of thumb is that the fewer parameters the model has, the easier it is to get good performance with DP.
Most state-of-the-art NLP models are quite deep and large (e.g. BERT-base has over 100M parameters), which makes the task of training text models on private datasets rather challenging.
One way of addressing this problem is to divide the training process into two stages. First, we will pre-train the model on a public dataset, exposing the model to generic text data. Assuming that the generic text data is public, we will not be using differential privacy at this step. Then, we freeze most of the layers, leaving only a few upper layers to be trained on the private dataset using DP-SGD. This way we can get the best of both worlds - we have a deep and powerful text understanding model, while only training a small number of parameters with differentially private algorithm.
In this tutorial, we will take the pre-trained BERT-base model and fine-tune it to recognize textual entailment on the SNLI dataset.
We also fine-tune it with Ghost Clipping DP-SGD, a memory-efficient implementation of DP-SGD, which enables the use of large batch sizes.
First, we need to download the dataset (we'll use Stanford NLP mirror)
STANFORD_SNLI_URL = "https://nlp.stanford.edu/projects/snli/snli_1.0.zip"
DATA_DIR = "data"
import zipfile
import urllib.request
import os
import warnings
warnings.simplefilter("ignore")
def download_and_extract(dataset_url, data_dir):
print("Downloading and extracting ...")
filename = "snli.zip"
urllib.request.urlretrieve(dataset_url, filename)
with zipfile.ZipFile(filename) as zip_ref:
zip_ref.extractall(data_dir)
os.remove(filename)
print("Completed!")
download_and_extract(STANFORD_SNLI_URL, DATA_DIR)
Downloading and extracting ... Completed!
The dataset comes in two formats (tsv
and json
) and has already been split into train/dev/test. Let’s verify that’s the case.
snli_folder = os.path.join(DATA_DIR, "snli_1.0")
os.listdir(snli_folder)
['snli_1.0_dev.txt', 'README.txt', 'snli_1.0_dev.jsonl', 'Icon\r', '.DS_Store', 'snli_1.0_test.txt', 'snli_1.0_train.jsonl', 'snli_1.0_test.jsonl', 'snli_1.0_train.txt']
Let's now take a look inside. SNLI dataset provides ample syntactic metadata, but we'll only use raw input text. Therefore, the only fields we're interested in are sentence1 (premise), sentence2 (hypothesis), and gold_label (label chosen by the majority of annotators).
The label defines the relation between premise and hypothesis: either contradiction, neutral, or entailment.
import pandas as pd
train_path = os.path.join(snli_folder, "snli_1.0_train.txt")
dev_path = os.path.join(snli_folder, "snli_1.0_dev.txt")
df_train = pd.read_csv(train_path, sep='\t')
df_test = pd.read_csv(dev_path, sep='\t')
df_train[['sentence1', 'sentence2', 'gold_label']][:5]
sentence1 | sentence2 | gold_label | |
---|---|---|---|
0 | A person on a horse jumps over a broken down a... | A person is training his horse for a competition. | neutral |
1 | A person on a horse jumps over a broken down a... | A person is at a diner, ordering an omelette. | contradiction |
2 | A person on a horse jumps over a broken down a... | A person is outdoors, on a horse. | entailment |
3 | Children smiling and waving at camera | They are smiling at their parents | neutral |
4 | Children smiling and waving at camera | There are children present | entailment |
BERT (Bidirectional Encoder Representations from Transformers) is a state-of-the-art approach to various NLP tasks. It uses a Transformer architecture and relies heavily on the concept of pre-training.
We'll use a pre-trained BERT-base model, provided in the huggingface transformers repo. It gives us a PyTorch implementation for the classic BERT architecture, as well as a tokenizer and weights, pre-trained on a public English corpus (Wikipedia).
Please follow these installation instructions before proceeding.
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
model_name = "bert-base-cased"
config = BertConfig.from_pretrained(
model_name,
num_labels=3,
)
tokenizer = BertTokenizer.from_pretrained(
"bert-base-cased",
do_lower_case=False,
)
model = BertForSequenceClassification.from_pretrained(
"bert-base-cased",
config=config,
)
To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html 100%|██████████| 433/433 [00:00<00:00, 455171.34B/s] 100%|██████████| 213450/213450 [00:00<00:00, 37577090.82B/s] 100%|██████████| 435779157/435779157 [00:11<00:00, 39433911.33B/s]
The model has the following structure. It uses a combination of word, positional and token embeddings to create a sequence representation, then passes the data through 12 transformer encoders and finally uses a linear classifier to produce the final label.
As the model is already pre-trained and we only plan to fine-tune a few upper layers, we want to freeze all layers, except for the last encoder and above (BertPooler
and Classifier
).
from IPython.display import Image
Image(filename='img/BERT.png')
trainable_layers = [model.bert.encoder.layer[-1], model.bert.pooler, model.classifier]
total_params = 0
trainable_params = 0
for p in model.parameters():
p.requires_grad = False
total_params += p.numel()
for layer in trainable_layers:
for p in layer.parameters():
p.requires_grad = True
trainable_params += p.numel()
print(f"Total parameters count: {total_params}") # ~108M
print(f"Trainable parameters count: {trainable_params}") # ~7M
Total parameters count: 108312579 Trainable parameters count: 7680771
Thus, by using a pre-trained model we reduce the number of trainable params from over 100 million to just above 7.5 million. This will help both performance and convergence with added noise.
Before we begin training, we need to preprocess the data and convert it to the format our model expects.
(Note: it'll take 5-10 minutes to run on a laptop)
LABEL_LIST = ['contradiction', 'entailment', 'neutral']
MAX_SEQ_LENGHT = 128
import torch
import torch.nn as nn
import transformers
from torch.utils.data import TensorDataset
from transformers.data.processors.utils import InputExample
from transformers.data.processors.glue import glue_convert_examples_to_features
def _create_examples(df, set_type):
""" Convert raw dataframe to a list of InputExample. Filter malformed examples
"""
examples = []
for index, row in df.iterrows():
if row['gold_label'] not in LABEL_LIST:
continue
if not isinstance(row['sentence1'], str) or not isinstance(row['sentence2'], str):
continue
guid = f"{index}-{set_type}"
examples.append(
InputExample(guid=guid, text_a=row['sentence1'], text_b=row['sentence2'], label=row['gold_label']))
return examples
def _df_to_features(df, set_type):
""" Pre-process text. This method will:
1) tokenize inputs
2) cut or pad each sequence to MAX_SEQ_LENGHT
3) convert tokens into ids
The output will contain:
`input_ids` - padded token ids sequence
`attention mask` - mask indicating padded tokens
`token_type_ids` - mask indicating the split between premise and hypothesis
`label` - label
"""
examples = _create_examples(df, set_type)
#backward compatibility with older transformers versions
legacy_kwards = {}
from packaging import version
if version.parse(transformers.__version__) < version.parse("2.9.0"):
legacy_kwards = {
"pad_on_left": False,
"pad_token": tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
"pad_token_segment_id": 0,
}
return glue_convert_examples_to_features(
examples=examples,
tokenizer=tokenizer,
label_list=LABEL_LIST,
max_length=MAX_SEQ_LENGHT,
output_mode="classification",
**legacy_kwards,
)
def _features_to_dataset(features):
""" Convert features from `_df_to_features` into a single dataset
"""
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor(
[f.attention_mask for f in features], dtype=torch.long
)
all_token_type_ids = torch.tensor(
[f.token_type_ids for f in features], dtype=torch.long
)
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
dataset = TensorDataset(
all_input_ids, all_attention_mask, all_token_type_ids, all_labels
)
return dataset
train_features = _df_to_features(df_train, "train")
test_features = _df_to_features(df_test, "test")
train_dataset = _features_to_dataset(train_features)
test_dataset = _features_to_dataset(test_features)
Let's talk about batch sizes for a bit.
In addition to all the considerations you normally take into account when choosing batch size, training models with DP adds another one - privacy cost.
Because of the threat model we assume and the way we add noise to the gradients, larger batch sizes (to a certain extent) generally help convergence. We add the same amount of noise to each gradient update (scaled to the norm of one sample in the batch) regardless of the batch size. What this means is that as the batch size increases, the relative amount of noise added decreases. while preserving the same epsilon guarantee.
You should, however, keep in mind that increasing batch size has its price in terms of epsilon, which grows at O(sqrt(batch_size))
as we train (therefore larger batches make it grow faster). The good strategy here is to experiment with multiple combinations of batch_size
and noise_multiplier
to find the one that provides the best possible quality at acceptable privacy guarantee.
There's another side to this - memory. Opacus computes and stores per sample gradients, so for every normal gradient, Opacus will store n=batch_size
per-sample gradients on each step, thus increasing the memory footprint by at least O(batch_size)
. In reality, however, the peak memory requirement is O(batch_size^2)
compared to a non-private model. This is because some intermediate steps in per sample gradient computation involve operations on two matrices, each with batch_size as one of the dimensions.
The good news is, we can pick the most appropriate batch size, regardless of memory constraints. Opacus has built-in support for virtual batches. Using it we can separate physical steps (gradient computation) and logical steps (noise addition and parameter updates): use larger batches for training, while keeping memory footprint low. Below we will specify two constants:
MAX_PHYSICAL_BATCH_SIZE
defines the maximum batch size we can afford from a memory standpoint, and only affects computation speedBATCH_SIZE
, on the other hand, will affect only convergence and privacy guarantee.BATCH_SIZE = 32
MAX_PHYSICAL_BATCH_SIZE = 8
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from opacus.utils.uniform_sampler import UniformWithReplacementSampler
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=BATCH_SIZE)
# Move the model to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Set the model to train mode (HuggingFace models load in eval mode)
model = model.train()
# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-8)
First, we specify some training parameters ready to run the training loop for three epochs
EPOCHS = 3
LOGGING_INTERVAL = 5000 # once every how many steps we run evaluation cycle and report metrics
EPSILON = 7.5
DELTA = 1 / len(train_dataloader) # Parameter for privacy accounting. Probability of not achieving privacy guarantees
Let’s now define the evaluation cycle.
import numpy as np
from tqdm.notebook import tqdm
def accuracy(preds, labels):
return (preds == labels).mean()
# define evaluation cycle
def evaluate(model):
model.eval()
loss_arr = []
accuracy_arr = []
for batch in test_dataloader:
batch = tuple(t.to(device) for t in batch)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2],
'labels': batch[3]}
outputs = model(**inputs)
loss, logits = outputs[:2]
preds = np.argmax(logits.detach().cpu().numpy(), axis=1)
labels = inputs['labels'].detach().cpu().numpy()
loss_arr.append(loss.item())
accuracy_arr.append(accuracy(preds, labels))
model.train()
return np.mean(loss_arr), np.mean(accuracy_arr)
Next, we will define and attach PrivacyEngine. There are two parameters you need to consider here:
noise_multiplier
. It defines the trade-off between privacy and accuracy. Adding more noise will provide stronger privacy guarantees, but will also hurt model quality. In this run, the PrivacyEngine will determine this value based on the target values of EPSILON
, DELTA
, and EPOCHS
. For the default settings, this will set noise_multiplier
to about 0.4.max_grad_norm
. Defines the maximum magnitude of L2 norms to which we clip per sample gradients. There is a bit of tug of war with this threshold: on the one hand, a low threshold means that we will clip many gradients, hurting convergence, so we might be tempted to raise it. However, recall that we add noise with std=noise_multiplier * max_grad_norm
so we will pay for the increased threshold with more noise. In most cases you can rely on the model being quite resilient to clipping (after the first few iterations your model will tend to adjust so that its gradients stay below the clipping threshold), so you can often just keep the default value (=1.0
) and focus on tuning batch_size
and noise_multiplier
instead. That being said, sometimes clipping hurts the model so it may be worth experimenting with different clipping thresholds, like we are doing in this tutorial.These two parameters define the scale of the noise we add to gradients: the noise will be sampled from a Gaussian distribution with std=noise_multiplier * max_grad_norm
.
from opacus import PrivacyEngine
MAX_GRAD_NORM = 0.1
privacy_engine = PrivacyEngine()
model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_dataloader,
target_delta=DELTA,
target_epsilon=EPSILON,
epochs=EPOCHS,
max_grad_norm=MAX_GRAD_NORM,
)
Now we can train the model.
from opacus.utils.batch_memory_manager import BatchMemoryManager
for epoch in range(1, EPOCHS+1):
losses = []
with BatchMemoryManager(
data_loader=train_dataloader,
max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
optimizer=optimizer
) as memory_safe_data_loader:
for step, batch in enumerate(tqdm(memory_safe_data_loader)):
optimizer.zero_grad()
batch = tuple(t.to(device) for t in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2],
'labels': batch[3]}
outputs = model(**inputs) # output = loss, logits, hidden_states, attentions
loss = outputs[0]
loss.backward()
losses.append(loss.item())
optimizer.step()
if step > 0 and step % LOGGING_INTERVAL == 0:
train_loss = np.mean(losses)
eps = privacy_engine.get_epsilon(DELTA)
eval_loss, eval_accuracy = evaluate(model)
print(
f"Epoch: {epoch} | "
f"Step: {step} | "
f"Train loss: {train_loss:.3f} | "
f"Eval loss: {eval_loss:.3f} | "
f"Eval accuracy: {eval_accuracy:.3f} | "
f"ɛ: {eps:.2f}"
)
For the test accuracy, after training for three epochs you should expect something close to the results below.
You can see that we can achieve quite strong privacy guarantee at epsilon=7.5 with a moderate accuracy cost of 11 percentage points compared to non-private model trained in a similar setting (upper layers only) and 16 points compared to best results we were able to achieve using the same architecture.
NB: When not specified, DP-SGD is trained with upper layers only
Model | Noise multiplier | Batch size | Accuracy | Epsilon |
---|---|---|---|---|
no DP, train full model | N/A | 32 | 90.1% | N/A |
no DP, train upper layers only | N/A | 32 | 85.4% | N/A |
DP-SGD | 1.0 | 32 | 70.5% | 0.7 |
DP-SGD (this tutorial) | 0.4 | 32 | 74.3% | 7.5 |
DP-SGD | 0.3 | 32 | 75.8% | 20.7 |
DP-SGD | 0.1 | 32 | 78.3% | 2865 |
DP-SGD | 0.4 | 8 | 67.3% | 5.9 |
In this section, we show how to use Fast Gradient Clipping and Ghost Clipping DP-SGD. The training loop is nearly identical to the existing one in Opacus, which was based on the (non-private) PyTorch training loop. To use Fast Gradient Clipping, we need to pass grad_sample_mode = 'ghost' in the make_private function.
The other change is that privacy engine's make_private function takes the loss criterion as input too and sanitizes it. This allows us to repurpose loss.backward to do two backward passes, and a loss rescaling in between. The first backward computes per-sample gradient norms, where as the second backward on the rescaled loss computes the aggregard clipped gradient
device = torch.device("cuda:0")
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-8)
model = model.train()
privacy_engine = PrivacyEngine()
criterion = nn.CrossEntropyLoss(reduction="mean")
model_gc, optimizer_gc, criterion_gc, train_dataloader = (
privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_dataloader,
criterion=criterion,
target_delta=DELTA,
target_epsilon=EPSILON,
epochs=EPOCHS,
max_grad_norm=MAX_GRAD_NORM,
grad_sample_mode="ghost",
)
)
model_gc = model_gc.to(device)
model_gc = model_gc.train()
for epoch in range(1, EPOCHS + 1):
losses = []
for step, batch in enumerate(tqdm(train_dataloader)):
optimizer_gc.zero_grad()
batch = tuple(t.to(device) for t in batch)
inputs = {
"input_ids": batch[0],
"attention_mask": batch[1],
"token_type_ids": batch[2],
"labels": batch[3],
}
outputs = model_gc(**inputs) # output = loss, logits, hidden_states, attentions
loss = criterion_gc(outputs[1], batch[3])
loss.backward()
optimizer_gc.step()
losses.append(loss.item())
if step > 0 and step % LOGGING_INTERVAL == 0:
train_loss = np.mean(losses)
eval_loss, eval_accuracy = evaluate(model_gc)
eps = privacy_engine.get_epsilon(DELTA)
print(
f"Epoch: {epoch} | "
f"Step: {step} | "
f"Train loss: {train_loss:.3f} | "
f"Eval loss: {eval_loss:.3f} | "
f"Eval accuracy: {eval_accuracy:.3f} | "
f"ɛ: {eps:.2f}"
)
Epoch: 1 | Step: 500 | Train loss: 1.209 | Eval loss: 1.409 | Eval accuracy: 0.443 | ɛ: 5.25 Epoch: 1 | Step: 1000 | Train loss: 1.273 | Eval loss: 1.496 | Eval accuracy: 0.481 | ɛ: 6.12 Epoch: 1 | Step: 1500 | Train loss: 1.316 | Eval loss: 1.514 | Eval accuracy: 0.537 | ɛ: 6.72