| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | from transformers import (RobertaConfig, RobertaModel, RobertaTokenizer, |
| | BartConfig, BartForConditionalGeneration, BartTokenizer, |
| | T5Config, T5ForConditionalGeneration, T5Tokenizer) |
| | import logging |
| |
|
logger = logging.getLogger(__name__)

# Registry keyed by the ``model_type`` option: each entry is the
# (config class, model class, tokenizer class) triple used to load that family.
# Note that ``codet5`` pairs the T5 architecture with a RoBERTa (BPE) tokenizer.
MODEL_CLASSES = dict(
    roberta=(RobertaConfig, RobertaModel, RobertaTokenizer),
    t5=(T5Config, T5ForConditionalGeneration, T5Tokenizer),
    codet5=(T5Config, T5ForConditionalGeneration, RobertaTokenizer),
    bart=(BartConfig, BartForConditionalGeneration, BartTokenizer),
)
| |
|
| |
|
def get_model_size(model):
    """Return the trainable-parameter count of *model* in millions, e.g. ``"125M"``.

    Only parameters with ``requires_grad`` set are counted, and the total is
    rounded to the nearest million before formatting.
    """
    per_tensor = [np.prod(param.size()) for param in model.parameters() if param.requires_grad]
    total = sum(per_tensor)
    millions = round(total / 1e+6)
    return "{}M".format(millions)
| |
|
| |
|
def build_or_load_gen_model(args):
    """Build the config, generation model and tokenizer selected by ``args``.

    ``args.model_type`` picks a ``MODEL_CLASSES`` entry. For ``roberta`` a
    ``Seq2Seq`` is assembled from the pretrained encoder plus a freshly
    initialised 6-layer ``nn.TransformerDecoder``. Every other type is loaded
    directly with ``model_class.from_pretrained``. If ``args.load_model_path``
    is set, that state dict is loaded on top.

    Args:
        args: run options. This function reads ``model_type``, ``config_name``,
            ``tokenizer_name``, ``model_name_or_path``, ``load_model_path``, and,
            for ``roberta``, ``beam_size`` and ``max_target_length``.

    Returns:
        ``(config, model, tokenizer)``.
    """
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
    # Mirror the config lookup: fall back to the model checkpoint when no
    # explicit tokenizer name is given, instead of calling from_pretrained(None/"").
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path)
    if args.model_type == 'roberta':
        encoder = model_class.from_pretrained(args.model_name_or_path, config=config)
        decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
        decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                        beam_size=args.beam_size, max_length=args.max_target_length,
                        sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
    else:
        model = model_class.from_pretrained(args.model_name_or_path)

    logger.info("Finish loading model [%s] from %s", get_model_size(model), args.model_name_or_path)

    if args.load_model_path is not None:
        logger.info("Reload model from %s", args.load_model_path)
        # Deserialize onto CPU: load_state_dict copies into the model's own
        # parameters anyway. This avoids a transient GPU copy and lets
        # GPU-saved checkpoints load on CPU-only hosts.
        model.load_state_dict(torch.load(args.load_model_path, map_location='cpu'))

    return config, model, tokenizer
| |
|
| |
|
class RobertaClassificationHead(nn.Module):
    """Two-class head that scores consecutive *pairs* of sentence vectors.

    The input's last dimension is ``hidden_size``. Every two consecutive
    vectors are concatenated into one ``2 * hidden_size`` feature, then passed
    through a tanh-activated dense layer and a 2-way output projection.
    """

    def __init__(self, config):
        super(RobertaClassificationHead, self).__init__()
        width = config.hidden_size
        # Attribute names (dense, out_proj) define state_dict keys; keep them stable.
        self.dense = nn.Linear(2 * width, width)
        self.out_proj = nn.Linear(width, 2)

    def forward(self, x, **kwargs):
        paired = x.reshape(-1, 2 * x.size(-1))
        activated = torch.tanh(self.dense(paired))
        return self.out_proj(activated)
