| | from __future__ import annotations |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from monai.utils import optional_import |
| | from torch.cuda.amp import autocast |
| |
|
| | tqdm, has_tqdm = optional_import("tqdm", name="tqdm") |
| |
|
| |
|
class Sampler:
    """Classifier-free-guidance sampler for a two-stage latent diffusion model.

    Runs the reverse-diffusion loop with classifier-free guidance (CFG) over
    ``scheduler.timesteps``, then decodes the final latent through the
    autoencoder's second-stage decoder.
    """

    def __init__(self) -> None:
        super().__init__()

    @torch.no_grad()
    def sampling_fn(
        self,
        noise: torch.Tensor,
        autoencoder_model: nn.Module,
        diffusion_model: nn.Module,
        scheduler: nn.Module,
        prompt_embeds: torch.Tensor,
        guidance_scale: float = 7.0,
        scale_factor: float = 0.3,
    ) -> torch.Tensor:
        """Generate a sample from ``noise`` conditioned on ``prompt_embeds``.

        Args:
            noise: initial latent noise, shape ``(batch, ...)``.
            autoencoder_model: model providing ``decode_stage_2_outputs`` to
                map latents back to image space.
            diffusion_model: denoising network called as
                ``diffusion_model(x, timesteps=..., context=...)``.
            scheduler: noise scheduler exposing ``timesteps`` and ``step``.
            prompt_embeds: text embeddings; assumed to hold the concatenated
                [unconditional, conditional] pair, i.e. batch dim is
                ``2 * noise.shape[0]`` — TODO confirm against caller.
            guidance_scale: CFG weight; larger values follow the prompt harder.
            scale_factor: scaling applied to latents at encode time; the final
                latent is divided by it before decoding.

        Returns:
            The decoded sample tensor.
        """
        # Show a progress bar over timesteps when tqdm is installed.
        progress_bar = tqdm(scheduler.timesteps) if has_tqdm else scheduler.timesteps

        for t in progress_bar:
            # Duplicate the latent so the unconditional and text-conditioned
            # predictions come out of a single forward pass.
            noise_input = torch.cat([noise] * 2)
            # Build the timestep tensor in one step with the right dtype and
            # device (the float intermediate of torch.Tensor(...).to(...).long()
            # is wasteful, and int(t) also avoids the copy-construct warning
            # when t is a 0-d tensor).
            timesteps = torch.tensor((int(t),), dtype=torch.long, device=noise.device)
            model_output = diffusion_model(noise_input, timesteps=timesteps, context=prompt_embeds)
            noise_pred_uncond, noise_pred_text = model_output.chunk(2)
            # Classifier-free guidance: move the prediction away from the
            # unconditional estimate, toward the text-conditioned one.
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise, _ = scheduler.step(noise_pred, t, noise)

        # Undo the encode-time scale factor, then decode the final latent.
        with autocast():
            sample = autoencoder_model.decode_stage_2_outputs(noise / scale_factor)

        return sample
| |
|