# CULTURE / train.py
# pranamanam's picture
# Upload 7 files
# a0e0ff1 verified
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from model import CustomBERTModel
from config import Config
import pandas as pd
from tqdm import tqdm
def load_data(file_path):
    """Read a header-less CSV file and return its contents as a float32 tensor.

    Args:
        file_path: Path (or any object accepted by ``pandas.read_csv``) of a
            CSV file whose first line is already data, not column names.

    Returns:
        A ``torch.float32`` tensor holding the CSV cells in file order
        (one tensor row per CSV row).
    """
    # header=None: the first line is data, so pandas must not consume it
    # as column names.
    frame = pd.read_csv(file_path, header=None)
    cell_values = frame.values
    return torch.tensor(cell_values, dtype=torch.float32)
def create_mlm_data(data, mlm_probability, *, vocab_size=None, mask_value=0):
    """Build masked-language-model inputs and labels from a batch of tokens.

    Each position is selected for prediction independently with probability
    ``mlm_probability``. Of the selected positions, 80% are overwritten with
    ``mask_value``, 10% with a random id drawn from ``[0, vocab_size)`` and
    the remaining 10% are left unchanged.

    Args:
        data: Tensor of token values. It is NOT modified; a corrupted copy
            is returned instead.
        mlm_probability: Probability in ``[0, 1]`` that a position is
            selected for prediction.
        vocab_size: Exclusive upper bound for random replacement ids.
            Defaults to ``Config.vocab_size`` (resolved only when omitted).
        mask_value: Value written at [MASK] positions (default ``0``).

    Returns:
        ``(masked_inputs, labels)``: two tensors with the shape, dtype and
        device of ``data``. ``labels`` holds the original value at selected
        positions and ``-100`` everywhere else.
    """
    if vocab_size is None:
        vocab_size = Config.vocab_size

    # Create every random tensor on the input's device so the masks and the
    # replacement values can be combined with `data` on GPU as well as CPU.
    device = data.device
    shape = data.shape

    # Work on copies: the caller's tensor must stay intact.
    masked_inputs = data.clone()
    labels = data.clone()

    # float(): an integer probability (0 or 1) would make torch.full return
    # an integer tensor, which torch.bernoulli does not accept.
    selection_probs = torch.full(shape, float(mlm_probability), device=device)
    masked_indices = torch.bernoulli(selection_probs).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, replace the selected token with the [MASK] value.
    indices_replaced = (
        torch.bernoulli(torch.full(shape, 0.8, device=device)).bool()
        & masked_indices
    )
    masked_inputs[indices_replaced] = mask_value

    # 10% of the time, replace the selected token with a random word:
    # half of the 20% that were not turned into [MASK].
    indices_random = (
        torch.bernoulli(torch.full(shape, 0.5, device=device)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(
        vocab_size, shape, dtype=torch.long, device=device
    )
    masked_inputs[indices_random] = random_words[indices_random].to(
        masked_inputs.dtype
    )

    # The remaining 10% of selected tokens are kept unchanged.
    return masked_inputs, labels
def _run_epoch(model, batches, config, optimizer=None):
    """Run one pass over ``batches`` and return the mean per-batch loss.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch (with gradient clipping); without one the model is put in
    evaluation mode and gradients are disabled. Returns ``nan`` when
    ``batches`` yields nothing, instead of dividing by zero.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for batch in batches:
            # TensorDataset yields 1-tuples; the tensor is the first element.
            inputs = batch[0].to(config.device)
            masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)

            if training:
                optimizer.zero_grad()
            outputs = model(masked_inputs, labels=labels)
            loss = outputs.loss
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")


def train(save_path="bert_mlm_model.pth"):
    """Train ``CustomBERTModel`` with a masked-language-model objective.

    Reads the training and validation CSV files named in ``Config``, trains
    for ``config.num_train_epochs`` epochs with AdamW, prints the average
    training loss and the validation loss after every epoch, and finally
    saves the model's ``state_dict``.

    Args:
        save_path: Destination of the saved ``state_dict``. Defaults to
            ``"bert_mlm_model.pth"``.
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    train_loader = DataLoader(
        TensorDataset(load_data(config.train_file)),
        batch_size=config.batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(load_data(config.val_file)),
        batch_size=config.batch_size,
    )

    for epoch in range(config.num_train_epochs):
        epoch_tag = f"Epoch {epoch+1}/{config.num_train_epochs}"

        avg_train_loss = _run_epoch(
            model, tqdm(train_loader, desc=epoch_tag), config, optimizer
        )
        print(f"{epoch_tag}, Average training loss: {avg_train_loss:.4f}")

        # Validation. Masks are re-sampled on every call, so the reported
        # loss carries sampling noise from one epoch to the next.
        avg_val_loss = _run_epoch(model, val_loader, config)
        print(f"Validation loss: {avg_val_loss:.4f}")

    # Save the model
    torch.save(model.state_dict(), save_path)
    print(f"Model saved as {save_path}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()