# Source code for opacus.utils.fast_gradient_clipping_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
    GradSampleModuleFastGradientClipping,
)
from opacus.optimizers import DPOptimizerFastGradientClipping


class DPTensorFastGradientClipping:
    """
    Packages the training loop for Fast Gradient and Ghost Clipping into loss.backward().

    Instances wrap a per-sample loss tensor together with the DP module and
    optimizer so that calling ``backward()`` performs the two-pass
    clipped-gradient computation. Scalar arithmetic operators are provided so
    the wrapper can be scaled/combined like an ordinary loss tensor.
    """

    def __init__(
        self,
        module: "GradSampleModuleFastGradientClipping",
        optimizer: "DPOptimizerFastGradientClipping",
        loss_per_sample: torch.Tensor,
        loss_reduction: str = "mean",
    ):
        """
        Args:
            module: the module to train
            optimizer: the optimizer used to train the module
            loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
            loss_reduction: how per-sample losses are reduced, "mean" or "sum"
        """
        self.module = module
        self.optimizer = optimizer
        self.loss_per_sample = loss_per_sample
        self.loss_reduction = loss_reduction

    def item(self) -> float:
        """Return the reduced loss as a Python float."""
        return self.detach().item()

    def detach(self) -> torch.Tensor:
        """Return the reduced loss, detached from the autograd graph.

        Raises:
            ValueError: if ``loss_reduction`` is neither "mean" nor "sum".
        """
        if self.loss_reduction == "mean":
            return torch.mean(self.loss_per_sample).detach()
        elif self.loss_reduction == "sum":
            return torch.sum(self.loss_per_sample).detach()
        # Fix: previously an unsupported reduction silently returned None;
        # raise explicitly, consistent with backward().
        raise ValueError(
            f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
        )

    def __truediv__(self, other):
        """
        Division operation for DPTensorFastGradientClipping.
        Enables: loss / scalar
        """
        return DPTensorFastGradientClipping(
            self.module,
            self.optimizer,
            self.loss_per_sample / other,
            self.loss_reduction,
        )

    def __mul__(self, other):
        """
        Multiplication operation for DPTensorFastGradientClipping.
        Enables: loss * scalar or scalar * loss
        """
        return DPTensorFastGradientClipping(
            self.module,
            self.optimizer,
            self.loss_per_sample * other,
            self.loss_reduction,
        )

    def __rmul__(self, other):
        """
        Left multiplication by a scalar.
        Required to support loss weighting: weight * loss
        """
        return self.__mul__(other)

    def __add__(self, other):
        """
        Addition operation for DPTensorFastGradientClipping.
        Enables: loss + scalar or loss + loss. Required to support combining
        multiple losses in a single training step.

        Raises:
            ValueError: when adding two wrapped losses with different reductions.
        """
        if isinstance(other, DPTensorFastGradientClipping):
            if self.loss_reduction != other.loss_reduction:
                raise ValueError(
                    f"Cannot add losses with different reductions: {self.loss_reduction} vs {other.loss_reduction}"
                )
            return DPTensorFastGradientClipping(
                self.module,
                self.optimizer,
                self.loss_per_sample + other.loss_per_sample,
                self.loss_reduction,
            )
        else:
            return DPTensorFastGradientClipping(
                self.module,
                self.optimizer,
                self.loss_per_sample + other,
                self.loss_reduction,
            )

    def __radd__(self, other):
        """
        Right addition operation for DPTensorFastGradientClipping.
        Enables: scalar + loss
        """
        return self.__add__(other)

    def __sub__(self, other):
        """
        Subtraction operation for DPTensorFastGradientClipping.
        Enables: loss - scalar or loss - loss

        Raises:
            ValueError: when subtracting two wrapped losses with different reductions.
        """
        if isinstance(other, DPTensorFastGradientClipping):
            if self.loss_reduction != other.loss_reduction:
                raise ValueError(
                    f"Cannot subtract losses with different reductions: {self.loss_reduction} vs {other.loss_reduction}"
                )
            return DPTensorFastGradientClipping(
                self.module,
                self.optimizer,
                self.loss_per_sample - other.loss_per_sample,
                self.loss_reduction,
            )
        else:
            return DPTensorFastGradientClipping(
                self.module,
                self.optimizer,
                self.loss_per_sample - other,
                self.loss_reduction,
            )

    def __rsub__(self, other):
        """
        Right subtraction operation for DPTensorFastGradientClipping.
        Enables: scalar - loss
        """
        return DPTensorFastGradientClipping(
            self.module,
            self.optimizer,
            other - self.loss_per_sample,
            self.loss_reduction,
        )

    def __neg__(self):
        """
        Negation operation for DPTensorFastGradientClipping.
        Enables: -loss
        """
        return DPTensorFastGradientClipping(
            self.module,
            self.optimizer,
            -self.loss_per_sample,
            self.loss_reduction,
        )

    def __repr__(self):
        """String representation"""
        return f"DPTensorFastGradientClipping(loss_reduction={self.loss_reduction}, shape={self.loss_per_sample.shape})"

    def __str__(self):
        """String representation"""
        return f"DPTensorFastGradientClipping({self.item():.4f})"

    def backward(self):
        """
        Repurposes loss.backward() to perform two backward passes, as well as
        the loss rescaling and hook operations in between
        """
        if self.loss_reduction == "mean":
            reduced_loss = torch.mean(self.loss_per_sample, dim=0)
        elif self.loss_reduction == "sum":
            reduced_loss = torch.sum(self.loss_per_sample, dim=0)
        else:
            raise ValueError(
                f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
            )
        # First pass: populates per-sample gradient norms via the module's hooks.
        # retain_graph=True keeps the graph alive for the second backward pass.
        reduced_loss.backward(retain_graph=True)
        self.optimizer.zero_grad()
        coeff = self.module.get_clipping_coef()
        # Rescale each sample's loss by its clipping coefficient before the
        # second pass so the resulting gradients are clipped.
        second_loss_per_sample = (
            coeff.to(self.loss_per_sample.device) * self.loss_per_sample
        )
        second_loss = torch.sum(second_loss_per_sample)
        # Hooks are disabled so the second pass accumulates ordinary gradients.
        self.module.disable_hooks()
        second_loss.backward()
        self.module.enable_hooks()
