import copy
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.m2m_100.modeling_m2m_100 import (
    M2M100Config,
    M2M100ScaledWordEmbedding,
    M2M100Decoder,
    M2M100PreTrainedModel,
    GenerationMixin,
    Seq2SeqLMOutput,
    BaseModelOutput,
    shift_tokens_right,
    Cache,
    CrossEntropyLoss,
)


# override model type to register AutoModels
class SonarDecoderConfig(M2M100Config):
    """M2M100 configuration exposed under its own ``model_type`` string."""

    model_type = "SonarDecoderModel"


class SonarDecoderModel(M2M100PreTrainedModel, GenerationMixin):
    """Decoder-only language-model head built on ``M2M100Decoder``.

    The model owns a shared embedding table, an ``M2M100Decoder`` and a
    bias-free linear output projection. It has no encoder of its own:
    ``forward`` requires the caller to pass ``encoder_outputs``.
    """

    # override config class to register AutoModels
    config_class = SonarDecoderConfig
    _tied_weights_keys = {
        "decoder.embed_tokens.weight": "shared.weight",
        "lm_head.weight": "shared.weight",
    }
    # NOTE(review): presumably lets checkpoints that still contain encoder
    # weights load without "unexpected key" reports -- confirm.
    _keys_to_ignore_on_load_unexpected = [r"encoder"]

    def __init__(self, config: M2M100Config):
        """Build the shared embedding, the decoder stack and the LM head.

        Args:
            config: Model configuration. ``vocab_size`` and ``d_model`` size
                the embedding table and the output projection.
        """
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # The decoder receives a deep copy, so the two overrides below do not
        # modify the configuration stored on the model itself.
        decoder_config = copy.deepcopy(config)
        decoder_config.use_cache = False
        decoder_config.is_encoder_decoder = False
        self.decoder = M2M100Decoder(decoder_config)

        self.lm_head = nn.Linear(config.d_model, self.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared embedding module."""
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared embedding and point the decoder at the same module."""
        self.shared = value
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        """Tie the decoder embedding and the LM head to the shared embedding.

        Nothing is tied unless ``config.tie_word_embeddings`` is truthy.
        """
        # NOTE(review): the indentation of this file was lost; both tie calls
        # are placed under the guard, mirroring the usual M2M100 layout --
        # confirm against the upstream source.
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
            # in SONAR models, input and output projections are tied (ideally, this should be configurable)
            self._tie_or_clone_weights(self.lm_head, self.shared)

    def get_decoder(self):
        """Return the wrapped ``M2M100Decoder``."""
        return self.decoder

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        decoder_inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | Seq2SeqLMOutput:
        """Decode against precomputed encoder outputs and project to vocabulary logits.

        Args:
            input_ids: Accepted for signature compatibility; not read here.
            attention_mask: Passed to the decoder as ``encoder_attention_mask``.
            decoder_input_ids: Decoder token ids. When ``None`` and ``labels``
                is given, they are derived with ``shift_tokens_right``.
            decoder_attention_mask: Passed to the decoder as ``attention_mask``.
            encoder_outputs: Required. Its first element is handed to the
                decoder as ``encoder_hidden_states``.
            past_key_values: Forwarded to the decoder unchanged.
            inputs_embeds: Accepted for signature compatibility; not read here.
            decoder_inputs_embeds: Passed to the decoder as ``inputs_embeds``.
            labels: Target ids. When given, a cross-entropy loss over the
                flattened logits is returned as well.
            use_cache: Forwarded to the decoder unchanged.
            output_attentions: Forwarded to the decoder unchanged.
            output_hidden_states: Forwarded to the decoder unchanged.
            return_dict: Falls back to ``config.use_return_dict`` when ``None``.
            cache_position: Forwarded to the decoder unchanged.
            **kwargs: Accepted and ignored.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` resolves to a truthy
            value. Otherwise a tuple of the logits followed by the remaining
            decoder outputs, with the loss prepended when ``labels`` is given.

        Raises:
            ValueError: If ``encoder_outputs`` is ``None``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        if encoder_outputs is None:
            raise ValueError("SonarDecoderModel expects the `encoder_outputs` to be always present.")

        # Wrap a plain tuple so the attribute access used when building the
        # Seq2SeqLMOutput below works.
        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        lm_logits = self.lm_head(decoder_outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # move labels to the correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 according to ``beam_idx``.

        Args:
            past_key_values: Iterable of per-layer iterables of tensors.
            beam_idx: Index tensor given to ``index_select``; it is moved to
                each cached tensor's device before indexing.

        Returns:
            A tuple of tuples with the same nesting as ``past_key_values``.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )

    @classmethod
    def _can_set_experts_implementation(cls) -> bool:
        """Always return ``False``.

        NOTE(review): presumably opts this class out of the library's
        experts-implementation switching -- confirm against the base class.
        """
        return False