from pathlib import Path

import torch
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
| |
|
| | |
# Flan-T5-base tokenizer plus its encoder-only model; the encoder is then
# moved to the GPU in bfloat16.
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base")
t5_model = t5_model.to("cuda", torch.bfloat16)
| |
|
# CLIP ViT-L/14 tokenizer plus its text tower; the text model is then moved
# to the GPU in bfloat16.
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
clip_model = clip_model.to("cuda", torch.bfloat16)
| |
|
| | |
# Autoencoder from the `vae` subfolder of the FLUX.1-schnell repository,
# loaded in bfloat16 and then placed on the GPU.
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
)
vae = vae.to("cuda")
| |
|
| | |
# Fetch the TinyFlux model definition from the Hub and execute it in this
# module's namespace. It is expected to define TinyFluxConfig and
# TinyFluxDeep, which the code below uses without importing them.
#
# SECURITY NOTE(review): exec() runs arbitrary downloaded code with the full
# privileges of this process. Only do this for a repository you trust, and
# consider pinning `revision=` to a specific commit so the executed code
# cannot change underneath you.
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
# Read as UTF-8 (Python source encoding) via pathlib so no file handle is
# left open. compile() with the real path makes tracebacks from the
# downloaded code point at that file instead of "<string>".
exec(compile(Path(model_py).read_text(encoding="utf-8"), model_py, "exec"))
| |
|
# Build TinyFluxDeep with `use_sol_prior` and `use_t5_vec` switched off, in
# bfloat16 on the GPU.
config = TinyFluxConfig(
    use_sol_prior=False,
    use_t5_vec=False,
)
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)

# Load the v4 init checkpoint. strict=False lets load_state_dict accept a
# checkpoint whose keys don't exactly match this configuration, but on its
# own it silently skips every mismatch. Report what didn't line up so a
# partial or wrong load is visible instead of only showing up as bad samples.
weights = load_file(
    hf_hub_download(
        "AbstractPhil/tiny-flux-deep",
        "checkpoint_runs/v4_init/lailah_401434_v4_init.safetensors",
    )
)
load_result = model.load_state_dict(weights, strict=False)
if load_result.missing_keys:
    print(
        f"Checkpoint is missing {len(load_result.missing_keys)} model keys "
        f"(left at their initial values), e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"Checkpoint has {len(load_result.unexpected_keys)} keys the model "
        f"does not use (ignored), e.g. {load_result.unexpected_keys[:5]}"
    )
model.eval()
| |
|
def encode_prompt(prompt):
    """Encode *prompt* with both the T5 and the CLIP text encoders.

    Both tokenizers pad/truncate to exactly 77 tokens, and the resulting
    token tensors are moved to the GPU before the encoders run (without
    gradient tracking).

    Args:
        prompt: Text passed unchanged to both tokenizers.

    Returns:
        tuple: ``(t5_emb, clip_pooled)`` -- the T5 encoder's
        ``last_hidden_state`` and CLIP's ``pooler_output``, both cast to
        bfloat16.
    """
    tokenizer_kwargs = dict(
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    )
    t5_inputs = t5_tokenizer(prompt, **tokenizer_kwargs).to("cuda")
    clip_inputs = clip_tokenizer(prompt, **tokenizer_kwargs).to("cuda")

    with torch.no_grad():
        t5_hidden = t5_model(**t5_inputs).last_hidden_state
        clip_pooled = clip_model(**clip_inputs).pooler_output

    return t5_hidden.to(torch.bfloat16), clip_pooled.to(torch.bfloat16)
| |
|
| |
|
def flux_shift(t, s=3.0):
    """Warp timesteps with the Flux-style shift ``s*t / (1 + (s-1)*t)``.

    The map keeps the endpoints fixed (0 -> 0, 1 -> 1). It is the identity
    for ``s == 1`` and pushes interior values toward 1 for ``s > 1``. It
    works on Python floats and elementwise on tensors.

    NOTE(review): the sampler in this file uses t=0 for noise and t=1 for
    data, whereas Flux applies this shift to sigma with sigma=1 being noise.
    With s > 1 this version therefore packs sampling steps toward the data
    end. Confirm this matches the schedule the model was trained with.
    """
    numerator = s * t
    denominator = 1 + (s - 1) * t
    return numerator / denominator
