import argparse
import logging
import csv
import random
import warnings
import time
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
)
from rasterio.errors import NotGeoreferencedWarning
import terramind  # noqa: F401  (not referenced directly; kept in case it registers TerraMind components)
from terratorch.tasks import ClassificationTask

from methane_classification_datamodule import MethaneClassificationDataModule


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Silence noisy rasterio/georeferencing warnings and generic FutureWarnings.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def set_seed(seed: int = 42):
    """Sets the seed for reproducibility across random, numpy, and torch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_training_transforms() -> A.Compose:
    """Returns the albumentations training pipeline."""
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])


class MetricTracker:
    """Accumulates targets and predictions to calculate epoch-level metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        self.total_loss += loss
        self.steps += 1
        # Targets are expected as one-hot/probabilistic vectors; argmax recovers class indices.
        self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().numpy())
        self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().numpy())

    def compute(self) -> Dict[str, float]:
        """Calculates aggregate metrics for the accumulated data."""
        if not self.all_targets:
            return {}

        tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
        }


class MethaneTrainer:
    """
    Handles the training lifecycle: model setup, training loop, validation, and checkpointing.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.model = self._init_model()  # also sets self.task
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion

        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Initializes the TerraTorch ClassificationTask and its model."""
        model_config = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            # Necks reshape the transformer token outputs into image-like features
            # and pick intermediate encoder layers to feed the decoder.
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )

        self.task = ClassificationTask(
            model_args=model_config,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay
        )
        # The 'verbose' argument is deprecated in recent PyTorch releases, so it is omitted here.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=5
        )
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Runs a single epoch for either training or validation."""
        is_train = stage == "train"
        if is_train:
            self.model.train()
        else:
            self.model.eval()

        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)

            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                targets = batch['label'].to(self.device)

                # Compute the loss on raw logits (cross-entropy applies its own log-softmax);
                # softmax probabilities are only used for metric tracking.
                outputs = self.model(x={"S2L2A": inputs})
                logits = outputs.output
                loss = self.criterion(logits, targets)
                probabilities = torch.softmax(logits, dim=1)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tracker.update(loss.item(), targets, probabilities)
                pbar.set_postfix(loss=f"{loss.item():.4f}")

        return tracker.compute()

    def save_checkpoint(self, filename: str):
        path = self.save_dir / filename
        torch.save(self.model.state_dict(), path)
        logger.info(f"Saved model to {path}")

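    # To reload a saved checkpoint later (sketch, assuming the same model/task
    # configuration is rebuilt first, e.g. via MethaneTrainer(args)):
    #   state_dict = torch.load(checkpoint_path, map_location="cpu")
    #   trainer.model.load_state_dict(state_dict)
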
    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Appends metrics to the CSV log file."""
        csv_path = self.save_dir / 'train_val_metrics.csv'
        file_exists = csv_path.exists()

        headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]

        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(headers)
            row = [epoch] + list(train_metrics.values()) + list(val_metrics.values())
            writer.writerow(row)

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Main training entry point."""
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            self.scheduler.step(val_metrics['Loss'])

            self.log_to_csv(epoch, train_metrics, val_metrics)
            logger.info(
                f"Train Loss: {train_metrics['Loss']:.4f} | "
                f"Val Loss: {val_metrics['Loss']:.4f} | "
                f"Val F1: {val_metrics['F1']:.4f}"
            )

            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                self.save_checkpoint("best_model.pth")
                logger.info(f"--> New best model (Val Loss: {self.best_val_loss:.4f})")

        self.save_checkpoint("final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")


def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Prepares the DataModule and returns the train/validation loaders."""
    try:
        df = pd.read_csv(args.excel_file) if args.excel_file.endswith('.csv') else pd.read_excel(args.excel_file)
    except Exception as e:
        logger.error(f"Failed to load summary file: {e}")
        raise

    # Hold out the test fold; the remaining folds form the train/validation pool.
    all_folds = range(1, args.num_folds + 1)
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    df_filtered = df[df['Fold'].isin(train_pool_folds)]
    if df_filtered.empty:
        raise ValueError(f"No data found for folds {train_pool_folds}. Check the 'Fold' column in the summary file.")

    paths = df_filtered['Filename'].tolist()
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)

    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    datamodule = MethaneClassificationDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=train_paths,
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # The same datamodule instance is reused for both splits: assign the path list,
    # run setup for the matching stage, then grab the loader before switching.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader


def parse_args():
    parser = argparse.ArgumentParser(description="Methane Classification Training with TerraTorch")

    # Paths
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for satellite images')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to summary Excel/CSV file')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Directory to save outputs')

    # Training hyperparameters
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--num_folds', type=int, default=5)
    parser.add_argument('--test_fold', type=int, default=2, help='Fold ID to hold out for testing')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args()

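# Example invocation (illustrative values; adjust the script name and paths to your setup):
#   python methane_train.py --root_dir /data/s2l2a_chips --excel_file fold_summary.csv \
#       --save_dir ./checkpoints --test_fold 2 --epochs 100
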
if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)

    train_loader, val_loader = get_data_loaders(args)

    trainer = MethaneTrainer(args)
    trainer.fit(train_loader, val_loader)