| | import argparse |
| | import os |
| | import yaml |
| | import glob |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import torch.utils.data as data |
| | from torch.utils.tensorboard import SummaryWriter |
| | import numpy as np |
| | from models.unet import DiffusionUNet |
| | from diff2flow import dict2namespace |
| | import utils.logging |
| |
|
| |
|
class ReflowDataset(data.Dataset):
    """Map-style dataset over pre-generated reflow batches stored as ``*.pth`` files.

    Each file is deserialised with ``torch.load`` and returned unchanged; the
    training loop in this script reads the keys ``"x_data"``, ``"x_noise"`` and
    ``"x_cond"`` from every item.
    """

    def __init__(self, data_dir):
        """Index every ``*.pth`` file directly inside ``data_dir``.

        Args:
            data_dir: Directory holding the reflow batch files. The file list is
                sorted so that index -> file mapping is deterministic.
        """
        super().__init__()
        self.files = sorted(glob.glob(os.path.join(data_dir, "*.pth")))
        print(f"Found {len(self.files)} files in {data_dir}")

    def __len__(self):
        """Return the number of batch files that were found."""
        return len(self.files)

    def __getitem__(self, index):
        """Load and return the object stored in the ``index``-th file.

        Tensors are always mapped to CPU. Without ``map_location`` a file saved
        from GPU tensors fails to load on a machine without CUDA. It can also
        fail inside forked DataLoader worker processes once the parent has
        initialised CUDA. The training loop moves tensors to the target device
        itself.
        """
        path = self.files[index]
        # NOTE: torch.load unpickles the file -- only point data_dir at trusted data.
        return torch.load(path, map_location="cpu")
| |
|
| |
|
def _load_pretrained_weights(model, path, device):
    """Load a state dict from ``path`` into ``model`` with strict key matching.

    Accepts either a bare state dict or a checkpoint dict that wraps one under
    ``"state_dict"``. A leading ``"module."`` is stripped from every key
    (the prefix typically added by DataParallel/DDP wrappers) so such
    checkpoints load into the unwrapped model.
    """
    print(f"Loading pretrained weights from {path}")
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    prefix = "module."
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(cleaned, strict=True)


def train_reflow(args, config):
    """Fine-tune a DiffusionUNet as a velocity predictor on pre-generated pairs.

    For each (x_data, x_noise, x_cond) batch, a time ``t ~ U[0, 1)`` is drawn
    per sample. The model sees ``x_t = (1 - t) * x_data + t * x_noise`` and is
    regressed (MSE) onto the constant velocity ``x_noise - x_data``.

    Args:
        args: Parsed CLI namespace. Reads ``output``, ``resume``,
            ``data_dir_reflow`` and ``epochs``.
        config: Config namespace. Reads ``device``, ``optim.lr`` and
            ``diffusion.num_diffusion_timesteps``; also passed whole to
            ``DiffusionUNet``.

    Side effects:
        Writes TensorBoard scalars to ``<output>/logs``. Saves the model state
        dict to ``<output>/reflow_ckpt_<epoch>.pth`` after the first epoch,
        after every 5th epoch, and always after the final epoch.
    """
    device = config.device

    writer = SummaryWriter(log_dir=os.path.join(args.output, "logs"))
    try:
        print("Loading model...")
        model = DiffusionUNet(config)
        model.to(device)

        if args.resume:
            _load_pretrained_weights(model, args.resume, device)

        optimizer = optim.Adam(model.parameters(), lr=config.optim.lr)

        dataset = ReflowDataset(args.data_dir_reflow)
        # batch_size=1: each file already stores a whole batch; the extra
        # collate dimension is squeezed away in the loop below.
        loader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

        model.train()

        print("Starting training...")

        step = 0
        N = config.diffusion.num_diffusion_timesteps
        last_epoch = args.epochs - 1

        for epoch in range(args.epochs):
            for batch_dict in loader:
                # Remove the size-1 collate dim; the remaining leading dim is the
                # stored batch (presumably (B, C, H, W) given the view below).
                x_0 = batch_dict["x_data"].squeeze(0).to(device)
                x_1 = batch_dict["x_noise"].squeeze(0).to(device)
                x_cond = batch_dict["x_cond"].squeeze(0).to(device)

                B = x_0.shape[0]

                # One time per sample; t = 0 is the data end, t = 1 the noise end.
                t = torch.rand(B, device=device)
                t_expand = t.view(B, 1, 1, 1)
                x_t = (1 - t_expand) * x_0 + t_expand * x_1
                # Along a straight path d(x_t)/dt is constant.
                v_target = x_1 - x_0

                # Rescale continuous t onto the [0, N-1] timestep scale the UNet's
                # time input presumably expects from pretraining.
                t_input = t * (N - 1)

                # Conditioning is channel-concatenated in front of x_t.
                model_input = torch.cat([x_cond, x_t], dim=1)
                v_pred = model(model_input, t_input)

                loss = torch.mean((v_pred - v_target) ** 2)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if step % 10 == 0:
                    # .item() synchronises with the device, so only call it when logging.
                    loss_value = loss.item()
                    print(f"Epoch {epoch}, Step {step}, Loss: {loss_value:.6f}")
                    writer.add_scalar("Loss/train", loss_value, step)

                step += 1

            # Always include the final epoch so the end of training is never lost
            # when `epochs` is not a multiple of 5.
            if (epoch + 1) % 5 == 0 or epoch == 0 or epoch == last_epoch:
                save_path = os.path.join(args.output, f"reflow_ckpt_{epoch}.pth")
                torch.save(model.state_dict(), save_path)
                print(f"Saved checkpoint to {save_path}")
    finally:
        # Flush and release the event file even if training fails midway.
        writer.close()
| |
|
| |
|
def main(argv=None):
    """Command-line entry point for reflow fine-tuning.

    Parses arguments and loads ``configs/<--config>`` into a namespace. It then
    applies the ``--lr`` override, picks CUDA when available, and seeds torch
    and NumPy. Finally it creates the output directory and runs
    ``train_reflow``.

    Args:
        argv: Optional argument list (without the program name). ``None``, the
            default, parses ``sys.argv[1:]`` exactly as before; passing a list
            allows programmatic invocation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="YAML config file name, resolved inside ./configs")
    parser.add_argument("--resume", type=str, default="",
                        help="optional checkpoint to initialise the model from")
    parser.add_argument("--data_dir_reflow", type=str, required=True,
                        help="directory of pre-generated *.pth reflow batches")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--output", type=str, default="results/reflow_train",
                        help="directory for logs and checkpoints")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="learning rate; overrides optim.lr from the config")
    args = parser.parse_args(argv)

    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # mis-decode non-ASCII YAML.
    with open(os.path.join("configs", args.config), "r", encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)

    # Compare against None, not truthiness, so an explicit `--lr 0` is not
    # silently ignored. With the current 1e-5 default the CLI value always
    # takes precedence over the config's optim.lr.
    if args.lr is not None:
        config.optim.lr = args.lr

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config.device = device

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.output, exist_ok=True)

    train_reflow(args, config)
| |
|
| |
|
if __name__ == "__main__":
    # Run training only when executed as a script, not on import. main()
    # returns None, so SystemExit(None) exits with status 0 like falling off the end.
    raise SystemExit(main())
| |
|