# CULTURE / test.py
# Uploaded by pranamanam ("Upload 7 files", commit a0e0ff1, verified).
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from model import CustomBERTModel
from config import Config
import pandas as pd
from tqdm import tqdm
def load_data(file_path):
    """Read a header-less CSV file into a float32 tensor.

    Each line of the file becomes one row of the returned tensor and each
    comma-separated field one column.

    Args:
        file_path: Path (or anything ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        A ``torch.float32`` tensor holding the CSV's values.
    """
    # NOTE(review): rows are presumably equal-length token-id sequences;
    # pandas pads shorter rows with NaN, which would end up in the tensor.
    frame = pd.read_csv(file_path, header=None)
    as_array = frame.to_numpy()
    return torch.tensor(as_array, dtype=torch.float32)
def create_mlm_data(data, mlm_probability, vocab_size=None, mask_token_id=0):
    """Apply BERT-style random masking to a batch of token ids.

    Each position is selected with probability ``mlm_probability``. Of the
    selected positions, ~80% are replaced by ``mask_token_id``, ~10% by a
    random id drawn from ``[0, vocab_size)``, and the remaining ~10% keep
    their original value.

    Args:
        data: Tensor of token ids. It is NOT modified; a masked copy is
            returned instead.
        mlm_probability: Probability that any given position is selected.
        vocab_size: Exclusive upper bound for random replacement ids.
            Defaults to ``Config.vocab_size``.
        mask_token_id: Value written where the mask token is used. Defaults
            to 0, the value this code has always assumed for [MASK].

    Returns:
        ``(masked_data, labels)``. ``masked_data`` has the same shape, dtype
        and device as ``data``. ``labels`` is a copy of the original ``data``
        with every non-selected position set to -100, so a loss using
        ``ignore_index=-100`` scores only the selected positions.
    """
    if vocab_size is None:
        vocab_size = Config.vocab_size
    device = data.device

    labels = data.clone()
    # Work on a copy so the caller's tensor is not masked in place.
    masked = data.clone()

    # Build every random tensor on the same device as ``data`` so a CUDA
    # batch is never indexed/assigned with CPU tensors. float() keeps
    # torch.bernoulli happy when an int probability (0 or 1) is passed.
    probability_matrix = torch.full(labels.shape, float(mlm_probability), device=device)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # Only masked positions contribute to the loss.
    # NOTE(review): labels keep the dtype of ``data`` (float32 when produced by
    # load_data); confirm the model casts them to long before cross-entropy.

    # 80% of the selected positions -> mask token.
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool()
        & masked_indices
    )
    masked[indices_replaced] = mask_token_id

    # Half of the remaining 20% (i.e. ~10% overall) -> random token id.
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long, device=device)
    # Cast to the input's own dtype so both float and integer id tensors work.
    masked[indices_random] = random_words[indices_random].to(masked.dtype)

    # The final ~10% of selected positions are left unchanged.
    return masked, labels
def test(checkpoint_path="bert_mlm_model.pth"):
    """Evaluate a saved MLM checkpoint on ``config.test_file`` and print metrics.

    Loads the weights at ``checkpoint_path`` into a fresh ``CustomBERTModel``.
    It masks every test batch at random with ``create_mlm_data``, then prints
    two metrics: the mean per-batch loss and the token accuracy over the
    masked positions. Masking is random, so the numbers vary between runs
    unless the torch RNG is seeded beforehand.

    Args:
        checkpoint_path: Path of the saved ``state_dict``. Defaults to
            ``"bert_mlm_model.pth"``.
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on
    # another (e.g. a CPU-only machine). torch.load unpickles the file, so
    # only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()

    test_data = load_data(config.test_file)
    test_loader = DataLoader(TensorDataset(test_data), batch_size=config.batch_size)

    total_loss = 0.0
    total_correct = 0
    total_predictions = 0
    num_batches = 0
    with torch.no_grad():
        for (inputs,) in tqdm(test_loader, desc="Testing"):
            inputs = inputs.to(config.device)
            masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)
            outputs = model(masked_inputs, labels=labels)
            total_loss += outputs.loss.item()
            num_batches += 1

            # Accuracy is measured only where labels were kept (not -100).
            predictions = outputs.logits.argmax(dim=-1)
            mask = labels != -100
            total_correct += (predictions[mask] == labels[mask]).sum().item()
            total_predictions += mask.sum().item()

    if num_batches == 0:
        print("Test set is empty; nothing to evaluate.")
        return

    avg_loss = total_loss / num_batches
    # Guard against a run where no position happened to be masked.
    accuracy = total_correct / total_predictions if total_predictions else float("nan")
    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.4f}")
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not on import.
    test()