pranamanam commited on
Commit
a0e0ff1
·
verified ·
1 Parent(s): c9c7186

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +78 -3
  2. config.py +34 -0
  3. get_embeddings.py +44 -0
  4. model.py +36 -0
  5. requirements.txt +5 -0
  6. test.py +65 -0
  7. train.py +83 -0
README.md CHANGED
@@ -1,3 +1,78 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CULTURE: Curve Understanding and Learning Transformer for Unique Replication Estimation
2
+
3
+ CULTURE (Curve Understanding and Learning Transformer for Unique Replication Estimation) is a custom BERT-based model with a Masked Language Model (MLM) head for processing microbial growth curve data. The model is trained on 1x128 vector inputs representing microbial growth curves and can generate hidden state embeddings for downstream tasks in microbiology and bioinformatics.
4
+
5
+ ## Background
6
+
7
+ Microbial growth curves are time-series data that represent the growth of microorganisms over time. Each 1x128 vector in our dataset represents a single growth curve, with 128 time points measuring microbial population or density. By applying CULTURE to this data, we aim to capture complex patterns and relationships within these growth curves, potentially enabling better analysis and prediction in microbiology research.
8
+
9
+ ## Requirements
10
+
11
+ The project requires Python 3.7+ and the following packages:
12
+
13
+ - PyTorch 1.9.0
14
+ - Transformers 4.11.3
15
+ - pandas 1.3.3
16
+ - numpy 1.21.2
17
+ - tqdm 4.62.3
18
+
19
+ You can install the required packages using the provided `requirements.txt` file:
20
+
21
+ ```
22
+ pip install -r requirements.txt
23
+ ```
24
+
25
+ ## File Structure
26
+
27
+ - `config.py`: Contains all configurable parameters for the CULTURE model and training process.
28
+ - `model.py`: Defines the CULTURE model architecture.
29
+ - `train.py`: Script for training the CULTURE model.
30
+ - `test.py`: Script for evaluating the CULTURE model's performance on the MLM task.
31
+ - `get_embeddings.py`: Script for generating embeddings using the trained CULTURE model.
32
+ - `requirements.txt`: List of required Python packages.
33
+ - `train.csv`: Training data file containing microbial growth curves (not included in this repository).
34
+ - `val.csv`: Validation data file containing microbial growth curves (not included in this repository).
35
+ - `test.csv`: Test data file containing microbial growth curves (not included in this repository).
36
+
37
+ ## Usage
38
+
39
+ 1. Prepare your data:
40
+ - Ensure you have `train.csv`, `val.csv`, and `test.csv` files in the project directory.
41
+ - Each file should contain 1x128 vector data points representing microbial growth curves, without headers.
42
+
43
+ 2. Configure the model:
44
+ - Open `config.py` and adjust the hyperparameters as needed.
45
+ - You can modify the hidden size (`hidden_size`), number of encoder layers (`num_hidden_layers`), number of attention heads (`num_attention_heads`), and masking probability (`mlm_probability`) among other parameters.
46
+
47
+ 3. Train the model:
48
+ ```
49
+ python train.py
50
+ ```
51
+ This will train the CULTURE model on your microbial growth curve data and save it as `culture_model.pth`.
52
+
53
+ 4. Evaluate the model:
54
+ ```
55
+ python test.py
56
+ ```
57
+ This will load the trained CULTURE model and evaluate its performance on the MLM task using the test data.
58
+
59
+ 5. Generate embeddings:
60
+ ```
61
+ python get_embeddings.py input_file.csv output_embeddings.npy
62
+ ```
63
+ This will load the trained CULTURE model, process the input growth curve data, and save the embeddings as a NumPy file.
64
+
65
+ ## Customization
66
+
67
+ - To use different input dimensions (e.g., if your growth curves have a different number of time points), modify the `input_dim` parameter in `config.py`.
68
+ - Adjust learning rate, batch size, and other training parameters in `config.py`.
69
+ - For more advanced modifications, you can edit the model architecture in `model.py`.
70
+
71
+ ## Output
72
+
73
+ After running `get_embeddings.py`, you'll get a NumPy file containing the hidden state embeddings for your input data. These embeddings can be used for downstream tasks.
74
+
75
+ ## Note
76
+
77
+ CULTURE assumes that all growth curves are sampled at consistent time intervals and have been preprocessed to have the same length (128 time points). If your data differs significantly from this format, you may need to preprocess it or adjust the model architecture accordingly.
78
+
config.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class Config:
4
+ # Model parameters
5
+ vocab_size = 30522 # BERT vocabulary size
6
+ hidden_size = 128 # Adjusted to match input dimension
7
+ num_hidden_layers = 6 # Can be varied
8
+ num_attention_heads = 8 # Can be varied
9
+ intermediate_size = 512 # Adjusted based on hidden_size
10
+ hidden_act = "gelu"
11
+ hidden_dropout_prob = 0.1
12
+ attention_probs_dropout_prob = 0.1
13
+ max_position_embeddings = 512
14
+ type_vocab_size = 2
15
+ initializer_range = 0.02
16
+ layer_norm_eps = 1e-12
17
+
18
+ # Training parameters
19
+ batch_size = 32
20
+ learning_rate = 5e-5
21
+ num_train_epochs = 3
22
+ warmup_steps = 0
23
+ max_grad_norm = 1.0
24
+ weight_decay = 0.01
25
+
26
+ # Data parameters
27
+ train_file = "train.csv"
28
+ val_file = "val.csv"
29
+ test_file = "test.csv"
30
+ input_dim = 128
31
+ mlm_probability = 0.15 # Can be adjusted by the user
32
+
33
+ # Device
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
get_embeddings.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader, TensorDataset
3
+ import numpy as np
4
+ from model import CustomBERTModel
5
+ from config import Config
6
+ import pandas as pd
7
+
8
+ def load_data(file_path):
9
+ df = pd.read_csv(file_path, header=None)
10
+ return torch.tensor(df.values, dtype=torch.float32)
11
+
12
+ def get_embeddings(input_file, output_file):
13
+ config = Config()
14
+ model = CustomBERTModel(config).to(config.device)
15
+ model.load_state_dict(torch.load("bert_mlm_model.pth"))
16
+ model.eval()
17
+
18
+ input_data = load_data(input_file)
19
+ dataset = TensorDataset(input_data)
20
+ data_loader = DataLoader(dataset, batch_size=config.batch_size)
21
+
22
+ all_embeddings = []
23
+
24
+ with torch.no_grad():
25
+ for batch in data_loader:
26
+ inputs = batch[0].to(config.device)
27
+ embeddings = model.get_encoder_output(inputs)
28
+ all_embeddings.append(embeddings.cpu().numpy())
29
+
30
+ all_embeddings = np.concatenate(all_embeddings, axis=0)
31
+ print(f"Generated embeddings shape: {all_embeddings.shape}")
32
+
33
+ # Save embeddings
34
+ np.save(output_file, all_embeddings)
35
+ print(f"Embeddings saved as {output_file}")
36
+
37
+ if __name__ == "__main__":
38
+ import argparse
39
+ parser = argparse.ArgumentParser(description="Generate embeddings for microbial growth curves")
40
+ parser.add_argument("input_file", help="Path to the input CSV file containing growth curves")
41
+ parser.add_argument("output_file", help="Path to save the output embeddings (as .npy file)")
42
+ args = parser.parse_args()
43
+
44
+ get_embeddings(args.input_file, args.output_file)
model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertConfig, BertForMaskedLM
4
+ from config import Config
5
+
6
+ class CustomBERTModel(nn.Module):
7
+ def __init__(self, config):
8
+ super(CustomBERTModel, self).__init__()
9
+ self.input_proj = nn.Linear(config.input_dim, config.hidden_size)
10
+
11
+ bert_config = BertConfig(
12
+ vocab_size=config.vocab_size,
13
+ hidden_size=config.hidden_size,
14
+ num_hidden_layers=config.num_hidden_layers,
15
+ num_attention_heads=config.num_attention_heads,
16
+ intermediate_size=config.intermediate_size,
17
+ hidden_act=config.hidden_act,
18
+ hidden_dropout_prob=config.hidden_dropout_prob,
19
+ attention_probs_dropout_prob=config.attention_probs_dropout_prob,
20
+ max_position_embeddings=config.max_position_embeddings,
21
+ type_vocab_size=config.type_vocab_size,
22
+ initializer_range=config.initializer_range,
23
+ layer_norm_eps=config.layer_norm_eps
24
+ )
25
+
26
+ self.bert = BertForMaskedLM(bert_config)
27
+
28
+ def forward(self, x, labels=None):
29
+ x = self.input_proj(x)
30
+ outputs = self.bert(inputs_embeds=x, labels=labels)
31
+ return outputs
32
+
33
+ def get_encoder_output(self, x):
34
+ x = self.input_proj(x)
35
+ outputs = self.bert.bert(inputs_embeds=x)
36
+ return outputs.last_hidden_state
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==1.9.0
2
+ transformers==4.11.3
3
+ pandas==1.3.3
4
+ numpy==1.21.2
5
+ tqdm==4.62.3
test.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader, TensorDataset
3
+ import numpy as np
4
+ from model import CustomBERTModel
5
+ from config import Config
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+
9
+ def load_data(file_path):
10
+ df = pd.read_csv(file_path, header=None)
11
+ return torch.tensor(df.values, dtype=torch.float32)
12
+
13
+ def create_mlm_data(data, mlm_probability):
14
+ labels = data.clone()
15
+ probability_matrix = torch.full(labels.shape, mlm_probability)
16
+ masked_indices = torch.bernoulli(probability_matrix).bool()
17
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
18
+
19
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
20
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
21
+ data[indices_replaced] = 0 # Assume 0 is the representation of [MASK]
22
+
23
+ # 10% of the time, we replace masked input tokens with random word
24
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
25
+ random_words = torch.randint(Config.vocab_size, labels.shape, dtype=torch.long)
26
+ data[indices_random] = random_words[indices_random].float()
27
+
28
+ return data, labels
29
+
30
+ def test():
31
+ config = Config()
32
+ model = CustomBERTModel(config).to(config.device)
33
+ model.load_state_dict(torch.load("bert_mlm_model.pth"))
34
+ model.eval()
35
+
36
+ test_data = load_data(config.test_file)
37
+ test_dataset = TensorDataset(test_data)
38
+ test_loader = DataLoader(test_dataset, batch_size=config.batch_size)
39
+
40
+ total_loss = 0
41
+ total_correct = 0
42
+ total_predictions = 0
43
+
44
+ with torch.no_grad():
45
+ for batch in tqdm(test_loader, desc="Testing"):
46
+ inputs = batch[0].to(config.device)
47
+ masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)
48
+
49
+ outputs = model(masked_inputs, labels=labels)
50
+ loss = outputs.loss
51
+ total_loss += loss.item()
52
+
53
+ predictions = outputs.logits.argmax(dim=-1)
54
+ mask = labels != -100
55
+ total_correct += (predictions[mask] == labels[mask]).sum().item()
56
+ total_predictions += mask.sum().item()
57
+
58
+ avg_loss = total_loss / len(test_loader)
59
+ accuracy = total_correct / total_predictions
60
+
61
+ print(f"Test Loss: {avg_loss:.4f}")
62
+ print(f"Test Accuracy: {accuracy:.4f}")
63
+
64
+ if __name__ == "__main__":
65
+ test()
train.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, TensorDataset
5
+ import numpy as np
6
+ from model import CustomBERTModel
7
+ from config import Config
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+
11
+ def load_data(file_path):
12
+ df = pd.read_csv(file_path, header=None)
13
+ return torch.tensor(df.values, dtype=torch.float32)
14
+
15
+ def create_mlm_data(data, mlm_probability):
16
+ labels = data.clone()
17
+ probability_matrix = torch.full(labels.shape, mlm_probability)
18
+ masked_indices = torch.bernoulli(probability_matrix).bool()
19
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
20
+
21
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
22
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
23
+ data[indices_replaced] = 0 # Assume 0 is the representation of [MASK]
24
+
25
+ # 10% of the time, we replace masked input tokens with random word
26
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
27
+ random_words = torch.randint(Config.vocab_size, labels.shape, dtype=torch.long)
28
+ data[indices_random] = random_words[indices_random].float()
29
+
30
+ return data, labels
31
+
32
+ def train():
33
+ config = Config()
34
+ model = CustomBERTModel(config).to(config.device)
35
+ optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
36
+
37
+ train_data = load_data(config.train_file)
38
+ val_data = load_data(config.val_file)
39
+
40
+ train_dataset = TensorDataset(train_data)
41
+ val_dataset = TensorDataset(val_data)
42
+
43
+ train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
44
+ val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
45
+
46
+ for epoch in range(config.num_train_epochs):
47
+ model.train()
48
+ total_loss = 0
49
+ for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_train_epochs}"):
50
+ inputs = batch[0].to(config.device)
51
+ masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)
52
+
53
+ optimizer.zero_grad()
54
+ outputs = model(masked_inputs, labels=labels)
55
+ loss = outputs.loss
56
+ loss.backward()
57
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
58
+ optimizer.step()
59
+
60
+ total_loss += loss.item()
61
+
62
+ avg_train_loss = total_loss / len(train_loader)
63
+ print(f"Epoch {epoch+1}/{config.num_train_epochs}, Average training loss: {avg_train_loss:.4f}")
64
+
65
+ # Validation
66
+ model.eval()
67
+ total_val_loss = 0
68
+ with torch.no_grad():
69
+ for batch in val_loader:
70
+ inputs = batch[0].to(config.device)
71
+ masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)
72
+ outputs = model(masked_inputs, labels=labels)
73
+ total_val_loss += outputs.loss.item()
74
+
75
+ avg_val_loss = total_val_loss / len(val_loader)
76
+ print(f"Validation loss: {avg_val_loss:.4f}")
77
+
78
+ # Save the model
79
+ torch.save(model.state_dict(), "bert_mlm_model.pth")
80
+ print("Model saved as bert_mlm_model.pth")
81
+
82
+ if __name__ == "__main__":
83
+ train()