#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import List
import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
from opacus.grad_sample.grad_sample_module import (
GradSampleModule,
create_or_accumulate_grad_sample,
promote_current_grad_sample,
)
from opacus.utils.module_utils import requires_grad, trainable_parameters
logger = logging.getLogger(__name__)
logger.disabled = True
def create_norm_sample(
*, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
"""
    Creates a ``_norm_sample`` attribute in the given parameter.

    Args:
        param: Parameter to which ``_norm_sample`` will be added
        grad_sample: Per-sample gradients tensor. Must have the same shape as
            ``param``, with an extra leading batch dimension
        max_batch_len: Expected maximum batch size; 0 denotes an empty batch
            (which may arise from Poisson sampling)
"""
if param.requires_grad:
if (
max_batch_len == 0
): # To handle the case of empty batch that may arise from Poisson sampling
param._norm_sample = torch.tensor(
[], device=grad_sample.device, dtype=grad_sample.dtype
)
else:
param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(
2, dim=-1
)
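
# Illustration (a hedged sketch, not part of this module's API): given per-sample
# gradients of shape ``[batch_size, *param.shape]``, ``create_norm_sample`` stores
# one L2 norm per example on the parameter, e.g.:
#
#     p = nn.Parameter(torch.randn(3, 2))
#     gs = torch.randn(5, 3, 2)  # per-sample gradients for a batch of 5
#     create_norm_sample(param=p, grad_sample=gs, max_batch_len=5)
#     assert torch.allclose(p._norm_sample, gs.reshape(5, -1).norm(2, dim=-1))
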
class GradSampleModuleFastGradientClipping(GradSampleModule):
"""
    Hooks-based implementation of ``GradSampleModule`` with Fast Gradient and Ghost
    Clipping. Computes per-sample gradient norms without storing instantiated
    per-sample gradients.
"""
NORM_SAMPLERS = {}
def __init__(
self,
m: nn.Module,
*,
batch_first=True,
loss_reduction="mean",
strict: bool = True,
force_functorch=False,
max_grad_norm=1,
use_ghost_clipping=True,
):
"""
Args:
m: nn.Module to be wrapped
batch_first: Flag to indicate if the input tensor to the corresponding module
                has the first dimension representing the batch. If set to True, the
                input tensor is expected to have shape ``[batch_size, ...]``; otherwise
                ``[K, batch_size, ...]``
loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values "sum" or "mean"
max_grad_norm: The value at which gradients are to be clipped.
            strict: If set to True, the input module will be validated to make sure
                that none of its submodules has buffers.
force_functorch: If set to ``True``, will use functorch to compute
all per sample gradients. Otherwise, functorch will be used only
for layers without registered grad sampler methods.
use_ghost_clipping: If set to ``True``, Ghost Clipping
will be used for clipping gradients of supported layers. If ``False``, Fast
Gradient Clipping will be used for all layers.
Raises:
NotImplementedError
If ``strict`` is set to ``True`` and module ``m`` (or any of its
submodules) includes a buffer.
"""
super().__init__(
m,
batch_first=batch_first,
loss_reduction=loss_reduction,
strict=strict,
force_functorch=force_functorch,
)
self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
self.max_grad_norm = max_grad_norm
self.use_ghost_clipping = use_ghost_clipping
def get_clipping_coef(self) -> torch.Tensor:
"""Get per-example gradient scaling factor for clipping."""
norm_sample = self.get_norm_sample()
return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
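    # Illustration: with ``max_grad_norm = C`` and per-example gradient norm ``n_i``,
    # the returned coefficient is ``min(1, C / (n_i + 1e-6))``: examples whose
    # gradient norm is already below ``C`` are (almost) left unscaled, while larger
    # gradients are scaled down to have norm approximately ``C``.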
def get_norm_sample(self) -> torch.Tensor:
"""Get per-example gradient norms."""
norm_sample = torch.stack(
[param._norm_sample for param in self.trainable_parameters], dim=0
).norm(2, dim=0)
return norm_sample
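    # Illustration: the per-example norm aggregates per-parameter norms as
    # ``norm_sample[i] = sqrt(sum_p p._norm_sample[i] ** 2)``, i.e. the L2 norm of
    # the full flattened per-example gradient across all trainable parameters.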
def capture_activations_hook(
self,
module: nn.Module,
forward_input: List[torch.Tensor],
_forward_output: torch.Tensor,
):
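        """
        Forward hook: stores detached input activations on ``module.activations`` and
        increments each trainable parameter's ``_forward_counter``. Raises
        ``NotImplementedError`` when a parameter of a Ghost Clipping-supported module
        is encountered more than once in the forward pass (parameter tying), which is
        unsupported.
        """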
if (
not requires_grad(module)
or not module.training
or not torch.is_grad_enabled()
or not self.hooks_enabled
):
return
if not hasattr(module, "activations"):
module.activations = []
module.activations.append([t.detach() for t in forward_input]) # pyre-ignore
for _, p in trainable_parameters(module):
p._forward_counter += 1
if (
self.use_ghost_clipping
and p._forward_counter > 1
and type(module) in self.NORM_SAMPLERS
):
raise NotImplementedError(
"Parameter tying is not supported with Ghost Clipping"
)
def capture_backprops_hook(
self,
module: nn.Module,
_forward_input: torch.Tensor,
forward_output: torch.Tensor,
loss_reduction: str,
batch_first: bool,
):
"""
        Computes per-sample gradient norms given the current backprops and the
        activations stored by the associated forward hook. The computed norms are
        stored in the ``_norm_sample`` field of each parameter.

        Args:
            module: nn.Module whose parameters' per-sample gradient norms are computed
            _forward_input: Module input (unused)
            forward_output: Gradients of the loss with respect to the module output
            loss_reduction: Indicates if the loss reduction (for aggregating the
                gradients) is a sum or a mean operation. Can take values "sum" or "mean"
            batch_first: Flag to indicate if the first dimension represents the batch
"""
if not self.hooks_enabled:
return
backprops = forward_output[0].detach()
activations, backprops = self.rearrange_grad_samples(
module=module,
backprops=backprops,
loss_reduction=loss_reduction,
batch_first=batch_first,
)
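        # Ghost Clipping path: for supported layer types, compute per-sample gradient
        # *norms* directly from activations and backprops via a registered norm
        # sampler, without materializing per-sample gradients. Other layers fall back
        # to Fast Gradient Clipping: materialize per-sample gradients, compute their
        # norms, then discard the gradients.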
if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS:
norm_sampler_fn = self.NORM_SAMPLERS[type(module)]
norm_samples = norm_sampler_fn(module, activations, backprops)
for param, ns in norm_samples.items():
if param.requires_grad:
param._norm_sample = ns
param._forward_counter -= 1
else:
if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
else:
grad_sampler_fn = ft_compute_per_sample_gradient
grad_samples = grad_sampler_fn(module, activations, backprops)
for param, gs in grad_samples.items():
create_or_accumulate_grad_sample(
param=param, grad_sample=gs, max_batch_len=module.max_batch_len
)
del grad_samples
# Detect end of current batch processing and switch accumulation
# mode from sum to stacking. Used for RNNs and tied parameters
# (See #417 for details)
for _, p in trainable_parameters(module):
p._forward_counter -= 1
if p._forward_counter == 0:
promote_current_grad_sample(p)
create_norm_sample(
param=p,
grad_sample=p.grad_sample,
max_batch_len=module.max_batch_len,
)
del p.grad_sample
if len(module.activations) == 0:
if hasattr(module, "max_batch_len"):
del module.max_batch_len
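

# Minimal usage sketch (illustrative only; the model, data, and loss below are
# assumptions for demonstration, not part of this module). In practice this wrapper
# is driven by a training loop that performs a second backward pass on a loss
# rescaled by ``get_clipping_coef()``.
if __name__ == "__main__":
    model = GradSampleModuleFastGradientClipping(
        nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2)),
        batch_first=True,
        loss_reduction="mean",
        max_grad_norm=1.0,
        use_ghost_clipping=True,
    )
    x = torch.randn(4, 16)
    per_sample_loss = model(x).pow(2).sum(dim=1)  # shape [4]: one loss per example

    # First backward pass: the hooks above populate per-example gradient norms
    # (``_norm_sample``) without instantiating per-sample gradients for layers
    # supported by Ghost Clipping, such as ``nn.Linear``.
    per_sample_loss.mean().backward(retain_graph=True)

    print(model.get_norm_sample())    # per-example gradient norms, shape [4]
    print(model.get_clipping_coef())  # min(1, max_grad_norm / (norm + 1e-6))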