Source code for opacus.autograd_grad_sample

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
r"""
*Based on* https://github.com/cybertronai/autograd-hacks

This module provides functions to capture per-sample gradients by using hooks.

Notes:
    The ``register_backward_hook()`` function has a known issue being tracked at
    https://github.com/pytorch/pytorch/issues/598. However, it is the only known
    way of implementing this as of now (your suggestions and contributions are
    very welcome). The behaviour has been verified to be correct for the layers
    currently supported by opacus.
"""

from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
from opacus.layers.dp_lstm import LSTMLinear

from .supported_layers_grad_samplers import _supported_layers_grad_samplers
from .utils.module_inspection import get_layer_type, requires_grad


# work-around for https://github.com/pytorch/pytorch/issues/25723
_hooks_disabled: bool = False


def add_hooks(model: nn.Module, loss_reduction: str = "mean", batch_first: bool = True):
    r"""
    Adds hooks to model to save activations and backprop values.

    The hooks will

    1. save activations into ``param.activations`` during forward pass.
    2. compute per-sample gradients and save them in ``param.grad_sample`` during
       backward pass.

    Args:
        model: Model to which hooks are added.
        loss_reduction: Indicates if the loss reduction (for aggregating the
            gradients) is a sum or a mean operation. Can take values ``sum`` or
            ``mean``.
        batch_first: Flag to indicate if the input tensor to the corresponding module
            has the first dimension represent the batch, for example of shape
            ``[batch_size, ..., ...]``. Set to True if batch appears in first
            dimension else set to False (``batch_first=False`` implies that the
            batch is always in the second dimension).
    """
    if hasattr(model, "autograd_grad_sample_hooks"):
        raise ValueError("Trying to add hooks twice to the same model")

    global _hooks_disabled
    _hooks_disabled = False

    handles = []
    for layer in model.modules():
        if get_layer_type(layer) in _supported_layers_grad_samplers.keys():
            # pyre-fixme[16]: `Module` has no attribute `register_forward_hook`.
            handles.append(layer.register_forward_hook(_capture_activations))
            handles.append(
                # pyre-fixme[16]: `Module` has no attribute `register_backward_hook`.
                layer.register_backward_hook(
                    partial(
                        _capture_backprops,
                        loss_reduction=loss_reduction,
                        batch_first=batch_first,
                    )
                )
            )

    model.__dict__.setdefault("autograd_grad_sample_hooks", []).extend(handles)
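
# Illustrative usage sketch (a hypothetical helper, not part of this module's API):
# one way ``add_hooks()`` might be used. Run a forward and backward pass, then read
# ``p.grad_sample`` from the parameters of supported layers. It assumes the caller
# supplies the model, criterion and a single batch, and that the model is in
# training mode.
def _example_per_sample_grads(
    model: nn.Module,
    criterion: nn.Module,
    batch: torch.Tensor,
    targets: torch.Tensor,
):
    add_hooks(model, loss_reduction="mean", batch_first=True)
    loss = criterion(model(batch), targets)
    loss.backward()  # backward hooks populate ``p.grad_sample`` on supported layers
    per_sample = {
        name: p.grad_sample
        for name, p in model.named_parameters()
        if hasattr(p, "grad_sample")
    }
    remove_hooks(model)
    return per_sample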

def remove_hooks(model: nn.Module):
    r"""
    Removes hooks added by ``add_hooks()``.

    Args:
        model: Model from which hooks are to be removed.
    """
    if not hasattr(model, "autograd_grad_sample_hooks"):
        raise ValueError("Asked to remove hooks, but no hooks found")
    else:
        # pyre-fixme[16]: `Module` has no attribute `autograd_grad_sample_hooks`.
        for handle in model.autograd_grad_sample_hooks:
            handle.remove()
        del model.autograd_grad_sample_hooks

def disable_hooks():
    r"""
    Globally disables all hooks installed by this library.
    """
    global _hooks_disabled
    _hooks_disabled = True

def enable_hooks():
    r"""
    Globally enables all hooks installed by this library.
    """
    global _hooks_disabled
    _hooks_disabled = False
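
# Hedged sketch (hypothetical helper, not part of this module): disable_hooks() and
# enable_hooks() can bracket a backward pass whose per-sample gradients should not
# be captured while the hooks remain attached to the model.
def _example_backward_without_capture(loss: torch.Tensor):
    disable_hooks()
    try:
        loss.backward()  # hooks stay installed but do nothing for this pass
    finally:
        enable_hooks()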

def is_supported(layer: nn.Module) -> bool:
    r"""
    Checks if the ``layer`` is supported by this library.

    Args:
        layer: Layer for which we need to determine if the support for capturing
            per-sample gradients is available.

    Returns:
        Whether the ``layer`` is supported by this library.
    """
    return get_layer_type(layer) in _supported_layers_grad_samplers.keys()
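
# Hedged sketch (hypothetical helper, not part of this module): one way is_supported()
# might be used to audit a model before adding hooks, by listing trainable leaf
# layers for which per-sample gradients cannot be captured.
def _example_unsupported_layers(model: nn.Module):
    return [
        get_layer_type(layer)
        for layer in model.modules()
        if requires_grad(layer)
        and len(list(layer.children())) == 0  # leaf modules only
        and not is_supported(layer)
    ]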

def _capture_activations(
    layer: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]
):
    r"""Forward hook handler captures and saves activations flowing into the
    ``layer`` in ``layer.activations`` during forward pass.

    Args:
        layer: Layer to capture the activations in.
        inputs: Inputs to the ``layer``.
        outputs: Outputs of the ``layer``.
    """
    layer_type = get_layer_type(layer)
    if (
        not requires_grad(layer)
        or layer_type not in _supported_layers_grad_samplers.keys()
        or not layer.training
    ):
        return

    if _hooks_disabled:
        return
    if get_layer_type(layer) not in _supported_layers_grad_samplers.keys():
        raise ValueError("Hook installed on unsupported layer")

    if not hasattr(layer, "activations"):
        layer.activations = []

    layer.activations.append(inputs[0].detach())


def _capture_backprops(
    layer: nn.Module,
    inputs: Tuple[torch.Tensor],
    outputs: Tuple[torch.Tensor],
    loss_reduction: str,
    batch_first: bool,
):
    r"""Backward hook handler captures backpropagated gradients during backward pass,
    and computes and stores per-sample gradients.

    Args:
        layer: Layer to capture gradients in.
        inputs: Gradients of the tensor on which ``.backward()`` is called with
            respect to the inputs to the ``layer``.
        outputs: Gradients of the tensor on which ``.backward()`` is called with
            respect to the outputs of the ``layer``.
        loss_reduction: Indicates if the loss reduction (for aggregating the
            gradients) is a sum or a mean operation. Can take values ``sum`` or
            ``mean``.
        batch_first: Flag to indicate if the input tensor to the corresponding module
            has the first dimension represent the batch, for example of shape
            ``[batch_size, ..., ...]``. Set to True if batch appears in first
            dimension else set to False (``batch_first=False`` implies that the
            batch is always in the second dimension).
    """
    if _hooks_disabled:
        return

    backprops = outputs[0].detach()
    _compute_grad_sample(layer, backprops, loss_reduction, batch_first)


def _compute_grad_sample(
    layer: nn.Module, backprops: torch.Tensor, loss_reduction: str, batch_first: bool
):
    r"""Computes per-sample gradients with respect to the parameters of the ``layer``
    (if supported), and saves them in ``param.grad_sample``.

    Args:
        layer: Layer to capture per-sample gradients in.
        backprops: Back propagated gradients captured by the backward hook.
        loss_reduction: Indicates if the loss reduction (for aggregating the
            gradients) is a sum or a mean operation. Can take values ``sum`` or
            ``mean``.
        batch_first: Flag to indicate if the input tensor to the corresponding module
            has the first dimension represent the batch, for example of shape
            ``[batch_size, ..., ...]``. Set to True if batch appears in first
            dimension else set to False (``batch_first=False`` implies that the
            batch is always in the second dimension).
    """
    layer_type = get_layer_type(layer)
    if (
        not requires_grad(layer)
        or layer_type not in _supported_layers_grad_samplers.keys()
        or not layer.training
    ):
        return

    if not hasattr(layer, "activations"):
        raise ValueError(
            f"No activations detected for {type(layer)},"
            " run forward after add_hooks(model)"
        )

    # Outside of the LSTM there is "batch_first" but not for the Linear inside the LSTM
    batch_dim = 0 if batch_first or type(layer) is LSTMLinear else 1

    if isinstance(layer.activations, list):
        A = layer.activations.pop()
    else:
        A = layer.activations

    n = A.shape[batch_dim]
    if loss_reduction == "mean":
        B = backprops * n
    elif loss_reduction == "sum":
        B = backprops
    else:
        raise ValueError(
            f"loss_reduction = {loss_reduction}."
            " Only 'sum' and 'mean' losses are supported"
        )

    # rearrange the blob dimensions
    if batch_dim != 0:
        A = A.permute([batch_dim] + [x for x in range(A.dim()) if x != batch_dim])
        # pyre-fixme[6]: Expected `int` for 1st param but got `List[int]`.
        B = B.permute([batch_dim] + [x for x in range(B.dim()) if x != batch_dim])

    # compute grad sample for individual layers
    compute_layer_grad_sample = _supported_layers_grad_samplers.get(
        get_layer_type(layer)
    )
    compute_layer_grad_sample(layer, A, B)
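
# Hedged worked example (hypothetical helper, not part of this module) illustrating
# the ``loss_reduction`` scaling above: with a mean-reduced loss over ``n`` samples,
# autograd's backpropagated gradients carry a 1/n factor, so ``_compute_grad_sample``
# multiplies the backprops by ``n`` to restore per-sample scale. For a model whose
# samples contribute independently (e.g. no BatchNorm), the mean of ``p.grad_sample``
# over the batch dimension should then match ``p.grad`` up to floating point tolerance.
def _example_check_grad_sample(
    model: nn.Module, batch: torch.Tensor, targets: torch.Tensor
):
    criterion = nn.CrossEntropyLoss(reduction="mean")
    add_hooks(model, loss_reduction="mean", batch_first=True)
    model.zero_grad()
    criterion(model(batch), targets).backward()
    for p in model.parameters():
        if hasattr(p, "grad_sample"):
            # per-sample gradients average back to the aggregated gradient
            assert torch.allclose(p.grad_sample.mean(dim=0), p.grad, atol=1e-6)
    remove_hooks(model)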