| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, TensorDataset |
| import numpy as np |
| from model import CustomBERTModel |
| from config import Config |
| import pandas as pd |
| from tqdm import tqdm |
|
|
def load_data(file_path):
    """Read a header-less CSV file and return its contents as a float32 tensor.

    Args:
        file_path: Path (or any source accepted by ``pandas.read_csv``) of a
            CSV file whose first row is data, not column names.

    Returns:
        A ``torch.float32`` tensor holding the table's values, one tensor row
        per CSV row.
    """
    # header=None: the first line is treated as data rather than a header.
    frame = pd.read_csv(file_path, header=None)
    raw_values = frame.values
    return torch.tensor(raw_values, dtype=torch.float32)
|
|
def create_mlm_data(data, mlm_probability, vocab_size=None, mask_token_id=0):
    """Build masked-language-model inputs and labels from a batch of token ids.

    Each position is selected independently with probability
    ``mlm_probability``. Of the selected positions, roughly 80% are replaced
    by ``mask_token_id``, roughly 10% by a random id drawn from
    ``[0, vocab_size)``, and the rest keep their original value.

    The input tensor is NOT modified; a masked copy is returned instead.

    Args:
        data: Tensor of token ids (any shape). All helper tensors are created
            on ``data.device`` so the function works for CPU and GPU input.
        mlm_probability: Probability in ``[0, 1]`` of selecting a position.
        vocab_size: Exclusive upper bound for random replacement ids.
            Defaults to ``Config.vocab_size`` when omitted.
        mask_token_id: Value written at positions chosen for masking
            (defaults to 0, the value that was previously hard-coded).

    Returns:
        A ``(masked_data, labels)`` tuple of tensors shaped like ``data``.
        ``labels`` holds the original value at selected positions and -100
        everywhere else.
    """
    if vocab_size is None:
        vocab_size = Config.vocab_size

    device = data.device
    shape = data.shape

    def _sample(probability):
        # Explicit float dtype: torch.full would infer an integer dtype from
        # an int probability (e.g. 0 or 1), which torch.bernoulli rejects.
        probs = torch.full(shape, float(probability), dtype=torch.float32, device=device)
        return torch.bernoulli(probs).bool()

    # Work on copies so the caller's tensor is left untouched.
    masked_data = data.clone()
    labels = data.clone()

    masked_indices = _sample(mlm_probability)
    # -100 marks positions that are not prediction targets; presumably the
    # model's loss ignores this value (PyTorch's default ignore_index).
    labels[~masked_indices] = -100

    # ~80% of the selected positions become the mask token.
    indices_replaced = _sample(0.8) & masked_indices
    masked_data[indices_replaced] = mask_token_id

    # Half of the remaining selected positions (~10% overall) get a random id.
    indices_random = _sample(0.5) & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, shape, dtype=torch.long, device=device)
    masked_data[indices_random] = random_words[indices_random].to(masked_data.dtype)

    # The remaining selected positions keep their original value.
    return masked_data, labels
|
|
def train():
    """Train the custom BERT model with a masked-language-model objective.

    Builds the model and an AdamW optimizer from ``Config``, loads the
    training and validation CSV files, and for every epoch runs one pass of
    gradient updates (with gradient-norm clipping) followed by one evaluation
    pass without gradients. Average losses are printed per epoch. After the
    last epoch the model's ``state_dict`` is written to
    ``bert_mlm_model.pth`` in the current working directory.
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    train_tensor = load_data(config.train_file)
    val_tensor = load_data(config.val_file)

    # Only the training loader shuffles; validation keeps file order.
    train_loader = DataLoader(
        TensorDataset(train_tensor),
        batch_size=config.batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(val_tensor),
        batch_size=config.batch_size,
    )

    def _prepare(batch):
        # TensorDataset yields 1-tuples; move the tensor to the target device
        # and apply fresh random masking.
        on_device = batch[0].to(config.device)
        return create_mlm_data(on_device, config.mlm_probability)

    for epoch in range(config.num_train_epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_train_epochs}")
        for batch in progress:
            masked_inputs, labels = _prepare(batch)

            optimizer.zero_grad()
            loss = model(masked_inputs, labels=labels).loss
            loss.backward()
            # Clip before stepping so the update uses the bounded gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()

            running_train_loss += loss.item()

        avg_train_loss = running_train_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{config.num_train_epochs}, Average training loss: {avg_train_loss:.4f}")

        # ---- validation pass ----
        # Masks are re-sampled randomly here too, so the reported validation
        # loss varies between runs.
        model.eval()
        running_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                masked_inputs, labels = _prepare(batch)
                running_val_loss += model(masked_inputs, labels=labels).loss.item()

        avg_val_loss = running_val_loss / len(val_loader)
        print(f"Validation loss: {avg_val_loss:.4f}")

    # Persist only the weights (state_dict), not the whole module object.
    weights = model.state_dict()
    torch.save(weights, "bert_mlm_model.pth")
    print("Model saved as bert_mlm_model.pth")
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()
|
|