In this tutorial, we will build a differentially-private LSTM model to classify names into their source languages. This is the same task as in the NLP From Scratch tutorial (https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html), but since our goal here is to demonstrate an LSTM trained with privacy guarantees, we use an LSTM in place of the bare-bones RNN defined in the original tutorial. Specifically, we use the DPLSTM
module from opacus.layers.dp_lstm
, which supports the computation of per-example gradients; these are needed for gradient clipping and noise addition when applying differential privacy. DPLSTM
has the same API and functionality as nn.LSTM
, with some restrictions (e.g., we currently support only single-layer LSTMs).
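As a minimal illustration (this snippet is not part of the tutorial pipeline below), DPLSTM is constructed and called just like nn.LSTM:
import torch
from opacus.layers import DPLSTM

# 4 sequences of length 10 with 32 features each, batch-first as with nn.LSTM
x = torch.randn(4, 10, 32)
dp_lstm = DPLSTM(input_size=32, hidden_size=64, batch_first=True)
output, (h_n, c_n) = dp_lstm(x)  # same return signature as nn.LSTM
print(output.shape)              # torch.Size([4, 10, 64])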
First, let us download the dataset of names and their associated language labels as given in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html. We train our differentially-private LSTM on the same dataset as in that tutorial.
import warnings
warnings.simplefilter("ignore")
import os
import requests
NAMES_DATASET_URL = "https://download.pytorch.org/tutorial/data.zip"
DATA_DIR = "names"
import zipfile
import urllib.request
def download_and_extract(dataset_url, data_dir):
    print("Downloading and extracting ...")
    filename = "data.zip"
    urllib.request.urlretrieve(dataset_url, filename)
    with zipfile.ZipFile(filename) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(filename)
    print("Completed!")

download_and_extract(NAMES_DATASET_URL, DATA_DIR)
Downloading and extracting ... Completed!
names_folder = os.path.join(DATA_DIR, 'data', 'names')
all_filenames = []
for language_file in os.listdir(names_folder):
    all_filenames.append(os.path.join(names_folder, language_file))
print(os.listdir(names_folder))
['Italian.txt', 'Arabic.txt', 'English.txt', 'German.txt', 'French.txt', 'Spanish.txt', 'Greek.txt', 'Dutch.txt', 'Korean.txt', 'Portuguese.txt', 'Japanese.txt', 'Polish.txt', 'Irish.txt', 'Chinese.txt', 'Russian.txt', 'Czech.txt', 'Vietnamese.txt', 'Scottish.txt']
import torch
import torch.nn as nn
class CharByteEncoder(nn.Module):
    """
    This encoder takes a UTF-8 string and encodes its bytes into a Tensor. It can also
    perform the opposite operation to check a result.

    Examples:
        >>> encoder = CharByteEncoder()
        >>> t = encoder('Ślusàrski')  # returns tensor([256, 197, 154, 108, 117, 115, 195, 160, 114, 115, 107, 105, 257])
        >>> encoder.decode(t)  # returns "<s>Ślusàrski</s>"
    """

    def __init__(self):
        super().__init__()
        self.start_token = "<s>"
        self.end_token = "</s>"
        self.pad_token = "<pad>"

        self.start_idx = 256
        self.end_idx = 257
        self.pad_idx = 258

    def forward(self, s: str, pad_to=0) -> torch.LongTensor:
        """
        Encodes a string. It will append a start token <s> (id=self.start_idx) and an end token </s>
        (id=self.end_idx).

        Args:
            s: The string to encode.
            pad_to: If not zero, pad by appending self.pad_idx until string is of length `pad_to`.
                Defaults to 0.

        Returns:
            The encoded LongTensor of indices.
        """
        encoded = s.encode()
        n_pad = pad_to - len(encoded) if pad_to > len(encoded) else 0
        return torch.LongTensor(
            [self.start_idx]
            + [c for c in encoded]  # noqa
            + [self.end_idx]
            + [self.pad_idx for _ in range(n_pad)]
        )

    def decode(self, char_ids_tensor: torch.LongTensor) -> str:
        """
        The inverse of `forward`. Keeps the start, end, and pad indices.
        """
        char_ids = char_ids_tensor.cpu().detach().tolist()

        out = []
        buf = []
        for c in char_ids:
            if c < 256:
                buf.append(c)
            else:
                if buf:
                    out.append(bytes(buf).decode())
                    buf = []
                if c == self.start_idx:
                    out.append(self.start_token)
                elif c == self.end_idx:
                    out.append(self.end_token)
                elif c == self.pad_idx:
                    out.append(self.pad_token)

        if buf:  # in case some are left
            out.append(bytes(buf).decode())
        return "".join(out)

    def __len__(self):
        """
        The length of our encoder space. This is fixed to 256 (one byte) + 3 special chars
        (start, end, pad).

        Returns:
            259
        """
        return 259
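As a quick sanity check (this cell is illustrative and not part of the original tutorial), we can round-trip a name through the encoder:
encoder = CharByteEncoder()
t = encoder("Ślusàrski")
print(t)                  # byte values bracketed by the start (256) and end (257) ids
print(encoder.decode(t))  # "<s>Ślusàrski</s>"
print(len(encoder))       # 259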
from torch.nn.utils.rnn import pad_sequence
def padded_collate(batch, padding_idx=0):
    x = pad_sequence(
        [elem[0] for elem in batch], batch_first=True, padding_value=padding_idx
    )
    y = torch.stack([elem[1] for elem in batch]).long()
    return x, y
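For illustration (again, not part of the tutorial pipeline), collating two names of different lengths pads the shorter one to the length of the longer:
encoder = CharByteEncoder()
sample_batch = [
    (encoder("Kim"), torch.tensor(0)),       # hypothetical example names and label ids
    (encoder("Yamamoto"), torch.tensor(1)),
]
x, y = padded_collate(sample_batch)
print(x.shape)  # [2, length of the longest encoded name in the batch]
print(y)        # tensor([0, 1])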
from torch.utils.data import Dataset
from pathlib import Path
from collections import Counter  # used by label_count() below


class NamesDataset(Dataset):
    def __init__(self, root):
        self.root = Path(root)

        self.labels = list({langfile.stem for langfile in self.root.iterdir()})
        self.labels_dict = {label: i for i, label in enumerate(self.labels)}
        self.encoder = CharByteEncoder()
        self.samples = self.construct_samples()

    def __getitem__(self, i):
        return self.samples[i]

    def __len__(self):
        return len(self.samples)

    def construct_samples(self):
        samples = []
        for langfile in self.root.iterdir():
            label_name = langfile.stem
            label_id = self.labels_dict[label_name]
            with open(langfile, "r") as fin:
                for row in fin:
                    samples.append(
                        (self.encoder(row.strip()), torch.tensor(label_id).long())
                    )
        return samples

    def label_count(self):
        cnt = Counter()
        for _x, y in self.samples:
            label = self.labels[int(y)]
            cnt[label] += 1
        return cnt
VOCAB_SIZE = 256 + 3 # 256 alternatives in one byte, plus 3 special characters.
We split the dataset into an 80-20 train/validation split.
secure_mode = False
train_split = 0.8
test_every = 5
batch_size = 800
ds = NamesDataset(names_folder)
train_len = int(train_split * len(ds))
test_len = len(ds) - train_len
print(f"{train_len} samples for training, {test_len} for testing")
train_ds, test_ds = torch.utils.data.random_split(ds, [train_len, test_len])
16059 samples for training, 4015 for testing
from torch.utils.data import DataLoader
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    pin_memory=True,
    collate_fn=padded_collate,
)

test_loader = DataLoader(
    test_ds,
    batch_size=2 * batch_size,
    shuffle=False,
    pin_memory=True,
    collate_fn=padded_collate,
)
Having split the dataset into training and validation sets, we also need the data in a numeric form suitable for training the LSTM. Each name is encoded by CharByteEncoder
as the sequence of its UTF-8 byte values, bracketed by the start and end token indices. Since names have different lengths, padded_collate
pads every encoded name in a batch to the length of the longest name in that batch, so each batch arrives as a single rectangular LongTensor of shape [batch_size, max_name_length]. We use a batch size of 800 for all the experiments in this tutorial.
The training and evaluation functions train()
and test()
are defined below. During the training loop, Opacus computes the per-example gradients, clips them (to bound their sensitivity), and adds noise to the aggregated gradient before the optimizer updates the parameters.
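To make those mechanics concrete, here is a schematic of a single DP-SGD step in plain PyTorch. This is illustrative only: the values are made up, and the Opacus-wrapped optimizer performs the equivalent steps for us inside optimizer.step().
# Schematic DP-SGD step on per-example gradients (illustrative only)
per_example_grads = torch.randn(8, 5)   # 8 examples, 5-dimensional parameter
max_grad_norm = 1.5                     # clipping bound C
noise_multiplier = 0.6                  # hypothetical sigma

norms = per_example_grads.norm(dim=1, keepdim=True)
clip_factor = (max_grad_norm / (norms + 1e-6)).clamp(max=1.0)
clipped = per_example_grads * clip_factor          # clip each example's gradient to norm <= C
summed = clipped.sum(dim=0)
noisy = summed + noise_multiplier * max_grad_norm * torch.randn_like(summed)  # Gaussian noise with std sigma * C
grad = noisy / len(per_example_grads)              # average over the batch before the parameter update
print(grad)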
from statistics import mean
def train(model, criterion, optimizer, train_loader, epoch, privacy_engine=None, device="cuda:0"):
    accs = []
    losses = []
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        preds = logits.argmax(-1)
        n_correct = float(preds.eq(y).sum())
        batch_accuracy = n_correct / len(y)

        accs.append(batch_accuracy)
        losses.append(float(loss))

    printstr = (
        f"\t Epoch {epoch}. Accuracy: {mean(accs):.6f} | Loss: {mean(losses):.6f}"
    )
    if privacy_engine:
        epsilon = privacy_engine.get_epsilon(delta)
        printstr += f" | (ε = {epsilon:.2f}, δ = {delta})"
    print(printstr)
    return
def test(model, test_loader, privacy_engine, device="cuda:0"):
    accs = []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x).argmax(-1)
            n_correct = float(preds.eq(y).sum())
            batch_accuracy = n_correct / len(y)

            accs.append(batch_accuracy)
    printstr = "\n----------------------------\n" f"Test Accuracy: {mean(accs):.6f}"
    if privacy_engine:
        epsilon = privacy_engine.get_epsilon(delta)
        printstr += f" (ε = {epsilon:.2f}, δ = {delta})"
    print(printstr + "\n----------------------------\n")
    return
There are two sets of hyper-parameters associated with this model. The first set consists of the usual training hyper-parameters, such as the learning rate and the number of epochs. The second set is specific to the privacy engine: the target privacy budget (epsilon and delta) and the maximum L2 norm to which the per-sample gradients are clipped (max_grad_norm
). Since we use make_private_with_epsilon below, the amount of noise added to the gradients (noise_multiplier
) is derived automatically from the target budget rather than set by hand.
# Training hyper-parameters
epochs = 50
learning_rate = 2.0
# Privacy engine hyper-parameters
max_per_sample_grad_norm = 1.5
delta = 8e-5
epsilon = 12.0
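If you are curious how the target (ε, δ) budget translates into a noise multiplier, the sketch below shows one way to compute it with an Opacus accountant helper. This cell is illustrative and not part of the original tutorial; the helper name and signature reflect Opacus 1.x and may differ across versions, and make_private_with_epsilon performs an equivalent calibration internally.
from opacus.accountants.utils import get_noise_multiplier

# sample_rate uses the batch_size and train_len defined earlier in this notebook
sample_rate = batch_size / train_len
sigma = get_noise_multiplier(
    target_epsilon=epsilon,
    target_delta=delta,
    sample_rate=sample_rate,
    epochs=epochs,
)
print(f"Derived noise_multiplier: {sigma:.2f}")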
We define the name classification model in the cell below. Note that it is a simple char-LSTM classifier, where the input characters are passed through an nn.Embedding
layer, and are subsequently input to the DPLSTM.
import torch
from torch import nn
from opacus.layers import DPLSTM
class CharNNClassifier(nn.Module):
    def __init__(
        self,
        embedding_size,
        hidden_size,
        output_size,
        num_lstm_layers=1,
        bidirectional=False,
        vocab_size=VOCAB_SIZE,
    ):
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = DPLSTM(
            embedding_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.out_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        x = self.embedding(x)  # -> [B, T, D]
        x, _ = self.lstm(x, hidden)  # -> [B, T, H]
        x = x[:, -1, :]  # -> [B, H]
        x = self.out_layer(x)  # -> [B, C]
        return x
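As a quick shape check (illustrative, not part of the original tutorial), the classifier maps a batch of encoded names to one logit per language class:
# Toy instance: 18 output classes matches the number of language files listed above
toy_model = CharNNClassifier(embedding_size=64, hidden_size=128, output_size=18)
toy_batch = torch.randint(0, VOCAB_SIZE, (4, 12))  # 4 "names", 12 token ids each
print(toy_model(toy_batch).shape)                  # torch.Size([4, 18])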
We now instantiate the objects (privacy engine, model and optimizer) for our differentially-private LSTM training. Note that inside the model the nn.LSTM
is replaced with a DPLSTM
module, which enables us to compute per-example gradients.
# Set the device to run on a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define classifier parameters
embedding_size = 64
hidden_size = 128 # Number of neurons in hidden layer after LSTM
n_lstm_layers = 1
bidirectional_lstm = False
model = CharNNClassifier(
    embedding_size,
    hidden_size,
    len(ds.labels),
    n_lstm_layers,
    bidirectional_lstm,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
from opacus import PrivacyEngine
privacy_engine = PrivacyEngine(secure_mode=secure_mode)
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    max_grad_norm=max_per_sample_grad_norm,
    target_delta=delta,
    target_epsilon=epsilon,
    epochs=epochs,
)
Finally, we can start training! We will train for 50 epochs (where each epoch corresponds to a pass over the whole dataset) and report the privacy epsilon every test_every
epochs. We will also benchmark this differentially-private model against a model without privacy and obtain almost identical performance. Further, the private model trained with Opacus incurs only minimal overhead in training time, with the differentially-private classifier only slightly slower (by a couple of minutes) than the non-private model.
print("Train stats: \n")
for epoch in range(epochs):
    train(model, criterion, optimizer, train_loader, epoch, privacy_engine, device=device)
    if test_every:
        if epoch % test_every == 0:
            test(model, test_loader, privacy_engine, device=device)
test(model, test_loader, privacy_engine, device=device)
Train stats: Epoch 0. Accuracy: 0.428835 | Loss: 2.220773 ---------------------------- Test Accuracy: 0.469154 (ε = 2.30, δ = 8e-05) ---------------------------- Epoch 1. Accuracy: 0.472534 | Loss: 1.895850 Epoch 2. Accuracy: 0.471778 | Loss: 1.893783 Epoch 3. Accuracy: 0.459604 | Loss: 1.958717 Epoch 4. Accuracy: 0.491896 | Loss: 1.782331 Epoch 5. Accuracy: 0.540205 | Loss: 1.577036 ---------------------------- Test Accuracy: 0.559490 (ε = 4.16, δ = 8e-05) ---------------------------- Epoch 6. Accuracy: 0.593796 | Loss: 1.456133 Epoch 7. Accuracy: 0.616827 | Loss: 1.388250 Epoch 8. Accuracy: 0.632560 | Loss: 1.345773 Epoch 9. Accuracy: 0.639074 | Loss: 1.327238 Epoch 10. Accuracy: 0.650502 | Loss: 1.316831 ---------------------------- Test Accuracy: 0.650821 (ε = 5.43, δ = 8e-05) ---------------------------- Epoch 11. Accuracy: 0.649294 | Loss: 1.315323 Epoch 12. Accuracy: 0.656350 | Loss: 1.288794 Epoch 13. Accuracy: 0.656104 | Loss: 1.285352 Epoch 14. Accuracy: 0.656424 | Loss: 1.283710 Epoch 15. Accuracy: 0.666633 | Loss: 1.273102 ---------------------------- Test Accuracy: 0.667164 (ε = 6.51, δ = 8e-05) ---------------------------- Epoch 16. Accuracy: 0.672707 | Loss: 1.247125 Epoch 17. Accuracy: 0.680121 | Loss: 1.223817 Epoch 18. Accuracy: 0.686456 | Loss: 1.214923 Epoch 19. Accuracy: 0.694982 | Loss: 1.193048 Epoch 20. Accuracy: 0.694282 | Loss: 1.184953 ---------------------------- Test Accuracy: 0.682519 (ε = 7.46, δ = 8e-05) ---------------------------- Epoch 21. Accuracy: 0.701802 | Loss: 1.161172 Epoch 22. Accuracy: 0.706358 | Loss: 1.166274 Epoch 23. Accuracy: 0.722667 | Loss: 1.097268 Epoch 24. Accuracy: 0.703950 | Loss: 1.185700 Epoch 25. Accuracy: 0.720196 | Loss: 1.112226 ---------------------------- Test Accuracy: 0.707127 (ε = 8.33, δ = 8e-05) ---------------------------- Epoch 26. Accuracy: 0.720644 | Loss: 1.115221 Epoch 27. Accuracy: 0.708652 | Loss: 1.158104 Epoch 28. Accuracy: 0.724744 | Loss: 1.119688 Epoch 29. Accuracy: 0.733490 | Loss: 1.088846 Epoch 30. Accuracy: 0.729441 | Loss: 1.089938 ---------------------------- Test Accuracy: 0.701941 (ε = 9.15, δ = 8e-05) ---------------------------- Epoch 31. Accuracy: 0.731014 | Loss: 1.096586 Epoch 32. Accuracy: 0.736907 | Loss: 1.065786 Epoch 33. Accuracy: 0.733743 | Loss: 1.098627 Epoch 34. Accuracy: 0.741741 | Loss: 1.064197 Epoch 35. Accuracy: 0.742394 | Loss: 1.053995 ---------------------------- Test Accuracy: 0.720777 (ε = 9.93, δ = 8e-05) ---------------------------- Epoch 36. Accuracy: 0.749420 | Loss: 1.034596 Epoch 37. Accuracy: 0.748662 | Loss: 1.037211 Epoch 38. Accuracy: 0.745869 | Loss: 1.061525 Epoch 39. Accuracy: 0.751734 | Loss: 1.022538 Epoch 40. Accuracy: 0.751194 | Loss: 1.028292 ---------------------------- Test Accuracy: 0.744636 (ε = 10.67, δ = 8e-05) ---------------------------- Epoch 41. Accuracy: 0.754300 | Loss: 1.032082 Epoch 42. Accuracy: 0.753252 | Loss: 1.017024 Epoch 43. Accuracy: 0.755629 | Loss: 1.035767 Epoch 44. Accuracy: 0.758195 | Loss: 1.029165 Epoch 45. Accuracy: 0.751091 | Loss: 1.028669 ---------------------------- Test Accuracy: 0.739427 (ε = 11.38, δ = 8e-05) ---------------------------- Epoch 46. Accuracy: 0.760692 | Loss: 0.995788 Epoch 47. Accuracy: 0.763821 | Loss: 0.990309 Epoch 48. Accuracy: 0.763423 | Loss: 0.997126 Epoch 49. Accuracy: 0.767976 | Loss: 0.982944 ---------------------------- Test Accuracy: 0.752090 (ε = 11.93, δ = 8e-05) ----------------------------
The differentially-private name classification model obtains a test accuracy of 0.75 with an epsilon of just under 12. This shows that we can achieve good accuracy on this task, with minimal loss of privacy.
We also run a comparison with a non-private model to see whether the performance obtained with privacy is comparable. To do this, we keep the model architecture, batch size, and number of epochs the same, and define a separate instance of the model along with its own (non-DP) optimizer.
model_nodp = CharNNClassifier(
    embedding_size,
    hidden_size,
    len(ds.labels),
    n_lstm_layers,
    bidirectional_lstm,
).to(device)

optimizer_nodp = torch.optim.SGD(model_nodp.parameters(), lr=0.5)

for epoch in range(epochs):
    train(model_nodp, criterion, optimizer_nodp, train_loader, epoch, device=device)
    if test_every:
        if epoch % test_every == 0:
            test(model_nodp, test_loader, None, device=device)
test(model_nodp, test_loader, None, device=device)
Epoch 0. Accuracy: 0.423231 | Loss: 1.957621 ---------------------------- Test Accuracy: 0.469154 ---------------------------- Epoch 1. Accuracy: 0.470835 | Loss: 1.850998 Epoch 2. Accuracy: 0.461741 | Loss: 1.845881 Epoch 3. Accuracy: 0.466039 | Loss: 1.848411 Epoch 4. Accuracy: 0.470612 | Loss: 1.857506 Epoch 5. Accuracy: 0.460152 | Loss: 1.845789 ---------------------------- Test Accuracy: 0.469154 ---------------------------- Epoch 6. Accuracy: 0.477714 | Loss: 1.775618 Epoch 7. Accuracy: 0.518488 | Loss: 1.622382 Epoch 8. Accuracy: 0.535421 | Loss: 1.565642 Epoch 9. Accuracy: 0.545521 | Loss: 1.511846 Epoch 10. Accuracy: 0.543908 | Loss: 1.514014 ---------------------------- Test Accuracy: 0.575170 ---------------------------- Epoch 11. Accuracy: 0.561950 | Loss: 1.454853 Epoch 12. Accuracy: 0.605502 | Loss: 1.388555 Epoch 13. Accuracy: 0.607155 | Loss: 1.367188 Epoch 14. Accuracy: 0.615066 | Loss: 1.346803 Epoch 15. Accuracy: 0.621913 | Loss: 1.332553 ---------------------------- Test Accuracy: 0.635465 ---------------------------- Epoch 16. Accuracy: 0.619772 | Loss: 1.314691 Epoch 17. Accuracy: 0.629337 | Loss: 1.302999 Epoch 18. Accuracy: 0.634173 | Loss: 1.277790 Epoch 19. Accuracy: 0.647275 | Loss: 1.226866 Epoch 20. Accuracy: 0.652142 | Loss: 1.226686 ---------------------------- Test Accuracy: 0.651832 ---------------------------- Epoch 21. Accuracy: 0.646773 | Loss: 1.219855 Epoch 22. Accuracy: 0.663006 | Loss: 1.195204 Epoch 23. Accuracy: 0.670526 | Loss: 1.165726 Epoch 24. Accuracy: 0.676121 | Loss: 1.148621 Epoch 25. Accuracy: 0.687536 | Loss: 1.109896 ---------------------------- Test Accuracy: 0.690590 ---------------------------- Epoch 26. Accuracy: 0.690961 | Loss: 1.110705 Epoch 27. Accuracy: 0.674958 | Loss: 1.158181 Epoch 28. Accuracy: 0.696233 | Loss: 1.091395 Epoch 29. Accuracy: 0.699146 | Loss: 1.077446 Epoch 30. Accuracy: 0.710076 | Loss: 1.061827 ---------------------------- Test Accuracy: 0.716664 ---------------------------- Epoch 31. Accuracy: 0.714624 | Loss: 1.040824 Epoch 32. Accuracy: 0.709445 | Loss: 1.044048 Epoch 33. Accuracy: 0.719751 | Loss: 1.021937 Epoch 34. Accuracy: 0.722247 | Loss: 1.002287 Epoch 35. Accuracy: 0.725602 | Loss: 0.985023 ---------------------------- Test Accuracy: 0.717073 ---------------------------- Epoch 36. Accuracy: 0.721840 | Loss: 0.990956 Epoch 37. Accuracy: 0.726419 | Loss: 0.978770 Epoch 38. Accuracy: 0.730414 | Loss: 0.945205 Epoch 39. Accuracy: 0.733045 | Loss: 0.931660 Epoch 40. Accuracy: 0.743858 | Loss: 0.914782 ---------------------------- Test Accuracy: 0.724982 ---------------------------- Epoch 41. Accuracy: 0.751916 | Loss: 0.876523 Epoch 42. Accuracy: 0.737594 | Loss: 0.914662 Epoch 43. Accuracy: 0.735986 | Loss: 0.923208 Epoch 44. Accuracy: 0.752869 | Loss: 0.868417 Epoch 45. Accuracy: 0.753095 | Loss: 0.867506 ---------------------------- Test Accuracy: 0.740716 ---------------------------- Epoch 46. Accuracy: 0.755373 | Loss: 0.851085 Epoch 47. Accuracy: 0.755981 | Loss: 0.842593 Epoch 48. Accuracy: 0.768917 | Loss: 0.813079 Epoch 49. Accuracy: 0.761222 | Loss: 0.829013 ---------------------------- Test Accuracy: 0.754173 ----------------------------
We ran the training loop again, this time without privacy and for the same number of epochs. The non-private classifier also obtains a test accuracy of around 0.75 with the same architecture and number of epochs. In other words, we give up very little performance on the name classification task in exchange for the differential privacy guarantee.