class DPLossFastGradientClipping:
    """
    Wrapper on the loss function to be used with Fast Gradient and Ghost Clipping.
    It computes the per-sample loss, and wraps it in DPTensorFastGradientClipping.
    """

    def __init__(
        self,
        module: "GradSampleModuleFastGradientClipping",
        optimizer: "DPOptimizerFastGradientClipping",
        criterion,
        loss_reduction: str = "mean",
    ):
        """
        Args:
            module: the GradSampleModule being trained
            optimizer: the DP optimizer used to train the module
            criterion: the underlying loss function; must expose a ``reduction``
                attribute (one is inherited from ``module`` if missing)
            loss_reduction: "mean" or "sum"; must match module, optimizer, and
                criterion
        """
        assert loss_reduction in [
            "mean",
            "sum",
        ], "loss_reduction should be either 'mean' or 'sum'"
        # If the criterion is missing a reduction attribute, inherit the module's.
        if not hasattr(criterion, "reduction"):
            setattr(criterion, "reduction", module.loss_reduction)
        assert (
            loss_reduction
            == criterion.reduction
            == module.loss_reduction
            == optimizer.loss_reduction
        ), "loss_reduction should be the same across GradSampleModule, Optimizer, Criterion, and loss_reduction"
        self.optimizer = optimizer
        self.module = module
        self.criterion = criterion
        self.loss_reduction = loss_reduction

    def __call__(self, *args, shape=None, **kwargs) -> "DPTensorFastGradientClipping":
        """
        Redefining the forward function to compute per-sample loss and wrap
        it in DPTensorFastGradientClipping

        Args:
            *args, **kwargs: forwarded to the wrapped criterion; by convention
                args[0] is the input/logits and args[1] (or kwargs["target"])
                is the target.
            shape: for generative NLP tasks, the logits shape before
                flattening, i.e. [batch_size, sequence_length, vocab_size];
                enables per-sequence loss aggregation.
        """
        # Temporarily force element-wise loss to obtain per-sample values.
        old_reduction = self.criterion.reduction
        self.criterion.reduction = "none"
        loss_per_sample = self.criterion(*args, **kwargs)
        self.criterion.reduction = old_reduction

        if shape is not None and loss_per_sample.shape[0] == shape[0] * shape[1]:
            # Note that the privacy unit for generative NLP tasks is per sequence.
            # The shape variable is the shape of the logits before flattening,
            # i.e., [batch_size, sequence_length, vocab_size]. This variable is
            # necessary for ghost clipping to work with generative NLP tasks.
            loss_per_sample = loss_per_sample.view(shape[0], shape[1])  # B x T
            if self.loss_reduction == "mean":
                # When the criterion has ignore_index, positions matching it
                # produce zero loss but should also be excluded from the
                # denominator (matching PyTorch's CrossEntropyLoss behavior).
                ignore_index = getattr(self.criterion, "ignore_index", None)
                # Fix: resolve the target whether it was passed positionally or
                # as a keyword; previously a keyword-only target skipped the
                # ignore_index correction and divided by the full length T.
                targets = kwargs.get("target")
                if targets is None and len(args) >= 2:
                    targets = args[1]
                if ignore_index is not None and targets is not None:
                    mask = targets.view(shape[0], shape[1]) != ignore_index
                    # clamp avoids division by zero for all-ignored sequences.
                    num_valid = mask.sum(dim=1).clamp(min=1)
                    loss_per_sample = loss_per_sample.sum(dim=1) / num_valid  # B
                else:
                    loss_per_sample = loss_per_sample.mean(dim=1)  # B
            elif self.loss_reduction == "sum":
                loss_per_sample = loss_per_sample.sum(dim=1)  # B
            else:
                raise ValueError(
                    f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
                )
        return DPTensorFastGradientClipping(
            self.module, self.optimizer, loss_per_sample, self.loss_reduction
        )