| | import torch |
| | from torch import Tensor |
| | from torch import nn |
| | from typing import Union, Tuple, List, Iterable, Dict |
| | import os |
| | import json |
| |
|
| |
|
class Pooling(nn.Module):
    """Performs pooling (max or mean) on the token embeddings.

    Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
    This layer also allows to use the CLS token if it is returned by the underlying word embedding model.
    You can concatenate multiple poolings together.

    :param word_embedding_dimension: Dimensions for the word embeddings
    :param pooling_mode: Can be a string: mean/max/cls/weightedmean/lasttoken. If set, overwrites the other pooling_mode_* settings
    :param pooling_mode_cls_token: Use the first token (CLS token) as text representations
    :param pooling_mode_max_tokens: Use max in each dimension over all tokens.
    :param pooling_mode_mean_tokens: Perform mean-pooling
    :param pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but divide by sqrt(input_length).
    :param pooling_mode_weightedmean_tokens: Perform (position) weighted mean pooling, see https://arxiv.org/abs/2202.08904
    :param pooling_mode_lasttoken: Perform last token pooling, see https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005
    """
    def __init__(self,
                 word_embedding_dimension: int,
                 pooling_mode: str = None,
                 pooling_mode_cls_token: bool = False,
                 pooling_mode_max_tokens: bool = False,
                 pooling_mode_mean_tokens: bool = True,
                 pooling_mode_mean_sqrt_len_tokens: bool = False,
                 pooling_mode_weightedmean_tokens: bool = False,
                 pooling_mode_lasttoken: bool = False,
                 ):
        super(Pooling, self).__init__()

        # Keys serialized by get_config_dict()/save() and consumed by load()
        self.config_keys = ['word_embedding_dimension', 'pooling_mode_cls_token', 'pooling_mode_mean_tokens', 'pooling_mode_max_tokens',
                            'pooling_mode_mean_sqrt_len_tokens', 'pooling_mode_weightedmean_tokens', 'pooling_mode_lasttoken']

        if pooling_mode is not None:  # Set pooling mode by string; overrides the individual flags
            pooling_mode = pooling_mode.lower()
            assert pooling_mode in ['mean', 'max', 'cls', 'weightedmean', 'lasttoken']
            pooling_mode_cls_token = (pooling_mode == 'cls')
            pooling_mode_max_tokens = (pooling_mode == 'max')
            pooling_mode_mean_tokens = (pooling_mode == 'mean')
            pooling_mode_weightedmean_tokens = (pooling_mode == 'weightedmean')
            pooling_mode_lasttoken = (pooling_mode == 'lasttoken')

        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens
        self.pooling_mode_lasttoken = pooling_mode_lasttoken

        # Several pooling strategies may be active at once; their outputs are
        # concatenated, so the final dimension is a multiple of the word dimension.
        pooling_mode_multiplier = sum([pooling_mode_cls_token, pooling_mode_max_tokens, pooling_mode_mean_tokens,
                                       pooling_mode_mean_sqrt_len_tokens, pooling_mode_weightedmean_tokens, pooling_mode_lasttoken])
        self.pooling_output_dimension = (pooling_mode_multiplier * word_embedding_dimension)

    def __repr__(self):
        return "Pooling({})".format(self.get_config_dict())

    def get_pooling_mode_str(self) -> str:
        """
        Returns the pooling mode as string
        """
        modes = []
        if self.pooling_mode_cls_token:
            modes.append('cls')
        if self.pooling_mode_mean_tokens:
            modes.append('mean')
        if self.pooling_mode_max_tokens:
            modes.append('max')
        if self.pooling_mode_mean_sqrt_len_tokens:
            modes.append('mean_sqrt_len_tokens')
        if self.pooling_mode_weightedmean_tokens:
            modes.append('weightedmean')
        if self.pooling_mode_lasttoken:
            modes.append('lasttoken')

        return "+".join(modes)

    def forward(self, features: Dict[str, Tensor]):
        """Pools 'token_embeddings' (bs, seq_len, dim) using the configured
        strategies, guided by 'attention_mask' (bs, seq_len; 1 = real token,
        0 = padding — assumes right-padding), and stores the concatenated
        result under features['sentence_embedding'].
        """
        token_embeddings = features['token_embeddings']
        attention_mask = features['attention_mask']

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            # Prefer an explicitly provided CLS embedding; otherwise take token 0
            cls_token = features.get('cls_token_embeddings', token_embeddings[:, 0])
            output_vectors.append(cls_token)
        if self.pooling_mode_max_tokens:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Mask out padding with a large negative value WITHOUT mutating the
            # input tensor. (An in-place assignment here would corrupt
            # features['token_embeddings'] for any downstream consumer.)
            masked_embeddings = token_embeddings.masked_fill(input_mask_expanded == 0, -1e9)
            max_over_time = torch.max(masked_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if 'token_weights_sum' in features:
                sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)  # avoid division by zero for empty sequences

            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
        if self.pooling_mode_weightedmean_tokens:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Position weights: token_embeddings shape (bs, seq_len, dim);
            # weight token i by (i + 1) so later tokens count more
            weights = (
                torch.arange(start=1, end=token_embeddings.shape[1] + 1)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(token_embeddings.size())
                .float().to(token_embeddings.device)
            )
            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
            input_mask_expanded = input_mask_expanded * weights

            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if 'token_weights_sum' in features:
                sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(sum_embeddings / sum_mask)
        if self.pooling_mode_lasttoken:
            bs, seq_len, hidden_dim = token_embeddings.shape
            # attention_mask shape: (bs, seq_len)
            # argmin finds the first 0 (start of right-padding); the last real
            # token sits directly before it.
            gather_indices = torch.argmin(attention_mask, 1, keepdim=False) - 1  # Shape [bs]

            # A sequence with NO padding (all-ones mask) makes argmin return 0,
            # so the index becomes -1; the correct last token is then the final
            # position seq_len - 1. (The previous clamp-to-0 wrongly selected
            # the FIRST token in that case.) An all-zeros mask also maps to
            # seq_len - 1, which is harmless: the mask multiplication below
            # zeroes the result anyway.
            gather_indices = torch.where(gather_indices < 0,
                                         torch.full_like(gather_indices, seq_len - 1),
                                         gather_indices)

            # Turn indices from shape [bs] -> [bs, 1, hidden_dim] for torch.gather
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (bs, 1, hidden_dim)

            # Gather along the seq_len dimension; multiply by the mask first so
            # fully-padded (empty) sequences yield a zero vector.
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
            output_vectors.append(embedding)

        output_vector = torch.cat(output_vectors, 1)
        features.update({'sentence_embedding': output_vector})
        return features

    def get_sentence_embedding_dimension(self):
        """Dimension of the produced sentence embedding (multiple of word dim)."""
        return self.pooling_output_dimension

    def get_config_dict(self):
        """Serializable configuration; inverse of load()."""
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path):
        """Writes the pooling configuration to <output_path>/config.json."""
        with open(os.path.join(output_path, 'config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path):
        """Reconstructs a Pooling layer from <input_path>/config.json."""
        with open(os.path.join(input_path, 'config.json')) as fIn:
            config = json.load(fIn)

        return Pooling(**config)
| |
|