Guide to grad samplers

DP-SGD guarantees the privacy of every sample used in training. To achieve this, we have to bound the sensitivity of every sample, that is, how much any single sample can influence a model update. To bound it, we have to clip the gradient of every sample individually. Unfortunately, PyTorch doesn't keep the gradients of individual samples in a batch. It only exposes the aggregated gradient of all the samples in a batch via the .grad attribute.
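To make the last point concrete, here is a small pure-PyTorch check. The toy nn.Linear model, the random data and the MSE loss are arbitrary choices for this illustration. After an ordinary batched backward pass, .grad equals the mean of the per-sample gradients. The individual contributions that DP-SGD needs to clip are no longer available.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
x = torch.randn(4, 3)  # a batch of 4 samples
y = torch.randn(4, 1)

# Ordinary batched backward pass: .grad only holds the aggregated (mean) gradient
nn.functional.mse_loss(model(x), y).backward()
batch_grad = model.weight.grad.clone()

# Per-sample gradients, computed one sample at a time
per_sample = []
for x_i, y_i in zip(x, y):
    model.zero_grad()
    nn.functional.mse_loss(model(x_i.unsqueeze(0)), y_i.unsqueeze(0)).backward()
    per_sample.append(model.weight.grad.clone())
per_sample = torch.stack(per_sample)  # shape: (4, 1, 3)

# The batch gradient is just the average of the per-sample gradients
print(torch.allclose(batch_grad, per_sample.mean(dim=0), atol=1e-6))  # True
```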

The most straightforward way to get per-sample gradients is to run the samples through the model one at a time, effectively training with a batch size of 1:

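import torch
from torch.utils.data import DataLoader

# Assumes `model`, `criterion` and `train_dataset` are already defined.
max_grad_norm = 1.0     # per-sample clipping bound
noise_multiplier = 1.0  # illustrative DP noise level
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for x, y in DataLoader(train_dataset, batch_size=128):
    # Running sum of the clipped per-sample gradients in this batch
    summed_grads = [torch.zeros_like(p) for p in model.parameters()]

    # Run samples one-by-one to get per-sample gradients
    for x_i, y_i in zip(x, y):
        model.zero_grad()  # p.grad is accumulative, so we need to manually reset
        y_hat_i = model(x_i.unsqueeze(0))  # restore the batch dimension
        loss = criterion(y_hat_i, y_i.unsqueeze(0))
        loss.backward()

        # Clip this sample's gradient; the norm is taken jointly over all parameters
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        for acc, p in zip(summed_grads, model.parameters()):
            acc.add_(p.grad)

    # Aggregate clipped gradients of all samples in a batch, and add DP noise
    for p, acc in zip(model.parameters(), summed_grads):
        noise = torch.randn_like(acc) * noise_multiplier * max_grad_norm
        p.grad = (acc + noise) / len(x)

    optimizer.step()
    optimizer.zero_grad()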

This, however, would be a criminal waste of time and resources: it leaves all the vectorized optimizations on the sidelines.

GradSampleModule is an nn.Module replacement offered by Opacus to solve the above problem. In addition to the usual .grad attribute, the parameters of this module also get a .grad_sample attribute. It holds the per-sample gradients, stacked along a leading batch dimension.

GradSampleModule internals

For most modules, Opacus provides a function (a.k.a. a grad_sampler) that computes the per-sample gradients of a batch by, more or less, doing the backpropagation "by hand".

GradSampleModule is a wrapper around existing nn.Modules. It attaches the above function to the modules it wraps using hooks:

- Forward hooks capture each layer's activations.
- Backward hooks capture the backpropagated gradients and invoke the grad_sampler.

It also provides auxiliary methods, such as validation, utilities to add, remove, set and reset grad_sample, and utilities to attach and remove hooks.

TL;DR: grad_samplers contain the logic to compute per-sample gradients given the activations and backpropagated gradients. GradSampleModule takes care of everything else: it attaches the grad_samplers to the right modules and exposes a simple, minimal interface to the user.
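To make this division of labor concrete, here is a deliberately simplified, pure-PyTorch sketch of the mechanism for a single nn.Linear. ToyGradSampler is a hypothetical name and not part of Opacus; Opacus's real implementation is far more general. A forward hook stores the layer input (the activations), and a full backward hook receives the gradient with respect to the layer output (the backprops). From these, einsum computes per-sample gradients. Because the loss is a sum over samples, the per-sample gradients should add up to the ordinary .grad.

```python
import torch
import torch.nn as nn


class ToyGradSampler:
    """Hypothetical, minimal hook-based per-sample gradient computer for nn.Linear."""

    def __init__(self, layer: nn.Linear):
        self.activations = None
        layer.register_forward_hook(self._save_activations)
        layer.register_full_backward_hook(self._compute_grad_sample)

    def _save_activations(self, module, inputs, output):
        # Forward hook: remember the layer input ("activations")
        self.activations = inputs[0].detach()

    def _compute_grad_sample(self, module, grad_input, grad_output):
        # Backward hook: grad_output[0] is dL/d(output) ("backprops")
        backprops = grad_output[0].detach()
        module.weight.grad_sample = torch.einsum(
            "n...i,n...j->nij", backprops, self.activations
        )
        if module.bias is not None:
            module.bias.grad_sample = torch.einsum("n...k->nk", backprops)


torch.manual_seed(0)
layer = nn.Linear(5, 3)
sampler = ToyGradSampler(layer)
x = torch.randn(8, 5)
layer(x).pow(2).sum().backward()  # summed loss: per-sample grads add up to .grad

print(layer.weight.grad_sample.shape)  # torch.Size([8, 3, 5])
print(torch.allclose(layer.weight.grad_sample.sum(dim=0), layer.weight.grad, atol=1e-5))  # True
```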

Let's see an example. Say you want to get a GradSampleModule version of nn.Linear. This is what you would have to do:

import torch.nn as nn
from opacus.grad_sample import GradSampleModule

lin_mod = nn.Linear(42, 2)
print(f"Before wrapping: {lin_mod}")

gs_lin_mod = GradSampleModule(lin_mod)
print(f"After wrapping : {gs_lin_mod}")

# Output:
# Before wrapping: Linear(in_features=42, out_features=2, bias=True)
# After wrapping : GradSample(Linear(in_features=42, out_features=2, bias=True))

That's it! GradSampleModule wraps your linear module with everything it needs to compute per-sample gradients. You can use the wrapped module as a drop-in replacement for the original one.

grad_sampler internals

Now, what does the grad_sampler for the above nn.Linear layer look like? It looks as follows:

from typing import Dict

import torch
import torch.nn as nn


def compute_linear_grad_sample(
    layer: nn.Linear, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    """
    Computes per sample gradients for ``nn.Linear`` layer
    Args:
        layer: Layer
        activations: Activations
        backprops: Backpropagations
    """
    gs = torch.einsum("n...i,n...j->nij", backprops, activations)
    ret = {layer.weight: gs}
    if layer.bias is not None:
        ret[layer.bias] = torch.einsum("n...k->nk", backprops)

    return ret

The above grad_sampler takes in the activations and backpropagated gradients, computes the per-sample gradients with respect to the module parameters, and maps them to the corresponding parameters. This blog discusses the implementation and the math behind it in detail.
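For reference, the two einsum calls correspond to the per-sample gradient formulas below. The layer computes $y_{n,t} = W a_{n,t} + b$, where:

- $a_{n,t}$ are the activations (the layer input) of sample $n$.
- $g_{n,t} = \partial \mathcal{L} / \partial y_{n,t}$ are the backprops.
- $t$ runs over any extra dimensions between the batch and feature dimensions (the `...` in the einsum strings).

This assumes each sample's output depends only on its own input, which holds for nn.Linear.

```latex
\left(\nabla_{W}\mathcal{L}\right)_{n,i,j} = \sum_{t} g_{n,t,i}\, a_{n,t,j},
\qquad
\left(\nabla_{b}\mathcal{L}\right)_{n,k} = \sum_{t} g_{n,t,k}
```

Summing either expression over $n$ recovers the ordinary batch gradient stored in .grad.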

Registering a grad_sampler

But how do you tell Opacus that this is the grad_sampler for nn.Linear? Simple: decorate it with register_grad_sampler:

from typing import Dict

import torch
import torch.nn as nn
from opacus.grad_sample import register_grad_sampler


@register_grad_sampler(nn.Linear)
def compute_linear_grad_sample(
    layer: nn.Linear, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    """
    Computes per sample gradients for ``nn.Linear`` layer
    Args:
        layer: Layer
        activations: Activations
        backprops: Backpropagations
    """
    gs = torch.einsum("n...i,n...j->nij", backprops, activations)
    ret = {layer.weight: gs}
    if layer.bias is not None:
        ret[layer.bias] = torch.einsum("n...k->nk", backprops)

    return ret

Once again, that's it! No, really: check out the code in Opacus; it is literally just this.

The register_grad_sampler decorator, defined in grad_sample/utils, registers the function as the grad_sampler for nn.Linear, the module type passed as an argument to the decorator. GradSampleModule maintains a registry of all the grad_samplers and the module types they correspond to.

If you want to register a custom grad_sampler, all you have to do is decorate your function as shown above. Note that the order of registration matters; if you register more than one grad_sampler for a certain module, the last one wins.

Supported modules

Opacus offers grad_samplers for most common modules; you can see the full list here. As you can see, this list is not at all exhaustive; we wholeheartedly welcome your contributions.

By design, GradSampleModule does just that: it computes grad samples. While it is built for use with Opacus, it certainly isn't restricted to DP use cases and can be used for any task that needs per-sample gradients.

If you have any questions or comments, please don't hesitate to post them on our forum.

Per-sample gradients correctness utility

Opacus also provides a simple utility function, check_per_sample_gradients_are_correct, that checks whether the grad_sampler works correctly with a particular module. It compares the Opacus per-sample gradients against gradients computed on micro-batches.

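import torch
import torch.nn as nn
from opacus.utils.per_sample_gradients_utils import check_per_sample_gradients_are_correct

N, Z, W = 8, 4, 16  # illustrative batch size, sequence length and feature dimension
x_shape = [N, Z, W]
x = torch.randn(x_shape)
model = nn.Linear(W, W + 2)
assert check_per_sample_gradients_are_correct(
    x,
    model,
)  # This will fail only if the Opacus per-sample gradients do not match the micro-batch gradients.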
Copyright © 2025 Meta Platforms, Inc.