| |
|
| |
|
class CloneModel(nn.Module):
    """Two-class classifier for code-clone detection.

    ``forward`` reshapes ``source_ids`` into rows of ``args.max_source_length``
    tokens and turns each row into one vector with the backbone selected by
    ``args.model_type``. ``RobertaClassificationHead`` then concatenates
    consecutive pairs of those vectors before classifying. Presumably each pair
    is the two snippets of a candidate clone pair; confirm against the data
    loader.

    Args:
        encoder: backbone. It is called as a seq2seq model (returning a mapping
            with ``decoder_hidden_states``) for ``codet5``/``t5``/``bart``, and
            as an encoder whose first output is the token states for ``roberta``.
        config: backbone config; ``hidden_size`` and ``eos_token_id`` are read.
        tokenizer: only ``pad_token_id`` is read.
        args: run options; ``model_type`` and ``max_source_length`` are read.
    """

    def __init__(self, encoder, config, tokenizer, args):
        super(CloneModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config)
        self.args = args

    def _last_eos_vec(self, source_ids):
        """Return the last decoder hidden state at each row's final ``eos`` token.

        This is the shared implementation behind the T5 and BART paths.

        Raises:
            ValueError: if rows contain different numbers of ``eos`` tokens,
                or contain none at all.
        """
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        # The source doubles as the decoder target, so decoder positions line up with source_ids.
        outputs = self.encoder(input_ids=source_ids, attention_mask=attention_mask,
                               labels=source_ids, decoder_attention_mask=attention_mask,
                               output_hidden_states=True)
        hidden_states = outputs['decoder_hidden_states'][-1]
        eos_mask = source_ids.eq(self.config.eos_token_id)
        eos_counts = eos_mask.sum(1)
        if len(torch.unique(eos_counts)) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        if eos_counts.numel() > 0 and int(eos_counts[0]) == 0:
            # Without this check the reshape below fails with an opaque RuntimeError.
            raise ValueError("All examples must contain at least one <eos> token.")
        # Equal per-row counts let the masked selection be viewed as (batch, n_eos, hidden).
        vec = hidden_states[eos_mask, :].view(hidden_states.size(0), -1,
                                              hidden_states.size(-1))[:, -1, :]
        return vec

    def get_t5_vec(self, source_ids):
        """Sentence vector for T5-family backbones (final ``eos`` decoder state)."""
        return self._last_eos_vec(source_ids)

    def get_bart_vec(self, source_ids):
        """Sentence vector for BART backbones (final ``eos`` decoder state)."""
        return self._last_eos_vec(source_ids)

    def get_roberta_vec(self, source_ids):
        """Sentence vector for RoBERTa: the first token's last-layer state."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        vec = self.encoder(input_ids=source_ids, attention_mask=attention_mask)[0][:, 0, :]
        return vec

    def forward(self, source_ids=None, labels=None):
        """Classify the input.

        Returns:
            ``(loss, prob)`` when ``labels`` is given, otherwise ``prob``.
            ``prob`` is the softmax over the two classes.

        Raises:
            ValueError: for a ``model_type`` this class does not support.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type in ('codet5', 't5'):
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            raise ValueError("Unsupported model_type: {}".format(model_type))

        logits = self.classifier(vec)
        # Explicit class dimension; the implicit-dim form is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, prob
        return prob
