GradSampleModule

class opacus.grad_sample.grad_sample_module.GradSampleModule(m, *, batch_first=True, loss_reduction='mean', strict=True, force_functorch=False)

Hooks-based implementation of AbstractGradSampleModule.

Computes per-sample gradients using custom-written methods for each layer. See README.md for more details.

Parameters:
  • m (Module) – nn.Module to be wrapped

  • batch_first – Flag to indicate if the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions on input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...]

  • loss_reduction – Indicates whether the loss reduction (used to aggregate the gradients) is a sum or a mean operation. Can take values “sum” or “mean”.

  • strict (bool) – If set to True, the input module will be validated to check that GradSampleModule has grad sampler functions for all submodules of the input module, i.e. that it knows how to calculate per sample gradients for all model parameters. If set to False, per sample gradients will be computed on a “best effort” basis: they will be available where possible and set to None otherwise. This is not recommended, because some unsupported modules (e.g. BatchNorm) affect other parameters and invalidate the concept of per sample gradients for the entire model.

  • force_functorch – If set to True, will use functorch to compute all per sample gradients. Otherwise, functorch will be used only for layers without registered grad sampler methods.

Raises:

NotImplementedError – If strict is set to True and module m (or any of its submodules) doesn’t have a registered grad sampler function.

add_hooks(*, loss_reduction='mean', batch_first=True, force_functorch=False)

Adds hooks to the model to save activations and backprop values. The hooks will:

1. save activations into param.activations during the forward pass;
2. compute per-sample gradients in param.grad_sample during the backward pass.

Call remove_hooks() to disable this.

Parameters:
  • model – the model to which hooks are added

  • batch_first (bool) – Flag to indicate if the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions on input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...]

  • loss_reduction (str) – Indicates whether the loss reduction (used to aggregate the gradients) is a sum or a mean operation. Can take values “sum” or “mean”.

  • force_functorch (bool) – If set to True, will use functorch to compute all per sample gradients. Otherwise, functorch will be used only for layers without registered grad sampler methods.

Return type:

None
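To make the forward/backward split concrete, here is a minimal pure-Python sketch of the two hooks for a bias-free linear layer. Everything in it (ToyLinear, capture_activations, compute_grad_sample and the hand-rolled hook lists) is hypothetical and stands in for PyTorch's nn.Module hook machinery; it only illustrates the idea that, for a linear layer y = W x, the per-sample weight gradient is the outer product of each sample's backprop and its stored activation.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


class ToyLinear:
    """Hypothetical bias-free linear layer y = W x with hand-rolled hook lists."""

    def __init__(self, weight: Matrix) -> None:
        self.weight = weight
        self.activations: List[Vector] = []  # filled by the forward hook
        self.grad_sample: List[Matrix] = []  # filled by the backward hook
        self.forward_hooks: List[Callable] = []
        self.backward_hooks: List[Callable] = []

    def forward(self, batch: List[Vector]) -> List[Vector]:
        out = [[sum(w * xi for w, xi in zip(row, x)) for row in self.weight] for x in batch]
        for hook in self.forward_hooks:
            hook(self, batch)
        return out

    def backward(self, backprops: List[Vector]) -> None:
        # backprops[i] is dL_i/dy_i, the gradient w.r.t. sample i's output
        for hook in self.backward_hooks:
            hook(self, backprops)


def capture_activations(layer: ToyLinear, batch: List[Vector]) -> None:
    """Forward hook: keep a copy of each sample's input."""
    layer.activations = [list(x) for x in batch]


def compute_grad_sample(layer: ToyLinear, backprops: List[Vector]) -> None:
    """Backward hook: per-sample dL_i/dW = outer(backprop_i, activation_i)."""
    layer.grad_sample = [
        [[g * a for a in act] for g in bp]
        for bp, act in zip(backprops, layer.activations)
    ]


layer = ToyLinear([[1.0, 2.0], [0.0, 1.0]])
layer.forward_hooks.append(capture_activations)
layer.backward_hooks.append(compute_grad_sample)
layer.forward([[1.0, 0.0], [2.0, 3.0]])
layer.backward([[1.0, 1.0], [0.5, 0.0]])
print(layer.grad_sample[1])  # [[1.0, 1.5], [0.0, 0.0]]
```

Summing grad_sample over the batch recovers the ordinary weight gradient of a summed loss; keeping the entries separate is what makes them per-sample.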

allow_grad_accumulation()

Unsets a flag to detect gradient accumulation (multiple forward/backward passes without an optimizer step or clearing out gradients).

When the flag is set, GradSampleModule throws a ValueError on the second backward pass; once it is unset, repeated backward passes are allowed.

capture_backprops_hook(module, _forward_input, forward_output, loss_reduction, batch_first)

Computes per sample gradients given the current backprops and activations stored by the associated forward hook. Computed per sample gradients are stored in grad_sample field in each parameter.

For non-recurrent layers the process is straightforward: for each loss.backward() call, this hook will be called exactly once. For recurrent layers, however, this is more complicated: the hook will be called multiple times while still processing the same batch of data.

For this reason, we first accumulate the gradients from the same batch in p._current_grad_sample and then, when we detect the end of a full backward pass, store the accumulated result in p.grad_sample.

From there, p.grad_sample is either a Tensor or, if accumulated over multiple batches, a list of Tensors.

Parameters:
  • module (Module) – the nn.Module for which per-sample gradients are computed

  • _forward_input (Tensor) – torch.Tensor

  • forward_output (Tensor) – torch.Tensor

  • loss_reduction (str) – either “sum” or “mean”

  • batch_first (bool) – True if the batch dimension is first
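The end-of-pass step described above (moving the per-batch accumulator into p.grad_sample, which turns into a list when several batches pile up) can be sketched in plain Python. The function name finish_backward_pass is hypothetical, a SimpleNamespace stands in for a parameter, and a tuple with one entry per sample stands in for a Tensor.

```python
from types import SimpleNamespace


def finish_backward_pass(p: SimpleNamespace) -> None:
    """Move the per-batch accumulator p._current_grad_sample into p.grad_sample."""
    current = p._current_grad_sample
    del p._current_grad_sample
    existing = getattr(p, "grad_sample", None)
    if existing is None:
        p.grad_sample = current              # first batch: a single "Tensor"
    elif isinstance(existing, list):
        existing.append(current)             # already a list: append
    else:
        p.grad_sample = [existing, current]  # second batch: switch to a list


p = SimpleNamespace(grad_sample=None)

# Batch 1: a recurrent layer fires the hook twice; the partial results are summed first.
p._current_grad_sample = (1.0, 2.0)
p._current_grad_sample = tuple(a + b for a, b in zip(p._current_grad_sample, (0.5, 0.5)))
finish_backward_pass(p)
print(p.grad_sample)  # (1.5, 2.5)

# Batch 2, without clearing gradients in between: grad_sample becomes a list.
p._current_grad_sample = (1.0, 1.0)
finish_backward_pass(p)
print(p.grad_sample)  # [(1.5, 2.5), (1.0, 1.0)]
```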

disable_hooks()

Globally disables all hooks installed by this library. This is needed because of a bug in Autograd (see https://github.com/pytorch/pytorch/issues/25723): removing hooks has no effect if the graph has already been constructed. This method at least turns them off.

Return type:

None

enable_hooks()

The opposite of disable_hooks(). Hooks are always enabled unless you explicitly disable them, so you only need to call this to re-enable them.

Return type:

None
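Why a flag helps where removal does not can be illustrated with a toy model in plain Python. ToyGraph, ToyModule and the hooks_enabled attribute are hypothetical: the "graph" keeps its own copy of the hooks when it is built, mimicking the Autograd behaviour described in the linked issue, so clearing the module's hook list has no effect, while a flag checked inside the hook does.

```python
from typing import Callable, List


class ToyGraph:
    """Stands in for an autograd graph that snapshots hooks when it is built."""

    def __init__(self, hooks: List[Callable[[], None]]) -> None:
        self.hooks = list(hooks)  # later changes to the module's list are not seen

    def backward(self) -> None:
        for hook in self.hooks:
            hook()


class ToyModule:
    def __init__(self) -> None:
        self.hooks_enabled = True
        self.hook_calls = 0
        self.hooks: List[Callable[[], None]] = [self._backward_hook]

    def _backward_hook(self) -> None:
        if not self.hooks_enabled:
            return  # still referenced by the graph, but does nothing
        self.hook_calls += 1

    def build_graph(self) -> ToyGraph:
        return ToyGraph(self.hooks)

    def remove_hooks(self) -> None:
        self.hooks.clear()

    def disable_hooks(self) -> None:
        self.hooks_enabled = False

    def enable_hooks(self) -> None:
        self.hooks_enabled = True


m = ToyModule()
graph = m.build_graph()
m.remove_hooks()
graph.backward()
print(m.hook_calls)  # 1: removal did not reach the existing graph
m.disable_hooks()
graph.backward()
print(m.hook_calls)  # 1: the flag silenced the hook
m.enable_hooks()
graph.backward()
print(m.hook_calls)  # 2
```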

forbid_grad_accumulation()

Sets a flag to detect gradient accumulation (multiple forward/backward passes without an optimizer step or clearing out gradients).

When set, GradSampleModule will throw a ValueError on the second backward pass.
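The forbid/allow pair toggles a single flag. A minimal sketch of how such a flag can turn a second backward pass into a ValueError is shown here; AccumulationGuard, its attributes, the error message and the choice to allow accumulation by default are all hypothetical and not taken from GradSampleModule itself.

```python
class AccumulationGuard:
    """Hypothetical tracker for a forbid/allow gradient-accumulation flag."""

    def __init__(self) -> None:
        self.forbid = False               # toy default: accumulation allowed
        self.backward_since_step = False  # a backward ran since the last step/zero_grad

    def forbid_grad_accumulation(self) -> None:
        self.forbid = True

    def allow_grad_accumulation(self) -> None:
        self.forbid = False

    def on_backward(self) -> None:
        if self.forbid and self.backward_since_step:
            raise ValueError(
                "gradient accumulation detected: second backward pass "
                "without an optimizer step or zero_grad()"
            )
        self.backward_since_step = True

    def zero_grad(self) -> None:
        self.backward_since_step = False


guard = AccumulationGuard()
guard.forbid_grad_accumulation()
guard.on_backward()
try:
    guard.on_backward()  # second pass without clearing: rejected
except ValueError as err:
    print("rejected:", err)
guard.zero_grad()
guard.on_backward()      # fine after clearing gradients
guard.allow_grad_accumulation()
guard.on_backward()      # accumulation is now permitted
```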

forward(*args, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

classmethod is_supported(module)

Checks if this individual module is supported (i.e. has a registered grad sampler function).

Notes

Note that this method does not check submodules.

Parameters:

module (Module) – nn.Module to be checked

Return type:

bool

Returns:

True if grad sampler is found, False otherwise

rearrange_grad_samples(*, module, backprops, loss_reduction, batch_first)

Rearrange activations and grad_samples based on loss reduction and batch dim

Parameters:
  • module (Module) – the module for which per-sample gradients are computed

  • backprops (Tensor) – the captured backprops

  • loss_reduction (str) – either “mean” or “sum” depending on whether backpropped loss was averaged or summed over batch

  • batch_first (bool) – True if the batch dimension is first

Return type:

Tuple[Tensor, Tensor]
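The two adjustments can be sketched in plain Python with nested lists standing in for tensors; the helper names scale, batch_to_front and rearrange are hypothetical. With loss_reduction="mean" the loss is (1/n) * sum_i L_i, so every captured backprop carries a factor 1/n, and multiplying by the batch size n recovers the gradient of each sample's own loss. With batch_first=False the first two dimensions are swapped so that the batch comes first.

```python
from typing import Any, List, Tuple


def scale(x: Any, factor: float) -> Any:
    """Multiply every number in a nested list by factor."""
    if isinstance(x, list):
        return [scale(v, factor) for v in x]
    return x * factor


def batch_to_front(x: List[Any]) -> List[Any]:
    """Swap the first two dimensions: [K, B, ...] -> [B, K, ...]."""
    return [list(row) for row in zip(*x)]


def rearrange(activations: List[Any], backprops: List[Any],
              loss_reduction: str, batch_first: bool) -> Tuple[List[Any], List[Any]]:
    if not batch_first:
        activations = batch_to_front(activations)
        backprops = batch_to_front(backprops)
    n = len(backprops)  # batch size, now the leading dimension
    if loss_reduction == "mean":
        backprops = scale(backprops, n)  # undo the 1/n from averaging the loss
    elif loss_reduction != "sum":
        raise ValueError(f"loss_reduction must be 'sum' or 'mean', got {loss_reduction!r}")
    return activations, backprops


# K = 3 steps, B = 2 samples, scalar features, batch in dimension 1
acts = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
bps = [[0.5, 1.0], [0.25, 0.5], [0.0, 1.5]]
a, b = rearrange(acts, bps, "mean", batch_first=False)
print(a)  # [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]
print(b)  # [[1.0, 0.5, 0.0], [2.0, 1.0, 3.0]]
```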

remove_hooks()

Removes hooks added by add_hooks()

Return type:

None

classmethod validate(module, *, strict=False)

Check if per sample gradients can be fully computed for a given model

Parameters:
  • module (Module) – nn.Module to be checked

  • strict (bool) – Behaviour in case of a negative check result. If set to False, returns the list of exceptions; otherwise, raises a NotImplementedError.

Return type:

List[NotImplementedError]

Returns:

Empty list if validation is successful. List of validation errors if strict=False and unsupported modules are found.

Raises:

NotImplementedError – If strict=True and unsupported modules are found
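The difference between is_supported (one module only) and validate (the whole module tree) can be sketched with a toy registry. The classes, the GRAD_SAMPLERS dict and the owns_params attribute are hypothetical stand-ins for nn.Module, the library's grad sampler registry and "has trainable parameters of its own"; only the strict keyword mirrors the signature above.

```python
from typing import Iterator, List


class ToyModule:
    owns_params = True  # whether the module has parameters of its own

    def __init__(self, *children: "ToyModule") -> None:
        self.children = list(children)


class Linear(ToyModule):
    pass


class BatchNorm(ToyModule):
    pass


class Sequential(ToyModule):
    owns_params = False  # a pure container


GRAD_SAMPLERS = {Linear: "linear grad sampler"}  # hypothetical registry


def is_supported(module: ToyModule) -> bool:
    # Looks at this module only; submodules are not inspected.
    return type(module) in GRAD_SAMPLERS


def iter_modules(module: ToyModule) -> Iterator[ToyModule]:
    yield module
    for child in module.children:
        yield from iter_modules(child)


def validate(module: ToyModule, *, strict: bool = False) -> List[NotImplementedError]:
    errors = [
        NotImplementedError(f"no grad sampler for {type(m).__name__}")
        for m in iter_modules(module)
        if m.owns_params and not is_supported(m)
    ]
    if strict and errors:
        raise NotImplementedError(errors)
    return errors


print(is_supported(Sequential(Linear())))               # False: submodules are ignored
print(validate(Sequential(Linear())))                   # []
print(len(validate(Sequential(Linear(), BatchNorm())))) # 1
```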

opacus.grad_sample.grad_sample_module.create_or_accumulate_grad_sample(*, param, grad_sample, max_batch_len)

Creates a _current_grad_sample attribute in the given parameter, or adds to it if the _current_grad_sample attribute already exists.

Parameters:
  • param (Tensor) – Parameter to which grad_sample will be added

  • grad_sample (Tensor) – Per-sample gradients tensor. Must be of the same shape as param, with an extra batch dimension

  • layer – the nn.Module the parameter belongs to

Return type:

None
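A minimal pure-Python sketch of the create-or-add behaviour, with a list of floats standing in for the per-sample gradient of a scalar parameter and a SimpleNamespace standing in for the parameter. The function name create_or_accumulate is hypothetical, and the role given to max_batch_len here (sizing a new accumulator so that chunks covering fewer samples add into its leading rows) is an assumption, since that parameter is not described above.

```python
from types import SimpleNamespace
from typing import List


def create_or_accumulate(param: SimpleNamespace, grad_sample: List[float],
                         max_batch_len: int) -> None:
    """Toy version: grad_sample holds one value per sample of a scalar parameter."""
    n = len(grad_sample)
    if hasattr(param, "_current_grad_sample"):
        buf = param._current_grad_sample
        for i in range(n):
            buf[i] += grad_sample[i]  # add to the existing accumulator
    else:
        buf = [0.0] * max_batch_len   # assumption: buffer sized by max_batch_len
        buf[:n] = grad_sample
        param._current_grad_sample = buf


p = SimpleNamespace()
create_or_accumulate(p, [1.0, 2.0, 3.0], max_batch_len=3)  # creates the attribute
create_or_accumulate(p, [0.5, 0.5], max_batch_len=3)       # shorter chunk: adds to first rows
print(p._current_grad_sample)  # [1.5, 2.5, 3.0]
```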