Upload 7 files
Browse files- README.md +78 -3
- config.py +34 -0
- get_embeddings.py +44 -0
- model.py +36 -0
- requirements.txt +5 -0
- test.py +65 -0
- train.py +83 -0
README.md
CHANGED
|
@@ -1,3 +1,78 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CULTURE: Curve Understanding and Learning Transformer for Unique Replication Estimation
|
| 2 |
+
|
| 3 |
+
CULTURE (Curve Understanding and Learning Transformer for Unique Replication Estimation) is a custom BERT-based model with a Masked Language Model (MLM) head for processing microbial growth curve data. The model is trained on 1x128 vector inputs representing microbial growth curves and can generate hidden state embeddings for downstream tasks in microbiology and bioinformatics.
|
| 4 |
+
|
| 5 |
+
## Background
|
| 6 |
+
|
| 7 |
+
Microbial growth curves are time-series data that represent the growth of microorganisms over time. Each 1x128 vector in our dataset represents a single growth curve, with 128 time points measuring microbial population or density. By applying CULTURE to this data, we aim to capture complex patterns and relationships within these growth curves, potentially enabling better analysis and prediction in microbiology research.
|
| 8 |
+
|
| 9 |
+
## Requirements
|
| 10 |
+
|
| 11 |
+
The project requires Python 3.7+ and the following packages:
|
| 12 |
+
|
| 13 |
+
- PyTorch 1.9.0
|
| 14 |
+
- Transformers 4.11.3
|
| 15 |
+
- pandas 1.3.3
|
| 16 |
+
- numpy 1.21.2
|
| 17 |
+
- tqdm 4.62.3
|
| 18 |
+
|
| 19 |
+
You can install the required packages using the provided `requirements.txt` file:
|
| 20 |
+
|
| 21 |
+
```
|
| 22 |
+
pip install -r requirements.txt
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## File Structure
|
| 26 |
+
|
| 27 |
+
- `config.py`: Contains all configurable parameters for the CULTURE model and training process.
|
| 28 |
+
- `model.py`: Defines the CULTURE model architecture.
|
| 29 |
+
- `train.py`: Script for training the CULTURE model.
|
| 30 |
+
- `test.py`: Script for evaluating the CULTURE model's performance on the MLM task.
|
| 31 |
+
- `get_embeddings.py`: Script for generating embeddings using the trained CULTURE model.
|
| 32 |
+
- `requirements.txt`: List of required Python packages.
|
| 33 |
+
- `train.csv`: Training data file containing microbial growth curves (not included in this repository).
|
| 34 |
+
- `val.csv`: Validation data file containing microbial growth curves (not included in this repository).
|
| 35 |
+
- `test.csv`: Test data file containing microbial growth curves (not included in this repository).
|
| 36 |
+
|
| 37 |
+
## Usage
|
| 38 |
+
|
| 39 |
+
1. Prepare your data:
|
| 40 |
+
- Ensure you have `train.csv`, `val.csv`, and `test.csv` files in the project directory.
|
| 41 |
+
- Each file should contain 1x128 vector data points representing microbial growth curves, without headers.
|
| 42 |
+
|
| 43 |
+
2. Configure the model:
|
| 44 |
+
- Open `config.py` and adjust the hyperparameters as needed.
|
| 45 |
+
- You can modify the hidden size (`hidden_size`), number of encoder layers (`num_hidden_layers`), number of attention heads (`num_attention_heads`), and masking probability (`mlm_probability`) among other parameters.
|
| 46 |
+
|
| 47 |
+
3. Train the model:
|
| 48 |
+
```
|
| 49 |
+
python train.py
|
| 50 |
+
```
|
| 51 |
+
This will train the CULTURE model on your microbial growth curve data and save it as `culture_model.pth`.
|
| 52 |
+
|
| 53 |
+
4. Evaluate the model:
|
| 54 |
+
```
|
| 55 |
+
python test.py
|
| 56 |
+
```
|
| 57 |
+
This will load the trained CULTURE model and evaluate its performance on the MLM task using the test data.
|
| 58 |
+
|
| 59 |
+
5. Generate embeddings:
|
| 60 |
+
```
|
| 61 |
+
python get_embeddings.py input_file.csv output_embeddings.npy
|
| 62 |
+
```
|
| 63 |
+
This will load the trained CULTURE model, process the input growth curve data, and save the embeddings as a NumPy file.
|
| 64 |
+
|
| 65 |
+
## Customization
|
| 66 |
+
|
| 67 |
+
- To use different input dimensions (e.g., if your growth curves have a different number of time points), modify the `input_dim` parameter in `config.py`.
|
| 68 |
+
- Adjust learning rate, batch size, and other training parameters in `config.py`.
|
| 69 |
+
- For more advanced modifications, you can edit the model architecture in `model.py`.
|
| 70 |
+
|
| 71 |
+
## Output
|
| 72 |
+
|
| 73 |
+
After running `get_embeddings.py`, you'll get a NumPy file containing the hidden state embeddings for your input data. These embeddings can be used for downstream tasks.
|
| 74 |
+
|
| 75 |
+
## Note
|
| 76 |
+
|
| 77 |
+
CULTURE assumes that all growth curves are sampled at consistent time intervals and have been preprocessed to have the same length (128 time points). If your data differs significantly from this format, you may need to preprocess it or adjust the model architecture accordingly.
|
| 78 |
+
|
config.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
class Config:
|
| 4 |
+
# Model parameters
|
| 5 |
+
vocab_size = 30522 # BERT vocabulary size
|
| 6 |
+
hidden_size = 128 # Adjusted to match input dimension
|
| 7 |
+
num_hidden_layers = 6 # Can be varied
|
| 8 |
+
num_attention_heads = 8 # Can be varied
|
| 9 |
+
intermediate_size = 512 # Adjusted based on hidden_size
|
| 10 |
+
hidden_act = "gelu"
|
| 11 |
+
hidden_dropout_prob = 0.1
|
| 12 |
+
attention_probs_dropout_prob = 0.1
|
| 13 |
+
max_position_embeddings = 512
|
| 14 |
+
type_vocab_size = 2
|
| 15 |
+
initializer_range = 0.02
|
| 16 |
+
layer_norm_eps = 1e-12
|
| 17 |
+
|
| 18 |
+
# Training parameters
|
| 19 |
+
batch_size = 32
|
| 20 |
+
learning_rate = 5e-5
|
| 21 |
+
num_train_epochs = 3
|
| 22 |
+
warmup_steps = 0
|
| 23 |
+
max_grad_norm = 1.0
|
| 24 |
+
weight_decay = 0.01
|
| 25 |
+
|
| 26 |
+
# Data parameters
|
| 27 |
+
train_file = "train.csv"
|
| 28 |
+
val_file = "val.csv"
|
| 29 |
+
test_file = "test.csv"
|
| 30 |
+
input_dim = 128
|
| 31 |
+
mlm_probability = 0.15 # Can be adjusted by the user
|
| 32 |
+
|
| 33 |
+
# Device
|
| 34 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
get_embeddings.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 3 |
+
import numpy as np
|
| 4 |
+
from model import CustomBERTModel
|
| 5 |
+
from config import Config
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
def load_data(file_path):
|
| 9 |
+
df = pd.read_csv(file_path, header=None)
|
| 10 |
+
return torch.tensor(df.values, dtype=torch.float32)
|
| 11 |
+
|
| 12 |
+
def get_embeddings(input_file, output_file):
|
| 13 |
+
config = Config()
|
| 14 |
+
model = CustomBERTModel(config).to(config.device)
|
| 15 |
+
model.load_state_dict(torch.load("bert_mlm_model.pth"))
|
| 16 |
+
model.eval()
|
| 17 |
+
|
| 18 |
+
input_data = load_data(input_file)
|
| 19 |
+
dataset = TensorDataset(input_data)
|
| 20 |
+
data_loader = DataLoader(dataset, batch_size=config.batch_size)
|
| 21 |
+
|
| 22 |
+
all_embeddings = []
|
| 23 |
+
|
| 24 |
+
with torch.no_grad():
|
| 25 |
+
for batch in data_loader:
|
| 26 |
+
inputs = batch[0].to(config.device)
|
| 27 |
+
embeddings = model.get_encoder_output(inputs)
|
| 28 |
+
all_embeddings.append(embeddings.cpu().numpy())
|
| 29 |
+
|
| 30 |
+
all_embeddings = np.concatenate(all_embeddings, axis=0)
|
| 31 |
+
print(f"Generated embeddings shape: {all_embeddings.shape}")
|
| 32 |
+
|
| 33 |
+
# Save embeddings
|
| 34 |
+
np.save(output_file, all_embeddings)
|
| 35 |
+
print(f"Embeddings saved as {output_file}")
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
import argparse
|
| 39 |
+
parser = argparse.ArgumentParser(description="Generate embeddings for microbial growth curves")
|
| 40 |
+
parser.add_argument("input_file", help="Path to the input CSV file containing growth curves")
|
| 41 |
+
parser.add_argument("output_file", help="Path to save the output embeddings (as .npy file)")
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
get_embeddings(args.input_file, args.output_file)
|
model.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import BertConfig, BertForMaskedLM
|
| 4 |
+
from config import Config
|
| 5 |
+
|
| 6 |
+
class CustomBERTModel(nn.Module):
|
| 7 |
+
def __init__(self, config):
|
| 8 |
+
super(CustomBERTModel, self).__init__()
|
| 9 |
+
self.input_proj = nn.Linear(config.input_dim, config.hidden_size)
|
| 10 |
+
|
| 11 |
+
bert_config = BertConfig(
|
| 12 |
+
vocab_size=config.vocab_size,
|
| 13 |
+
hidden_size=config.hidden_size,
|
| 14 |
+
num_hidden_layers=config.num_hidden_layers,
|
| 15 |
+
num_attention_heads=config.num_attention_heads,
|
| 16 |
+
intermediate_size=config.intermediate_size,
|
| 17 |
+
hidden_act=config.hidden_act,
|
| 18 |
+
hidden_dropout_prob=config.hidden_dropout_prob,
|
| 19 |
+
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
|
| 20 |
+
max_position_embeddings=config.max_position_embeddings,
|
| 21 |
+
type_vocab_size=config.type_vocab_size,
|
| 22 |
+
initializer_range=config.initializer_range,
|
| 23 |
+
layer_norm_eps=config.layer_norm_eps
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
self.bert = BertForMaskedLM(bert_config)
|
| 27 |
+
|
| 28 |
+
def forward(self, x, labels=None):
|
| 29 |
+
x = self.input_proj(x)
|
| 30 |
+
outputs = self.bert(inputs_embeds=x, labels=labels)
|
| 31 |
+
return outputs
|
| 32 |
+
|
| 33 |
+
def get_encoder_output(self, x):
|
| 34 |
+
x = self.input_proj(x)
|
| 35 |
+
outputs = self.bert.bert(inputs_embeds=x)
|
| 36 |
+
return outputs.last_hidden_state
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==1.9.0
|
| 2 |
+
transformers==4.11.3
|
| 3 |
+
pandas==1.3.3
|
| 4 |
+
numpy==1.21.2
|
| 5 |
+
tqdm==4.62.3
|
test.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 3 |
+
import numpy as np
|
| 4 |
+
from model import CustomBERTModel
|
| 5 |
+
from config import Config
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
def load_data(file_path):
|
| 10 |
+
df = pd.read_csv(file_path, header=None)
|
| 11 |
+
return torch.tensor(df.values, dtype=torch.float32)
|
| 12 |
+
|
| 13 |
+
def create_mlm_data(data, mlm_probability):
|
| 14 |
+
labels = data.clone()
|
| 15 |
+
probability_matrix = torch.full(labels.shape, mlm_probability)
|
| 16 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 17 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 18 |
+
|
| 19 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 20 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
| 21 |
+
data[indices_replaced] = 0 # Assume 0 is the representation of [MASK]
|
| 22 |
+
|
| 23 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 24 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
| 25 |
+
random_words = torch.randint(Config.vocab_size, labels.shape, dtype=torch.long)
|
| 26 |
+
data[indices_random] = random_words[indices_random].float()
|
| 27 |
+
|
| 28 |
+
return data, labels
|
| 29 |
+
|
| 30 |
+
def test():
|
| 31 |
+
config = Config()
|
| 32 |
+
model = CustomBERTModel(config).to(config.device)
|
| 33 |
+
model.load_state_dict(torch.load("bert_mlm_model.pth"))
|
| 34 |
+
model.eval()
|
| 35 |
+
|
| 36 |
+
test_data = load_data(config.test_file)
|
| 37 |
+
test_dataset = TensorDataset(test_data)
|
| 38 |
+
test_loader = DataLoader(test_dataset, batch_size=config.batch_size)
|
| 39 |
+
|
| 40 |
+
total_loss = 0
|
| 41 |
+
total_correct = 0
|
| 42 |
+
total_predictions = 0
|
| 43 |
+
|
| 44 |
+
with torch.no_grad():
|
| 45 |
+
for batch in tqdm(test_loader, desc="Testing"):
|
| 46 |
+
inputs = batch[0].to(config.device)
|
| 47 |
+
masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)
|
| 48 |
+
|
| 49 |
+
outputs = model(masked_inputs, labels=labels)
|
| 50 |
+
loss = outputs.loss
|
| 51 |
+
total_loss += loss.item()
|
| 52 |
+
|
| 53 |
+
predictions = outputs.logits.argmax(dim=-1)
|
| 54 |
+
mask = labels != -100
|
| 55 |
+
total_correct += (predictions[mask] == labels[mask]).sum().item()
|
| 56 |
+
total_predictions += mask.sum().item()
|
| 57 |
+
|
| 58 |
+
avg_loss = total_loss / len(test_loader)
|
| 59 |
+
accuracy = total_correct / total_predictions
|
| 60 |
+
|
| 61 |
+
print(f"Test Loss: {avg_loss:.4f}")
|
| 62 |
+
print(f"Test Accuracy: {accuracy:.4f}")
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
test()
|
train.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 5 |
+
import numpy as np
|
| 6 |
+
from model import CustomBERTModel
|
| 7 |
+
from config import Config
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
def load_data(file_path):
|
| 12 |
+
df = pd.read_csv(file_path, header=None)
|
| 13 |
+
return torch.tensor(df.values, dtype=torch.float32)
|
| 14 |
+
|
| 15 |
+
def create_mlm_data(data, mlm_probability):
|
| 16 |
+
labels = data.clone()
|
| 17 |
+
probability_matrix = torch.full(labels.shape, mlm_probability)
|
| 18 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
| 19 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 20 |
+
|
| 21 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 22 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
| 23 |
+
data[indices_replaced] = 0 # Assume 0 is the representation of [MASK]
|
| 24 |
+
|
| 25 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 26 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
| 27 |
+
random_words = torch.randint(Config.vocab_size, labels.shape, dtype=torch.long)
|
| 28 |
+
data[indices_random] = random_words[indices_random].float()
|
| 29 |
+
|
| 30 |
+
return data, labels
|
| 31 |
+
|
| 32 |
+
def train():
|
| 33 |
+
config = Config()
|
| 34 |
+
model = CustomBERTModel(config).to(config.device)
|
| 35 |
+
optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
|
| 36 |
+
|
| 37 |
+
train_data = load_data(config.train_file)
|
| 38 |
+
val_data = load_data(config.val_file)
|
| 39 |
+
|
| 40 |
+
train_dataset = TensorDataset(train_data)
|
| 41 |
+
val_dataset = TensorDataset(val_data)
|
| 42 |
+
|
| 43 |
+
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
|
| 44 |
+
val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
|
| 45 |
+
|
| 46 |
+
for epoch in range(config.num_train_epochs):
|
| 47 |
+
model.train()
|
| 48 |
+
total_loss = 0
|
| 49 |
+
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_train_epochs}"):
|
| 50 |
+
inputs = batch[0].to(config.device)
|
| 51 |
+
masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)
|
| 52 |
+
|
| 53 |
+
optimizer.zero_grad()
|
| 54 |
+
outputs = model(masked_inputs, labels=labels)
|
| 55 |
+
loss = outputs.loss
|
| 56 |
+
loss.backward()
|
| 57 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
|
| 58 |
+
optimizer.step()
|
| 59 |
+
|
| 60 |
+
total_loss += loss.item()
|
| 61 |
+
|
| 62 |
+
avg_train_loss = total_loss / len(train_loader)
|
| 63 |
+
print(f"Epoch {epoch+1}/{config.num_train_epochs}, Average training loss: {avg_train_loss:.4f}")
|
| 64 |
+
|
| 65 |
+
# Validation
|
| 66 |
+
model.eval()
|
| 67 |
+
total_val_loss = 0
|
| 68 |
+
with torch.no_grad():
|
| 69 |
+
for batch in val_loader:
|
| 70 |
+
inputs = batch[0].to(config.device)
|
| 71 |
+
masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)
|
| 72 |
+
outputs = model(masked_inputs, labels=labels)
|
| 73 |
+
total_val_loss += outputs.loss.item()
|
| 74 |
+
|
| 75 |
+
avg_val_loss = total_val_loss / len(val_loader)
|
| 76 |
+
print(f"Validation loss: {avg_val_loss:.4f}")
|
| 77 |
+
|
| 78 |
+
# Save the model
|
| 79 |
+
torch.save(model.state_dict(), "bert_mlm_model.pth")
|
| 80 |
+
print("Model saved as bert_mlm_model.pth")
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
train()
|