aaljabari commited on
Commit
0b2b5cc
·
verified ·
1 Parent(s): 4462f0f

Create BertTrainer.py

Browse files
Files changed (1) hide show
  1. Nested/trainers/BertTrainer.py +163 -0
Nested/trainers/BertTrainer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import torch
4
+ import numpy as np
5
+ from Nested.trainers import BaseTrainer
6
+ from Nested.utils.metrics import compute_single_label_metrics
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class BertTrainer(BaseTrainer):
12
+ def __init__(self, **kwargs):
13
+ super().__init__(**kwargs)
14
+
15
+ def train(self):
16
+ best_val_loss, test_loss = np.inf, np.inf
17
+ num_train_batch = len(self.train_dataloader)
18
+ patience = self.patience
19
+
20
+ for epoch_index in range(self.max_epochs):
21
+ self.current_epoch = epoch_index
22
+ train_loss = 0
23
+
24
+ for batch_index, (_, gold_tags, _, _, logits) in enumerate(self.tag(
25
+ self.train_dataloader, is_train=True
26
+ ), 1):
27
+ self.current_timestep += 1
28
+ batch_loss = self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
29
+ batch_loss.backward()
30
+
31
+ # Avoid exploding gradient by doing gradient clipping
32
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
33
+
34
+ self.optimizer.step()
35
+ self.scheduler.step()
36
+ train_loss += batch_loss.item()
37
+
38
+ if self.current_timestep % self.log_interval == 0:
39
+ logger.info(
40
+ "Epoch %d | Batch %d/%d | Timestep %d | LR %.10f | Loss %f",
41
+ epoch_index,
42
+ batch_index,
43
+ num_train_batch,
44
+ self.current_timestep,
45
+ self.optimizer.param_groups[0]['lr'],
46
+ batch_loss.item()
47
+ )
48
+
49
+ train_loss /= num_train_batch
50
+
51
+ logger.info("** Evaluating on validation dataset **")
52
+ val_preds, segments, valid_len, val_loss = self.eval(self.val_dataloader)
53
+ val_metrics = compute_single_label_metrics(segments)
54
+
55
+ epoch_summary_loss = {
56
+ "train_loss": train_loss,
57
+ "val_loss": val_loss
58
+ }
59
+ epoch_summary_metrics = {
60
+ "val_micro_f1": val_metrics.micro_f1,
61
+ "val_precision": val_metrics.precision,
62
+ "val_recall": val_metrics.recall
63
+ }
64
+
65
+ logger.info(
66
+ "Epoch %d | Timestep %d | Train Loss %f | Val Loss %f | F1 %f",
67
+ epoch_index,
68
+ self.current_timestep,
69
+ train_loss,
70
+ val_loss,
71
+ val_metrics.micro_f1
72
+ )
73
+
74
+ if val_loss < best_val_loss:
75
+ patience = self.patience
76
+ best_val_loss = val_loss
77
+ logger.info("** Validation improved, evaluating test data **")
78
+ test_preds, segments, valid_len, test_loss = self.eval(self.test_dataloader)
79
+ self.segments_to_file(segments, os.path.join(self.output_path, "predictions.txt"))
80
+ test_metrics = compute_single_label_metrics(segments)
81
+
82
+ epoch_summary_loss["test_loss"] = test_loss
83
+ epoch_summary_metrics["test_micro_f1"] = test_metrics.micro_f1
84
+ epoch_summary_metrics["test_precision"] = test_metrics.precision
85
+ epoch_summary_metrics["test_recall"] = test_metrics.recall
86
+
87
+ logger.info(
88
+ f"Epoch %d | Timestep %d | Test Loss %f | F1 %f",
89
+ epoch_index,
90
+ self.current_timestep,
91
+ test_loss,
92
+ test_metrics.micro_f1
93
+ )
94
+
95
+ self.save()
96
+ else:
97
+ patience -= 1
98
+
99
+ # No improvements, terminating early
100
+ if patience == 0:
101
+ logger.info("Early termination triggered")
102
+ break
103
+
104
+ self.summary_writer.add_scalars("Loss", epoch_summary_loss, global_step=self.current_timestep)
105
+ self.summary_writer.add_scalars("Metrics", epoch_summary_metrics, global_step=self.current_timestep)
106
+
107
+ def eval(self, dataloader):
108
+ golds, preds, segments, valid_lens = list(), list(), list(), list()
109
+ loss = 0
110
+
111
+ for _, gold_tags, tokens, valid_len, logits in self.tag(
112
+ dataloader, is_train=False
113
+ ):
114
+ loss += self.loss(logits.view(-1, logits.shape[-1]), gold_tags.view(-1))
115
+ preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
116
+ segments += tokens
117
+ valid_lens += list(valid_len)
118
+
119
+ loss /= len(dataloader)
120
+
121
+ # Update segments, attach predicted tags to each token
122
+ segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
123
+
124
+ return preds, segments, valid_lens, loss.item()
125
+
126
+ def infer(self, dataloader):
127
+ golds, preds, segments, valid_lens = list(), list(), list(), list()
128
+
129
+ for _, gold_tags, tokens, valid_len, logits in self.tag(
130
+ dataloader, is_train=False
131
+ ):
132
+ preds += torch.argmax(logits, dim=2).detach().cpu().numpy().tolist()
133
+ segments += tokens
134
+ valid_lens += list(valid_len)
135
+
136
+ segments = self.to_segments(segments, preds, valid_lens, dataloader.dataset.vocab)
137
+ return segments
138
+
139
+ def to_segments(self, segments, preds, valid_lens, vocab):
140
+ if vocab is None:
141
+ vocab = self.vocab
142
+
143
+ tagged_segments = list()
144
+ tokens_stoi = vocab.tokens.get_stoi()
145
+ tags_itos = vocab.tags[0].get_itos()
146
+ unk_id = tokens_stoi["UNK"]
147
+
148
+ for segment, pred, valid_len in zip(segments, preds, valid_lens):
149
+ # First, the token at 0th index [CLS] and token at nth index [SEP]
150
+ # Combine the tokens with their corresponding predictions
151
+ segment_pred = zip(segment[1:valid_len-1], pred[1:valid_len-1])
152
+
153
+ # Ignore the sub-tokens/subwords, which are identified with text being UNK
154
+ segment_pred = list(filter(lambda t: tokens_stoi[t[0].text] != unk_id, segment_pred))
155
+
156
+ # Attach the predicted tags to each token
157
+ list(map(lambda t: setattr(t[0], 'pred_tag', [{"tag": tags_itos[t[1]]}]), segment_pred))
158
+
159
+ # We are only interested in the tagged tokens, we do no longer need raw model predictions
160
+ tagged_segment = [t for t, _ in segment_pred]
161
+ tagged_segments.append(tagged_segment)
162
+
163
+ return tagged_segments