GradSampleModule¶
- class opacus.grad_sample.grad_sample_module.GradSampleModule(m, *, batch_first=True, loss_reduction='mean', strict=True, force_functorch=False)[source]¶
Hooks-based implementation of AbstractGradSampleModule
Computes per-sample gradients using custom-written methods for each layer. See README.md for more details
- Parameters:
m (Module) – nn.Module to be wrapped
batch_first – Flag indicating whether the first dimension of the input tensor to the corresponding module is the batch dimension. If set to True, the input tensor is expected to have dimensions [batch_size, ...], otherwise [K, batch_size, ...]
loss_reduction – Indicates whether the loss reduction (for aggregating the gradients) is a sum or a mean operation. Can take values “sum” or “mean”
strict (bool) – If set to True, the input module will be validated to check that GradSampleModule has grad sampler functions for all submodules of the input module (i.e. that it knows how to calculate per-sample gradients for all model parameters). If set to False, per-sample gradients will be computed on a “best effort” basis: they will be available where possible and set to None otherwise. This is not recommended, because some unsupported modules (e.g. BatchNorm) affect other parameters and invalidate the concept of per-sample gradients for the entire model.
force_functorch – If set to True, functorch will be used to compute all per-sample gradients. Otherwise, functorch will be used only for layers without registered grad sampler methods.
- Raises:
NotImplementedError – If strict is set to True and module m (or any of its submodules) doesn’t have a registered grad sampler function.
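To make the term "per-sample gradient" concrete, here is a toy sketch in plain Python lists rather than PyTorch or Opacus. It assumes a bias-free linear layer y = W x, for which the per-sample gradient of W is the outer product of that sample's backprop (dL/dy) with its activation (x). The function names and numbers are hypothetical and are not part of the Opacus API.

```python
def outer(u, v):
    """Outer product of two vectors given as lists."""
    return [[ui * vj for vj in v] for ui in u]


def linear_grad_samples(activations, backprops):
    """Per-sample weight gradients for a bias-free linear layer y = W x.

    activations: list of input vectors x_i, one per sample
    backprops:   list of dL/dy_i vectors, one per sample
    Returns one [out_features][in_features] matrix per sample.
    """
    return [outer(b, a) for a, b in zip(activations, backprops)]


# Two samples, in_features=2, out_features=1 (made-up numbers).
activations = [[1.0, 2.0], [3.0, 4.0]]
backprops = [[0.5], [-1.0]]

grad_samples = linear_grad_samples(activations, backprops)

# Summing the per-sample gradients gives the ordinary aggregated gradient.
batch_grad = [
    [sum(g[r][c] for g in grad_samples) for c in range(2)]
    for r in range(1)
]
```

A regular backward pass only exposes the aggregated `batch_grad`. A grad sampler function keeps the individual entries of `grad_samples` instead.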
- add_hooks(*, loss_reduction='mean', batch_first=True, force_functorch=False)[source]¶
Adds hooks to the model to save activations and backprop values. The hooks will:
1. save activations into param.activations during the forward pass
2. compute per-sample gradients in param.grad_sample during the backward pass
Call remove_hooks(model) to disable this.
- Parameters:
model – the model to which hooks are added
batch_first (bool) – Flag indicating whether the first dimension of the input tensor to the corresponding module is the batch dimension. If set to True, the input tensor is expected to have dimensions [batch_size, ...], otherwise [K, batch_size, ...]
loss_reduction (str) – Indicates whether the loss reduction (for aggregating the gradients) is a sum or a mean operation. Can take values “sum” or “mean”
force_functorch (bool) – If set to True, functorch will be used to compute all per-sample gradients. Otherwise, functorch will be used only for layers without registered grad sampler methods.
- Return type:
None
- capture_backprops_hook(module, _forward_input, forward_output, loss_reduction, batch_first)[source]¶
Computes per-sample gradients given the current backprops and activations stored by the associated forward hook. Computed per-sample gradients are stored in the grad_sample field of each parameter.
For non-recurrent layers the process is straightforward: for each loss.backward() call this hook will be called exactly once. For recurrent layers, however, this is more complicated and the hook will be called multiple times, while still processing the same batch of data.
For this reason we first accumulate the gradients from the same batch in p._current_grad_sample and then, when we detect the end of a full backward pass, we store the accumulated result on p.grad_sample.
From there, p.grad_sample could be either a Tensor or a list of Tensors, if accumulated over multiple batches.
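The accumulation described above can be sketched with a toy model in plain Python. This is an illustration under stated assumptions, not the Opacus implementation: `ToyParam`, `accumulate`, and the `is_last_call` flag are hypothetical. A tuple stands in for a Tensor, and the real hook detects the end of the backward pass on its own rather than being told.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter; not a torch.nn.Parameter."""

    def __init__(self):
        self._current_grad_sample = None  # partial sum for the batch in flight
        self.grad_sample = None           # finished result(s)


def accumulate(param, grad_sample, is_last_call):
    """Mimic one backward-hook call.

    grad_sample:  per-sample gradients from this call (a tuple of floats)
    is_last_call: assumed signal that the backward pass for this batch is over
    """
    if param._current_grad_sample is None:
        param._current_grad_sample = tuple(grad_sample)
    else:
        param._current_grad_sample = tuple(
            a + b for a, b in zip(param._current_grad_sample, grad_sample)
        )

    if is_last_call:
        finished = param._current_grad_sample
        param._current_grad_sample = None
        if param.grad_sample is None:
            param.grad_sample = finished                 # single "Tensor"
        elif isinstance(param.grad_sample, list):
            param.grad_sample.append(finished)           # already a list
        else:
            param.grad_sample = [param.grad_sample, finished]


p = ToyParam()

# Recurrent-style layer: the hook fires three times for one batch of 2 samples.
accumulate(p, (1.0, 2.0), is_last_call=False)
accumulate(p, (1.0, 2.0), is_last_call=False)
accumulate(p, (1.0, 2.0), is_last_call=True)
after_first_batch = p.grad_sample

# A second batch processed without clearing grad_sample turns it into a list.
accumulate(p, (10.0, 20.0), is_last_call=True)
```

After the first batch `p.grad_sample` holds a single value. After the second it holds a list with one entry per batch, which matches the "Tensor or list of Tensors" behaviour described above.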
- disable_hooks()[source]¶
Globally disable all hooks installed by this library. Why is this needed? As per https://github.com/pytorch/pytorch/issues/25723, there is a bug in Autograd that makes removing hooks do nothing if the graph was already constructed. For this reason, we have this method to at least turn them off.
- Return type:
None
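One plausible way to "turn off" hooks that cannot be reliably removed is to leave them registered and have each hook check a shared flag before doing any work. The sketch below shows that pattern in plain Python. `HookedLayer` and `hooks_enabled` are hypothetical names, and Opacus may implement the switch differently.

```python
class HookedLayer:
    """Hypothetical layer whose hook honours a flag shared by every instance."""

    hooks_enabled = True  # class-level flag: one switch for all instances

    def __init__(self):
        self.captured = []

    def _hook(self, value):
        # The hook stays registered; it simply returns early when disabled.
        if not HookedLayer.hooks_enabled:
            return
        self.captured.append(value)

    def forward(self, value):
        self._hook(value)
        return value


layer = HookedLayer()
layer.forward(1)                    # captured
HookedLayer.hooks_enabled = False   # analogous to disable_hooks()
layer.forward(2)                    # ignored
HookedLayer.hooks_enabled = True    # analogous to enable_hooks()
layer.forward(3)                    # captured again
```

Because the flag is checked at call time, the switch takes effect even for a graph that was constructed before the hooks were disabled.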
- enable_hooks()[source]¶
The opposite of disable_hooks(). Hooks are always enabled unless you explicitly disable them, so you don’t need to call this unless you want to re-enable them.
- Return type:
None
- forward(*args, **kwargs)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- classmethod is_supported(module)[source]¶
Checks if this individual module is supported (i.e. has a registered grad sampler function)
Notes
Note that this method does not check submodules
- Parameters:
module (Module) – nn.Module to be checked
- Return type:
bool
- Returns:
True if grad sampler is found, False otherwise
- rearrange_grad_samples(*, module, backprops, loss_reduction, batch_first)[source]¶
Rearrange activations and grad_samples based on loss reduction and batch dim
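As a rough illustration of what rearranging by batch dimension and loss reduction can mean, the sketch below works on nested Python lists instead of tensors. It assumes two things: inputs laid out as [K, batch_size] are transposed to [batch_size, K] when batch_first is False, and backprops from a mean-reduced loss are multiplied by the batch size to undo the 1/batch_size factor. The function `rearrange` is hypothetical and only mirrors the parameter names from the signature above.

```python
def rearrange(activations, backprops, loss_reduction, batch_first):
    """Toy 2-D version: inputs are nested lists, [batch][K] or [K][batch]."""
    if loss_reduction not in ("sum", "mean"):
        raise ValueError(
            f"loss_reduction must be 'sum' or 'mean', got {loss_reduction!r}"
        )

    if not batch_first:
        # Move the batch dimension to the front: [K, batch] -> [batch, K].
        activations = [list(row) for row in zip(*activations)]
        backprops = [list(row) for row in zip(*backprops)]

    if loss_reduction == "mean":
        # A mean-reduced loss scales every backprop by 1 / batch_size;
        # multiplying by batch_size restores each sample's own gradient signal.
        batch_size = len(backprops)
        backprops = [[b * batch_size for b in row] for row in backprops]

    return activations, backprops


# K = 3 steps, batch of 2 samples, laid out as [K][batch] (made-up numbers).
acts = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
bps = [[0.5, 0.25], [0.5, 0.25], [0.5, 0.25]]

acts_out, bps_out = rearrange(acts, bps, loss_reduction="mean", batch_first=False)
```

With `loss_reduction="sum"` and `batch_first=True` the inputs pass through unchanged.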
- classmethod validate(module, *, strict=False)[source]¶
Check if per sample gradients can be fully computed for a given model
- Parameters:
module (Module) – nn.Module to be checked
strict – Behaviour in case of a negative check result. Will return the list of exceptions if set to False, and throw otherwise
- Return type:
List[NotImplementedError]
- Returns:
Empty list if validation is successful. List of validation errors if strict=False and unsupported modules are found
- Raises:
NotImplementedError – If strict=True and unsupported modules are found
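The two outcomes of validation (return a list of errors, or raise) can be sketched with a toy registry. This is a minimal illustration that assumes layers are identified by name strings. The `SUPPORTED` set and both functions are hypothetical stand-ins and do not reflect how Opacus stores its grad sampler registry.

```python
SUPPORTED = {"Linear", "Conv2d"}  # hypothetical registry of layer type names


def is_supported(layer_name):
    """Check a single layer only; nothing nested inside it is inspected."""
    return layer_name in SUPPORTED


def validate(layer_names, *, strict=False):
    """Collect one NotImplementedError per unsupported layer.

    With strict=False the (possibly empty) list is returned.
    With strict=True a NotImplementedError is raised if the list is non-empty.
    """
    errors = [
        NotImplementedError(f"no grad sampler registered for {name}")
        for name in layer_names
        if not is_supported(name)
    ]
    if strict and errors:
        raise NotImplementedError(errors)
    return errors


errors = validate(["Linear", "BatchNorm2d"])   # best-effort: returns a list
clean = validate(["Linear", "Conv2d"])         # nothing unsupported

try:
    validate(["BatchNorm2d"], strict=True)
    raised = False
except NotImplementedError:
    raised = True
```

Returning the errors rather than raising lets a caller report every unsupported layer at once before deciding how to proceed.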