Source code for opacus.grad_sample.grad_sample_module_fast_gradient_clipping

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import List

import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
from opacus.grad_sample.grad_sample_module import (
    GradSampleModule,
    create_or_accumulate_grad_sample,
    promote_current_grad_sample,
)
from opacus.utils.module_utils import requires_grad, trainable_parameters

logger = logging.getLogger(__name__)
logger.disabled = True

def create_norm_sample(
    *, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
    """
    Creates a ``_norm_sample`` attribute in the given parameter

    The attribute holds one L2 norm per sample, computed over all non-batch
    dimensions of ``grad_sample``. Nothing is done if ``param`` does not
    require gradients.

    Args:
        param: Parameter to which ``_norm_sample`` will be added
        grad_sample: Per-sample gradients tensor. Must be of the same shape as
            ``param`` with extra batch dimension
        max_batch_len: Maximum batch length. When it is ``0`` (an empty
            batch), ``_norm_sample`` is set to an empty tensor with the
            device and dtype of ``grad_sample``
    """
    if not param.requires_grad:
        return

    if max_batch_len == 0:
        # To handle the case of empty batch that may arise from Poisson sampling
        param._norm_sample = torch.tensor(
            [], device=grad_sample.device, dtype=grad_sample.dtype
        )
        return

    # Flatten each per-sample gradient to a vector and take its L2 norm.
    # No placeholder tensor is allocated beforehand: it would be overwritten
    # immediately by this assignment.
    param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)
