from collections.abc import Iterable
from dataclasses import dataclass
from enum import StrEnum

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    ModernBertModel,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

BATCH_SIZE = 16


class ModelURI(StrEnum):
    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    concept: str
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor


class DisamBertSingleSense(PreTrainedModel):
    """Word-sense disambiguation head over ModernBERT: scores a marked entity
    span against the gloss of each candidate sense."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First construction: load pretrained weights and grow the
            # embedding matrix by two rows for the span-marker tokens.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            # Reload from a saved checkpoint: the config already carries the
            # enlarged vocab_size, so build the architecture directly.
            self.BaseModel = ModernBertModel(config)
        # Ensure a saved config never re-triggers the pretrained-download path.
        config.init_basemodel = False
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "DisamBertSingleSense":
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int) -> None:
        # Stored on the config so the marker token ids survive save/load.
        self.config.start_token = start
        self.config.end_token = end

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        candidate_tokens: torch.Tensor,
        candidate_attention_masks: torch.Tensor,
        candidate_mapping: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        base_model_output = self.BaseModel(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Build a 0/1 mask over each marked span; assumes exactly one
        # start/end marker pair per sequence, in matching order.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Entity representation: sum of the token vectors inside the span.
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            candidate_tokens, candidate_attention_masks, candidate_mapping
        )
        # Score each candidate by the dot product between the entity vector
        # and its gloss vector; zero-padded gloss slots score exactly 0.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)
        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(
        self,
        candidates: torch.Tensor,
        candidate_attention_masks: torch.Tensor,
        candidate_mapping: torch.Tensor,
    ) -> torch.Tensor:
        # torch.device as a context manager (PyTorch >= 2.0) makes the factory
        # calls below allocate on the model's device.
        with self.device:
            # Encode every gloss and keep its [CLS] vector.
            vectors = self.BaseModel(candidates, candidate_attention_masks).last_hidden_state[:, 0]
            # Group gloss vectors by the sentence they belong to.
            chunks = [
                torch.squeeze(vectors[(candidate_mapping == sentence_index).nonzero()], dim=1)
                for sentence_index in torch.unique(candidate_mapping)
            ]
            # Zero-pad every group to the largest candidate count so the
            # result stacks into (batch, max_candidates, hidden).
            maxlen = max(chunk.shape[0] for chunk in chunks)
            return torch.stack(
                [
                    torch.cat(
                        [
                            chunk,
                            torch.zeros(
                                (maxlen - chunk.shape[0], self.config.hidden_size),
                                dtype=vectors.dtype,
                            ),
                        ]
                    )
                    for chunk in chunks
                ]
            )


class CandidateLabeller:
    """Collator: pads pre-tokenized sentences, gathers and pads the candidate
    glosses, and maps each gloss back to the sentence it belongs to."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        ontology: Iterable[LexicalExample],
        device: torch.device,
        retain_candidates: bool = False,
    ):
        self.tokenizer = tokenizer
        self.device = device
        # Pre-tokenize every gloss once; padding happens per batch in __call__.
        self.gloss_tokens = {
            example.concept: self.tokenizer(example.definition) for example in ontology
        }
        self.retain_candidates = retain_candidates

    def __call__(self, batch: list[dict]) -> dict:
        with self.device:
            encoded = [
                BatchEncoding(
                    {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
                )
                for example in batch
            ]
            tokens = self.tokenizer.pad(encoded, padding=True, return_tensors="pt")
            # Flatten all candidate glosses across the batch into one padded block.
            candidate_tokens = self.tokenizer.pad(
                [
                    self.gloss_tokens[concept]
                    for example in batch
                    for concept in example["candidates"]
                ],
                padding=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            result = {
                "input_ids": tokens.input_ids,
                "attention_mask": tokens.attention_mask,
                "candidate_tokens": candidate_tokens.input_ids,
                "candidate_attention_masks": candidate_tokens.attention_mask,
                # candidate_mapping[j] = index of the sentence gloss j belongs to.
                "candidate_mapping": torch.cat(
                    [
                        torch.tensor([i] * len(example["candidates"]))
                        for (i, example) in enumerate(batch)
                    ]
                ),
            }
            if "label" in batch[0]:
                # The training label is the gold concept's position in the
                # example's own candidate list.
                result["labels"] = torch.tensor(
                    [example["candidates"].index(example["label"]) for example in batch]
                )
            if self.retain_candidates:
                result["candidates"] = [example["candidates"] for example in batch]
            return result
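
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one end-to-end step wiring the tokenizer,
# model, and collator together. The marker strings "[E]"/"[/E]", the toy
# ontology, and the example field values below are assumptions chosen to match
# what CandidateLabeller expects, not a fixed API of this module. Note that
# device_map="auto" in __init__ requires the accelerate package.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(ModelURI.BASE)
    # Two new marker tokens; __init__ resized the embedding matrix by two
    # rows to make room for them.
    tokenizer.add_tokens(["[E]", "[/E]"], special_tokens=True)

    model = DisamBertSingleSense.from_base(ModelURI.BASE)
    model.add_special_tokens(
        tokenizer.convert_tokens_to_ids("[E]"),
        tokenizer.convert_tokens_to_ids("[/E]"),
    )

    # Toy ontology: one gloss per candidate sense of "bank".
    ontology = [
        LexicalExample("bank.financial", "an institution that accepts deposits"),
        LexicalExample("bank.river", "sloping land beside a body of water"),
    ]
    collate = CandidateLabeller(tokenizer, ontology, model.device, retain_candidates=True)

    # A pre-tokenized example with the ambiguous span wrapped in markers,
    # shaped as a DataLoader(collate_fn=collate) would hand it over.
    encoded = tokenizer("She deposited the cheque at the [E] bank [/E] yesterday.")
    batch = collate(
        [
            {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
                "candidates": ["bank.financial", "bank.river"],
                "label": "bank.financial",
            }
        ]
    )
    candidates = batch.pop("candidates")  # forward() does not accept this key
    with torch.no_grad():
        output = model(**batch)
    print("predicted sense:", candidates[0][output.logits[0].argmax().item()])
    print("loss:", output.loss.item())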