Per Sample Gradients

Autograd Grad Sample

Based on

This module provides functions to capture per-sample gradients by using hooks.


The register_backward_hook() function has a known issue that is being tracked separately. However, it is currently the only known way to implement this functionality (your suggestions and contributions are very welcome). Its behaviour has been verified to be correct for the layers currently supported by Opacus.

opacus.autograd_grad_sample.add_hooks(model, loss_reduction='mean', batch_first=True)

Adds hooks to model to save activations and backpropagated values. The hooks will:

  1. save activations into param.activations during the forward pass.

  2. compute per-sample gradients and save them in param.grad_sample during the backward pass.

Parameters

  • model (Module) – Model to which hooks are added.

  • loss_reduction (str) – Indicates whether the loss reduction used to aggregate the gradients is a sum or a mean. Can take the values sum or mean.

  • batch_first (bool) – Flag indicating whether the first dimension of the input tensor to the corresponding module is the batch dimension, for example a tensor of shape [batch_size, ..., ...]. Set to True if the batch appears in the first dimension; otherwise set to False (batch_first=False implies that the batch is always in the second dimension).
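To illustrate what the backward hook computes, the following minimal sketch handles a bias-free linear layer y = W a using plain Python lists instead of tensors. The function names linear_grad_sample and swap_first_two_dims are hypothetical and not part of this library. The sketch assumes activations and backprops arrive as [N, T, features] nested lists when batch_first=True and as [T, N, features] otherwise. With loss_reduction='mean', each backprop carries a 1/N factor from the averaged loss, so the sketch multiplies by the batch size N to recover the gradient of each sample's own loss.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def swap_first_two_dims(x: list) -> list:
    """Turn a [T, N, ...] nested list into [N, T, ...]."""
    return [list(row) for row in zip(*x)]


def linear_grad_sample(
    activations: List[List[Vector]],
    backprops: List[List[Vector]],
    loss_reduction: str = "mean",
    batch_first: bool = True,
) -> List[Matrix]:
    """Hypothetical helper: per-sample gradients of W for y = W a (no bias).

    Returns one [out_features, in_features] matrix per sample.
    """
    if loss_reduction not in ("sum", "mean"):
        raise ValueError(f"loss_reduction must be 'sum' or 'mean', got {loss_reduction!r}")
    if not batch_first:
        # The batch sits in the second dimension; move it to the front.
        activations = swap_first_two_dims(activations)
        backprops = swap_first_two_dims(backprops)
    n = len(activations)
    # A mean-reduced loss scales every backprop by 1/n; undo that so each
    # result is the gradient of that sample's own loss.
    scale = float(n) if loss_reduction == "mean" else 1.0
    grad_sample: List[Matrix] = []
    for a_seq, b_seq in zip(activations, backprops):
        out_dim, in_dim = len(b_seq[0]), len(a_seq[0])
        g = [[0.0] * in_dim for _ in range(out_dim)]
        # Sum the outer products b_t a_t^T over the remaining (T) dimension.
        for a_t, b_t in zip(a_seq, b_seq):
            for i in range(out_dim):
                for j in range(in_dim):
                    g[i][j] += scale * b_t[i] * a_t[j]
        grad_sample.append(g)
    return grad_sample


# Usage: N=2 samples, T=1 position each, 2 input features, 1 output feature.
acts = [[[1.0, 2.0]], [[3.0, 4.0]]]
bps = [[[0.5]], [[1.0]]]  # dL/dy under a mean-reduced loss
print(linear_grad_sample(acts, bps, loss_reduction="mean"))
# [[[1.0, 2.0]], [[6.0, 8.0]]]
```

Averaging the two per-sample gradients gives [[3.5, 5.0]], which equals the batch gradient 0.5 * [1, 2] + 1.0 * [3, 4] of the mean-reduced loss. This is why the rescaling is needed: without it, the per-sample results would be too small by a factor of N.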


Globally disables all hooks installed by this library.


Globally enables all hooks installed by this library.
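The global enable/disable behaviour described above can be pictured as a module-level flag that every hook checks before doing any work. The sketch below assumes that design; the names _hooks_disabled, disable_all_hooks, enable_all_hooks and capture_activations are hypothetical and are not this library's API. A typical use would be a forward pass whose per-sample gradients are not needed, so no activations should be stored.

```python
from typing import List

# Hypothetical module-level switch shared by every hook.
_hooks_disabled = False


def disable_all_hooks() -> None:
    """Turn every hook in this module into a no-op."""
    global _hooks_disabled
    _hooks_disabled = True


def enable_all_hooks() -> None:
    """Re-activate hooks after disable_all_hooks()."""
    global _hooks_disabled
    _hooks_disabled = False


def capture_activations(store: List[list], inputs: list) -> None:
    """Toy forward hook: records the layer input unless hooks are disabled."""
    if _hooks_disabled:
        return
    store.append(list(inputs))


saved: List[list] = []
capture_activations(saved, [1.0, 2.0])
disable_all_hooks()
capture_activations(saved, [3.0, 4.0])  # ignored while disabled
enable_all_hooks()
capture_activations(saved, [5.0, 6.0])
print(saved)  # [[1.0, 2.0], [5.0, 6.0]]
```

Because the hooks stay registered and only check the flag, toggling is cheap and does not require re-adding hooks afterwards.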


Checks if the layer is supported by this library.


Parameters

layer (Module) – Layer for which to determine whether capturing per-sample gradients is supported.

Return type

bool

Returns

Whether the layer is supported by this library.
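A support check of this kind can be sketched as a lookup of the layer's class name in a set of names that have grad samplers, which is consistent with the Dict[str, Callable] mapping keyed by layer name later on this page. This is an assumption for illustration: SUPPORTED_LAYER_NAMES, the toy classes, is_supported_sketch and unsupported_layers are all hypothetical, and the listed names are examples rather than this library's actual list.

```python
from typing import List

# Hypothetical: support is decided by class name. These names are examples,
# not this library's actual list of supported layers.
SUPPORTED_LAYER_NAMES = {"Linear", "Conv2d"}


class Linear:  # toy stand-ins for real layer classes
    pass


class Conv2d:
    pass


class MyCustomLayer:
    pass


def is_supported_sketch(layer: object) -> bool:
    """True if a grad sampler is available for this layer's class name."""
    return type(layer).__name__ in SUPPORTED_LAYER_NAMES


def unsupported_layers(layers: List[object]) -> List[str]:
    """Class names of layers that would get no per-sample gradients."""
    return [type(layer).__name__ for layer in layers if not is_supported_sketch(layer)]


print(unsupported_layers([Linear(), Conv2d(), MyCustomLayer()]))  # ['MyCustomLayer']
```

Matching on the exact class name means a subclass of a supported layer is not treated as supported unless its own name is registered; an isinstance check would behave differently.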


Removes hooks added by add_hooks().


Parameters

model (Module) – Model from which the hooks are to be removed.
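Removing hooks is easiest when registration keeps the handles it receives. In PyTorch, Module.register_forward_hook() returns a handle whose remove() method unregisters the hook. The pure-Python sketch below imitates that pattern with toy classes; ToyModule, ToyHandle, add_hooks_sketch, remove_hooks_sketch and the hook_handles attribute are hypothetical and only illustrate the idea.

```python
from typing import Callable, Dict, List


class ToyHandle:
    """Simplified stand-in for the handle a hook registration returns."""

    def __init__(self, hooks: Dict[int, Callable], key: int) -> None:
        self._hooks = hooks
        self._key = key

    def remove(self) -> None:
        # Deleting the entry unregisters the hook; a second call is harmless.
        self._hooks.pop(self._key, None)


class ToyModule:
    def __init__(self) -> None:
        self.forward_hooks: Dict[int, Callable] = {}
        self._next_id = 0

    def register_forward_hook(self, hook: Callable) -> ToyHandle:
        self.forward_hooks[self._next_id] = hook
        handle = ToyHandle(self.forward_hooks, self._next_id)
        self._next_id += 1
        return handle


def add_hooks_sketch(model: ToyModule, hook: Callable) -> None:
    """Register a hook and keep its handle on the model for later removal."""
    handles: List[ToyHandle] = getattr(model, "hook_handles", [])
    handles.append(model.register_forward_hook(hook))
    model.hook_handles = handles


def remove_hooks_sketch(model: ToyModule) -> None:
    """Undo add_hooks_sketch(): remove every stored handle."""
    if not hasattr(model, "hook_handles"):
        raise ValueError("Asked to remove hooks, but no hooks were found")
    for handle in model.hook_handles:
        handle.remove()
    del model.hook_handles


m = ToyModule()
add_hooks_sketch(m, lambda *args: None)
add_hooks_sketch(m, lambda *args: None)
remove_hooks_sketch(m)
print(len(m.forward_hooks))  # 0
```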

Layers Grad Samplers

This module is a collection of grad samplers: methods that calculate per-sample gradients for a layer given two tensors, the activations (module inputs) and the backpropagations (gradient values propagated back from downstream layers).
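To make the activations/backpropagations contract concrete, here is a minimal grad sampler for an embedding lookup, chosen only as an example layer. It uses nested lists, the name embedding_grad_sample is hypothetical, and it assumes the backprops are already per-sample (any mean-reduction rescaling has been applied). For an embedding, the activations are token indices rather than floats.

```python
from typing import List

Vector = List[float]


def embedding_grad_sample(
    indices: List[List[int]],
    backprops: List[List[Vector]],
    num_embeddings: int,
) -> List[List[Vector]]:
    """Per-sample gradients of an embedding table.

    indices[n][t] is the row looked up at position t of sample n (the layer's
    activations); backprops[n][t] is dL/d(output) at that position.
    Returns one [num_embeddings, dim] gradient per sample.
    """
    grad_sample = []
    for idx_seq, b_seq in zip(indices, backprops):
        dim = len(b_seq[0])
        g = [[0.0] * dim for _ in range(num_embeddings)]
        # Each looked-up row receives the gradient of the position that used
        # it; rows looked up several times accumulate.
        for idx, b in zip(idx_seq, b_seq):
            for k in range(dim):
                g[idx][k] += b[k]
        grad_sample.append(g)
    return grad_sample


idx = [[0, 2, 0], [1]]                  # sample 0 has 3 tokens, sample 1 has 1
bps = [[[1.0], [2.0], [3.0]], [[5.0]]]  # embedding dim = 1
print(embedding_grad_sample(idx, bps, num_embeddings=3))
# [[[4.0], [0.0], [2.0]], [[0.0], [5.0], [0.0]]]
```

Summing the per-sample results over the batch gives [[4.0], [5.0], [2.0]], the ordinary sum-reduced gradient of the embedding table; keeping them separate is what per-sample gradients add.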


Mapping from layer name to the corresponding grad sampler.

Type

Dict[str, Callable]
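A mapping of this type can be filled and used as sketched below, assuming samplers are looked up by the layer's class name and called as sampler(layer, activations, backprops). GRAD_SAMPLERS, register_grad_sampler, the toy Scale layer and compute_grad_sample_sketch are hypothetical names for illustration, not this library's API.

```python
from typing import Callable, Dict, List

GradSampler = Callable[..., None]

# Hypothetical registry playing the role of the Dict[str, Callable] above.
GRAD_SAMPLERS: Dict[str, GradSampler] = {}


def register_grad_sampler(layer_name: str) -> Callable[[GradSampler], GradSampler]:
    """Decorator that files a grad sampler under a layer class name."""
    def wrap(fn: GradSampler) -> GradSampler:
        GRAD_SAMPLERS[layer_name] = fn
        return fn
    return wrap


class Scale:
    """Toy layer y = w * x with a scalar weight w."""

    def __init__(self, w: float) -> None:
        self.w = w
        self.grad_sample: List[float] = []


@register_grad_sampler("Scale")
def _scale_grad_sample(layer: Scale, activations: List[float], backprops: List[float]) -> None:
    # dl_n/dw = b_n * x_n for each sample n; it does not depend on w itself.
    layer.grad_sample = [b * a for a, b in zip(activations, backprops)]


def compute_grad_sample_sketch(layer: object, activations: list, backprops: list) -> None:
    """Look up the sampler by the layer's class name and run it."""
    name = type(layer).__name__
    if name not in GRAD_SAMPLERS:
        raise ValueError(f"No grad sampler registered for {name}")
    GRAD_SAMPLERS[name](layer, activations, backprops)


layer = Scale(3.0)
compute_grad_sample_sketch(layer, [1.0, 2.0], [0.5, -1.0])
print(layer.grad_sample)  # [0.5, -2.0]
```

Keying by name keeps the dispatch a single dictionary lookup, and supporting a new layer type only requires registering one more function.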