| | import argparse, os, sys, glob |
| | from omegaconf import OmegaConf |
| | from PIL import Image |
| | from tqdm import tqdm |
| | import numpy as np |
| | import torch |
| | from main import instantiate_from_config |
| | from ldm.models.diffusion.ddim import DDIMSampler |
| |
|
| |
|
def make_batch(image, mask, device):
    """Load an image/mask pair from disk and pack it into a model batch.

    Args:
        image: Path (or file object) of the picture; it is read as RGB.
        mask: Path (or file object) of the mask; it is read as 8-bit
            grayscale and binarized at half intensity.
        device: Torch device the resulting tensors are moved to.

    Returns:
        A dict with the keys ``"image"`` (1x3xHxW), ``"mask"`` (1x1xHxW) and
        ``"masked_image"`` (the image with the masked pixels zeroed before
        rescaling). Every tensor is float32, lives on ``device`` and has been
        mapped from [0, 1] to [-1, 1] via ``x * 2 - 1`` -- the mask included.
    """
    # RGB picture: HxWx3 uint8 -> 1x3xHxW float32 in [0, 1].
    rgb = np.array(Image.open(image).convert("RGB")).astype(np.float32) / 255.0
    image_t = torch.from_numpy(np.expand_dims(rgb, 0).transpose(0, 3, 1, 2))

    # Mask: HxW uint8 -> hard 0/1 float32 map with leading batch/channel axes.
    gray = np.array(Image.open(mask).convert("L")).astype(np.float32) / 255.0
    binary = (gray >= 0.5).astype(np.float32)
    mask_t = torch.from_numpy(binary[None, None])

    # Pixels where the mask is 1 are blanked out of the conditioning image.
    tensors = {
        "image": image_t,
        "mask": mask_t,
        "masked_image": (1 - mask_t) * image_t,
    }

    # Move to the target device and rescale [0, 1] -> [-1, 1].
    return {
        key: value.to(device=device) * 2.0 - 1.0
        for key, value in tensors.items()
    }
| |
|
| |
|
if __name__ == "__main__":
    # Command-line driver: for every `<name>_mask.png` in --indir, inpaint the
    # matching `<name>.png` with the pretrained model and write the composited
    # result to --outdir under the image's file name.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--indir",
        type=str,
        nargs="?",
        help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    opt = parser.parse_args()

    # Both directories are needed. Without this check a missing flag only
    # surfaces later as an opaque TypeError from os.path.join / os.makedirs.
    if opt.indir is None or opt.outdir is None:
        parser.error("both --indir and --outdir must be given")

    mask_suffix = "_mask.png"
    masks = sorted(glob.glob(os.path.join(opt.indir, "*" + mask_suffix)))
    # Strip only the trailing suffix: str.replace would also rewrite any
    # directory component that happens to contain "_mask.png".
    images = [path[:-len(mask_suffix)] + ".png" for path in masks]
    print(f"Found {len(masks)} inputs.")

    config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
    model = instantiate_from_config(config.model)
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    # map_location="cpu" keeps the CPU fallback below usable even when the
    # checkpoint's tensors were saved from a CUDA device; the model is moved
    # to the selected device right after the weights are loaded.
    checkpoint = torch.load("models/ldm/inpainting_big/last.ckpt",
                            map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    with torch.no_grad(), model.ema_scope():
        # zip() has no length, so pass the total for a real progress bar.
        for image_path, mask_path in tqdm(zip(images, masks), total=len(masks)):
            outpath = os.path.join(opt.outdir, os.path.split(image_path)[1])
            batch = make_batch(image_path, mask_path, device=device)

            # Conditioning: the encoded masked image with the mask, resized
            # to the encoding's spatial size, appended as an extra channel.
            c = model.cond_stage_model.encode(batch["masked_image"])
            cc = torch.nn.functional.interpolate(batch["mask"],
                                                 size=c.shape[-2:])
            c = torch.cat((c, cc), dim=1)

            # Sample shape (no batch axis): the conditioning's channels minus
            # the single mask channel, at the same spatial size.
            shape = (c.shape[1] - 1,) + c.shape[2:]
            samples_ddim, _ = sampler.sample(S=opt.steps,
                                             conditioning=c,
                                             batch_size=c.shape[0],
                                             shape=shape,
                                             verbose=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)

            # Undo the [-1, 1] scaling applied by make_batch.
            image = torch.clamp((batch["image"] + 1.0) / 2.0,
                                min=0.0, max=1.0)
            mask = torch.clamp((batch["mask"] + 1.0) / 2.0,
                               min=0.0, max=1.0)
            predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                          min=0.0, max=1.0)

            # Keep the original pixels where mask == 0 and use the model's
            # prediction where mask == 1.
            inpainted = (1 - mask) * image + mask * predicted_image
            inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
            Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
| |
|