# Source code for opacus.per_sample_gradient_clip

#!/usr/bin/env python3
r"""
The process of adding differential privacy to a model involves bounding its sensitivity prior to
applying the Gaussian mechanism. This is achieved by clipping the per-sample gradients.
Normally, for a parameterized layer with a parameter tensor of size [m, n], the gradient
tensor will have the same size, which means the gradients have been aggregated over the batch.
Here, we instead keep them per-sample, i.e. we have a tensor of size [b_sz, m, n], where
the slice [i, :, :] corresponds to the per-sample gradients for the i-th example in the batch.

Per-sample gradient clipping has to be achieved under the following constraints:

1. The norm of the grad_sample of the loss with respect to all model parameters has
to be clipped so that, if all of them were concatenated into a single vector, the total
norm of that vector would be at most C, where C is the clipping threshold.

Example:
>>> T = torch.cat([p.grad_sample.flatten() for p in model.parameters()])

T will have shape [B, N_TOTAL_PARAMS]. The total L2 norm of each row of T
cannot be greater than C.

2. This clipping should not backpropagate. This means that clipping in the layer i+1
should not affect computing the gradient of layer i. To make sure this is followed
we will first compute the grad_sample of all layers **without clipping**. In a second pass, we will
go back to the per-sample gradients, clip them, and accumulate them in .grad
(thus replacing the "real" gradients).

Notes:
There is only a single .backward() call as the second pass just works on top of
the per-sample gradients (grad_sample) already stored during the first pass.
"""

from typing import Callable, Iterator, Optional, Tuple

import torch
from torch import nn

from . import autograd_grad_sample
from .utils.clipping import NormClipper
from .utils.tensor_utils import calc_sample_norms

r"""
Class to define a per-sample gradient clipper for a module. Per-sample gradient clipping
bounds the sensitivity of the computation before applying the Gaussian mechanism.
"""

def __init__(
self,
module: nn.Module,
norm_clipper: NormClipper,
batch_first: bool = True,
loss_reduction: str = "mean",
):
r"""
Attaches to a module, and clips all grad_sample in the backward
pass. It then puts them in each parameter's .grad.

Args:
module: Module to which backward hooks are added and for which per-sample
gradients are computed

norm_clipper: A norm clipper object of class
:class:~opacus.utils.clipping.NormClipper which encapsulated different
clipping strategies (such as flat clipping for the entire model, or
per-layer clipping)

batch_first: Flag to indicate if the input tensor to the corresponding module
has the first dimension represent the batch, for example of shape
[batch_size, ..., ...]. Set to True if batch appears in first
dimension else set to False (batch_first=False implies that the batch
is always in the second dimension).

loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values sum or mean
"""
self.module = module
autograd_grad_sample.add_hooks(
    self.module, batch_first=batch_first, loss_reduction=loss_reduction
)
self.norm_clipper = norm_clipper
self.batch_first = batch_first
self.loss_reduction = loss_reduction

self._reset_aggregated_state()

self.hooks_attached = True
self.on_batch_clip_func = None
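
# A minimal end-to-end usage sketch for this constructor (illustrative only).
# It assumes a NormClipper implementation such as ConstantFlatClipper from
# opacus.utils.clipping, plus ordinary model / loss / optimizer objects:
#
#   >>> clipper = PerSampleGradientClipper(model, ConstantFlatClipper(1.0))
#   >>> loss.backward()                      # hooks store p.grad_sample per parameter
#   >>> clipper.clip_and_accumulate()        # clip per-sample grads into p.summed_grad
#   >>> max_norm, bsz = clipper.pre_step()   # write scaled, clipped grads into p.grad
#   >>> optimizer.step()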

[docs]    def set_on_batch_clip_func(self, on_batch_clip_func: Callable[..., None]) -> None:
r"""
Sets the function to be called after clipping to the input callable parameter
(for example clipping stats collection)

Args:
on_batch_clip_func: Function to be called after clipping
"""
self.on_batch_clip_func = on_batch_clip_func
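
# Hypothetical callback sketch (the function name and the assumption that the
# arguments are forwarded as keywords are illustrative, not part of this module):
#
#   >>> def log_clipping(param_name=None, clipping_factor=None, **kwargs):
#   ...     if param_name is not None and clipping_factor is not None:
#   ...         print(param_name, float(clipping_factor.mean()))
#   >>> clipper.set_on_batch_clip_func(log_clipping)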

def __del__(self):
r"""
Destructor to remove all attached hooks from the module when the clipper
object is deleted
"""
self.close()

[docs]    def close(self) -> None:
r"""
Removes backward hooks from the module
"""
if self.hooks_attached:  # do not close twice
    autograd_grad_sample.remove_hooks(self.module)
self.hooks_attached = False

def __repr__(self):
return f"PerSampleGradientClipModuleHook on {self.module}"

def _reset_aggregated_state(self) -> None:
r"""
Resets the aggregated state of the clipper to be zero for
the batch size and zero tensors for the per-layer thresholds
"""
self._aggr_batch_size = 0
self._aggr_thresh = torch.zeros(1)

def _get_aggregated_state(self) -> Tuple[torch.Tensor, int]:
r"""
Returns an aggregated state of the clipper consisting of the
list of layer thresholds (for those providing gradient norms)
as well as the aggregate batch size

Returns:
Aggregated state (layer thresholds and batch size)
"""
return self._aggr_thresh, self._aggr_batch_size

[docs]    def pre_step(self) -> Tuple[torch.Tensor, int]:
r"""
Prepares the .grad field of the parameters and provides statistics on the
maximum gradient norm which should be used to scale noise in the privacy engine
(:class:~opacus.privacy_engine.PrivacyEngine). This function is called before
the optimizer step().

Returns:
The maximum gradient norm for the batch (repeated once per parameter
tensor, as a tensor) and the batch size
"""

# check if we've already accumulated clipped gradients for this batch
if self._aggr_batch_size == 0:
raise ValueError("You need to call clip_and_accumulate first")

threshs, batch_size = self._get_aggregated_state()
# now that we know the full batch size, we can average the gradients
n = 0
for _, p in self._named_params():
    p.grad = self._scale_summed_grad(
        p.summed_grad, batch_size  # pyre-ignore[16]
    )
    n += 1
    del p.summed_grad

# NOTE: For Renyi-based epsilon calculation, we will calculate a flat
# max norm equal to the norm of all clip values per layer.
max_norm = threshs.new_full((n,), threshs.norm(2))  # pyre-ignore[16]
self._reset_aggregated_state()
return max_norm, batch_size
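
# Worked example (illustrative): with per-layer thresholds [1.0, 2.0] and n = 2
# trainable parameter tensors, threshs.norm(2) = sqrt(1.0**2 + 2.0**2) ~ 2.2361,
# so pre_step returns (tensor([2.2361, 2.2361]), batch_size).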

[docs]    def clip_and_accumulate(self) -> None:
r"""
Clips and sums up per-sample gradients into an accumulator. When this function is called
N >= 1 times on mini-batches of size B (could be smaller on final batch), a call to
:meth:~opacus.per_sample_gradient_clip.PerSampleGradientClipper.pre_step
will populate the .grad field with the average gradient over the entire batch of size
(N-1)* B + b with b <= B.
"""
# step 0 : calculate the layer norms
all_norms = calc_sample_norms(
    named_params=self._named_grad_samples(),
    flat=not self.norm_clipper.is_per_layer,
)

# step 1: calculate the clipping factors based on the per-sample norms
clipping_factor = self.norm_clipper.calc_clipping_factors(all_norms)

# step 2: update the aggregated thresholds and batch size
self._aggr_thresh = torch.max(
self._aggr_thresh, self.norm_clipper.thresholds
)  # retain the largest clipping thresholds across the entire batch
batch_size = next(p.shape[0] for (_, p) in self._named_grad_samples())
# The size for every param.grad_sample is the batch size
self._aggr_batch_size += batch_size

for i, (clip_factor, named_param) in enumerate(
    zip(clipping_factor, self._named_params())
):
    # Do the clipping
    name, p = named_param
    summed_grad = self._weighted_sum(
        clip_factor, p.grad_sample  # pyre-ignore[16]
    )
    clipping_thresh = self.norm_clipper.thresholds[
        i if len(self.norm_clipper.thresholds) > 1 else 0
    ]
    per_sample_norm = all_norms[i if len(all_norms) > 1 else 0]
    # accumulate the summed gradient for this mini-batch
    if hasattr(p, "summed_grad"):
        p.summed_grad += summed_grad  # pyre-ignore[16]
    else:
        p.summed_grad = summed_grad  # pyre-ignore[16]

    self._on_batch_clip(
        name,
        clip_factor,
        clipping_thresh,
        per_sample_norm,
    )

    # remove the per-sample gradients
    del p.grad_sample

self._on_batch_clip()  # inform analysis of the whole module
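
# Illustrative sketch of the accumulation behaviour described in the docstring
# above: clip_and_accumulate may be called once per mini-batch (virtual steps)
# before a single pre_step / optimizer.step(). The names "virtual_batches",
# "model", "criterion" and "optimizer" are assumptions for this example only.
#
#   >>> for x, y in virtual_batches:          # N mini-batches of size <= B
#   ...     criterion(model(x), y).backward()
#   ...     clipper.clip_and_accumulate()     # accumulates into p.summed_grad
#   >>> clipper.pre_step()                    # scales over all (N-1)*B + b samples
#   >>> optimizer.step()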

"""
Deletes the added attributes, grad_sample and summed_grad.

The two mentioned attributes are
automatically deleted when pre_step or
clip_and_accumulate are properly called. This is a safety measure
to avoid further issues if regular use has not been followed.
"""
for _, param in self._named_params():

def _named_params(self) -> Iterator[Tuple[str, nn.Parameter]]:
r"""
Helper function to get parameters, along with their names, that require grad

Returns:
Iterator over parameters with their names
"""
return ((n, p) for n, p in self.module.named_parameters() if p.requires_grad)

def _named_grad_samples(self) -> Iterator[Tuple[str, torch.Tensor]]:
r"""
Helper function to get names and per-sample gradients for parameters

Returns:
Iterator of parameter names and per-sample gradients
"""
return (
    (n, p.grad_sample)  # pyre-ignore[16]
    for n, p in self.module.named_parameters()
    if p.requires_grad
)

def _scale_summed_grad(
    self, summed_grad: torch.Tensor, batch_size: int
) -> torch.Tensor:
r"""
Depending on the loss type, this function averages the summed gradient over the batch
if the attribute loss_reduction is set to "mean"; otherwise it returns the input summed
gradient unchanged.

Args:
summed_grad: Summed gradient tensor which might be averaged depending on loss_reduction

batch_size: Batch size of gradient tensor

Returns:
Summed gradient tensor if loss_reduction is set to sum else averaged over batch.

Raises:
ValueError
If the loss reduction is not defined to be either 'sum' or 'mean'
"""
if self.loss_reduction == "mean":
    return summed_grad / batch_size
elif self.loss_reduction == "sum":
    return summed_grad
else:
    raise ValueError(
        f"Loss reduction must be either sum or mean. Got {self.loss_reduction}"
    )
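
# Example (illustrative): with loss_reduction == "mean" and batch_size == 4,
# a summed_grad of tensor([8., 4.]) is scaled to tensor([2., 1.]); with
# loss_reduction == "sum" it is returned unchanged.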

def _weighted_sum(
self, batch_weight: torch.Tensor, param: torch.Tensor
) -> torch.Tensor:
r"""
Helper function to calculate a weighted sum of tensor param
along the batch dimension weighted by tensor batch_weight.

Args:
batch_weight: Tensor of shape B (where B is the batch size) corresponding
to weights along the batch dimension. Each sample in the batch has its own weight.
param: Tensor to be weighted, is of shape [B,...] where B represents the
batch size.

Returns:
Weighted sum tensor for param along the batch dimension weighted by batch_weight.
"""

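# Example (illustrative): batch_weight of shape [B] and param of shape [B, m, n]
# produce a weighted sum of shape [m, n]; e.g. batch_weight = tensor([0.5, 1.0])
# and param of shape [2, 3] give 0.5 * param[0] + 1.0 * param[1].
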
def _on_batch_clip(
self,
param_name: Optional[str] = None,
clipping_factor: Optional[torch.Tensor] = None,
clipping_threshold: Optional[torch.Tensor] = None,
per_sample_norm: Optional[torch.Tensor] = None,
per_sample_grad: Optional[torch.Tensor] = None,
grad_before_clip: Optional[torch.Tensor] = None,
grad_after_clip: Optional[torch.Tensor] = None,
):
r"""
Calls a pre-specified function (for example, for clipping stats computation) and
gives that function access to the current parameter state during the backpropagation
of each batch.

Args:
param_name: Name of the parameter, the parameter could be accessed by
self.module.state_dict()[param_name]. A value of None
indicates that all parameters have been processed.
clipping_factor: Scaling factor used in gradient clipping.
clipping_threshold: Threshold used in gradient clipping.
per_sample_norm: Per-sample gradient norms for clipping
per_sample_grad: Raw per sample gradients for parameter
grad_before_clip: Aggregated gradient before clipping (= per_sample_grad.mean())