# Source code for opacus.utils.clipping

#!/usr/bin/env python3

import math
from abc import ABC
from enum import IntEnum
from itertools import cycle
from typing import Iterator, List, Union

import torch

try:
from skimage.filters import threshold_otsu as otsu
except ImportError:

def otsu(*args, **kwargs) -> float:
raise NotImplementedError("Install skimage!")

def _mean_plus_r_var(data: torch.Tensor, ratio: float = 0, **kwargs) -> float:
"""
Caclulates mean + ratio x standard_deviation of the provided tensor
and returns the larger of this value and the smallest element in
the tensor (can happen when ratio is negative).

Args:
data: Pytorch tensor containing the data on which the mean and stdv.
is evaluated.
ratio: Value of the scaling factor in the value calculated by the
function.

Returns:
The result of the function.

"""
return max(data.min().item(), data.mean().item() + ratio * data.std().item() + 1e-8)

def _pvalue(data: torch.Tensor, ratio: float = 0.25, **kwargs) -> torch.Tensor:
"""
Finds the pth largest value in the tensor, where p = ratio x len(data).

Args:
data: Pytorch tensor against which the function is evaluated.
ratio: Value of the scaling factor in the value calculated by
the function.

Returns:
Tensor of dimension (1,) with the result of the function.
"""
cut = max(1, int(data.numel() * (1 - ratio)))

def _static(data: torch.Tensor, current_thresh: float, **kwargs) -> float:
"""
Passes through the specified input current_threshold.

Args:
data: Pytorch tensor containing the data.
current_thresh: The threshold value.

Returns:
The threshold value.
"""
return current_thresh

def _otsu(data: torch.Tensor, **kwargs) -> float:
"""
Returns an intensity threshold for an image that separates it
into backgorund and foreground pixels.

The implementation uses Otsu's method, which assumes a GMM with
2 components but uses some heuristic to maximize the variance
differences. The input data is shaped into a 2D image for the
purpose of evaluating the threshold value.

Args:
data: Pytorch tensor containing the data.

Returns:
Threshold value determined via Otsu's method.
"""
h = 2 ** int(1 + math.log2(data.shape) / 2)
fake_img = data.view(h, -1).cpu().numpy()
return otsu(fake_img, h)

[docs]class ClippingMethod(IntEnum):
STATIC = 0
PVALUE = 1
MEAN = 2
GMM = 3
OTSU = 4

_thresh_ = {
ClippingMethod.STATIC: _static,
ClippingMethod.PVALUE: _pvalue,
ClippingMethod.MEAN: _mean_plus_r_var,
ClippingMethod.OTSU: _otsu,
}

def _calculate_thresh_value(
data: torch.Tensor,
current_thresh: float,
clipping_mehod: ClippingMethod = ClippingMethod.STATIC,
clipping_ratio: float = -1,
) -> float:
"""
Calculates a clipping threshold by looking at the layer norms
of each example.

Four methods are supported: static threshold, threshold calculated
based on mean and variance of the norms, and threshold calculated
based on percentile values of the norms.

Args:
data: Pytorch tensor containing the data
current_thresh: Value of the current threshold.
clipping_method: Enum value defining the clipping strategy. Current
options are STATIC, PVALUE, MEAN, and OTSU.
clipping_ratio: Value that has different meaning for differnet strategies, it
is the percentile parameter for PVALUE, and a multiplier for
standard deviation for MEAN. It has no significance for OTSU and
STATIC.

Returns:
Clipping threshold value
"""
return _thresh_[clipping_mehod](
data, ratio=clipping_ratio, current_thresh=current_thresh
)

[docs]class NormClipper(ABC):
"""
An abstract class to calculate the clipping factor
"""

[docs]    def calc_clipping_factors(
self, norms: List[torch.Tensor]
) -> Union[List[torch.Tensor], Iterator[torch.Tensor]]:
"""
Calculates the clipping factor(s) based on the given
parameters. A concrete subclass must implement this.

Returns:
The clipping factors
"""
pass

