# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections import defaultdict
from typing import Callable, List, Optional, Union
import torch
from opacus.optimizers.utils import params
from torch import nn
from torch.optim import Optimizer
logger = logging.getLogger(__name__)
logger.disabled = True
def _mark_as_processed(obj: Union[torch.Tensor, List[torch.Tensor]]):
"""
Marks parameters that have already been used in the optimizer step.
DP-SGD puts certain restrictions on how gradients can be accumulated. In
particular, no gradient can be used twice: the client must call ``.zero_grad()``
between optimizer steps, otherwise privacy guarantees are compromised.
This method marks tensors that have already been used in optimizer steps so
that we can later check whether ``zero_grad`` has been duly called.
Notes:
This is used to only mark ``p.grad_sample`` and ``p.summed_grad``
Args:
obj: tensor or a list of tensors to be marked
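Example:
A minimal illustration; the mark is just an attribute set on the tensor:
>>> g = torch.zeros(2)
>>> _mark_as_processed(g)
>>> hasattr(g, "_processed")
True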
"""
if isinstance(obj, torch.Tensor):
obj._processed = True
elif isinstance(obj, list):
for x in obj:
x._processed = True
def _check_processed_flag_tensor(x: torch.Tensor):
"""
Checks if this gradient tensor has been previously used in an optimization step.
See Also:
:meth:`~opacus.optimizers.optimizer._mark_as_processed`
Args:
x: gradient tensor
Raises:
ValueError
If the tensor has the ``._processed`` attribute previously set by
the ``_mark_as_processed`` method
"""
if hasattr(x, "_processed"):
raise ValueError(
"Gradients haven't been cleared since the last optimizer step. "
"In order to obtain privacy guarantees you must call optimizer.zero_grad() "
"on each step"
)
def _check_processed_flag(obj: Union[torch.Tensor, List[torch.Tensor]]):
"""
Checks if this gradient tensor (or a list of tensors) has been previously
used in an optimization step.
See Also:
:meth:`~opacus.optimizers.optimizer._mark_as_processed`
Args:
obj: gradient tensor or a list of tensors
Raises:
ValueError
If the tensor (or at least one tensor in the list) has the
``._processed`` attribute previously set by the ``_mark_as_processed`` method
"""
if isinstance(obj, torch.Tensor):
_check_processed_flag_tensor(obj)
elif isinstance(obj, list):
for x in obj:
_check_processed_flag_tensor(x)
def _generate_noise(
std: float,
reference: torch.Tensor,
generator=None,
secure_mode: bool = False,
) -> torch.Tensor:
"""
Generates noise according to a Gaussian distribution with mean 0.
Args:
std: Standard deviation of the noise
reference: The reference Tensor to get the appropriate shape and device
for generating the noise
generator: The PyTorch noise generator
secure_mode: boolean indicating whether "secure" noise needs to be generated
(see the notes)
Notes:
If `secure_mode` is enabled, the generated noise is also secure
against the floating point representation attacks, such as the ones
in https://arxiv.org/abs/2107.10138 and https://arxiv.org/abs/2112.05307.
The attack for Opacus first appeared in https://arxiv.org/abs/2112.05307.
The implemented fix is based on https://arxiv.org/abs/2107.10138 and is
achieved by calling the Gaussian noise function ``2n`` times, with ``n=2``
(see section 5.1 in https://arxiv.org/abs/2107.10138).
Reason for choosing ``n=2``: ``n`` can be any integer > 1; the bigger it is,
the more computation needs to be done (``2n`` Gaussian samples are generated).
We chose ``n=2`` because ``n=1`` could be easy to break, while ``n>2`` is not
really necessary. The complexity of the attack is ``2^(p(2n-1))``. In PyTorch,
``p=53``, so the complexity is ``2^(53(2n-1))``: with ``n=1`` we get ``2^53``
(easy to break), but with ``n=2`` we get ``2^159``, which is hard enough for
an attacker to break.
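Example:
A shape-only sanity check (the values themselves are random):
>>> ref = torch.zeros(2, 3)
>>> noise = _generate_noise(std=1.0, reference=ref, secure_mode=True)
>>> noise.shape
torch.Size([2, 3])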
"""
zeros = torch.zeros(reference.shape, device=reference.device)
if std == 0:
return zeros
# TODO: handle device transfers: generator and reference tensor
# could be on different devices
if secure_mode:
torch.normal(
mean=0,
std=std,
size=(1, 1),
device=reference.device,
generator=generator,
) # generate, but throw away first generated Gaussian sample
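# sum 2n Gaussian samples (n=2, so 4 samples) and rescale by 1/2: the sum
# has standard deviation 2*std, so the result has standard deviation std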
noise = zeros
for _ in range(4):
noise += torch.normal(
mean=0,
std=std,
size=reference.shape,
device=reference.device,
generator=generator,
)
return noise / 2
else:
return torch.normal(
mean=0,
std=std,
size=reference.shape,
device=reference.device,
generator=generator,
)
class DPOptimizer(Optimizer):
"""
``torch.optim.Optimizer`` wrapper that adds additional functionality to clip per
sample gradients and add Gaussian noise.
Can be used with any ``torch.optim.Optimizer`` subclass as an underlying optimizer.
``DPOptimizer`` assumes that the parameters over which it performs optimization
belong to a ``GradSampleModule`` and therefore have the ``grad_sample`` attribute.
At a high level, ``DPOptimizer``'s step looks like this:
1) Aggregate ``p.grad_sample`` over all parameters to calculate per sample norms
2) Clip ``p.grad_sample`` so that per sample norm is not above threshold
3) Aggregate clipped per sample gradients into ``p.grad``
4) Add Gaussian noise to ``p.grad`` calibrated to a given noise multiplier and
max grad norm limit (``std = noise_multiplier * max_grad_norm``).
5) Call underlying optimizer to perform optimization step
Examples:
>>> module = MyCustomModel()
>>> optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
>>> dp_optimizer = DPOptimizer(
... optimizer=optimizer,
... noise_multiplier=1.0,
... max_grad_norm=1.0,
... expected_batch_size=4,
... )
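A typical training step then looks like this (a sketch: it assumes ``module``
is wrapped in ``GradSampleModule`` so that ``p.grad_sample`` is populated,
and that ``data``, ``labels`` and ``criterion`` are defined):
>>> dp_optimizer.zero_grad()
>>> loss = criterion(module(data), labels)
>>> loss.backward()
>>> dp_optimizer.step()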
"""
def __init__(
self,
optimizer: Optimizer,
*,
noise_multiplier: float,
max_grad_norm: float,
expected_batch_size: Optional[int],
loss_reduction: str = "mean",
generator=None,
secure_mode: bool = False,
):
"""
Args:
optimizer: wrapped optimizer.
noise_multiplier: noise multiplier
max_grad_norm: max grad norm used for gradient clipping
expected_batch_size: batch_size used for averaging gradients. When using
Poisson sampling the averaging denominator can't be inferred from the
actual batch size. Required if ``loss_reduction="mean"``, ignored if
``loss_reduction="sum"``
loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values "sum" or "mean"
generator: torch.Generator() object used as a source of randomness for
the noise
secure_mode: if ``True`` uses noise generation approach robust to floating
point arithmetic attacks.
See :meth:`~opacus.optimizers.optimizer._generate_noise` for details
"""
if loss_reduction not in ("mean", "sum"):
raise ValueError(f"Unexpected value for loss_reduction: {loss_reduction}")
if loss_reduction == "mean" and expected_batch_size is None:
raise ValueError(
"You must provide the expected batch size if the loss reduction is mean"
)
self.original_optimizer = optimizer
self.noise_multiplier = noise_multiplier
self.max_grad_norm = max_grad_norm
self.loss_reduction = loss_reduction
self.expected_batch_size = expected_batch_size
self.step_hook = None
self.generator = generator
self.secure_mode = secure_mode
self._step_skip_queue = []
self._is_last_step_skipped = False
for p in self.params:
p.summed_grad = None
def _get_flat_grad_sample(self, p: torch.Tensor):
"""
Returns the parameter's per sample gradients as a single tensor.
By default, per sample gradients (``p.grad_sample``) are stored on a per-batch
basis. Therefore, ``p.grad_sample`` is a single tensor if it holds results from
only one batch, and a list of tensors if gradients are accumulated over multiple
steps. This is done to provide visibility into which sample belongs to which batch,
and how many batches have been processed.
This method returns per sample gradients as a single concatenated tensor, regardless
of how many batches have been accumulated.
Args:
p: Parameter tensor. Must have ``grad_sample`` attribute
Returns:
``p.grad_sample`` if it's a tensor already, or a single tensor computed by
concatenating every tensor in ``p.grad_sample`` if it's a list
Raises:
ValueError
If ``p`` is missing ``grad_sample`` attribute
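Example:
A sketch with hand-made per sample gradients accumulated over two batches
of sizes 2 and 4 (``dp_optimizer`` as in the class-level example):
>>> p = nn.Parameter(torch.zeros(3))
>>> p.grad_sample = [torch.ones(2, 3), torch.ones(4, 3)]
>>> dp_optimizer._get_flat_grad_sample(p).shape
torch.Size([6, 3])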
"""
if not hasattr(p, "grad_sample"):
raise ValueError(
"Per sample gradient not found. Are you using GradSampleModule?"
)
if p.grad_sample is None:
raise ValueError(
"Per sample gradient is not initialized. Not updated in backward pass?"
)
if isinstance(p.grad_sample, torch.Tensor):
ret = p.grad_sample
elif isinstance(p.grad_sample, list):
ret = torch.cat(p.grad_sample, dim=0)
else:
raise ValueError(f"Unexpected grad_sample type: {type(p.grad_sample)}")
return ret
def signal_skip_step(self, do_skip=True):
"""
Signals the optimizer to skip an optimization step and only perform clipping and
per sample gradient accumulation.
On every call of ``.step()`` the optimizer will check the queue of skipped
step signals. If the queue is non-empty and the flag at its head is ``True``,
the optimizer will call ``self.clip_and_accumulate``, but won't proceed to
adding noise and performing the actual optimization step.
It also affects the behaviour of ``zero_grad()``. If the last step was skipped,
the optimizer will clear per sample gradients accumulated by
``self.clip_and_accumulate`` (``p.grad_sample``), but won't touch aggregated
clipped gradients (``p.summed_grad``).
Used by :class:`~opacus.utils.batch_memory_manager.BatchMemoryManager` to
simulate large virtual batches with limited memory footprint.
Args:
do_skip: flag if next step should be skipped
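Example:
A sketch of one virtual batch processed as two physical batches
(``dp_optimizer`` as in the class-level example; the two losses are
assumed to be computed on the two physical batches):
>>> dp_optimizer.signal_skip_step()
>>> loss_first.backward()
>>> dp_optimizer.step()       # clips and accumulates only, no noise/update
>>> dp_optimizer.zero_grad()  # keeps p.summed_grad after a skipped step
>>> loss_second.backward()
>>> dp_optimizer.step()       # adds noise and performs the real step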
"""
self._step_skip_queue.append(do_skip)
def _check_skip_next_step(self, pop_next=True):
"""
Checks if the next step should be skipped by the optimizer.
This is for large Poisson batches that get split into smaller physical
batches to fit on the device. Batches that do not correspond to the end of
a Poisson batch are thus ``skipped``, as their gradient gets accumulated
for one big step.
"""
if self._step_skip_queue:
if pop_next:
return self._step_skip_queue.pop(0)
else:
return self._step_skip_queue[0]
else:
return False
@property
def params(self) -> List[nn.Parameter]:
"""
Returns a flat list of ``nn.Parameter`` managed by the optimizer
"""
return params(self)
@property
def grad_samples(self) -> List[torch.Tensor]:
"""
Returns a flat list of per sample gradient tensors (one per parameter)
"""
ret = []
for p in self.params:
ret.append(self._get_flat_grad_sample(p))
return ret
@property
def accumulated_iterations(self) -> int:
"""
Returns the number of batches currently accumulated and not yet processed.
In other words, ``accumulated_iterations`` tracks the number of forward/backward
passes done in between two optimizer steps. The value would typically be 1,
but there are possible exceptions.
Used by privacy accountants to calculate real sampling rate.
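For example, after two ``loss.backward()`` calls with no optimizer step in
between, each ``p.grad_sample`` is a two-element list and this property
returns 2.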
"""
vals = []
for p in self.params:
if not hasattr(p, "grad_sample"):
raise ValueError(
"Per sample gradient not found. Are you using GradSampleModule?"
)
if isinstance(p.grad_sample, torch.Tensor):
vals.append(1)
elif isinstance(p.grad_sample, list):
vals.append(len(p.grad_sample))
else:
raise ValueError(f"Unexpected grad_sample type: {type(p.grad_sample)}")
if len(set(vals)) > 1:
raise ValueError(
"Number of accumulated steps is inconsistent across parameters"
)
return vals[0]
@property
def param_groups(self) -> List[dict]:
"""
Returns a list of dictionaries, one per parameter group managed by the wrapped optimizer.
"""
return self.original_optimizer.param_groups
@param_groups.setter
def param_groups(self, param_groups: List[dict]):
"""
Updates the param_groups of the optimizer.
"""
self.original_optimizer.param_groups = param_groups
@property
def state(self) -> defaultdict:
"""
Returns a dictionary holding current optimization state.
"""
return self.original_optimizer.state
@state.setter
def state(self, state: defaultdict):
"""
Updates the state of the optimizer.
"""
self.original_optimizer.state = state
@property
def defaults(self) -> dict:
"""
Returns a dictionary containing default values for optimization.
"""
return self.original_optimizer.defaults
@defaults.setter
def defaults(self, defaults: dict):
"""
Updates the defaults of the optimizer.
"""
self.original_optimizer.defaults = defaults
def attach_step_hook(self, fn: Callable[[DPOptimizer], None]):
"""
Attaches a hook to be executed after gradient clipping/noising, but before the
actual optimization step.
Most commonly used for privacy accounting.
Args:
fn: hook function. Expected signature: ``foo(optim: DPOptimizer)``
"""
self.step_hook = fn
def clip_and_accumulate(self):
"""
Performs gradient clipping.
Stores clipped and aggregated gradients into ``p.summed_grad``.
"""
if len(self.grad_samples[0]) == 0:
# Empty batch
per_sample_clip_factor = torch.zeros(
(0,), device=self.grad_samples[0].device
)
else:
per_param_norms = [
g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
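# per-sample scale factor min(1, C / ||g_i||_2), where C is max_grad_norm;
# the 1e-6 term guards against division by zero for all-zero gradients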
per_sample_clip_factor = (
self.max_grad_norm / (per_sample_norms + 1e-6)
).clamp(max=1.0)
for p in self.params:
_check_processed_flag(p.grad_sample)
grad_sample = self._get_flat_grad_sample(p)
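# scale each per sample gradient by its clip factor and sum over the
# batch dimension: grad = sum_i clip_factor_i * grad_sample_i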
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
if p.summed_grad is not None:
p.summed_grad += grad
else:
p.summed_grad = grad
_mark_as_processed(p.grad_sample)
def add_noise(self):
"""
Adds noise to clipped gradients. Stores the clipped and noised result in ``p.grad``.
"""
for p in self.params:
_check_processed_flag(p.summed_grad)
noise = _generate_noise(
std=self.noise_multiplier * self.max_grad_norm,
reference=p.summed_grad,
generator=self.generator,
secure_mode=self.secure_mode,
)
p.grad = (p.summed_grad + noise).view_as(p)
_mark_as_processed(p.summed_grad)
def scale_grad(self):
"""
Applies given ``loss_reduction`` to ``p.grad``.
Does nothing if ``loss_reduction="sum"``. Divides gradients by
``self.expected_batch_size * self.accumulated_iterations`` if ``loss_reduction="mean"``.
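For example, with ``loss_reduction="mean"``, ``expected_batch_size=4`` and
two accumulated batches, every ``p.grad`` is divided by 8.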
"""
if self.loss_reduction == "mean":
for p in self.params:
p.grad /= self.expected_batch_size * self.accumulated_iterations
def zero_grad(self, set_to_none: bool = False):
"""
Clear gradients.
Clears ``p.grad``, ``p.grad_sample`` and ``p.summed_grad`` for all of its parameters.
Notes:
The ``set_to_none`` argument only affects ``p.grad``. ``p.grad_sample`` and
``p.summed_grad`` are never zeroed out and are always set to ``None``.
Normal grads can do this, because their shape is always the same.
Grad samples do not behave like this, as we accumulate gradients from different
batches in a list.
Args:
set_to_none: instead of setting to zero, set the grads to None. (only
affects regular gradients. Per sample gradients are always set to None)
"""
if set_to_none is False:
logger.debug(
"Although set_to_none is set to False, "
"opacus will set p.grad_sample and p.summed_grad to None due to "
"non-trivial gradient accumulation behaviour"
)
for p in self.params:
p.grad_sample = None
if not self._is_last_step_skipped:
p.summed_grad = None
self.original_optimizer.zero_grad(set_to_none)
def pre_step(
self, closure: Optional[Callable[[], float]] = None
) -> bool:
"""
Perform actions specific to ``DPOptimizer`` before calling
the underlying ``optimizer.step()``.
Args:
closure: A closure that reevaluates the model and
returns the loss. Optional for most optimizers.
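Returns:
``True`` if the wrapped optimizer's ``step()`` should be performed,
``False`` if this step was skipped (gradients were only clipped and
accumulated).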
"""
# Corner case: the optimizer has no trainable parameters.
# Essentially the DPOptimizer acts as a normal optimizer.
if self.grad_samples is None or len(self.grad_samples) == 0:
return True
self.clip_and_accumulate()
if self._check_skip_next_step():
self._is_last_step_skipped = True
return False
self.add_noise()
self.scale_grad()
if self.step_hook:
self.step_hook(self)
self._is_last_step_skipped = False
return True
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
if closure is not None:
with torch.enable_grad():
closure()
if self.pre_step():
return self.original_optimizer.step()
else:
return None
def __repr__(self):
return self.original_optimizer.__repr__()
def state_dict(self):
return self.original_optimizer.state_dict()
def load_state_dict(self, state_dict) -> None:
self.original_optimizer.load_state_dict(state_dict)