import spaces # Must be imported first for ZeroGPU
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxKontextPipeline, FluxImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image
import os
import gc
import numpy as np
import peft # Required for LoRA support
import io, base64, requests
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN", "")
# Style LoRA mapping
STYLE_TYPE_LORA_DICT = {
'None': "",
"3D_Chibi": "3D_Chibi_lora_weights.safetensors",
"American_Cartoon": "American_Cartoon_lora_weights.safetensors",
"Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
"Clay_Toy": "Clay_Toy_lora_weights.safetensors",
"Fabric": "Fabric_lora_weights.safetensors",
"Ghibli": "Ghibli_lora_weights.safetensors",
"Irasutoya": "Irasutoya_lora_weights.safetensors",
"Jojo": "Jojo_lora_weights.safetensors",
"Oil_Painting": "Oil_Painting_lora_weights.safetensors",
"Pixel": "Pixel_lora_weights.safetensors",
"Snoopy": "Snoopy_lora_weights.safetensors",
"Poly": "Poly_lora_weights.safetensors",
"LEGO": "LEGO_lora_weights.safetensors",
"Origami": "Origami_lora_weights.safetensors",
"Pop_Art": "Pop_Art_lora_weights.safetensors",
"Van_Gogh": "Van_Gogh_lora_weights.safetensors",
"Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
"Line": "Line_lora_weights.safetensors",
"Vector": "Vector_lora_weights.safetensors",
"Picasso": "Picasso_lora_weights.safetensors",
"Macaron": "Macaron_lora_weights.safetensors",
"Rick_Morty": "Rick_Morty_lora_weights.safetensors"
}
# Helper: normalize an input image (PIL, API dict, URL, data URL, or filepath) to an RGB PIL image
def _to_pil(x):
# Already a PIL image from the UI?
if isinstance(x, Image.Image):
return x.convert("RGB")
    # Dict from an API client: {"data": "<raw base64 or data URL>", "name": "..."}
    if isinstance(x, dict) and "data" in x:
        b64 = x["data"]
        # Strip the "data:image/...;base64," prefix only when it is present;
        # raw base64 contains no comma to split on
        if isinstance(b64, str) and b64.startswith("data:image"):
            b64 = b64.split(",", 1)[1]
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
# String: could be URL, data URL, or server-side filepath
if isinstance(x, str):
if x.startswith("http://") or x.startswith("https://"):
r = requests.get(x, timeout=20)
r.raise_for_status()
return Image.open(io.BytesIO(r.content)).convert("RGB")
if x.startswith("data:image"):
b64 = x.split(",", 1)[1]
return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
if os.path.exists(x):
return Image.open(x).convert("RGB")
raise ValueError("Unsupported image input. Provide a PIL image, filepath/URL, or {data: <base64>}.")
# Global variables for pipeline management
pipeline = None
current_lora = None
BACKGROUND_LORA_REPO = "peteromallet/Flux-Kontext-InScene"
def load_pipeline():
"""Load the base FLUX Kontext pipeline"""
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
global pipeline
if pipeline is None:
print("Loading FLUX Kontext pipeline...")
try:
            # Try FluxKontextPipeline first (with the InScene background LoRA)
pipeline = FluxKontextPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev",
torch_dtype=dtype,
token=os.environ.get("HF_TOKEN"),
).to('cuda')
pipeline.load_lora_weights(
BACKGROUND_LORA_REPO,
token=os.environ.get("HF_TOKEN"),
adapter_name="background_lora"
)
pipeline.set_adapters(["background_lora"], adapter_weights=[0.6])
print("Pipeline loaded successfully with FluxImg2ImgPipeline!")
except Exception as e:
print(f"FluxImg2ImgPipeline failed: {e}")
print("Trying with regular FluxPipeline...")
# Fallback to regular FluxPipeline
pipeline = FluxImg2ImgPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Kontext-dev",
torch_dtype=dtype,
token=os.environ.get("HF_TOKEN"),
).to('cuda')
print("Pipeline loaded successfully with FluxPipeline!")
return pipeline
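# The 0.6 background-LoRA weight above can be retuned after loading; a minimal
# sketch, assuming the FluxKontextPipeline path succeeded so the
# "background_lora" adapter exists (the 0.8 here is an illustrative value):
#
#   pipe = load_pipeline()
#   pipe.set_adapters(["background_lora"], adapter_weights=[0.8])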
def download_lora(style_name):
"""Download LoRA weights if not already cached"""
if style_name == "None":
return None
lora_filename = STYLE_TYPE_LORA_DICT[style_name]
local_path = f"./LoRAs/{lora_filename}"
if not os.path.exists(local_path):
print(f"Downloading LoRA for {style_name}...")
os.makedirs("./LoRAs", exist_ok=True)
hf_hub_download(
repo_id="Owen777/Kontext-Style-Loras",
filename=lora_filename,
local_dir="./LoRAs"
)
print(f"LoRA downloaded: {local_path}")
return local_path
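# For example, download_lora("Ghibli") fetches the weights on first use and
# returns "./LoRAs/Ghibli_lora_weights.safetensors", while download_lora("None")
# returns None, which callers must handle before passing a path to load_lora_weights.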
@spaces.GPU
def generate_styled_image(
input_image,
input_base64,
style_name,
custom_prompt="",
num_inference_steps=24,
guidance_scale=7.5,
width=1024,
height=1024,
seed=-1,
):
"""Generate styled image using FLUX Kontext with LoRA"""
global pipeline, current_lora
try:
# Load pipeline if not loaded
pipeline = load_pipeline()
# Handle LoRA loading based on style selection
if current_lora != style_name:
# If switching to "None", just use background LoRA
if style_name == "None":
# Unload style LoRA if any was loaded
if current_lora is not None and current_lora != "None":
try:
pipeline.delete_adapters(["style_lora"])
                    except Exception:
                        pass  # style adapter may not have been loaded
# Set only background LoRA
pipeline.set_adapters(["background_lora"], adapter_weights=[0.6])
current_lora = style_name
print("Using only background LoRA (no style applied)")
else:
# Download and load style LoRA
lora_path = download_lora(style_name)
# Remove previous style LoRA if any
if current_lora is not None and current_lora != "None":
try:
pipeline.delete_adapters(["style_lora"])
                    except Exception:
                        pass  # style adapter may not have been loaded
# Load new style LoRA
try:
pipeline.load_lora_weights(lora_path, adapter_name="style_lora")
# Set both background and style LoRAs
pipeline.set_adapters(["background_lora", "style_lora"], adapter_weights=[0.4, 1.0])
current_lora = style_name
print(f"Loaded style LoRA: {style_name}")
except Exception as e:
print(f"Error loading LoRA {style_name}: {str(e)}")
                    # Restore the background-only adapter so the pipeline stays usable,
                    # then re-raise so the outer handler reports the failure
                    pipeline.set_adapters(["background_lora"], adapter_weights=[0.6])
                    raise e
# Note: When style hasn't changed, adapters are already set correctly from previous call
        # Normalize the input image to an RGB PIL image
        if input_image is not None:
            img = _to_pil(input_image)  # the UI passes a PIL image
        elif input_base64:
            # Accept either raw base64 or a data URL
            b64 = input_base64
            if b64.startswith("data:image"):
                b64 = b64.split(",", 1)[1]
            img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
        else:
            raise ValueError("Please provide an image (upload) or input_base64.")
        # Prepare the prompt; avoid producing "Turn this image into the None style."
        if custom_prompt.strip():
            prompt = custom_prompt
        elif style_name == "None":
            # Neutral default when no style is selected (wording is illustrative)
            prompt = "Enhance this image while preserving its original style."
        else:
            prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."
# Set seed for reproducibility
if seed != -1:
torch.manual_seed(seed)
print(f"Generating image with style: {style_name}")
print(f"Prompt: {prompt}")
# Generate image with proper error handling
try:
with torch.autocast("cuda", dtype=torch.bfloat16):
result = pipeline(
image=img,
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator(device="cuda").manual_seed(seed) if seed != -1 else None
)
output_image = result.images[0]
# Validate the output image
if output_image is None:
raise ValueError("Generated image is None")
# Convert to RGB if needed and ensure it's valid
if output_image.mode != 'RGB':
output_image = output_image.convert('RGB')
# Additional validation - check if image has valid pixel values
img_array = np.array(output_image)
if np.any(np.isnan(img_array)) or np.any(np.isinf(img_array)):
print("Warning: Generated image contains invalid values, attempting to fix...")
# Clamp values to valid range
img_array = np.clip(img_array, 0, 255)
output_image = Image.fromarray(img_array.astype(np.uint8))
except Exception as generation_error:
print(f"Error during generation: {str(generation_error)}")
# Try with different parameters as fallback
print("Attempting generation with fallback parameters...")
with torch.autocast("cuda", dtype=torch.bfloat16):
result = pipeline(
image=img,
prompt=prompt,
height=height,
width=width,
num_inference_steps=max(15, num_inference_steps // 2),
guidance_scale=min(guidance_scale, 7.0),
)
output_image = result.images[0]
if output_image.mode != 'RGB':
output_image = output_image.convert('RGB')
# Clean up GPU memory
torch.cuda.empty_cache()
gc.collect()
return output_image
except Exception as e:
print(f"Error generating image: {str(e)}")
return None
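# A minimal remote-call sketch using gradio_client (the Space id and api_name
# are hypothetical; the positional args mirror generate_styled_image's signature):
#
#   from gradio_client import Client
#   import base64
#
#   client = Client("user/flux-kontext-style-transfer")  # hypothetical Space id
#   with open("input.png", "rb") as f:
#       raw_b64 = base64.b64encode(f.read()).decode()
#   styled = client.predict(
#       None, raw_b64, "Ghibli", "", 24, 7.5, 1024, 1024, -1,
#       api_name="/predict",  # verify on the Space's "Use via API" page
#   )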
# Custom CSS for better UI
css = """
.gradio-container {
font-family: 'Helvetica Neue', Arial, sans-serif;
}
.title {
text-align: center;
font-size: 2.5em;
font-weight: bold;
margin-bottom: 1em;
color: #2c3e50;
}
.subtitle {
text-align: center;
font-size: 1.2em;
color: #7f8c8d;
margin-bottom: 2em;
}
"""
# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
gr.HTML('<div class="title">🎨 FLUX Kontext Style Transfer</div>')
gr.HTML('<div class="subtitle">Transform your images with 20+ artistic styles using LoRA adapters</div>')
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Input")
input_image = gr.Image(
label="Upload Image",
height=400,
type="pil",
)
style_dropdown = gr.Dropdown(
choices=list(STYLE_TYPE_LORA_DICT.keys()),
label="Choose Style",
value="None",
interactive=True
)
custom_prompt = gr.Textbox(
label="Custom Prompt (Discouraged when selecting a style)",
placeholder="Leave empty to use default style prompt",
lines=2
)
input_base64 = gr.Textbox(
label="Base64 / data URL (optional, for API callers)",
placeholder="data:image/png;base64,... or raw base64"
)
with gr.Accordion("Advanced Settings", open=False):
num_inference_steps = gr.Slider(
minimum=10,
maximum=50,
value=24,
step=1,
label="Inference Steps"
)
guidance_scale = gr.Slider(
minimum=1.0,
maximum=20.0,
value=7.5,
step=0.1,
label="Guidance Scale"
)
with gr.Row():
width = gr.Slider(
minimum=512,
maximum=1536,
value=1024,
step=64,
label="Width"
)
height = gr.Slider(
minimum=512,
maximum=1536,
value=1024,
step=64,
label="Height"
)
seed = gr.Number(
label="Seed (-1 for random)",
value=-1,
precision=0
)
generate_btn = gr.Button("🎨 Generate Styled Image", variant="primary", size="lg")
with gr.Column(scale=1):
gr.Markdown("### Output")
output_image = gr.Image(
label="Styled Image",
height=400
)
# Example images
gr.Markdown("### Examples")
gr.Examples(
examples=[
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "None", "", 24, 7.5, 1024, 1024, -1],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "Ghibli", "", 24, 7.5, 1024, 1024, -1],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "Pixel", "", 24, 7.5, 1024, 1024, -1],
["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "Van_Gogh", "", 24, 7.5, 1024, 1024, -1],
],
inputs=[input_image, input_base64, style_dropdown, custom_prompt, num_inference_steps, guidance_scale, width, height, seed],
outputs=[output_image],
fn=generate_styled_image,
cache_examples=False,
)
# Event handlers
generate_btn.click(
fn=generate_styled_image,
inputs=[
input_image,
input_base64,
style_dropdown,
custom_prompt,
num_inference_steps,
guidance_scale,
width,
height,
seed
],
outputs=[output_image]
)
# Information section
with gr.Accordion("About", open=False):
gr.Markdown("""
### FLUX Kontext Style Transfer
This application uses the FLUX.1 Kontext model with dual LoRA adapters: a background LoRA for scene understanding and optional style LoRAs for artistic transformation.
**Available Styles:**
- **None**: Uses only the background LoRA for natural scene enhancement
- 3D Chibi, American Cartoon, Chinese Ink, Clay Toy
- Fabric, Ghibli, Irasutoya, Jojo, Oil Painting
- Pixel, Snoopy, Poly, LEGO, Origami
- Pop Art, Van Gogh, Paper Cutting, Line, Vector
- Picasso, Macaron, Rick & Morty
**Tips:**
- Upload high-quality images for best results
- Select "None" for natural scene enhancement without style transfer
- Try different style LoRAs for various artistic effects
- Use custom prompts for more specific styling
- Higher inference steps generally produce better quality
**Model:** [Owen777/Kontext-Style-Loras](https://huggingface.co/Owen777/Kontext-Style-Loras)
**Training Code:** [GitHub Repository](https://github.com/Owen718/Kontext-Lora-Trainer)
""")
if __name__ == "__main__":
demo.launch()
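    # Launch options can be adjusted here if needed, e.g. (illustrative values):
    #   demo.launch(share=True, server_port=7860)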