@property
def thresholds(self) -> torch.Tensor:
"""
Depending on the type of clipper, returns threshold values.

Returns:
The threshold values
"""
pass

@property
def is_per_layer(self) -> bool:
"""
Depending on type of clipper, returns indicator as to whether
different clipping is applied to each layer in the model.

Returns:
Flag indicator as to whether different clipping is applied
to each layer in the model.
"""
pass

[docs]class ConstantFlatClipper(NormClipper):
"""
A clipper that clips all gradients in such a way that their norm is
at most equal to a specified value. This value is shared for all
layers in a model. Note that the process of clipping really involves
multiplying all gradients by a scaling factor. If this scaling factor
is > 1.0, it is instead capped at 1.0. The net effect is that the final
norm of the scaled gradients will be less than the specified value in
such a case. Thus it is better to think of the specified value as an
upper bound on the norm of final clipped gradients.
"""

def __init__(self, flat_value: float):
"""
Args:
flat_value: Constant value that is used to normalize gradients
such that their norm equals this value before clipping.
This threshold value is used for all layers.
"""
self.flat_value = float(flat_value)

[docs]    def calc_clipping_factors(
self, norms: List[torch.Tensor]
) -> Iterator[torch.Tensor]:
"""
Calculates the clipping factor based on the given
norm of gradients for all layers, so that the new
norm of clipped gradients is at most equal to
self.flat_value.

Args:
norms: List containing a single tensor of dimension (1,)
with the norm of all gradients.

Returns:
Tensor containing the single threshold value to be used
for all layers.
"""
# Expects a list of size one.
if len(norms) != 1:
raise ValueError(
"Waring: flat norm selected but "
)
per_sample_clip_factor = self.flat_value / (norms + 1e-6)
# We are *clipping* the gradient, so if the factor is ever >1 we set it to 1
per_sample_clip_factor = per_sample_clip_factor.clamp(max=1.0)  # pyre-ignore
# return this clipping factor for all layers
return cycle([per_sample_clip_factor])

@property
def thresholds(self) -> torch.Tensor:
"""
Returns singleton tensor of dimension (1,) containing
the common threshold value used for clipping all
layers in the model.

Returns:
Threshold values
"""

@property
def is_per_layer(self) -> bool:
"""
Returns indicator as to whether different clipping is applied
to each layer in the model. For this clipper, it is False.

Returns:
Flag with value False
"""
return False

[docs]class ConstantPerLayerClipper(NormClipper):
"""
A clipper that clips all gradients in such a way that their norm is
at most equal to a specified value. This value is specified for each
layer in a model. Note that the process of clipping really involves
multiplying all gradients by a scaling factor. If this scaling factor
is > 1.0, it is instead capped at 1.0. The net effect is that the final
norm of the scaled gradients will be less than the specified value in
such a case. Thus it is better to think of the specified value as an
upper bound on the norm of final clipped gradients.
"""

def __init__(self, flat_values: List[float]):
"""
Args:
flat_values: List of values that is used to normalize gradients
for each layer such that the norm equals the corresponding
value before clipping.
"""
self.flat_values = [float(fv) for fv in flat_values]