| |
|
| |
|
class DefectModel(nn.Module):
    """Two-class classifier for defect detection on a single code snippet.

    ``forward`` reshapes ``source_ids`` into rows of ``args.max_source_length``
    tokens. It turns each row into one vector with the backbone selected by
    ``args.model_type`` and scores that vector with a linear layer.

    Args:
        encoder: backbone. It is called as a seq2seq model (returning a mapping
            with ``decoder_hidden_states``) for ``codet5``/``t5``/``bart``, and
            as an encoder whose first output is the token states for ``roberta``.
        config: backbone config; ``hidden_size`` and ``eos_token_id`` are read.
        tokenizer: only ``pad_token_id`` is read.
        args: run options; ``model_type`` and ``max_source_length`` are read.
    """

    def __init__(self, encoder, config, tokenizer, args):
        super(DefectModel, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.args = args

    def _last_eos_vec(self, source_ids):
        """Return the last decoder hidden state at each row's final ``eos`` token.

        This is the shared implementation behind the T5 and BART paths.

        Raises:
            ValueError: if rows contain different numbers of ``eos`` tokens,
                or contain none at all.
        """
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        # The source doubles as the decoder target, so decoder positions line up with source_ids.
        outputs = self.encoder(input_ids=source_ids, attention_mask=attention_mask,
                               labels=source_ids, decoder_attention_mask=attention_mask,
                               output_hidden_states=True)
        hidden_states = outputs['decoder_hidden_states'][-1]
        eos_mask = source_ids.eq(self.config.eos_token_id)
        eos_counts = eos_mask.sum(1)
        if len(torch.unique(eos_counts)) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        if eos_counts.numel() > 0 and int(eos_counts[0]) == 0:
            # Without this check the reshape below fails with an opaque RuntimeError.
            raise ValueError("All examples must contain at least one <eos> token.")
        # Equal per-row counts let the masked selection be viewed as (batch, n_eos, hidden).
        vec = hidden_states[eos_mask, :].view(hidden_states.size(0), -1,
                                              hidden_states.size(-1))[:, -1, :]
        return vec

    def get_t5_vec(self, source_ids):
        """Sentence vector for T5-family backbones (final ``eos`` decoder state)."""
        return self._last_eos_vec(source_ids)

    def get_bart_vec(self, source_ids):
        """Sentence vector for BART backbones (final ``eos`` decoder state)."""
        return self._last_eos_vec(source_ids)

    def get_roberta_vec(self, source_ids):
        """Sentence vector for RoBERTa: the first token's last-layer state."""
        attention_mask = source_ids.ne(self.tokenizer.pad_token_id)
        vec = self.encoder(input_ids=source_ids, attention_mask=attention_mask)[0][:, 0, :]
        return vec

    def forward(self, source_ids=None, labels=None):
        """Classify the input.

        Returns:
            ``(loss, prob)`` when ``labels`` is given, otherwise ``prob``.
            ``prob`` is the softmax over the two classes.

        Raises:
            ValueError: for a ``model_type`` this class does not support.
        """
        source_ids = source_ids.view(-1, self.args.max_source_length)

        model_type = self.args.model_type
        if model_type in ('codet5', 't5'):
            vec = self.get_t5_vec(source_ids)
        elif model_type == 'bart':
            vec = self.get_bart_vec(source_ids)
        elif model_type == 'roberta':
            vec = self.get_roberta_vec(source_ids)
        else:
            raise ValueError("Unsupported model_type: {}".format(model_type))

        logits = self.classifier(vec)
        # Explicit class dimension; the implicit-dim form is deprecated.
        prob = nn.functional.softmax(logits, dim=-1)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, prob
        return prob
| |
|
| |
|
| | |
class Seq2Seq(nn.Module):
    """
    Build Sequence-to-Sequence.

    Parameters:

    * `encoder`- encoder of seq2seq model. e.g. roberta. Must expose
      `embeddings` (callable on token ids) with `embeddings.word_embeddings`.
    * `decoder`- decoder of seq2seq model. e.g. transformer (sequence-first layout)
    * `config`- configuration of encoder model (`hidden_size`, `vocab_size`,
      `torchscript` are read).
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of target for beam search.
    * `sos_id`- start of symbol ids in target for beam search.
    * `eos_id`- end of symbol ids in target for beam search.
    """

    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular causal pattern, sliced per target length (caps targets at 2048 tokens).
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.encoder.embeddings.word_embeddings)

    @staticmethod
    def _key_padding_mask(mask):
        """Turn an attention mask (nonzero/True = attend) into a padding mask (True = ignore).

        Works for bool, integer and float masks alike. Plain ``~mask`` is only
        correct for bool masks: on 0/1 integer tensors it yields -1/-2.
        """
        return None if mask is None else mask.eq(0)

    def _decode(self, input_ids, memory, memory_mask):
        """Run the decoder over ``input_ids`` and return ``tanh(dense(.))`` batch-first."""
        length = input_ids.shape[1]
        # Additive causal mask: 0 on/below the diagonal, -1e4 above it.
        attn_mask = -1e4 * (1 - self.bias[:length, :length])
        tgt_embeddings = self.encoder.embeddings(input_ids).permute([1, 0, 2]).contiguous()
        out = self.decoder(tgt_embeddings, memory, tgt_mask=attn_mask,
                           memory_key_padding_mask=self._key_padding_mask(memory_mask))
        return torch.tanh(self.dense(out)).permute([1, 0, 2]).contiguous()

    def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None):
        """Compute the training loss (``target_ids`` given) or beam-search predictions.

        ``args`` is accepted for call-site compatibility and is unused.

        Returns:
            With targets, ``(loss, loss * n_active, n_active)``, where ``n_active``
            counts non-padded shifted target positions. Without targets, a
            LongTensor of token ids shaped (batch, <=beam_size, max_length),
            zero-padded after each hypothesis.
        """
        outputs = self.encoder(source_ids, attention_mask=source_mask)
        # The decoder consumes (seq, batch, hidden).
        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
        if target_ids is not None:
            return self._compute_loss(encoder_output, source_mask, target_ids, target_mask)
        return self._beam_search(source_ids, encoder_output, source_mask)

    def _compute_loss(self, encoder_output, source_mask, target_ids, target_mask):
        """Teacher-forced next-token cross-entropy over non-padded target positions."""
        hidden_states = self._decode(target_ids, encoder_output, source_mask)
        lm_logits = self.lm_head(hidden_states)

        # Position t predicts token t + 1; padded target positions are excluded.
        active_loss = target_mask[..., 1:].ne(0).view(-1)
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()

        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                        shift_labels.view(-1)[active_loss])
        return loss, loss * active_loss.sum(), active_loss.sum()

    def _beam_search(self, source_ids, encoder_output, source_mask):
        """Decode every source row independently with a ``Beam`` of width ``beam_size``."""
        preds = []
        for i in range(source_ids.shape[0]):
            context = encoder_output[:, i:i + 1].repeat(1, self.beam_size, 1)
            context_mask = None if source_mask is None else source_mask[i:i + 1, :].repeat(self.beam_size, 1)
            beam = Beam(self.beam_size, self.sos_id, self.eos_id)
            input_ids = beam.getCurrentState()
            # Padding token allocated alongside the beam's own tensors (same device/dtype).
            zero = input_ids.new_zeros(1)
            for _ in range(self.max_length):
                if beam.done():
                    break
                hidden_states = self._decode(input_ids, context, context_mask)[:, -1, :]
                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                # Reorder prefixes to follow the surviving back-pointers, then append the new tokens.
                input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
                input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            pred = [torch.cat([x.view(-1) for x in p] + [zero] * (self.max_length - len(p))).view(1, -1)
                    for p in pred]
            preds.append(torch.cat(pred, 0).unsqueeze(0))

        return torch.cat(preds, 0)
