# Source code for opacus.utils.tensor_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utils for generating stats from torch tensors.
"""
import math
from typing import Iterator, List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

def calc_sample_norms(
    named_params: Iterator[Tuple[str, torch.Tensor]], *, flat: bool = True
) -> List[torch.Tensor]:
    r"""
    Calculates the norm of the given tensors for each sample.

    This function calculates the overall norm of the given tensors for each sample,
    assuming each batch's dim is zero.

    Args:
        named_params: An iterator of tuples <name, param> with name being a
            string and param being a tensor of shape [B, ...] where B
            is the size of the batch and is the 0th dimension.
        flat: A flag, when set to True returns a flat norm over all
            layers norms

    Returns:
        A list of tensor norms where length of the list is the number of layers

    Example:
        >>> t1 = torch.rand((2, 5))
        >>> t2 = torch.rand((2, 5))
        >>> norms = calc_sample_norms([("1", t1), ("2", t2)])
        >>> norms, norms[0].shape
        ([tensor([...])], torch.Size([2]))
    """
    # Per-layer, per-sample L2 norm: flatten all non-batch dims, then norm
    # along the flattened dim. Each entry has shape [B].
    norms = [param.view(len(param), -1).norm(2, dim=-1) for name, param in named_params]
    # calc norm over all layer norms if flat = True
    if flat:
        norms = [torch.stack(norms, dim=0).norm(2, dim=0)]
    return norms

def calc_sample_norms_one_layer(param: torch.Tensor) -> torch.Tensor:
    r"""
    Calculates the norm of the given tensor (a single parameter) for each sample.

    This function calculates the overall norm of the given tensor for each sample,
    assuming each batch's dim is zero.

    It is equivalent to:
        calc_sample_norms(named_params=((None, param),))[0]

    Args:
        param: A tensor of shape [B, ...] where B
            is the size of the batch and is the 0th dimension.

    Returns:
        A tensor of norms

    Example:
        >>> t1 = torch.rand((2, 5))
        >>> norms = calc_sample_norms_one_layer(t1)
        >>> norms, norms.shape
        (tensor([...]), torch.Size([2]))
    """
    # Flatten all non-batch dims and take the L2 norm per sample -> shape [B].
    norms = param.view(len(param), -1).norm(2, dim=-1)
    return norms

def sum_over_all_but_batch_and_last_n(
    tensor: torch.Tensor, n_dims: int
) -> torch.Tensor:
    r"""
    Calculates the sum over all dimensions, except the first
    (batch dimension), and excluding the last n_dims.

    This function will ignore the first dimension, and it will
    not aggregate over the last n_dims dimensions.

    Args:
        tensor: An input tensor of shape (B, ..., X[n_dims-1]).
        n_dims: Number of dimensions to keep.

    Returns:
        A tensor of shape (B, ..., X[n_dims-1])

    Example:
        >>> tensor = torch.ones(1, 2, 3, 4, 5)
        >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape
        torch.Size([1, 4, 5])
    """
    if tensor.dim() == n_dims + 1:
        # Only the batch dim plus the n_dims to keep: nothing to sum over.
        return tensor
    else:
        # Sum over every dim strictly between the batch dim and the kept tail.
        dims = list(range(1, tensor.dim() - n_dims))
        return tensor.sum(dim=dims)

def unfold2d(
    input,
    *,
    kernel_size: Tuple[int, int],
    padding: Union[str, Tuple[int, int]],
    stride: Tuple[int, int],
    dilation: Tuple[int, int],
):
    """
    Extracts sliding 2D blocks from a batched input tensor.

    See :meth:`~torch.nn.functional.unfold`. This version additionally
    accepts ``padding="same"`` and ``padding="valid"``.

    Args:
        input: tensor of shape (B, C, H, W)
        kernel_size: size (kH, kW) of the sliding blocks
        padding: an explicit (padH, padW) pair applied symmetrically to each
            spatial dim, or one of the strings "same" / "valid"
        stride: stride (sH, sW) of the sliding blocks
        dilation: spacing (dH, dW) between kernel elements

    Returns:
        A tensor of shape (B, C * kH * kW, L) where L is the number of block
        positions, matching ``torch.nn.functional.unfold``.
    """
    *shape, H, W = input.shape
    if padding == "same":
        # Pad just enough to keep the output spatial size equal to
        # ceil(size / stride); extra odd padding goes on the right/bottom.
        total_pad_H = dilation[0] * (kernel_size[0] - 1)
        total_pad_W = dilation[1] * (kernel_size[1] - 1)
        pad_H_left = math.floor(total_pad_H / 2)
        pad_H_right = total_pad_H - pad_H_left
        pad_W_left = math.floor(total_pad_W / 2)
        pad_W_right = total_pad_W - pad_W_left
    elif padding == "valid":
        pad_H_left, pad_H_right, pad_W_left, pad_W_right = (0, 0, 0, 0)
    else:
        # Numeric padding: (padH, padW) applied to both sides of each dim.
        pad_H_left, pad_H_right, pad_W_left, pad_W_right = (
            padding[0],
            padding[0],
            padding[1],
            padding[1],
        )

    # Output spatial sizes with the effective (dilated) kernel extent.
    H_effective = (
        H
        + pad_H_left
        + pad_H_right
        - (kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1))
    ) // stride[0] + 1
    W_effective = (
        W
        + pad_W_left
        + pad_W_right
        - (kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1))
    ) // stride[1] + 1
    # F.pad's first argument is the padding of the *last* dimension
    input = F.pad(input, (pad_W_left, pad_W_right, pad_H_left, pad_H_right))
    *shape_pad, H_pad, W_pad = input.shape
    strides = list(input.stride())
    strides = strides[:-2] + [
        W_pad * dilation[0],  # next kernel row: advance dilation[0] padded rows
        dilation[1],  # next kernel column: advance dilation[1] elements
        W_pad * stride[0],  # next output row: advance stride[0] padded rows
        stride[1],  # next output column: advance stride[1] elements
    ]
    # Zero-copy view of shape (B, C, kH, kW, H_out, W_out).
    out = input.as_strided(
        shape + [kernel_size[0], kernel_size[1], H_effective, W_effective], strides
    )

    return out.reshape(input.size(0), -1, H_effective * W_effective)

def unfold3d(
    tensor: torch.Tensor,
    *,
    kernel_size: Union[int, Tuple[int, int, int]],
    padding: Union[int, Tuple[int, int, int]] = 0,
    stride: Union[int, Tuple[int, int, int]] = 1,
    dilation: Union[int, Tuple[int, int, int]] = 1,
):
    r"""
    Extracts sliding local blocks from an batched input tensor.

    :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors).
    This method implements the same action for 5D inputs

    Args:
        tensor: An input tensor of shape (B, C, D, H, W).
        kernel_size: the size of the sliding blocks
        padding: implicit zero padding to be added on both sides of input;
            also accepts the strings "same" and "valid"
        stride: the stride of the sliding blocks in the input spatial dimensions
        dilation: the spacing between the kernel points.

    Returns:
        A tensor of shape (B, C * np.prod(kernel_size), L), where L - output spatial dimensions.
        See :class:`torch.nn.Unfold` for more details

    Example:
        >>> B, C, D, H, W = 3, 4, 5, 6, 7
        >>> tensor = torch.arange(1, B*C*D*H*W + 1.).view(B, C, D, H, W)
        >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape
        torch.Size([3, 32, 120])
    """

    if len(tensor.shape) != 5:
        raise ValueError(
            f"Input tensor must be of the shape [B, C, D, H, W]. Got {tensor.shape}"
        )

    # Normalize scalar arguments to per-dimension triples (D, H, W).
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)

    if isinstance(stride, int):
        stride = (stride, stride, stride)

    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    if padding == "same":
        # Pad so the output spatial size equals ceil(size / stride);
        # extra odd padding goes on the right side of each dim.
        total_pad_D = dilation[0] * (kernel_size[0] - 1)
        total_pad_H = dilation[1] * (kernel_size[1] - 1)
        total_pad_W = dilation[2] * (kernel_size[2] - 1)
        pad_D_left = math.floor(total_pad_D / 2)
        pad_D_right = total_pad_D - pad_D_left
        pad_H_left = math.floor(total_pad_H / 2)
        pad_H_right = total_pad_H - pad_H_left
        pad_W_left = math.floor(total_pad_W / 2)
        pad_W_right = total_pad_W - pad_W_left
    elif padding == "valid":
        pad_D_left, pad_D_right, pad_H_left, pad_H_right, pad_W_left, pad_W_right = (
            0,
            0,
            0,
            0,
            0,
            0,
        )
    else:
        # Numeric padding, applied symmetrically on both sides of each dim.
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        pad_D_left, pad_D_right, pad_H_left, pad_H_right, pad_W_left, pad_W_right = (
            padding[0],
            padding[0],
            padding[1],
            padding[1],
            padding[2],
            padding[2],
        )

    batch_size, channels, _, _, _ = tensor.shape

    # Input shape: (B, C, D, H, W)
    # F.pad's tuple is ordered from the *last* dimension backwards: W, H, D.
    tensor = F.pad(
        tensor,
        (pad_W_left, pad_W_right, pad_H_left, pad_H_right, pad_D_left, pad_D_right),
    )
    # Output shape: (B, C, D + pad_D, H + pad_H, W + pad_W)

    # Unfold with the dilated extent, then drop the in-between rows below.
    dilated_kernel_size = (
        kernel_size[0] + (kernel_size[0] - 1) * (dilation[0] - 1),
        kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1),
        kernel_size[2] + (kernel_size[2] - 1) * (dilation[2] - 1),
    )

    tensor = tensor.unfold(dimension=2, size=dilated_kernel_size[0], step=stride[0])
    tensor = tensor.unfold(dimension=3, size=dilated_kernel_size[1], step=stride[1])
    tensor = tensor.unfold(dimension=4, size=dilated_kernel_size[2], step=stride[2])

    if dilation != (1, 1, 1):
        tensor = filter_dilated_rows(tensor, dilation, dilated_kernel_size, kernel_size)

    # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
    # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`

    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose(
        1, 2
    )
    # Output shape: (B, C * kernel_size[0] * kernel_size[1] * kernel_size[2], D_out * H_out * W_out)

    return tensor

def filter_dilated_rows(
    tensor: torch.Tensor,
    dilation: Tuple[int, int, int],
    dilated_kernel_size: Tuple[int, int, int],
    kernel_size: Tuple[int, int, int],
):
    """
    A helper function that removes extra rows created during the process of
    implementing dilation.

    Args:
        tensor: A tensor containing the output slices resulting from unfolding
            the input tensor to unfold3d().
            Shape is (B, C, D_out, H_out, W_out, dilated_kernel_size[0],
            dilated_kernel_size[1], dilated_kernel_size[2]).
        dilation: The dilation given to unfold3d().
        dilated_kernel_size: The size of the dilated kernel.
        kernel_size: The size of the kernel given to unfold3d().

    Returns:
        A tensor of shape (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
        For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`.

    Example:
        >>> tensor = torch.zeros([1, 1, 3, 3, 3, 5, 5, 5])
        >>> dilation = (2, 2, 2)
        >>> dilated_kernel_size = (5, 5, 5)
        >>> kernel_size = (3, 3, 3)
        >>> filter_dilated_rows(tensor, dilation, dilated_kernel_size, kernel_size).shape
        torch.Size([1, 1, 3, 3, 3, 3, 3, 3])
    """

    kernel_rank = len(kernel_size)
    # The kernel dims are the trailing `kernel_rank` axes of the tensor.
    first_kernel_axis = tensor.dim() - kernel_rank

    # Along each kernel axis, keep only every `step`-th row: those are the
    # positions the dilated kernel actually touches.
    for offset, (size, step) in enumerate(zip(dilated_kernel_size, dilation)):
        keep = torch.arange(0, size, step, device=tensor.device)
        tensor = torch.index_select(
            tensor, dim=first_kernel_axis + offset, index=keep
        )

    return tensor