Privacy Engine
-
class opacus.privacy_engine.PrivacyEngine(module, batch_size, sample_size, alphas, noise_multiplier, max_grad_norm, secure_rng=False, batch_first=True, target_delta=1e-06, loss_reduction='mean', **misc_settings)
The main component of Opacus is the PrivacyEngine. To train a model with differential privacy, all you need to do is define a PrivacyEngine and attach it to your optimizer before running.
Example
This example shows how to define a PrivacyEngine and attach it to your optimizer.
>>> import torch
>>> from opacus import PrivacyEngine
>>> model = torch.nn.Linear(16, 32)  # An example model
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> batch_size, sample_size = 64, 60000  # example values
>>> privacy_engine = PrivacyEngine(model, batch_size, sample_size, alphas=range(2, 32), noise_multiplier=1.3, max_grad_norm=1.0)
>>> privacy_engine.attach(optimizer)  # That's it! Now it's business as usual.
- Parameters
module (Module) – The PyTorch module to which we are attaching the privacy engine.
batch_size (int) – Training batch size. Used in the privacy accountant.
sample_size (int) – The size of the sample (dataset). Used in the privacy accountant.
noise_multiplier (float) – The ratio of the standard deviation of the Gaussian noise to the L2-sensitivity of the function to which the noise is added.
max_grad_norm (Union[float, List[float]]) – The maximum norm of the per-sample gradients. Any gradient with a norm higher than this will be clipped to this value.
secure_rng (bool) – If True, torchcsprng is used for secure random number generation. This comes with a significant performance cost, so it is recommended to turn it off when just experimenting.
batch_first (bool) – Flag indicating whether the first dimension of the input tensor to the corresponding module is the batch dimension. If set to True, the dimensions of the input tensor will be [batch_size, ..., ...].
target_delta (float) – The target delta.
loss_reduction (str) – Indicates whether the loss reduction (used for aggregating the gradients) is a sum or a mean operation. Can take the values "sum" or "mean".
**misc_settings – Other arguments to the init.
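To make the roles of max_grad_norm, noise_multiplier and loss_reduction concrete, here is a minimal DP-SGD-style sketch of how one noisy gradient could be formed from per-sample gradients. It is a pure-Python illustration, not Opacus code, and rests on assumptions: gradients are flat lists of floats, a single float max_grad_norm is used (not the per-layer list form), and the noise standard deviation is noise_multiplier * max_grad_norm, since max_grad_norm is the L2-sensitivity of the clipped sum. The helpers `clip` and `dp_aggregate` are hypothetical names.

```python
import math
import random
from typing import List, Optional


def clip(grad: List[float], max_grad_norm: float) -> List[float]:
    """Scale one per-sample gradient so its L2 norm is at most max_grad_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    factor = min(1.0, max_grad_norm / (norm + 1e-6))
    return [g * factor for g in grad]


def dp_aggregate(
    per_sample_grads: List[List[float]],
    max_grad_norm: float,
    noise_multiplier: float,
    loss_reduction: str = "mean",
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical helper: clip each sample, sum, add Gaussian noise, reduce."""
    rng = rng or random.Random()
    clipped = [clip(g, max_grad_norm) for g in per_sample_grads]
    summed = [sum(col) for col in zip(*clipped)]
    # Noise std = noise_multiplier * L2-sensitivity of the clipped sum.
    std = noise_multiplier * max_grad_norm
    noisy = [s + rng.gauss(0.0, std) for s in summed]
    if loss_reduction == "mean":
        return [x / len(per_sample_grads) for x in noisy]
    if loss_reduction == "sum":
        return noisy
    raise ValueError("loss_reduction must be 'sum' or 'mean'")
```

With noise_multiplier=0.0 the noise vanishes, which makes it easy to check that the gradient [3, 4] (norm 5) is scaled down to roughly [0.6, 0.8] while [0.3, 0.4] (norm 0.5) is left untouched.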
-
attach(optimizer)
Attaches the privacy engine to the optimizer.
Attaches an optimizer object to the PrivacyEngine and injects the engine into the optimizer's step. To do that, it:
1. Validates that the model does not have unsupported layers.
2. Adds a pointer to this object (the PrivacyEngine) inside the optimizer.
3. Moves the optimizer's original step() function to original_step().
4. Monkeypatches the optimizer's step() function to call step() on the privacy engine automatically whenever it would call step() for itself.
- Parameters
optimizer (Optimizer) – The optimizer to which the privacy engine will attach.
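The steps above amount to ordinary Python monkeypatching. The sketch below mimics steps 2 to 4 with toy classes so that it runs without PyTorch; `ToyOptimizer`, `ToyEngine` and `dp_step` are hypothetical names, and the real engine also performs step 1 and additional bookkeeping.

```python
import types


class ToyOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer."""

    def __init__(self) -> None:
        self.updates = 0

    def step(self, closure=None):
        self.updates += 1


class ToyEngine:
    """Hypothetical stand-in for PrivacyEngine; it only counts its own steps."""

    def __init__(self) -> None:
        self.engine_steps = 0

    def step(self) -> None:
        # In Opacus, clipping and noise addition would happen here.
        self.engine_steps += 1

    def attach(self, optimizer) -> None:
        # Step 1 (validating the model's layers) is omitted in this toy.
        # Step 2: store a pointer to the engine inside the optimizer.
        optimizer.privacy_engine = self
        # Step 3: keep the original (bound) step() as original_step().
        optimizer.original_step = optimizer.step

        # Step 4: replace step() with a wrapper that runs the engine first.
        def dp_step(opt, closure=None):
            opt.privacy_engine.step()
            return opt.original_step(closure)

        optimizer.step = types.MethodType(dp_step, optimizer)
```

After `engine.attach(opt)`, every `opt.step()` first runs `engine.step()` and then the optimizer's own update, which is why user code does not change.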
-
detach()
Detaches the privacy engine from the optimizer.
To detach the PrivacyEngine from the optimizer, this method returns the model and the optimizer to their original states (i.e., all added attributes and methods are removed).
-
get_privacy_spent(target_delta=None)
Computes the (epsilon, delta) privacy budget spent so far.
This method converts an (alpha, epsilon)-RDP guarantee, for every alpha the PrivacyEngine was initialized with, into an (epsilon, delta)-DP guarantee. It returns the optimal alpha together with the best epsilon.
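A minimal sketch of that conversion, assuming the standard bound in which an (alpha, r)-RDP guarantee implies (r + log(1/delta) / (alpha - 1), delta)-DP, evaluated for every order and minimized over alpha. `rdp_to_dp` is a hypothetical helper, not the Opacus function, and the per-order RDP values in the usage note are made-up numbers.

```python
import math
from typing import Sequence, Tuple


def rdp_to_dp(
    alphas: Sequence[float], rdp: Sequence[float], delta: float
) -> Tuple[float, float]:
    """Return (epsilon, best_alpha) for the given target delta.

    Each order alpha with RDP value r yields the candidate
    epsilon = r + log(1 / delta) / (alpha - 1); keep the smallest one.
    """
    if not 0 < delta < 1:
        raise ValueError("delta must be in (0, 1)")
    best_eps, best_alpha = math.inf, float("nan")
    for alpha, r in zip(alphas, rdp):
        eps = r + math.log(1 / delta) / (alpha - 1)
        if eps < best_eps:
            best_eps, best_alpha = eps, alpha
    return best_eps, best_alpha
```

For example, with orders [2, 4, 8], illustrative RDP values [0.1, 0.3, 0.9] and delta = 1e-5, the candidates are roughly 11.61, 4.14 and 2.54, so alpha = 8 is selected. Larger orders shrink the log(1/delta) term but usually carry larger RDP values, which is why a range of alphas is scanned.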
-
step()
Takes a step for the privacy engine.
Notes
You should not call this method directly. Instead, once your PrivacyEngine is attached to the optimizer, the optimizer calls this method for you.
- Raises
ValueError – If the last batch of a training epoch is larger than the others. This check ensures that the clipper consumed the right number of gradients: the last batch of an epoch may be smaller than the others, but no batch should ever be too large.
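The documented condition can be sketched as a simple size check. `check_clipped_batch` is a hypothetical helper that only mirrors the rule stated above (smaller batches are accepted, larger ones are rejected); it is not the Opacus implementation.

```python
def check_clipped_batch(num_per_sample_grads: int, batch_size: int) -> None:
    """Hypothetical check mirroring the documented ValueError condition."""
    # A smaller final batch is fine; a larger one suggests that per-sample
    # gradients from several batches were concatenated (e.g. a missed zero_grad()).
    if num_per_sample_grads > batch_size:
        raise ValueError(
            f"Got {num_per_sample_grads} per-sample gradients, "
            f"but batch_size is {batch_size}"
        )
```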
-
to(device)
Moves the privacy engine to the target device.
- Parameters
device (Union[str, device]) – The device on which PyTorch tensors are allocated. See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
Example
This example shows how to use this method when moving the model after instantiating the PrivacyEngine.
>>> import torch
>>> from opacus import PrivacyEngine
>>> model = torch.nn.Linear(16, 32)  # An example model. Default device is CPU
>>> batch_size, sample_size = 64, 60000  # example values
>>> privacy_engine = PrivacyEngine(model, batch_size, sample_size, alphas=range(5, 64), noise_multiplier=0.8, max_grad_norm=0.5)
>>> device = "cuda:3"  # GPU
>>> model.to(device)  # If we move the model to the GPU, we should also call the privacy engine's to() method (next line)
>>> privacy_engine.to(device)
- Returns
The current PrivacyEngine.
-
virtual_step()
Takes a virtual step.
Virtual batches enable training with arbitrarily large batch sizes while keeping memory consumption constant. This is useful when the desired batch size is larger than what fits in memory at once.
Example
Imagine you want to train a model with a batch size of 2048, but you can only fit a batch size of 128 on your GPU. Since 2048 / 128 = 16, you can take a real step every 16 mini-batches and a virtual step otherwise:
>>> for i, (X, y) in enumerate(dataloader):
...     logits = model(X)
...     loss = criterion(logits, y)
...     loss.backward()
...     if i % 16 == 15:
...         optimizer.step()  # this will call the privacy engine's step()
...         optimizer.zero_grad()
...     else:
...         optimizer.virtual_step()  # this will call the privacy engine's virtual_step()
The rough idea of a virtual step is as follows:
1. Calling loss.backward() repeatedly stores the per-sample gradients for all mini-batches. If we call loss.backward() N times on mini-batches of size B, then each weight's .grad_sample field will contain N×B gradients. Then, when calling step(), the privacy engine clips all N×B gradients and computes the average gradient for an effective batch of size N×B. A call to optimizer.zero_grad() erases the per-sample gradients.
2. By calling virtual_step() after loss.backward(), the B per-sample gradients for this mini-batch are clipped and summed into a gradient accumulator, after which they can be discarded. After N iterations (alternating calls to loss.backward() and virtual_step()), a call to step() computes the average gradient for an effective batch of size N×B.
The advantage of the second approach is memory efficiency: it discards the per-sample gradients after every mini-batch, so it can handle batches of arbitrary size.
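The sketch below contrasts the two approaches on toy data. It is a pure-Python illustration, not Opacus code: gradients are plain lists, a single clipping norm is assumed, noise is omitted, and all function names are hypothetical. Both functions produce the same average, but the virtual-step version processes one mini-batch at a time and keeps only a running sum.

```python
import math
from typing import List

Vector = List[float]


def clip(grad: Vector, max_norm: float) -> Vector:
    """Scale one per-sample gradient to L2 norm at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    factor = min(1.0, max_norm / (norm + 1e-6))
    return [g * factor for g in grad]


def clip_and_sum(grads: List[Vector], max_norm: float) -> Vector:
    """Clip every per-sample gradient and sum them coordinate-wise."""
    return [sum(col) for col in zip(*(clip(g, max_norm) for g in grads))]


def average_all_at_once(mini_batches: List[List[Vector]], max_norm: float) -> Vector:
    """Approach 1: hold all N*B per-sample gradients, then clip and average."""
    everything = [g for batch in mini_batches for g in batch]
    return [s / len(everything) for s in clip_and_sum(everything, max_norm)]


def average_with_virtual_steps(mini_batches: List[List[Vector]], max_norm: float) -> Vector:
    """Approach 2: clip each mini-batch into an accumulator, then discard it."""
    accumulator = [0.0] * len(mini_batches[0][0])
    total = 0
    for batch in mini_batches:  # one "virtual step" per mini-batch
        partial = clip_and_sum(batch, max_norm)
        accumulator = [a + p for a, p in zip(accumulator, partial)]
        total += len(batch)  # this batch's per-sample gradients are no longer needed
    return [a / total for a in accumulator]  # the final "step" averages over N*B
```

Because clipping is applied per sample, the order in which samples are grouped does not change the clipped sum; only peak memory differs.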
-
zero_grad()
Resets the clipper's status.
The clipper keeps internal per-sample gradients for the batch in each forward call of the module; they need to be cleared before the next round. If these variables are not cleared, the per-sample gradients keep being concatenated across batches. If accumulating gradients is the intended behavior, e.g. when simulating a large batch, prefer using the virtual_step() function.