#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class SequenceBias(nn.Module):
r"""
    Adds one learnable bias element to the end of the sequence.

    If the input has shape ``(L, N, E)`` (``batch_first = False``),
    where ``L`` is the sequence length, ``N`` is the batch size, and ``E`` is
    the embedding dimension, the output has shape ``(L+1, N, E)``.
    When ``batch_first = True``, the input has shape ``(N, L, E)``
    and the output has shape ``(N, L+1, E)``.

Attributes:
bias (:class:`torch.nn.parameter.Parameter`): the learnable bias of
the module of shape ``(E)``, where ``E`` is the embedding dimension.
Example:
>>> m = SequenceBias(16, batch_first=False)
>>> input = torch.randn(20, 4, 16)
>>> output = m(input)
>>> output.size()
torch.Size([21, 4, 16])
"""
def __init__(self, embed_dim: int, batch_first: bool = False):
r"""
Args:
embed_dim: Embedding dimension
"""
super(SequenceBias, self).__init__()
self.batch_first = batch_first
self.bias = Parameter(torch.empty(embed_dim))
self._reset_parameters()
def _reset_parameters(self):
r"""
assigns Normally distributed random values to bias.
"""
nn.init.normal_(self.bias)
[docs]
def forward(self, x):
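        # Append the learnable bias as one extra element along the sequence dimension.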
if self.batch_first:
bsz, _, _ = x.shape
return torch.cat([x, self.bias.repeat(bsz, 1, 1)], 1)
else:
_, bsz, _ = x.shape
            return torch.cat([x, self.bias.repeat(1, bsz, 1)], 0)
class DPMultiheadAttention(nn.Module):
r"""
This is DP-friendly implementation of nn.MultiheadAttention.
For full reference see original module refer to
:class:`torch.nn.MultiheadAttention`.
Current implementation leverages pytorch modules as building blocks
to allow DP engine to calculate per-sample gradients.
This is in contrast with original implementation based on nn.functional.
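
    Example (illustrative sketch, not taken from the original module; assumes the
    default ``batch_first=False`` layout ``(L, N, E)``):
        >>> m = DPMultiheadAttention(16, 4)
        >>> q = torch.randn(20, 4, 16)
        >>> attn_output, attn_weights = m(q, q, q)
        >>> attn_output.size()
        torch.Size([20, 4, 16])
        >>> attn_weights.size()
        torch.Size([4, 20, 20])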
"""
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
device=None,
dtype=None,
):
super(DPMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
        # When self._qkv_same_embed_dim is True, nn.Transformer uses "in_proj_weight"
        # instead of the separate "q,k,v_weight" parameters and takes its fast-path
        # calculation, which should be avoided. This is why we force
        # self._qkv_same_embed_dim = False.
self._qkv_same_embed_dim = False
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.qlinear = nn.Linear(embed_dim, embed_dim, bias=bias)
self.klinear = nn.Linear(self.kdim, embed_dim, bias=bias)
self.vlinear = nn.Linear(self.vdim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.add_bias_kv = add_bias_kv
if self.add_bias_kv:
self.seq_bias_k = SequenceBias(embed_dim)
self.seq_bias_v = SequenceBias(embed_dim)
self.add_zero_attn = add_zero_attn
self.dropout = nn.Dropout(dropout)
# to avoid null pointers in Transformer.forward
self.in_proj_weight = None
self.in_proj_bias = None
def load_state_dict(self, state_dict):
r"""
Loads module from previously saved state.
Supports loading from both :class:`torch.nn.MultiheadAttention` and
:class:`opacus.layers.dp_multihead_attention.DPMultiheadAttention`.
Args:
state_dict: Please refer to
https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html.
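
        Example (illustrative sketch; assumes matching constructor arguments):
            >>> attn = nn.MultiheadAttention(16, 4)
            >>> dp_attn = DPMultiheadAttention(16, 4)
            >>> dp_attn.load_state_dict(attn.state_dict())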
"""
if "in_proj_weight" in state_dict:
qweight, kweight, vweight = state_dict["in_proj_weight"].chunk(3, dim=0)
state_dict["qlinear.weight"] = qweight
state_dict["klinear.weight"] = kweight
state_dict["vlinear.weight"] = vweight
del state_dict["in_proj_weight"]
if "in_proj_bias" in state_dict:
qbias, kbias, vbias = state_dict["in_proj_bias"].chunk(3, dim=0)
state_dict["qlinear.bias"] = qbias
state_dict["klinear.bias"] = kbias
state_dict["vlinear.bias"] = vbias
del state_dict["in_proj_bias"]
if "bias_k" in state_dict:
state_dict["seq_bias_k.bias"] = state_dict["bias_k"].squeeze()
del state_dict["bias_k"]
if "bias_v" in state_dict:
state_dict["seq_bias_v.bias"] = state_dict["bias_v"].squeeze()
del state_dict["bias_v"]
if "q_proj_weight" in state_dict:
state_dict["qlinear.weight"] = state_dict["q_proj_weight"]
del state_dict["q_proj_weight"]
if "k_proj_weight" in state_dict:
state_dict["klinear.weight"] = state_dict["k_proj_weight"]
del state_dict["k_proj_weight"]
if "v_proj_weight" in state_dict:
state_dict["vlinear.weight"] = state_dict["v_proj_weight"]
del state_dict["v_proj_weight"]
super(DPMultiheadAttention, self).load_state_dict(state_dict)
# flake8: noqa C901
def forward(
self,
query,
key,
value,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
is_causal=False,
):
is_batched = query.dim() == 3
        assert is_batched, "query must be a batched 3-dimensional tensor."
        # As per https://github.com/pytorch/opacus/issues/596, ``is_causal`` has to be
        # accepted as a dummy parameter, since ``nn.TransformerEncoderLayer.forward``
        # passes it to its attention module.
        assert not is_causal, "Causal masks are not currently supported."
        # The logic below follows ``nn.MultiheadAttention``
        # (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
if not self.batch_first:
tgt_len, bsz, embed_dim = query.size()
else:
bsz, tgt_len, embed_dim = query.size()
if embed_dim != self.embed_dim:
raise ValueError(
f"query has as size of {embed_dim} while the embedding"
" size is {self.embed_dim}"
)
head_dim = embed_dim // self.num_heads
if head_dim * self.num_heads != embed_dim:
raise ValueError(
f"embedding dimension {embed_dim} not divisible "
"by number of heads {num_heads}"
)
scaling = float(head_dim) ** -0.5
q = self.qlinear(query)
k = self.klinear(key)
v = self.vlinear(value)
q = q * scaling
if self.batch_first:
q, k, v = [x.transpose(0, 1) for x in (q, k, v)]
if attn_mask is not None:
if attn_mask.dtype not in (
torch.float32,
torch.float64,
torch.uint8,
torch.bool,
):
raise ValueError(
f"Only float, byte, and bool types are supported for attn_mask, "
f"not {attn_mask.dtype}."
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated."
"Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise ValueError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * self.num_heads,
query.size(0),
key.size(0),
]:
raise ValueError("The size of the 3D attn_mask is not correct.")
else:
raise ValueError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask in nn.MultiheadAttention"
"is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
if self.add_bias_kv:
k = self.seq_bias_k(k)
v = self.seq_bias_v(v)
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
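        # Reshape q, k, v from (len, bsz, embed_dim) to (bsz * num_heads, len, head_dim)
        # so that attention can be computed with a single batched matmul per head.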
q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
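        # add_zero_attn appends an all-zero slot to the keys and values (and pads the
        # masks accordingly), mirroring the behaviour of nn.MultiheadAttention.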
if self.add_zero_attn:
src_len += 1
k = torch.cat(
[
k,
torch.zeros(
(k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
),
],
dim=1,
)
v = torch.cat(
[
v,
torch.zeros(
(v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
),
],
dim=1,
)
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
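        # Scaled dot-product attention scores of shape (bsz * num_heads, tgt_len, src_len);
        # q has already been multiplied by the scaling factor above.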
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [
bsz * self.num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)
attn_output_weights = attn_output_weights.view(
bsz * self.num_heads, tgt_len, src_len
)
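        # Normalize the attention scores over the source dimension and apply dropout.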
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
attn_output_weights = self.dropout(attn_output_weights)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            # restore (bsz, tgt_len, embed_dim): concatenate the heads per position first
            attn_output = attn_output.transpose(0, 1).contiguous()
            attn_output = attn_output.view(tgt_len, bsz, embed_dim).transpose(0, 1)
        else:
            attn_output = (
                attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
            )
attn_output = self.out_proj(attn_output)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / self.num_heads
else:
return attn_output, None
def unsqueeze_0_2(self, t):
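        # (E) -> (1, 1, E), matching the bias_k / bias_v shape of nn.MultiheadAttention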
return torch.unsqueeze(torch.unsqueeze(t, 0), 0)
def state_dict(self, destination=None, prefix="", keep_vars=False):
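        r"""
        Returns a state dict that uses the parameter naming of
        :class:`torch.nn.MultiheadAttention` (``in_proj_weight`` or the separate
        ``q/k/v_proj_weight`` tensors, ``in_proj_bias``, ``bias_k``, ``bias_v`` and
        ``out_proj.*``), so that checkpoints stay interchangeable with the stock module.
        """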
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
destination_alter = OrderedDict()
if len(prefix.split(".")) > 2:
alter_key = ".".join(prefix.split(".")[:-2]) + ".emb.weight"
else:
alter_key = "emb.weight"
if alter_key in destination:
destination_alter[alter_key] = destination[alter_key]
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(
destination=destination,
prefix=prefix + name + ".",
keep_vars=keep_vars,
)
if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim):
destination_alter[prefix + "in_proj_weight"] = torch.cat(
(
destination[prefix + "qlinear.weight"],
destination[prefix + "klinear.weight"],
destination[prefix + "vlinear.weight"],
),
0,
)
else:
destination_alter[prefix + "q_proj_weight"] = destination[
prefix + "qlinear.weight"
]
destination_alter[prefix + "k_proj_weight"] = destination[
prefix + "klinear.weight"
]
destination_alter[prefix + "v_proj_weight"] = destination[
prefix + "vlinear.weight"
]
if (
(prefix + "qlinear.bias") in destination
and (prefix + "klinear.bias") in destination
and (prefix + "vlinear.bias") in destination
):
destination_alter[prefix + "in_proj_bias"] = torch.cat(
(
destination[prefix + "qlinear.bias"],
destination[prefix + "klinear.bias"],
destination[prefix + "vlinear.bias"],
),
0,
)
if self.add_bias_kv:
destination_alter[prefix + "bias_k"] = self.unsqueeze_0_2(
destination[prefix + "seq_bias_k.bias"]
)
destination_alter[prefix + "bias_v"] = self.unsqueeze_0_2(
destination[prefix + "seq_bias_v.bias"]
)
destination_alter[prefix + "out_proj.weight"] = destination[
prefix + "out_proj.weight"
]
if (prefix + "out_proj.bias") in destination:
destination_alter[prefix + "out_proj.bias"] = destination[
prefix + "out_proj.bias"
]
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination_alter, prefix, local_metadata)
if hook_result is not None:
destination_alter = hook_result
return destination_alter