Source code for opacus.utils.tensor_utils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utils for generating stats from torch tensors.
"""
from typing import Iterator, List, Tuple

import torch


def calc_sample_norms(
    named_params: Iterator[Tuple[str, torch.Tensor]], flat: bool = True
) -> List[torch.Tensor]:
    r"""
    Calculates the norm of the given tensors for each sample.

    This function calculates the L2 norm of every given tensor for each
    sample, assuming that the batch dimension of each tensor is dimension
    zero. Optionally, the per-tensor norms are combined into one overall
    per-sample norm.

    Args:
        named_params: An iterator of tuples <name, param> with name being a
            string and param being a tensor of shape ``[B, ...]`` where ``B``
            is the size of the batch and is the 0th dimension. The name is
            not used by the computation.
        flat: A flag, when set to `True` returns a flat norm over all
            layers norms. In that case the per-layer norms are stacked, so
            every param must share the same batch size ``B``.

    Example:
        >>> t1 = torch.tensor([[3.0, 0.0], [0.0, 4.0]])
        >>> t2 = torch.tensor([[4.0, 0.0], [3.0, 0.0]])
        >>> calc_sample_norms([("1", t1), ("2", t2)])
        [tensor([5., 5.])]
        >>> calc_sample_norms([("1", t1), ("2", t2)], flat=False)
        [tensor([3., 4.]), tensor([4., 3.])]

    Returns:
        A list of per-sample norm tensors, each of shape ``[B]``. When
        ``flat`` is ``False`` the list holds one tensor per entry of
        ``named_params``; when ``flat`` is ``True`` it holds a single tensor
        with the L2 norm taken across all per-layer norms.
    """
    # ``reshape`` instead of ``view``: ``view`` raises when the tensor's
    # memory layout cannot be flattened without a copy (e.g. after
    # ``permute``); ``reshape`` returns a view when possible and copies
    # otherwise, so the result is the same for inputs ``view`` accepted.
    norms = [
        param.reshape(len(param), -1).norm(2, dim=-1) for _, param in named_params
    ]
    if flat:
        # Stack to shape [num_layers, B] and reduce over the layer axis to get
        # one overall norm per sample.
        norms = [torch.stack(norms, dim=0).norm(2, dim=0)]
    return norms
def sum_over_all_but_batch_and_last_n(
    tensor: torch.Tensor, n_dims: int
) -> torch.Tensor:
    r"""
    Calculates the sum over all dimensions, except the first
    (batch dimension), and excluding the last n_dims.

    This function will ignore the first dimension and it will
    not aggregate over the last n_dims dimensions.

    Args:
        tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``. It must
            have at least ``n_dims + 1`` dimensions.
        n_dims: Number of trailing dimensions to keep.

    Example:
        >>> tensor = torch.ones(1, 2, 3, 4, 5)
        >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape
        torch.Size([1, 4, 5])

    Returns:
        A tensor of shape ``(B, X[0], ..., X[n_dims-1])``: the batch
        dimension followed by the last ``n_dims`` dimensions of the input.
        When the input has exactly ``n_dims + 1`` dimensions it is returned
        as-is (the same object, not a copy).

    Raises:
        ValueError: If ``tensor`` has fewer than ``n_dims + 1`` dimensions.
    """
    ndim = tensor.dim()
    if ndim == n_dims + 1:
        # Nothing lies between the batch dim and the kept trailing dims.
        return tensor
    if ndim < n_dims + 1:
        # Without this guard the list of dims to reduce would be empty, and
        # the result would not have the documented shape.
        raise ValueError(
            f"Expected a tensor with at least {n_dims + 1} dimensions "
            f"(batch + {n_dims} kept), but got {ndim}"
        )
    # Reduce every dimension strictly between the batch dim and the kept ones.
    inner_dims = list(range(1, ndim - n_dims))
    return tensor.sum(dim=inner_dims)