# Source code for opacus.per_sample_gradient_clip

```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
r"""
The process of adding differential privacy to a model involves bounds its sensitivity prior to
applying the Gaussian mechanism. This is achieved by clipping the per-sample gradients.
Normally for a parameterized layer if you have a tensor of parameters of size ``[m, n]``,
the size of the gradients will match it. This means that they get aggregated over the batch.
Here, we will keep them per-sample i.e., we will have a tensor of size ``[b_sz, m, n]``, where
the slice ``[i, :, :]`` corresponds to the per-example gradients for the i-th example in the batch.
Per-sample gradient clipping has to be achieved under the following constraints:
1. The norm of the grad_sample of the loss with respect to all model parameters has
to be clipped so that if they were to be put in a single vector together. If ``C`` is the clipping
threshold, this ensures the total norm will be at most ``C``.
Example:
>>> T = torch.cat([p.grad_sample.flatten() for p in model.parameters()])
``T`` will have shape ``[B, N_TOTAL_PARAMS]``. The total L2 norm of each row of ``T``
cannot be greater than ``C``.
2. This clipping should not backpropagate. This means that clipping in the layer ``i+1``
should not affect computing the gradient of layer ``i``. To make sure this is followed
we will first compute the grad_sample of all layers **without clipping**. In a second pass, we will
go back to the per-sample gradients, clip them, and accumulate them in ``.grad``
(thus replacing the "real" gradients).
Notes:
There is only a single .backward() call as the second pass just works on top of
the stored grad_sample.
"""
from typing import Callable, Iterator, Optional, Tuple

import torch
from torch import nn

from . import autograd_grad_sample
from .utils.clipping import NormClipper
from .utils.tensor_utils import calc_sample_norms

class PerSampleGradientClipper:
    r"""
    Class to define a per-sample gradient clipper for a module. Per-sample gradient clipping
    bounds the sensitivity of the computation before applying the Gaussian mechanism.
    """

    def __init__(
        self,
        module: nn.Module,
        norm_clipper: NormClipper,
        batch_first: bool = True,
        loss_reduction: str = "mean",
    ):
        r"""
        Attaches to a module, and clips all grad_sample in the backward
        pass. It then puts them in each parameter's ``.grad``.

        Args:
            module: Module to which backward hooks are added and for which per-sample
                gradients are clipped
            norm_clipper: A norm clipper object of class
                :class:`~opacus.utils.clipping.NormClipper` which encapsulates different
                clipping strategies (such as flat clipping for the entire model, or
                per-layer clipping)
            batch_first: Flag to indicate if the input tensor to the corresponding module
                has the first dimension representing the batch, for example of shape
                ``[batch_size, ..., ...]``. Set to True if the batch appears in the first
                dimension, else set to False (``batch_first=False`` implies that the batch
                is always in the second dimension).
            loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                is a sum or a mean operation. Can take values ``sum`` or ``mean``
        """
        self.module = module
        autograd_grad_sample.add_hooks(
            self.module, batch_first=batch_first, loss_reduction=loss_reduction
        )

        self.norm_clipper = norm_clipper
        self.batch_first = batch_first
        self.loss_reduction = loss_reduction

        self._reset_aggregated_state()
        self.hooks_attached = True
        self.on_batch_clip_func = None

    def set_on_batch_clip_func(self, on_batch_clip_func: Callable[..., None]) -> None:
        r"""
        Sets the function to be called after clipping (for example, a clipping
        stats collection routine).

        Args:
            on_batch_clip_func: Function to be called after clipping
        """
        self.on_batch_clip_func = on_batch_clip_func

    def __del__(self):
        r"""
        Destructor to remove all attached hooks from the module when the clipper
        object is deleted
        """
        self.close()

    def close(self) -> None:
        r"""
        Removes backward hooks from the module
        """
        if self.hooks_attached:  # do not close twice
            autograd_grad_sample.remove_hooks(self.module)
        self.hooks_attached = False

    def __repr__(self):
        return f"PerSampleGradientClipModuleHook on {self.module}"

    def _reset_aggregated_state(self) -> None:
        r"""
        Resets the aggregated state of the clipper: zero for the batch size
        and a zero tensor for the per-layer thresholds
        """
        self._aggr_batch_size = 0
        self._aggr_thresh = torch.zeros(1)

    def _get_aggregated_state(self) -> Tuple[torch.Tensor, int]:
        r"""
        Returns the aggregated state of the clipper, consisting of the
        list of layer thresholds (for those providing gradient norms)
        as well as the aggregate batch size

        Returns:
            Aggregated state (layer thresholds and batch size)
        """
        return self._aggr_thresh, self._aggr_batch_size

    def pre_step(self) -> Tuple[torch.Tensor, int]:
        r"""
        Prepares the ``.grad`` field of the parameters and provides statistics on the
        maximum gradient norm which should be used to scale noise in the privacy engine
        (:class:`~opacus.privacy_engine.PrivacyEngine`). This function is called before
        the optimizer ``step()``.

        Returns:
            The maximum gradient norm per batch (repeated in the batch dimension
            as a tensor) and the batch size
        """
        # check that we have already accumulated clipped gradients for this batch
        if self._aggr_batch_size == 0:
            raise ValueError("You need to call clip_and_accumulate first")

        threshs, batch_size = self._get_aggregated_state()
        # now that we know the full batch size, we can average the gradients
        n = 0
        for _, p in self._named_params():
            p.grad = self._scale_summed_grad(  # pyre-ignore[16]
                p.summed_grad, batch_size  # pyre-ignore[16]
            )
            n += 1
            del p.summed_grad

        # NOTE: For Renyi-based epsilon calculation, we will calculate a flat
        # max norm equal to the norm of all clip values per layer.
        max_norm = threshs.new_full((n,), threshs.norm(2))  # pyre-ignore[16]
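        # e.g. with aggregated per-layer thresholds [1.0, 1.0], threshs.norm(2) is sqrt(2),
        # so max_norm is a length-n tensor filled with sqrt(2)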
        self._reset_aggregated_state()
        return max_norm, batch_size

    def clip_and_accumulate(self) -> None:
        r"""
        Clips and sums up per-sample gradients into an accumulator. When this function is called
        ``N >= 1`` times on mini-batches of size ``B`` (which could be smaller on the final batch),
        a call to
        :meth:`~opacus.per_sample_gradient_clip.PerSampleGradientClipper.pre_step`
        will populate the ``.grad`` field with the average gradient over the entire batch of size
        ``(N-1)*B + b``, with ``b <= B``.
        """
        # step 0: calculate the layer norms
        all_norms = calc_sample_norms(
            named_params=self._named_grad_samples(),
            flat=not self.norm_clipper.is_per_layer,
        )

        # step 1: calculate the clipping factors based on the norms
        clipping_factor = self.norm_clipper.calc_clipping_factors(all_norms)

        # step 2: update the aggregated thresholds and batch size
        self._aggr_thresh = torch.max(
            self._aggr_thresh, self.norm_clipper.thresholds
        )  # retain the largest clipping thresholds across the entire batch
        batch_size = next(p.shape[0] for (_, p) in self._named_grad_samples())
        # The size of every param.grad_sample along dim 0 is the batch size
        self._aggr_batch_size += batch_size

        for i, (clip_factor, named_param) in enumerate(
            zip(clipping_factor, self._named_params())
        ):
            # Do the clipping
            name, p = named_param
            summed_grad = self._weighted_sum(
                clip_factor, p.grad_sample  # pyre-ignore[16]
            )
            clipping_thresh = self.norm_clipper.thresholds[
                i if len(self.norm_clipper.thresholds) > 1 else 0
            ]
            per_sample_norm = all_norms[i if len(all_norms) > 1 else 0]
            # accumulate the summed gradient for this mini-batch
            if hasattr(p, "summed_grad"):
                p.summed_grad += summed_grad  # pyre-ignore[16]
            else:
                p.summed_grad = summed_grad

            self._on_batch_clip(
                name,
                clip_factor,
                clipping_thresh,
                per_sample_norm,
                p.grad_sample,
                grad_before_clip=p.grad,  # pyre-ignore[16]
                grad_after_clip=self._scale_summed_grad(summed_grad, batch_size),
            )

            # remove the per-sample gradients
            del p.grad_sample
        self._on_batch_clip()  # inform analysis of the whole module

    def zero_grad(self):
        """
        Deletes the added attributes, ``grad_sample`` and ``summed_grad``.

        These two attributes are automatically deleted when ``pre_step`` or
        ``clip_and_accumulate`` are properly called. This method is a safety measure
        to avoid further issues if regular use has not been followed.
        """
        for _, param in self._named_params():
            if hasattr(param, "grad_sample"):
                del param.grad_sample
            if hasattr(param, "summed_grad"):
                del param.summed_grad

    def _named_params(self) -> Iterator[Tuple[str, nn.Parameter]]:
        r"""
        Helper function to get the parameters (with their names) that require grad

        Returns:
            Iterator over parameters with their names
        """
        return ((n, p) for n, p in self.module.named_parameters() if p.requires_grad)

    def _named_grad_samples(self) -> Iterator[Tuple[str, torch.Tensor]]:
        r"""
        Helper function to get the names and per-sample gradients of the parameters
        that require grad.

        Returns:
            Iterator of parameter names and per-sample gradients
        """
        return (
            (n, p.grad_sample)  # pyre-ignore[16]
            for n, p in self.module.named_parameters()
            if p.requires_grad
        )

    def _scale_summed_grad(
        self, summed_grad: torch.Tensor, batch_size: int
    ) -> torch.Tensor:
        r"""
        Depending on the loss type, this function averages the summed gradient over the batch
        if the attribute ``loss_reduction`` is set to "mean", else it returns the input summed
        gradient tensor.

        Args:
            summed_grad: Summed gradient tensor, which may be averaged depending on loss_reduction
            batch_size: Batch size of the gradient tensor

        Returns:
            Summed gradient tensor if loss_reduction is set to sum, else averaged over the batch.

        Raises:
            ValueError
                If the loss reduction is not defined to be either 'sum' or 'mean'
        """
        if self.loss_reduction == "mean":
            return summed_grad / batch_size
        elif self.loss_reduction == "sum":
            return summed_grad.detach()
        else:
            raise ValueError(
                f"Loss reduction must be either sum or mean. Got {self.loss_reduction}"
            )

    def _weighted_sum(
        self, batch_weight: torch.Tensor, param: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Helper function to calculate a weighted sum of tensor ``param``
        along the batch dimension, weighted by tensor ``batch_weight``.

        Args:
            batch_weight: Tensor of shape ``[B]`` (where ``B`` is the batch size) corresponding
                to weights along the batch dimension. Each sample in the batch has its own weight.
            param: Tensor to be weighted, of shape ``[B, ...]`` where ``B`` represents the
                batch size.

        Returns:
            Weighted sum of ``param`` along the batch dimension, weighted by ``batch_weight``.
        """
        return torch.einsum("i,i...", batch_weight, param)

    def _on_batch_clip(
        self,
        param_name: Optional[str] = None,
        clipping_factor: Optional[torch.Tensor] = None,
        clipping_threshold: Optional[torch.Tensor] = None,
        per_sample_norm: Optional[torch.Tensor] = None,
        per_sample_grad: Optional[torch.Tensor] = None,
        grad_before_clip: Optional[torch.Tensor] = None,
        grad_after_clip: Optional[torch.Tensor] = None,
    ):
        r"""
        Calls the pre-specified function (for example, for clipping stats computation) and
        passes it the current parameter state during the backpropagation of each batch.

        Args:
            param_name: Name of the parameter; the parameter can be accessed by
                ``self.module.state_dict()[param_name]``. A value of ``None``
                indicates that all parameters have been processed.
            clipping_factor: Scaling factor used in gradient clipping.
            clipping_threshold: Threshold used in gradient clipping.
            per_sample_norm: Per-sample gradient norms used for clipping
            per_sample_grad: Raw per-sample gradients for the parameter
            grad_before_clip: Aggregated gradient before clipping (``= per_sample_grad.mean()``)
            grad_after_clip: Aggregated gradient after clipping
        """
        if self.on_batch_clip_func:
            self.on_batch_clip_func(
                param_name=param_name,
                clipping_factor=clipping_factor,
                clipping_threshold=clipping_threshold,
                per_sample_norm=per_sample_norm,
                per_sample_grad=per_sample_grad,
                grad_before_clip=grad_before_clip,
                grad_after_clip=grad_after_clip,
            )
```
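
For context, a minimal sketch of how this clipper can be driven directly, outside of a full ``PrivacyEngine`` setup. It assumes ``ConstantFlatClipper`` is available from ``opacus.utils.clipping``; the toy model, data, and the ``log_clipping`` callback are placeholders for illustration. One backward pass stores the per-sample gradients, ``clip_and_accumulate`` clips and sums them, and ``pre_step`` writes the averaged result into ``.grad`` before the optimizer step.

```
import torch
from torch import nn, optim

from opacus.per_sample_gradient_clip import PerSampleGradientClipper
from opacus.utils.clipping import ConstantFlatClipper  # assumed flat-clipping strategy

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Clip the total per-sample gradient norm of the whole model to 1.0
clipper = PerSampleGradientClipper(
    module=model,
    norm_clipper=ConstantFlatClipper(flat_value=1.0),
    batch_first=True,
    loss_reduction="mean",
)

# Optional: inspect clipping stats after every clip_and_accumulate() call
def log_clipping(param_name=None, clipping_factor=None, **kwargs):
    if param_name is not None:  # called once per parameter, then once with None
        print(f"{param_name}: mean clip factor {clipping_factor.mean().item():.3f}")

clipper.set_on_batch_clip_func(log_clipping)

x = torch.randn(8, 4)           # toy batch of 8 samples
y = torch.randint(0, 2, (8,))

loss = criterion(model(x), y)
loss.backward()                 # hooks store per-sample grads in p.grad_sample

clipper.clip_and_accumulate()   # clip each sample's grads, sum into p.summed_grad
max_norm, batch_size = clipper.pre_step()  # average into p.grad, return norm stats

# (in full DP-SGD, calibrated noise would be added to the clipped gradients here)
optimizer.step()
optimizer.zero_grad()

clipper.close()                 # detach the backward hooks when done
```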