import spaces  # Must be imported first for ZeroGPU
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxKontextPipeline, FluxImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image
import os
import gc
import numpy as np
import peft  # Required for LoRA support
import io, base64, requests

os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN", "")
# Style LoRA mapping
STYLE_TYPE_LORA_DICT = {
    "None": "",
    "3D_Chibi": "3D_Chibi_lora_weights.safetensors",
    "American_Cartoon": "American_Cartoon_lora_weights.safetensors",
    "Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
    "Clay_Toy": "Clay_Toy_lora_weights.safetensors",
    "Fabric": "Fabric_lora_weights.safetensors",
    "Ghibli": "Ghibli_lora_weights.safetensors",
    "Irasutoya": "Irasutoya_lora_weights.safetensors",
    "Jojo": "Jojo_lora_weights.safetensors",
    "Oil_Painting": "Oil_Painting_lora_weights.safetensors",
    "Pixel": "Pixel_lora_weights.safetensors",
    "Snoopy": "Snoopy_lora_weights.safetensors",
    "Poly": "Poly_lora_weights.safetensors",
    "LEGO": "LEGO_lora_weights.safetensors",
    "Origami": "Origami_lora_weights.safetensors",
    "Pop_Art": "Pop_Art_lora_weights.safetensors",
    "Van_Gogh": "Van_Gogh_lora_weights.safetensors",
    "Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
    "Line": "Line_lora_weights.safetensors",
    "Vector": "Vector_lora_weights.safetensors",
    "Picasso": "Picasso_lora_weights.safetensors",
    "Macaron": "Macaron_lora_weights.safetensors",
    "Rick_Morty": "Rick_Morty_lora_weights.safetensors",
}
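# e.g. STYLE_TYPE_LORA_DICT["Ghibli"] -> "Ghibli_lora_weights.safetensors",
# the filename later fetched from the Owen777/Kontext-Style-Loras repo.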
# Helper function to convert input image to a PIL image
def _to_pil(x):
    # Already a PIL image from the UI?
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    # Dict from an API client: {"data": "<raw base64>", "name": "..."} or a data URL
    if isinstance(x, dict) and "data" in x:
        b64 = x["data"]
        if isinstance(b64, str) and b64.startswith("data:image"):
            # Strip the data-URL prefix; raw base64 passes through unchanged
            b64 = b64.split(",", 1)[1]
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    # String: could be a URL, data URL, or server-side filepath
    if isinstance(x, str):
        if x.startswith("http://") or x.startswith("https://"):
            r = requests.get(x, timeout=20)
            r.raise_for_status()
            return Image.open(io.BytesIO(r.content)).convert("RGB")
        if x.startswith("data:image"):
            b64 = x.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
        if os.path.exists(x):
            return Image.open(x).convert("RGB")
    raise ValueError("Unsupported image input. Provide a PIL image, filepath/URL, or {data: <base64>}.")
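# Illustrative inputs _to_pil accepts (values are hypothetical):
#   _to_pil(Image.new("RGB", (64, 64)))             # PIL image from the UI
#   _to_pil("https://example.com/cat.png")          # fetched over HTTP
#   _to_pil("data:image/png;base64,iVBORw0KG...")   # decoded from a data URL
#   _to_pil({"data": "iVBORw0KG..."})               # raw base64 from an API client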
# Global variables for pipeline management
pipeline = None
current_lora = None
BACKGROUND_LORA_REPO = "peteromallet/Flux-Kontext-InScene"


def load_pipeline():
    """Load the base FLUX Kontext pipeline"""
    global pipeline
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    if pipeline is None:
        print("Loading FLUX Kontext pipeline...")
        try:
            # Try FluxKontextPipeline first
            pipeline = FluxKontextPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Kontext-dev",
                torch_dtype=dtype,
                token=os.environ.get("HF_TOKEN"),
            ).to("cuda")
            pipeline.load_lora_weights(
                BACKGROUND_LORA_REPO,
                token=os.environ.get("HF_TOKEN"),
                adapter_name="background_lora",
            )
            pipeline.set_adapters(["background_lora"], adapter_weights=[0.6])
            print("Pipeline loaded successfully with FluxKontextPipeline!")
        except Exception as e:
            print(f"FluxKontextPipeline failed: {e}")
            print("Falling back to FluxImg2ImgPipeline...")
            pipeline = FluxImg2ImgPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Kontext-dev",
                torch_dtype=dtype,
                token=os.environ.get("HF_TOKEN"),
            ).to("cuda")
            # Load the background LoRA here too, so later set_adapters calls succeed
            pipeline.load_lora_weights(
                BACKGROUND_LORA_REPO,
                token=os.environ.get("HF_TOKEN"),
                adapter_name="background_lora",
            )
            pipeline.set_adapters(["background_lora"], adapter_weights=[0.6])
            print("Pipeline loaded successfully with FluxImg2ImgPipeline!")
    return pipeline
def download_lora(style_name):
    """Download LoRA weights if not already cached"""
    if style_name == "None":
        return None
    lora_filename = STYLE_TYPE_LORA_DICT[style_name]
    local_path = f"./LoRAs/{lora_filename}"
    if not os.path.exists(local_path):
        print(f"Downloading LoRA for {style_name}...")
        os.makedirs("./LoRAs", exist_ok=True)
        hf_hub_download(
            repo_id="Owen777/Kontext-Style-Loras",
            filename=lora_filename,
            local_dir="./LoRAs",
        )
        print(f"LoRA downloaded: {local_path}")
    return local_path
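# e.g. download_lora("Ghibli") -> "./LoRAs/Ghibli_lora_weights.safetensors"
# (downloaded once, then reused from the local file cache on later calls);
# download_lora("None") -> None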
@spaces.GPU  # Request a ZeroGPU slot for the duration of this call
def generate_styled_image(
    input_image,
    input_base64,
    style_name,
    custom_prompt="",
    num_inference_steps=24,
    guidance_scale=7.5,
    width=1024,
    height=1024,
    seed=-1,
):
    """Generate a styled image using FLUX Kontext with LoRA"""
    global pipeline, current_lora
    try:
        # Load pipeline if not loaded
        pipeline = load_pipeline()
        # Handle LoRA loading based on style selection
        if current_lora != style_name:
            if style_name == "None":
                # Unload the style LoRA if one was loaded; keep only the background LoRA
                if current_lora is not None and current_lora != "None":
                    try:
                        pipeline.delete_adapters(["style_lora"])
                    except Exception:
                        pass
                pipeline.set_adapters(["background_lora"], adapter_weights=[0.6])
                current_lora = style_name
                print("Using only background LoRA (no style applied)")
            else:
                # Download and load the style LoRA
                lora_path = download_lora(style_name)
                # Remove the previous style LoRA if any
                if current_lora is not None and current_lora != "None":
                    try:
                        pipeline.delete_adapters(["style_lora"])
                    except Exception:
                        pass
                try:
                    pipeline.load_lora_weights(lora_path, adapter_name="style_lora")
                    # Activate both background and style LoRAs
                    pipeline.set_adapters(["background_lora", "style_lora"], adapter_weights=[0.4, 1.0])
                    current_lora = style_name
                    print(f"Loaded style LoRA: {style_name}")
                except Exception as e:
                    print(f"Error loading LoRA {style_name}: {str(e)}")
                    # Fall back to just the background LoRA
                    pipeline.set_adapters(["background_lora"], adapter_weights=[0.6])
                    raise e
        # Note: when the style hasn't changed, adapters are already set from the previous call

        # Normalize the input image to an RGB PIL image
        if input_image is not None:
            img = _to_pil(input_image)  # the UI passes a PIL image
        elif input_base64:
            # Accept either raw base64 or a data URL
            b64 = input_base64
            if b64.startswith("data:image"):
                b64 = b64.split(",", 1)[1]
            img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
        else:
            raise ValueError("Please provide an image (upload) or input_base64.")

        # Prepare the prompt
        if custom_prompt.strip():
            prompt = custom_prompt
        else:
            prompt = f"Turn this image into the {style_name.replace('_', ' ')} style."

        # Set seed for reproducibility
        if seed != -1:
            torch.manual_seed(seed)

        print(f"Generating image with style: {style_name}")
        print(f"Prompt: {prompt}")
        # Generate image with proper error handling
        try:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                result = pipeline(
                    image=img,
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=torch.Generator(device="cuda").manual_seed(seed) if seed != -1 else None,
                )
            output_image = result.images[0]
            # Validate the output image
            if output_image is None:
                raise ValueError("Generated image is None")
            # Convert to RGB if needed
            if output_image.mode != "RGB":
                output_image = output_image.convert("RGB")
            # Additional validation: check for invalid pixel values
            img_array = np.array(output_image)
            if np.any(np.isnan(img_array)) or np.any(np.isinf(img_array)):
                print("Warning: Generated image contains invalid values, attempting to fix...")
                # Clamp values to the valid range
                img_array = np.clip(img_array, 0, 255)
                output_image = Image.fromarray(img_array.astype(np.uint8))
        except Exception as generation_error:
            print(f"Error during generation: {str(generation_error)}")
            # Retry with gentler fallback parameters
            print("Attempting generation with fallback parameters...")
            with torch.autocast("cuda", dtype=torch.bfloat16):
                result = pipeline(
                    image=img,
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=max(15, num_inference_steps // 2),
                    guidance_scale=min(guidance_scale, 7.0),
                )
            output_image = result.images[0]
            if output_image.mode != "RGB":
                output_image = output_image.convert("RGB")
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        return output_image
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        return None
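# Direct-call sketch (hypothetical filenames; the Gradio UI normally supplies these args):
#   styled = generate_styled_image(Image.open("photo.jpg"), "", "Ghibli", seed=42)
#   if styled is not None:
#       styled.save("photo_ghibli.png")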
# Custom CSS for better UI
css = """
.gradio-container {
    font-family: 'Helvetica Neue', Arial, sans-serif;
}
.title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 1em;
    color: #2c3e50;
}
.subtitle {
    text-align: center;
    font-size: 1.2em;
    color: #7f8c8d;
    margin-bottom: 2em;
}
"""
# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.HTML('<div class="title">🎨 FLUX Kontext Style Transfer</div>')
    gr.HTML('<div class="subtitle">Transform your images with 20+ artistic styles using LoRA adapters</div>')

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_image = gr.Image(
                label="Upload Image",
                height=400,
                type="pil",
            )
            style_dropdown = gr.Dropdown(
                choices=list(STYLE_TYPE_LORA_DICT.keys()),
                label="Choose Style",
                value="None",
                interactive=True,
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional; discouraged when a style is selected)",
                placeholder="Leave empty to use the default style prompt",
                lines=2,
            )
            input_base64 = gr.Textbox(
                label="Base64 / data URL (optional, for API callers)",
                placeholder="data:image/png;base64,... or raw base64",
            )
            with gr.Accordion("Advanced Settings", open=False):
                num_inference_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=24,
                    step=1,
                    label="Inference Steps",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.1,
                    label="Guidance Scale",
                )
                # lora_strength = gr.Slider(
                #     minimum=0.1, maximum=2.0, value=1.0, step=0.1,
                #     label="LoRA Strength"
                # )
                # img2img_strength = gr.Slider(
                #     minimum=0.1, maximum=1.0, value=0.8, step=0.05,
                #     label="Transformation Strength"
                # )
                with gr.Row():
                    width = gr.Slider(
                        minimum=512,
                        maximum=1536,
                        value=1024,
                        step=64,
                        label="Width",
                    )
                    height = gr.Slider(
                        minimum=512,
                        maximum=1536,
                        value=1024,
                        step=64,
                        label="Height",
                    )
                seed = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1,
                    precision=0,
                )
            generate_btn = gr.Button("🎨 Generate Styled Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### Output")
            output_image = gr.Image(
                label="Styled Image",
                height=400,
            )

    # Example images
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "None", "", 24, 7.5, 1024, 1024, -1],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "Ghibli", "", 24, 7.5, 1024, 1024, -1],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "Pixel", "", 24, 7.5, 1024, 1024, -1],
            ["https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg", "", "Van_Gogh", "", 24, 7.5, 1024, 1024, -1],
        ],
        inputs=[input_image, input_base64, style_dropdown, custom_prompt, num_inference_steps, guidance_scale, width, height, seed],
        outputs=[output_image],
        fn=generate_styled_image,
        cache_examples=False,
    )
    # Event handlers
    generate_btn.click(
        fn=generate_styled_image,
        inputs=[
            input_image,
            input_base64,
            style_dropdown,
            custom_prompt,
            num_inference_steps,
            guidance_scale,
            # lora_strength,
            # img2img_strength,
            width,
            height,
            seed,
        ],
        outputs=[output_image],
    )

    # Information section
    with gr.Accordion("About", open=False):
        gr.Markdown("""
        ### FLUX Kontext Style Transfer

        This application uses the FLUX.1 Kontext model with dual LoRA adapters: a background LoRA for scene understanding and optional style LoRAs for artistic transformation.

        **Available Styles:**
        - **None**: Uses only the background LoRA for natural scene enhancement
        - 3D Chibi, American Cartoon, Chinese Ink, Clay Toy
        - Fabric, Ghibli, Irasutoya, Jojo, Oil Painting
        - Pixel, Snoopy, Poly, LEGO, Origami
        - Pop Art, Van Gogh, Paper Cutting, Line, Vector
        - Picasso, Macaron, Rick & Morty

        **Tips:**
        - Upload high-quality images for best results
        - Select "None" for natural scene enhancement without style transfer
        - Try different style LoRAs for various artistic effects
        - Use custom prompts for more specific styling
        - Higher inference steps generally produce better quality

        **Model:** [Owen777/Kontext-Style-Loras](https://huggingface.co/Owen777/Kontext-Style-Loras)

        **Training Code:** [GitHub Repository](https://github.com/Owen718/Kontext-Lora-Trainer)
        """)

if __name__ == "__main__":
    demo.launch()
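# Minimal API-call sketch using gradio_client (assumes the package is installed
# and the Space is deployed; the Space id below is a placeholder):
#
#   from gradio_client import Client
#   client = Client("<user>/<space-name>")
#   result = client.predict(
#       None,                           # input_image (omit; use base64 instead)
#       "data:image/png;base64,...",    # input_base64
#       "Ghibli",                       # style
#       "",                             # custom prompt
#       24, 7.5, 1024, 1024, -1,        # steps, guidance, width, height, seed
#   )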