class GradSampleModuleFastGradientClipping(GradSampleModule):
    """
    Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping

    Computes norms of gradients without gradient instantiation
    """

    # Registry mapping a module type to a function that returns per-sample
    # gradient norms for that module's parameters (used for Ghost Clipping).
    NORM_SAMPLERS = {}

    def __init__(
        self,
        m: nn.Module,
        *,
        batch_first=True,
        loss_reduction="mean",
        strict: bool = True,
        force_functorch=False,
        max_grad_norm=1,
        use_ghost_clipping=True,
    ):
        """

        Args:
            m: nn.Module to be wrapped
            batch_first: Flag to indicate if the input tensor to the corresponding module
                has the first dimension representing the batch. If set to True, dimensions on
                input tensor are expected be ``[batch_size, ...]``, otherwise
                ``[K, batch_size, ...]``
            loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                is a sum or a mean operation. Can take values "sum" or "mean"
            max_grad_norm: The value at which gradients are to be clipped.
            strict: If set to True, the input module will be validated to make sure that
                it does not have buffers in all its submodules.
            force_functorch: If set to ``True``, will use functorch to compute
                all per sample gradients. Otherwise, functorch will be used only
                for layers without registered grad sampler methods.
            use_ghost_clipping: If set to ``True``, Ghost Clipping
                will be used for clipping gradients of supported layers. If ``False``, Fast
                Gradient Clipping will be used for all layers.

        Raises:
            NotImplementedError
                If ``strict`` is set to ``True`` and module ``m`` (or any of its
                submodules) includes a buffer.
        """

        super().__init__(
            m,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
            strict=strict,
            force_functorch=force_functorch,
        )
        # Cached once at construction; get_norm_sample() iterates this list.
        self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
        self.max_grad_norm = max_grad_norm
        self.use_ghost_clipping = use_ghost_clipping

    def get_clipping_coef(self) -> torch.Tensor:
        """Get per-example gradient scaling factor for clipping.

        Returns:
            ``max_grad_norm / (norm + 1e-6)`` for each example, clamped from
            above at ``1.0`` so that gradients are never scaled up.
        """
        norm_sample = self.get_norm_sample()
        # The small constant avoids division by zero for all-zero gradients.
        return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)

    def get_norm_sample(self) -> torch.Tensor:
        """Get per-example gradient norms.

        Stacks the ``_norm_sample`` tensors of all trainable parameters along
        a new leading dimension and takes the L2 norm across that dimension,
        i.e. combines the per-parameter norms into one norm per example.
        """
        norm_sample = torch.stack(
            [param._norm_sample for param in self.trainable_parameters], dim=0
        ).norm(2, dim=0)
        return norm_sample

    def capture_activations_hook(
        self,
        module: nn.Module,
        forward_input: List[torch.Tensor],
        _forward_output: torch.Tensor,
    ):
        """
        Forward hook that records detached inputs of ``module`` and counts
        forward passes per parameter.

        Args:
            module: Module the hook is attached to
            forward_input: Inputs passed to ``module``; each is detached and
                appended (as one list) to ``module.activations``
            _forward_output: Unused

        Raises:
            NotImplementedError
                If Ghost Clipping is enabled, ``type(module)`` has a registered
                norm sampler, and one of its parameters has been used in more
                than one pending forward pass.
        """
        # Nothing to record when the module has no trainable parameters, is in
        # eval mode, autograd is off, or hooks are disabled.
        if (
            not requires_grad(module)
            or not module.training
            or not torch.is_grad_enabled()
            or not self.hooks_enabled
        ):
            return

        if not hasattr(module, "activations"):
            module.activations = []
        module.activations.append([t.detach() for t in forward_input])  # pyre-ignore

        for _, p in trainable_parameters(module):
            p._forward_counter += 1
            if (
                self.use_ghost_clipping
                and p._forward_counter > 1
                and type(module) in self.NORM_SAMPLERS
            ):
                raise NotImplementedError(
                    "Parameter tying is not supported with Ghost Clipping"
                )

    def capture_backprops_hook(
        self,
        module: nn.Module,
        _forward_input: torch.Tensor,
        forward_output: torch.Tensor,
        loss_reduction: str,
        batch_first: bool,
    ):
        """
        Computes norms of per sample gradient given the current backprops and activations
        stored by the associated forward hook. Computed per sample gradient norms are
        stored in ``_norm_sample`` field in each parameter.

        Args:
            module: Module the hook is attached to
            _forward_input: Unused
            forward_output: Indexed as a sequence; its first element is detached
                and used as the backprops
            loss_reduction: Forwarded to ``rearrange_grad_samples``
            batch_first: Forwarded to ``rearrange_grad_samples``
        """
        if not self.hooks_enabled:
            return

        backprops = forward_output[0].detach()
        activations, backprops = self.rearrange_grad_samples(
            module=module,
            backprops=backprops,
            loss_reduction=loss_reduction,
            batch_first=batch_first,
        )

        if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS:
            # Ghost Clipping: the registered sampler yields per-sample norms
            # directly, so no per-sample gradients are stored.
            norm_sampler_fn = self.NORM_SAMPLERS[type(module)]
            norm_samples = norm_sampler_fn(module, activations, backprops)

            for param, ns in norm_samples.items():
                if param.requires_grad:
                    param._norm_sample = ns
                    param._forward_counter -= 1

        else:
            # Fast Gradient Clipping: compute per-sample gradients, reduce
            # them to norms, then drop the gradients.
            if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
                grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
            else:
                grad_sampler_fn = ft_compute_per_sample_gradient

            grad_samples = grad_sampler_fn(module, activations, backprops)
            for param, gs in grad_samples.items():
                create_or_accumulate_grad_sample(
                    param=param, grad_sample=gs, max_batch_len=module.max_batch_len
                )
            del grad_samples
            # Detect end of current batch processing and switch accumulation
            # mode from sum to stacking. Used for RNNs and tied parameters
            # (See #417 for details)
            for _, p in trainable_parameters(module):
                p._forward_counter -= 1
                if p._forward_counter == 0:
                    promote_current_grad_sample(p)
                    create_norm_sample(
                        param=p,
                        grad_sample=p.grad_sample,
                        max_batch_len=module.max_batch_len,
                    )
                    del p.grad_sample

        # Once no stored activations remain for this module, drop the batch
        # length attribute if it is present.
        if len(module.activations) == 0 and hasattr(module, "max_batch_len"):
            del module.max_batch_len