| | import torch |
| | import torchvision |
| | from torchvision.utils import save_image, make_grid |
| | import os |
| | import argparse |
| | from datetime import datetime |
| | from config import Config |
| | from model import SmoothDiffusionUNet |
| | from noise_scheduler_simple import FrequencyAwareNoise |
| | from sample_simple import simple_sample |
| |
|
def load_model(checkpoint_path, device):
    """Load a SmoothDiffusionUNet and its noise scheduler from a checkpoint.

    Two checkpoint layouts are accepted:

    * a dict containing ``'model_state_dict'`` (optionally with ``'config'``,
      ``'epoch'`` and ``'loss'`` entries), or
    * a bare model state dict.

    If the checkpoint carries no ``'config'`` entry, a default ``Config()`` is
    used to build the model and scheduler.

    Args:
        checkpoint_path: Path to a file written with ``torch.save``.
        device: ``torch.device`` the weights are mapped onto and the model is
            moved to.

    Returns:
        Tuple ``(model, noise_scheduler, config)``.
    """
    print(f"Loading model from: {checkpoint_path}")

    # The checkpoint may contain a pickled ``Config`` object, which PyTorch >= 2.6
    # refuses to load under its new default ``weights_only=True``; request full
    # unpickling explicitly.
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if 'config' in checkpoint:
        config = checkpoint['config']
    else:
        config = Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        # Training-checkpoint layout: weights plus bookkeeping metadata.
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # Bare state-dict layout: the whole file is the weights mapping.
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    return model, noise_scheduler, config
| |
|
def test_checkpoint(checkpoint_path, device, n_samples=16):
    """Sample from a single checkpoint and save the result as an image grid.

    The grid is written to ``test_samples_simple_<YYYYmmdd_HHMMSS>.png`` in the
    current working directory.

    Args:
        checkpoint_path: Path to the checkpoint passed to ``load_model``.
        device: ``torch.device`` used for loading and sampling.
        n_samples: Number of samples to generate.

    Returns:
        Tuple ``(samples, grid)`` as returned by ``simple_sample``.
    """
    model, noise_scheduler, _ = load_model(checkpoint_path, device)
    # Sampling is inference: put dropout/normalization layers into evaluation
    # mode so outputs do not depend on training-time behaviour.
    model.eval()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f"test_samples_simple_{timestamp}.png"

    print(f"Testing checkpoint with {n_samples} samples...")
    # No gradients are needed to generate samples; skip autograd bookkeeping.
    with torch.no_grad():
        samples, grid = simple_sample(model, noise_scheduler, device, n_samples=n_samples)

    # normalize=False writes pixel values as-is; presumably simple_sample
    # already returns the grid in [0, 1] -- TODO confirm.
    save_image(grid, save_path, normalize=False)
    print(f"Samples saved to: {save_path}")

    return samples, grid
| |
|
def main(argv=None):
    """Command-line entry point: validate arguments and sample from a checkpoint.

    Args:
        argv: Optional argument list; ``None`` (the default) parses ``sys.argv``.

    Exits with status 2 (via ``parser.error``) when the checkpoint file does
    not exist, ``--n_samples`` is not positive, ``--device`` is not a valid
    torch device string, or a CUDA device is requested but unavailable.
    """
    parser = argparse.ArgumentParser(description='Test trained diffusion model (simple version)')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint file')
    parser.add_argument('--n_samples', type=int, default=16, help='Number of samples to generate')
    parser.add_argument('--device', type=str, default='auto', help='Device to use (cuda/cpu/auto)')

    args = parser.parse_args(argv)

    # Fail fast with a usage message instead of a traceback deep inside torch.
    if not os.path.isfile(args.checkpoint):
        parser.error(f"checkpoint not found: {args.checkpoint}")
    if args.n_samples < 1:
        parser.error("--n_samples must be a positive integer")

    if args.device == 'auto':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        try:
            device = torch.device(args.device)
        except RuntimeError as err:
            # torch.device raises RuntimeError for unknown device strings.
            parser.error(f"invalid --device {args.device!r}: {err}")
        if device.type == 'cuda' and not torch.cuda.is_available():
            parser.error("--device cuda requested but CUDA is not available")

    print(f"Using device: {device}")

    print("=== Testing Checkpoint with Simple DDPM ===")
    test_checkpoint(args.checkpoint, device, args.n_samples)


if __name__ == "__main__":
    main()
| |
|