| |
|
| |
|
@torch.inference_mode()
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """Sample an image for *prompt* by Euler-integrating a rectified flow.

    Flow-matching formulation used by this sampler:
        x_t = (1 - t) * noise + t * data
        t = 0 is pure noise and t = 1 is pure data; the velocity
        v = data - noise is constant along each path, so sampling
        integrates dx/dt = v from t=0 (noise) to t=1 (data).

    Each step runs the model twice, once on *prompt* and once on the empty
    prompt, and mixes the two velocities with classifier-free guidance.

    Args:
        prompt: Text prompt to condition on.
        num_steps: Number of Euler steps (must be >= 1).
        cfg_scale: Guidance scale, applied as
            ``v_uncond + cfg_scale * (v_cond - v_uncond)``.
        seed: When not None, passed to ``torch.manual_seed`` before the
            initial noise is drawn, making the result reproducible.

    Returns:
        PIL.Image.Image: the VAE-decoded image.

    Raises:
        ValueError: if *num_steps* is less than 1 (zero steps would just
            decode the initial noise).
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if seed is not None:
        torch.manual_seed(seed)

    t5_emb, clip_pooled = encode_prompt(prompt)
    t5_null, clip_null = encode_prompt("")

    # Initial noise: a 64x64 grid of latent tokens with 16 channels each.
    # It is drawn in bfloat16, so existing seeds give the same starting
    # noise, and then upcast. The Euler state is accumulated in float32
    # because adding many small v*dt updates to a bfloat16 tensor (8-bit
    # mantissa) rounds most of each update away. Only model inputs are
    # bfloat16.
    x = torch.randn(1, 64 * 64, 16, device="cuda", dtype=torch.bfloat16).float()
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")

    # num_steps + 1 grid points from 0 to 1, warped by the timestep shift.
    t_linear = torch.linspace(0, 1, num_steps + 1, device="cuda", dtype=torch.float32)
    timesteps = flux_shift(t_linear, s=3.0)

    for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
        x_model = x.to(torch.bfloat16)
        t_batch = t_curr.unsqueeze(0)  # shape (1,): one timestep for the batch of 1

        v_cond = model(x_model, t5_emb, clip_pooled, t_batch, img_ids).float()
        v_uncond = model(x_model, t5_null, clip_null, t_batch, img_ids).float()

        # Classifier-free guidance, computed in float32.
        v = v_uncond + cfg_scale * (v_cond - v_uncond)

        # Euler step: x(t_next) = x(t_curr) + v * (t_next - t_curr).
        x = x + v * (t_next - t_curr)

    return _latents_to_pil(x)


@torch.inference_mode()
def _latents_to_pil(x):
    """Decode latent tokens of shape (1, 64*64, 16) into a PIL image."""
    from PIL import Image

    # Tokens -> (batch, channels, height, width) layout for the VAE.
    latents = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)
    # Undo the latent scaling. NOTE(review): diffusers' FluxPipeline decodes
    # with `latents / scaling_factor + shift_factor`; only the scaling is
    # undone here. Confirm this matches how the training latents were
    # normalized before adding the shift back.
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents.to(vae.dtype)).sample

    # Map the (assumed) [-1, 1] VAE output range to [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    pixels = image[0].permute(1, 2, 0).float().cpu().numpy()
    # Round to the nearest level instead of truncating. Truncation biases
    # every channel down by half a level on average and maps 0.999 to 254.
    return Image.fromarray((pixels * 255).round().astype("uint8"))
| |
|
| |
|
| | |
# Example run: one image with a fixed seed, written to tiger.png.
image = generate_image(
    "a photograph of a tiger in natural habitat",
    seed=42,
)
image.save("tiger.png")
image  # bare trailing expression so a notebook cell displays the image