[docs]    def calc_clipping_factors(self, norms: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Calculates separate clipping factors for each layer based on
its corresponding norm of gradients, such that its new norm is
at most equal to the flat value specified for that layer when
instantiating the object of
:class:~opacus.utils.clipping.ConstantPerLayerClipper.

Args:
norms: List containing the desired norm of gradients for each layer.

Returns:
List of tensors, each containing a single value specifying the
clipping factor per layer.
"""
if len(norms) != len(self.flat_values) and len(self.flat_values) != 1:
raise ValueError(
f"{len(norms)} layers have provided norms but the "
f"number of clipping thresholds is {len(self.flat_values)}"
)

self.flat_values = self.flat_values * (
len(norms) if len(self.flat_values) == 1 else 1
)

clipping_factor = []
for norm, threshold in zip(norms, self.flat_values):
per_sample_clip_factor = threshold / (norm + 1e-6)
clipping_factor.append(per_sample_clip_factor.clamp(max=1.0))
return clipping_factor

@property
def thresholds(self) -> torch.Tensor:
"""
Returns a tensor of values that are used to normalize gradients for
each layer such that the norm at most equals the corresponding
value before clipping.

Returns:
Tensor of thresholds
"""

@property
def is_per_layer(self) -> bool:
"""
Returns indicator as to whether different clipping is applied
to each layer in the model. For this clipper, it is True.

Returns:
Flag with value True
"""
return True

class _Dynamic_Clipper_(NormClipper):
"""
This is a generic clipper, that is in an experimental phase.
The clipper uses different stats to find a clipping threshold
based on the given per sample norms.

Notes:
This clipper breaks DP guarantees [use only for experimentation]
"""

def __init__(
self,
flat_values: List[float],
clip_per_layer: bool = False,
clipping_method: ClippingMethod = ClippingMethod.STATIC,
clipping_ratio: float = 0.0,
clipping_momentum: float = 0.9,
):
"""
Args:
flat_value: List of float values that is used to normalize gradients
for each layer such that the norm equals the corresponding
value before clipping.
clip_per_layer: Flag indicating if a separate desired norm value is
specified per layer or if a single value is shared for all.
clipping_method: Value in the enum ClippingMethod that specifies one
of the currently supported clipping types.
clipping_ratio: Value that can be used to evaluate the clipping threshold
for certain clipping types.
clipping_momentum: value defines the decaing factor of an ubiased estimator
of exponential averaging of clipping thresholds, i.e. weight used to
combine the threshold from the current batch and the previous one.
"""
self.flat_values = [float(float_value) for float_value in flat_values]
self.clip_per_layer = clip_per_layer
if clipping_method != ClippingMethod.STATIC:
print(
"Warning! Current implementations of dynamic clipping "
"are not privacy safe; Caclulated privacy loss is not "
"indicative of a proper bound."
)
self.clipping_method = clipping_method
self.clipping_ratio = clipping_ratio
self.clipping_momentum = clipping_momentum
self.thresh = []

def calc_clipping_factors(
self, norms: List[torch.Tensor]
) -> Union[List[torch.Tensor], Iterator[torch.Tensor]]:
"""
Calculates separate clipping factors for each layer based on
stats such as a threshold determined by Otsu's method, combinations
of mean and std. deviation, kth median value etc.

This is experimental and does not guarantee privacy and is not recommended
for production use.

Args:
norms: List containing the desired norm of gradients for each layer.

Returns:
Singleton list specifying a common clippng factor for all layers,
or an iterator of tensors specifying a clipping factor per layer
"""

if len(self.thresh) == 0:
current_threshs = self.flat_values
if len(self.flat_values) == 1 and self.clip_per_layer:
# a constant clipping factor applied to all non-rozen layers
# need to replicate it by the number of number of those layers
# (= number of norms).
current_threshs *= len(norms)
else:
current_threshs = self.thresh

clipping_factor = []
self.thresh = []

if len(norms) != len(current_threshs):
raise ValueError(
f"Provided grad norm max's size {len(current_threshs)}"
f" does not match the number of layers {len(norms)}"
)

for norm, current_thresh in zip(norms, current_threshs):
thresh = _calculate_thresh_value(
norm, current_thresh, self.clipping_method, self.clipping_ratio
)
thresh = float(
(1 - self.clipping_momentum) * thresh
+ self.clipping_momentum * current_thresh
)
self.thresh.append(thresh)
per_sample_clip_factor = thresh / (norm + 1e-6)
clipping_factor.append(per_sample_clip_factor.clamp(max=1.0))  # pyre-ignore
return clipping_factor if self.is_per_layer else cycle(clipping_factor)

@property
def thresholds(self) -> torch.Tensor:
"""
Returns a tensor of values that are used to normalize gradients
for each layer such that the norm at most equals the corresponding
value before clipping.

Returns:
Tensor of thresholds
"""