| |
|
| |
|
class Beam(object):
    """Beam-search bookkeeping for decoding a single source sequence.

    At step ``t``, ``nextYs[t]`` holds the token chosen by each of the ``size``
    slots. From step 1 on, ``prevKs[t - 1]`` holds the slot each one extended,
    so complete hypotheses are recovered by walking back-pointers (``getHyp``).

    Args:
        size: beam width.
        sos: token id placed in slot 0 at step 0 (other slots hold 0).
        eos: token id that ends a hypothesis.
        device: device for the beam's tensors. ``None`` (the default) keeps the
            original behaviour of allocating on the current CUDA device. Pass
            e.g. ``"cpu"`` to decode without a GPU.
    """

    def __init__(self, size, sos, eos, device=None):
        self.size = size
        self.device = torch.device('cuda') if device is None else torch.device(device)
        # Cumulative log-probability of each slot's hypothesis.
        self.scores = torch.zeros(size, dtype=torch.float, device=self.device)
        # Back-pointers: prevKs[t][k] is the slot that hypothesis k extended at step t + 1.
        self.prevKs = []
        # Tokens chosen by each slot at each step.
        self.nextYs = [torch.zeros(size, dtype=torch.long, device=self.device)]
        self.nextYs[0][0] = sos
        self._eos = eos
        # Becomes True once the best-scoring slot has emitted <eos>.
        self.eosTop = False
        # (score, timestep, slot) for every hypothesis that emitted <eos>.
        self.finished = []

    def getCurrentState(self):
        """Return the latest tokens as a (size, 1) tensor.

        A copy is returned because callers may modify the result in place.
        """
        return self.nextYs[-1].clone().view(-1, 1)

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Extend every hypothesis by one token and keep the `size` best.

        Parameters:

        * `wordLk`- log-probs of the next token for every slot (K x words)

        Hypotheses that emit `eos` are recorded in `finished` and are never
        extended again. Returns nothing.
        """
        numWords = wordLk.size(1)

        if self.prevKs:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # Forbid extending slots whose hypothesis already ended with <eos>.
            beamLk[self.nextYs[-1].eq(self._eos)] = -1e20
        else:
            # First step: every slot shares the same prefix, so expand row 0 only.
            beamLk = wordLk[0]
        bestScores, bestScoresId = beamLk.view(-1).topk(self.size, 0, True, True)

        self.scores = bestScores

        # Split each flat index into (source slot, token id).
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)

        # One host transfer instead of a device sync per element.
        step = len(self.nextYs) - 1
        for i, tok in enumerate(self.nextYs[-1].tolist()):
            if tok == self._eos:
                self.finished.append((self.scores[i], step, i))

        # The top-scoring hypothesis has ended: decoding may stop once enough have finished.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        """True when the best slot has emitted <eos> and `size` hypotheses are finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """Return up to `size` (score, timestep, slot) triples, finished ones first.

        Finished hypotheses are sorted by descending score. If fewer than `size`
        have finished, the best still-open slots of the last step fill the rest.
        """
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        missing = self.size - len(self.finished)
        # Only pad when short. A negative count would slice the wrong entries
        # into `finished`.
        if missing > 0:
            step = len(self.nextYs) - 1
            unfinished = [(self.scores[i], step, i)
                          for i, tok in enumerate(self.nextYs[-1].tolist())
                          if tok != self._eos]
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:missing]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(min(timestep, len(self.prevKs)) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Truncate each hypothesis at its first `eos` (exclusive)."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
| |
|