Extends nn.Module so that its parameter tensors have an extra field called .grad_sample.

Parameters:
• m (Module) – nn.Module to be wrapped

• batch_first – Flag to indicate if the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions on the input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...].

• loss_reduction – Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. Can take values “sum” or “mean”.

• strict (bool) – If set to True, the input module is validated to check that GradSampleModule has a grad sampler function for every submodule of the input module (i.e., that it knows how to compute per-sample gradients for all model parameters). If set to False, per-sample gradients are computed on a “best effort” basis: they are available where possible and set to None otherwise. This is not recommended, because some unsupported modules (e.g., BatchNorm) affect other parameters and invalidate the concept of per-sample gradients for the entire model.

Raises:

NotImplementedError – If strict is set to True and module m (or any of its submodules) doesn’t have a registered grad sampler function.
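To make the .grad_sample field concrete, here is a toy sketch in plain Python (no PyTorch; every name is hypothetical and none of it is this library's API). It assumes a linear model with a squared-error loss and shows that per-sample gradients carry an extra leading batch dimension, and that summing them (or averaging, for “mean”) recovers the single aggregated gradient that .grad would hold.

```python
# Toy illustration in plain Python (hypothetical names, not the library's API).
# Model: prediction = dot(w, x); per-sample loss = (prediction - y) ** 2.


def predict(w, x):
    """Dot product of the weight vector w and one input sample x."""
    return sum(w_i * x_i for w_i, x_i in zip(w, x))


def per_sample_grads(w, xs, ys):
    """Gradient of each sample's own loss w.r.t. w, one row per sample.

    The result has shape [batch_size, len(w)]: the parameter's shape with an
    extra leading batch dimension, like .grad_sample.
    """
    return [[2.0 * (predict(w, x) - y) * x_i for x_i in x] for x, y in zip(xs, ys)]


def aggregate(grad_sample, loss_reduction="sum"):
    """Collapse per-sample gradients into the single gradient .grad would hold."""
    summed = [sum(column) for column in zip(*grad_sample)]
    if loss_reduction == "sum":
        return summed
    if loss_reduction == "mean":
        return [g / len(grad_sample) for g in summed]
    raise ValueError(f"loss_reduction must be 'sum' or 'mean', got {loss_reduction!r}")


w = [1.0, -1.0]
xs = [[1.0, 2.0], [3.0, 1.0]]
ys = [1.0, 0.0]

grad_sample = per_sample_grads(w, xs, ys)
print(grad_sample)                     # [[-4.0, -8.0], [12.0, 4.0]]
print(aggregate(grad_sample, "sum"))   # [8.0, -4.0]
print(aggregate(grad_sample, "mean"))  # [4.0, -2.0]
```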

Adds hooks to the model to save activations and backprop values. The hooks will:

1. save activations into param.activations during the forward pass;
2. compute per-sample gradients into param.grad_sample during the backward pass.

Call remove_hooks() to disable this.

Parameters:
• model – the model to which hooks are added

• batch_first (bool) – Flag to indicate if the input tensor to the corresponding module has the first dimension representing the batch. If set to True, dimensions on the input tensor are expected to be [batch_size, ...], otherwise [K, batch_size, ...].

• loss_reduction (str) – Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. Can take values “sum” or “mean”.

Return type:

None
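A rough sketch of the two hook roles described above, in plain Python with hypothetical names (ToyLinear, forward, backward_hook); in PyTorch these would be registered as module hooks rather than called by hand. It assumes a bias-free linear layer, for which the per-sample weight gradient is the outer product of that sample's backprop and activation.

```python
# Toy model of the hook mechanism in plain Python (hypothetical names, not the
# library's API). Layer: out[i] = sum_j weight[i][j] * x[j] for each sample x.


class ToyLinear:
    def __init__(self, weight):
        self.weight = weight      # shape [out_features][in_features]
        self.activations = None   # filled in the forward-hook role
        self.grad_sample = None   # filled in the backward-hook role

    def forward(self, batch):
        # Forward-hook role: keep the inputs, they are needed in the backward pass.
        self.activations = batch
        return [[sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in self.weight]
                for x in batch]

    def backward_hook(self, backprops):
        # backprops[n][i] is dLoss/dout[n][i] for sample n.
        # Per-sample weight gradient is an outer product:
        # grad_sample[n][i][j] = backprops[n][i] * activations[n][j]
        self.grad_sample = [[[b_i * a_j for a_j in a] for b_i in b]
                            for a, b in zip(self.activations, backprops)]


layer = ToyLinear([[1.0, 2.0]])                # 1 output, 2 inputs
out = layer.forward([[1.0, 1.0], [2.0, 0.5]])  # [[3.0], [3.0]]
layer.backward_hook([[1.0], [-2.0]])
print(layer.grad_sample)  # [[[1.0, 1.0]], [[-4.0, -1.0]]]
```

Each entry of grad_sample has the weight's shape [1][2], and the outer list adds the batch dimension.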

capture_backprops_hook(module, _forward_input, forward_output, loss_reduction, batch_first)[source]

Computes per-sample gradients given the current backprops and the activations stored by the associated forward hook. The computed per-sample gradients are stored in the grad_sample field of each parameter.

For non-recurrent layers the process is straightforward: for each loss.backward() call, this hook is called exactly once. For recurrent layers, however, this is more complicated: the hook is called multiple times while still processing the same batch of data.

For this reason, we first accumulate the gradients from the same batch in p._current_grad_sample and then, when we detect the end of a full backward pass, store the accumulated result in p.grad_sample.

From there, p.grad_sample can be either a Tensor or, if accumulated over multiple batches, a list of Tensors.
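The accumulate-then-promote logic can be modeled as follows. This is a simplified, hypothetical sketch in plain Python, not the library's code: it assumes one stored activation per forward call, treats an empty activation stack as the end of a full backward pass, and uses a tuple to stand in for a per-sample tensor so that it can be told apart from the list of per-batch results.

```python
# Simplified, hypothetical sketch (plain Python, not the library's code) of how a
# running sum is promoted to grad_sample at the end of a full backward pass.
# A "tensor" of per-sample values is modeled as a tuple with one entry per sample.


class ToyRecurrentParam:
    def __init__(self):
        self.activations = []             # one entry per forward call (e.g. time step)
        self._current_grad_sample = None  # running sum within one backward pass
        self.grad_sample = None           # tuple, or list of tuples after several batches

    def forward_hook(self, activation):
        self.activations.append(activation)

    def backward_hook(self, backprops):
        # Backward calls arrive in reverse order, so pop the latest activation.
        activation = self.activations.pop()
        contribution = tuple(b * a for b, a in zip(backprops, activation))
        if self._current_grad_sample is None:
            self._current_grad_sample = contribution
        else:
            self._current_grad_sample = tuple(
                c + d for c, d in zip(self._current_grad_sample, contribution))
        if not self.activations:  # nothing pending: the full backward pass is done
            self._promote()

    def _promote(self):
        done, self._current_grad_sample = self._current_grad_sample, None
        if self.grad_sample is None:
            self.grad_sample = done
        elif isinstance(self.grad_sample, tuple):
            self.grad_sample = [self.grad_sample, done]  # second batch: start a list
        else:
            self.grad_sample.append(done)


p = ToyRecurrentParam()
# Batch 1: two samples, two time steps, so the backward hook fires twice.
p.forward_hook((1.0, 2.0))
p.forward_hook((3.0, 4.0))
p.backward_hook((1.0, 1.0))
p.backward_hook((2.0, 0.5))
after_first_batch = p.grad_sample  # (5.0, 5.0)
# Batch 2: one time step.
p.forward_hook((1.0, 1.0))
p.backward_hook((1.0, 2.0))
print(p.grad_sample)               # [(5.0, 5.0), (1.0, 2.0)]
```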

Deletes the .grad_sample attribute from all model parameters.

disable_hooks()[source]

Globally disables all hooks installed by this library. Why is this needed? As per https://github.com/pytorch/pytorch/issues/25723, there is a bug in Autograd that makes removing hooks have no effect if the graph was already constructed. For this reason, this method at least turns them off.

Return type:

None

enable_hooks()[source]

The opposite of disable_hooks(). Hooks are always enabled unless you explicitly disable them, so you don’t need to call this unless you want to re-enable them.

Return type:

None
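A minimal sketch, assuming hooks are implemented as functions that check an enabled flag before doing any work: the hooks remain registered, which sidesteps the Autograd issue above, but become no-ops while disabled. The class and attribute names are hypothetical, and a per-object flag stands in for the global switch.

```python
# Hypothetical sketch (plain Python): hooks stay registered but check a flag,
# so "disabling" works even when removing a hook would have no effect.


class ToyHookedModule:
    def __init__(self):
        self.hooks_enabled = True  # hooks are on unless explicitly disabled
        self.activations = []

    def disable_hooks(self):
        self.hooks_enabled = False

    def enable_hooks(self):
        self.hooks_enabled = True

    def capture_activations_hook(self, activation):
        if not self.hooks_enabled:
            return  # still registered, but a no-op while disabled
        self.activations.append(activation)


m = ToyHookedModule()
m.capture_activations_hook("step 1")
m.disable_hooks()
m.capture_activations_hook("step 2")  # ignored
m.enable_hooks()
m.capture_activations_hook("step 3")
print(m.activations)  # ['step 1', 'step 3']
```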

forward(*args, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

classmethod is_supported(module)[source]

Checks if this individual module is supported (i.e., has a registered grad sampler function).

Notes

This method does not check submodules.

Parameters:

module (Module) – nn.Module to be checked

Return type:

bool

Returns:

True if grad sampler is found, False otherwise
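A toy sketch of the lookup, assuming a registry dictionary keyed by module type. Module, Linear, BatchNorm2d, Sequential, GRAD_SAMPLERS and toy_is_supported are placeholders defined here, not the real PyTorch or library objects. The last call shows why submodules matter: a container is reported as unsupported even when its only child is supported.

```python
# Hypothetical sketch (plain Python): a registry keyed by module type.
# The classes and the registry contents are placeholders, not the library's.


class Module:
    def __init__(self, *children):
        self.children = list(children)


class Linear(Module):
    pass


class BatchNorm2d(Module):
    pass


class Sequential(Module):
    pass


# Maps a module type to the function that computes its per-sample gradients.
GRAD_SAMPLERS = {Linear: lambda activations, backprops: None}  # placeholder sampler


def toy_is_supported(module):
    # Looks only at this module's own type; children are not inspected.
    return type(module) in GRAD_SAMPLERS


print(toy_is_supported(Linear()))              # True
print(toy_is_supported(BatchNorm2d()))         # False
print(toy_is_supported(Sequential(Linear())))  # False: only Sequential itself is looked up
```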

Rearranges activations and grad_samples based on the loss reduction and the batch dimension.

Parameters:
• module (Module) – the module for which per-sample gradients are computed

• backprops (Tensor) – the captured backprops

• loss_reduction (str) – Either “mean” or “sum”, depending on whether the backpropagated loss was averaged or summed over the batch

• batch_first (bool) – True if the batch dimension is first
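A hedged sketch of what this rearrangement could look like, using nested Python lists instead of tensors. It assumes two steps: when batch_first is False, the first two dimensions are swapped so that the batch comes first, and for “mean” reduction the backprops are multiplied by the batch size to undo the 1/batch_size factor introduced by averaging the loss. The function names are hypothetical.

```python
# Hypothetical sketch (plain Python, nested lists instead of tensors) of the
# rearrangement step: put the batch dimension first and undo mean scaling.


def transpose01(x):
    """Swap the first two dimensions of a nested list: [K][N]... -> [N][K]..."""
    return [list(row) for row in zip(*x)]


def scale(x, factor):
    """Multiply every leaf value of a nested list by factor."""
    if isinstance(x, list):
        return [scale(item, factor) for item in x]
    return x * factor


def toy_rearrange(activations, backprops, loss_reduction, batch_first):
    if not batch_first:  # inputs are [K, batch_size, ...]
        activations = transpose01(activations)
        backprops = transpose01(backprops)
    batch_size = len(backprops)
    if loss_reduction == "mean":
        # A mean-reduced loss scales every sample's backprop by 1/batch_size;
        # multiplying back gives gradients of each sample's own loss.
        backprops = scale(backprops, batch_size)
    elif loss_reduction != "sum":
        raise ValueError(f"loss_reduction must be 'sum' or 'mean', got {loss_reduction!r}")
    return activations, backprops


acts = [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]  # [K=2][batch_size=3]
grads = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
new_acts, new_grads = toy_rearrange(acts, grads, "mean", batch_first=False)
print(new_acts)   # [[10.0, 40.0], [20.0, 50.0], [30.0, 60.0]]
print(new_grads)  # [[3.0, 12.0], [6.0, 15.0], [9.0, 18.0]]
```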

remove_hooks()[source]

Removes hooks added by add_hooks().

Return type:

None

Sets .grad_sample to None.

to_standard_module()[source]

Returns the standard nn.Module wrapped by this GradSampleModule, eliminating all traces of grad samples and hooks.

Return type:

Module

Returns:

The wrapped module

classmethod validate(module, *, strict=False)[source]

Checks whether per-sample gradients can be fully computed for a given model.

Parameters:
• module (Module) – nn.Module to be checked

• strict (bool) – Behaviour in case of a negative check result. If set to False, the list of validation errors is returned; if set to True, a NotImplementedError is raised instead.

Returns:

An empty list if validation is successful. A list of validation errors if strict=False and unsupported modules are found.

Raises:

NotImplementedError – If strict=True and unsupported modules are found.
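The check can be sketched as a walk over all submodules that collects one NotImplementedError per unsupported type and raises only when strict=True. This is a plain-Python illustration with placeholder classes and a hypothetical toy_validate function; skipping modules that own no parameters (such as containers) is an assumption of the sketch.

```python
# Hypothetical sketch (plain Python): walk every submodule that owns parameters
# and collect an error for each type without a registered grad sampler.


class Module:
    def __init__(self, *children, has_params=False):
        self.children = list(children)
        self.has_params = has_params

    def walk(self):
        """Yield this module and all of its submodules, depth first."""
        yield self
        for child in self.children:
            yield from child.walk()


class Sequential(Module):
    pass


class Linear(Module):
    def __init__(self):
        super().__init__(has_params=True)


class BatchNorm2d(Module):
    def __init__(self):
        super().__init__(has_params=True)


SUPPORTED = {Linear}  # placeholder registry


def toy_validate(module, *, strict=False):
    errors = [
        NotImplementedError(f"no grad sampler for {type(m).__name__}")
        for m in module.walk()
        if m.has_params and type(m) not in SUPPORTED
    ]
    if strict and errors:
        raise NotImplementedError(errors)
    return errors


print(toy_validate(Sequential(Linear())))  # []
errors = toy_validate(Sequential(Linear(), BatchNorm2d()))
print([str(e) for e in errors])            # ['no grad sampler for BatchNorm2d']

raised = None
try:
    toy_validate(Sequential(BatchNorm2d()), strict=True)
except NotImplementedError as exc:
    raised = exc
```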

Clears p.grad and p.grad_sample for all of its parameters.

Notes

The set_to_none argument only affects p.grad; p.grad_sample is never zeroed out and is always set to None. Regular gradients can be zeroed because their shape is always the same. Grad samples cannot, because gradients from different batches are accumulated in a list.

Parameters:
• set_to_none (bool) – Instead of setting to zero, set the grads to None. (Only affects regular gradients. Per-sample gradients are always set to None.)
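A small plain-Python sketch of this contract, using hypothetical names (ToyParam, toy_zero_grad): regular gradients are zeroed or set to None depending on set_to_none, while grad_sample, which may hold results with different batch sizes, is always reset to None.

```python
# Hypothetical sketch (plain Python) of the zero_grad contract described above.


class ToyParam:
    def __init__(self, grad, grad_sample):
        self.grad = grad                # list of floats, fixed shape
        self.grad_sample = grad_sample  # per-sample values, or a list of per-batch results


def toy_zero_grad(params, set_to_none=False):
    for p in params:
        if set_to_none:
            p.grad = None
        elif p.grad is not None:
            p.grad = [0.0] * len(p.grad)  # same shape, all zeros
        # Per-batch results may have different batch sizes, so there is no single
        # shape to zero; grad_sample is always reset to None.
        p.grad_sample = None


# a.grad_sample holds two batches: one with 1 sample, one with 2 samples.
a = ToyParam([1.0, 2.0], [[[0.5, 0.5]], [[0.5, 1.5], [0.5, 0.5]]])
b = ToyParam([3.0], [[3.0]])
toy_zero_grad([a])
toy_zero_grad([b], set_to_none=True)
print(a.grad, a.grad_sample)  # [0.0, 0.0] None
print(b.grad, b.grad_sample)  # None None
```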

Creates a _current_grad_sample attribute in the given parameter, or adds to it if the _current_grad_sample attribute already exists.

Parameters:
• param (Tensor) – Parameter to which grad_sample will be added

• grad_sample (Tensor) – Per-sample gradients tensor. Must have the same shape as param, with an extra batch dimension

Return type:

None
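A hypothetical sketch of this behaviour with nested lists standing in for tensors. The names toy_create_or_accumulate, shape, add and ToyParam are made up for illustration; the shape check encodes the requirement that grad_sample has the parameter's shape plus a leading batch dimension.

```python
# Hypothetical sketch (plain Python, nested lists instead of tensors) of
# "create or accumulate": check the shape, then create or add to the attribute.


def shape(x):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def add(a, b):
    """Element-wise sum of two nested lists with the same shape."""
    if isinstance(a, list):
        return [add(x, y) for x, y in zip(a, b)]
    return a + b


class ToyParam:
    def __init__(self, data):
        self.data = data


def toy_create_or_accumulate(param, grad_sample):
    # grad_sample must be param's shape with an extra leading batch dimension.
    if shape(grad_sample)[1:] != shape(param.data):
        raise ValueError(
            f"grad_sample shape {shape(grad_sample)} is not "
            f"(batch_size, *{shape(param.data)})")
    if hasattr(param, "_current_grad_sample"):
        param._current_grad_sample = add(param._current_grad_sample, grad_sample)
    else:
        param._current_grad_sample = grad_sample


p = ToyParam([1.0, 2.0])                                  # shape (2,)
toy_create_or_accumulate(p, [[0.5, 1.0], [1.5, 2.0]])     # creates the attribute
toy_create_or_accumulate(p, [[1.0, 1.0], [1.0, 1.0]])     # adds to it
print(p._current_grad_sample)  # [[1.5, 2.0], [2.5, 3.0]]

shape_error = None
try:
    toy_create_or_accumulate(p, [[1.0]])  # wrong: per-sample shape (1,), param is (2,)
except ValueError as exc:
    shape_error = exc
```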