| | |
| | |
| | |
| | |
| | """ |
| | Implementation of the following modules is borrowed from ml-cvnets repo: |
| | https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py |
| | https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py |
| | |
| | Please see ACKNOWLEDGEMENTS for license details. |
| | """ |
| |
|
import warnings
from typing import List, Optional, Union

import torch
from torch import Size, Tensor, nn
from torch.nn import functional as F
from torchvision.ops import StochasticDepth
| |
|
| |
|
class LayerNormFP32(nn.LayerNorm):
    """
    Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over an input tensor
    with FP32 precision.

    The input is upcast to ``torch.float32`` before normalization and cast back to its
    original dtype afterwards, which improves numerical stability when training in
    reduced precision (fp16/bf16).

    Args:
        normalized_shape: Input shape over which to normalize (int, list, or ``torch.Size``).
        eps: Value added to the denominator for numerical stability. Default: 1e-5.
        elementwise_affine: Whether to learn per-element affine parameters. Default: ``True``.
    """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs,
    ):
        # Fix: pass the known arguments positionally. The original
        # `super().__init__(normalized_shape=..., *args, ...)` form maps any extra
        # positional args onto the first parameters, raising
        # "got multiple values for argument 'normalized_shape'" whenever *args is
        # non-empty.
        super().__init__(
            normalized_shape,
            eps,
            elementwise_affine,
            *args,
            **kwargs,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` in fp32, then restore the caller's dtype."""
        inp_dtype = x.dtype
        return super().forward(x.to(torch.float32)).to(inp_dtype)
| |
|
| |
|
def get_normalization_layer(norm_type, num_features):
    """Build a normalization layer by name.

    Args:
        norm_type: Either ``"layer_norm"`` or ``"layer_norm_fp32"``.
        num_features: Number of features to normalize over.

    Returns:
        The corresponding normalization module.

    Raises:
        NotImplementedError: If ``norm_type`` is not recognized.
    """
    if norm_type == "layer_norm":
        return nn.LayerNorm(num_features)
    if norm_type == "layer_norm_fp32":
        return LayerNormFP32(num_features)
    raise NotImplementedError(f"Option: {norm_type} not supported.")
| |
|
| |
|
class PositionalEmbedding(nn.Module):
    """Wrapper module that dispatches to a positional-embedding implementation.

    Only the learnable variant is wired up here; ``is_learnable`` is accepted but
    not consulted. NOTE(review): presumably a non-learnable (sinusoidal) variant
    existed upstream — confirm before relying on ``is_learnable=False``.

    Args:
        num_embeddings: Maximum number of positions (sequence length).
        embedding_dim: Dimensionality of each positional embedding.
        padding_idx: Optional position index whose embedding is pinned to zeros.
        is_learnable: Unused; kept for signature compatibility.
        interpolation_mode: Interpolation mode used when the requested sequence
            length differs from ``num_embeddings``. Default: ``"bilinear"``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        is_learnable: Optional[bool] = False,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        self.pos_embed = LearnablePositionalEmbedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            interpolation_mode=interpolation_mode,
            *args,
            **kwargs,
        )

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        """Return positional embeddings of shape ``(1, seq_len, embedding_dim)``."""
        return self.pos_embed(seq_len, *args, **kwargs)

    def __repr__(self):
        # Delegate to the wrapped implementation's repr.
        return repr(self.pos_embed)
| |
|
| |
|
class LearnablePositionalEmbedding(nn.Module):
    """Learnable positional embedding.

    Holds a trainable table of shape ``(1, 1, num_embeddings, embedding_dim)`` and
    returns embeddings for a requested sequence length, resizing the table via
    interpolation when the requested length differs from ``num_embeddings``.

    Args:
        num_embeddings: Maximum number of positions in the table.
        embedding_dim: Dimensionality of each positional embedding.
        padding_idx: Optional position index whose embedding is pinned to zeros.
        interpolation_mode: Mode passed to ``F.interpolate`` when resizing.
            Default: ``"bilinear"``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        # 4D layout (1, 1, seq, dim) so F.interpolate can resize the seq axis.
        self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.interpolation_mode = interpolation_mode

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Truncated-normal init scaled by 1/sqrt(dim); zero the padding row."""
        nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
        if self.padding_idx is None:
            return
        with torch.no_grad():
            self.pos_embed[:, :, self.padding_idx, ...] = 0.0

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        """Return positional embeddings of shape ``(1, seq_len, embedding_dim)``."""
        pos_embed = self.pos_embed
        # Re-zero the padding position on every call (gradient updates may have
        # perturbed it between steps); done in place under no_grad.
        if self.padding_idx is not None:
            with torch.no_grad():
                pos_embed[:, :, self.padding_idx, ...] = 0.0

        # Resize the table along the sequence axis when the requested length
        # does not match the stored one.
        if seq_len != self.num_embeddings:
            pos_embed = F.interpolate(
                pos_embed,
                size=(seq_len, self.embedding_dim),
                mode=self.interpolation_mode,
            )

        # (1, 1, seq, dim) --> (1, seq, dim)
        return pos_embed.reshape(1, seq_len, self.embedding_dim)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(num_embeddings={self.num_embeddings}, "
            f"embedding_dim={self.embedding_dim}, padding_idx={self.padding_idx})"
        )
| |
|
| |
|
class MultiHeadAttention(nn.Module):
    """
    This layer applies a multi-head self- or cross-attention as described in
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (Optional[float]): Attention dropout. Default: 0.0
        bias (Optional[bool]): Use bias or not. Default: ``True``
        output_dim (Optional[int]): Output dimension of the final projection.
            Defaults to ``embed_dim``.

    Shape:
        - Input:
            - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
            and :math:`C_{in}` is input embedding dim
            - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
        - Output: same shape as the input

    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: Optional[float] = 0.0,
        bias: Optional[bool] = True,
        output_dim: Optional[int] = None,
        *args,
        **kwargs,
    ) -> None:
        if output_dim is None:
            output_dim = embed_dim
        super().__init__()
        if embed_dim % num_heads != 0:
            # Fix: the original merely constructed a Warning object, which is a
            # no-op — the message was silently discarded. Actually emit it.
            warnings.warn(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        # Fused projection producing query, key and value in one matmul;
        # split (or sliced for cross-attention) in _forward_impl.
        self.qkv_proj = nn.Linear(
            in_features=embed_dim, out_features=3 * embed_dim, bias=bias
        )

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(
            in_features=embed_dim, out_features=output_dim, bias=bias
        )

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        # True when the output projection changes dimensionality; kept for
        # compatibility with callers that inspect this attribute.
        self.use_separate_proj_weight = embed_dim != output_dim

    def __repr__(self):
        return "{}(head_dim={}, num_heads={}, attn_dropout={})".format(
            self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p
        )

    def _forward_impl(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute scaled dot-product attention.

        Args:
            x_q: Query tensor of shape (N, S, C).
            x_kv: Optional key/value source of shape (N, T, C). When ``None``,
                self-attention over ``x_q`` is performed.
            key_padding_mask: Optional boolean mask of shape (N, T); ``True``
                marks padded key positions that must not be attended to.
            attn_mask: Optional additive mask of shape (N, S, T).

        Returns:
            Tensor of shape (N, S, output_dim).
        """
        # [N, S, C]
        b_sz, S_len, in_channels = x_q.shape

        if x_kv is None:
            # Self-attention: project once, then split into q, k, v.
            # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = h * c
            qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
            # [N, S, 3, h, c] --> [N, h, 3, S, c]
            qkv = qkv.transpose(1, 3).contiguous()

            # [N, h, 3, S, c] --> 3 x [N, h, S, c]
            query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        else:
            T_len = x_kv.shape[1]

            # Cross-attention: slice the fused projection so queries come from
            # x_q (first embed_dim rows) and keys/values from x_kv (the rest).
            query = F.linear(
                x_q,
                weight=self.qkv_proj.weight[: self.embed_dim, ...],
                bias=self.qkv_proj.bias[: self.embed_dim]
                if self.qkv_proj.bias is not None
                else None,
            )
            # [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
            query = (
                query.reshape(b_sz, S_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous()
            )

            # [N, T, C] --> [N, T, 2C]
            kv = F.linear(
                x_kv,
                weight=self.qkv_proj.weight[self.embed_dim :, ...],
                bias=self.qkv_proj.bias[self.embed_dim :]
                if self.qkv_proj.bias is not None
                else None,
            )
            # [N, T, 2C] --> [N, T, 2, h, c]
            kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
            # [N, T, 2, h, c] --> [N, h, 2, T, c]
            kv = kv.transpose(1, 3).contiguous()
            key, value = kv[:, :, 0], kv[:, :, 1]

        # Scale queries by 1/sqrt(head_dim) before the dot product.
        query = query * self.scaling

        # [N, h, T, c] --> [N, h, c, T]
        key = key.transpose(-1, -2)

        # Raw attention scores:
        # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
        attn = torch.matmul(query, key)

        batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape
        if attn_mask is not None:
            # Additive mask per batch element, broadcast over heads.
            assert list(attn_mask.shape) == [
                batch_size,
                num_src_tokens,
                num_tgt_tokens,
            ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format(
                batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape
            )
            # [N, S, T] --> [N, 1, S, T]
            attn_mask = attn_mask.unsqueeze(1)
            attn = attn + attn_mask

        if key_padding_mask is not None:
            # True marks padding; those key positions get -inf before softmax so
            # they receive zero attention weight.
            assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
                batch_size,
                num_tgt_tokens,
            ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format(
                batch_size, num_tgt_tokens, key_padding_mask.shape
            )
            attn = attn.masked_fill(
                key_padding_mask.unsqueeze(1)
                .unsqueeze(2)
                .to(torch.bool),
                float("-inf"),
            )

        # Softmax in fp32 for numerical stability, then cast back.
        attn_dtype = attn.dtype
        attn_as_float = self.softmax(attn.float())
        attn = attn_as_float.to(attn_dtype)
        attn = self.attn_dropout(attn)

        # Weighted sum of values:
        # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
        out = torch.matmul(attn, value)

        # [N, h, S, c] --> [N, S, h, c] --> [N, S, C] --> out_proj
        out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
        out = self.out_proj(out)

        return out

    def forward(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        """Self-attention when ``x_kv`` is None, cross-attention otherwise."""
        return self._forward_impl(
            x_q=x_q,
            x_kv=x_kv,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
| |
|
| |
|
class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
    Args:
        embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
        ffn_latent_dim: Inner dimension of the FFN.
        num_heads: Number of heads in multi-head attention. Default: 8.
        attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0
        dropout: Dropout rate. Default: 0.0.
        ffn_dropout: Dropout between FFN layers. Default: 0.0.
        transformer_norm_layer: Normalization layer. Default: layer_norm.
        stochastic_dropout: Stochastic depth drop rate. Default: 0.0.

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
        and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        transformer_norm_layer: Optional[str] = "layer_norm",
        stochastic_dropout: Optional[float] = 0.0,
        *args,
        **kwargs,
    ) -> None:

        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True,
        )

        # Pre-norm attention branch: norm -> MHA -> dropout.
        # forward() indexes into this Sequential, so the module order is part of
        # the contract.
        self.pre_norm_mha = nn.Sequential(
            get_normalization_layer(
                norm_type=transformer_norm_layer, num_features=embed_dim
            ),
            attn_unit,
            nn.Dropout(p=dropout),
        )

        act_layer = nn.GELU()
        # Pre-norm feed-forward branch: norm -> expand -> GELU -> project back.
        self.pre_norm_ffn = nn.Sequential(
            get_normalization_layer(
                norm_type=transformer_norm_layer, num_features=embed_dim
            ),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            act_layer,
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout),
        )

        self.drop_path = nn.Identity()
        if stochastic_dropout > 0.0:
            if dropout > 0.0:
                # Fix: the original only constructed a Warning object — a no-op —
                # so the conflict was never reported. Also inserted the missing
                # space before "Got:" in the concatenated message.
                warnings.warn(
                    "Stochastic dropout and dropout are mutually exclusive. "
                    "Use either of them, but not both. "
                    "Got: {} and {}".format(stochastic_dropout, dropout)
                )
            self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")

        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.stochastic_dropout = stochastic_dropout
        self.std_dropout = dropout
        self.attn_fn_name = attn_unit.__class__.__name__
        self.act_fn_name = act_layer.__class__.__name__
        self.norm_type = transformer_norm_layer

    def __repr__(self) -> str:
        return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, stochastic_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format(
            self.__class__.__name__,
            self.embed_dim,
            self.ffn_dim,
            self.std_dropout,
            self.ffn_dropout,
            self.stochastic_dropout,
            self.attn_fn_name,
            self.act_fn_name,
            self.norm_type,
        )

    def forward(
        self,
        x: Tensor,
        x_prev: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        """Apply one pre-norm encoder layer.

        Args:
            x: Input tensor of shape (N, P, C).
            x_prev: Optional key/value source for cross-attention; ``None`` means
                self-attention.
            key_padding_mask: Optional boolean key padding mask (True = padded).
            attn_mask: Optional additive attention mask.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Multi-head attention sub-block with residual connection.
        res = x
        x = self.pre_norm_mha[0](x)  # norm
        x = self.pre_norm_mha[1](  # attention
            x_q=x,
            x_kv=x_prev,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            *args,
            **kwargs,
        )
        x = self.drop_path(self.pre_norm_mha[2](x))  # dropout + stochastic depth
        x = x + res

        # Feed-forward sub-block with residual connection.
        x = x + self.drop_path(self.pre_norm_ffn(x))
        return x
| |
|