File size: 17,082 Bytes
48975c7
e635ff3
 
 
5de336d
e635ff3
 
 
 
7c690ea
6aed840
0cd60d0
 
e635ff3
fc07314
 
e635ff3
 
676b1ac
e635ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cd60d0
 
 
 
 
 
 
 
 
71cd01f
 
0cd60d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e635ff3
 
 
 
ff74c34
 
e635ff3
 
491c0cd
e635ff3
 
 
7c690ea
 
589afd6
7c690ea
 
ff74c34
7c690ea
ff74c34
 
 
 
 
 
 
 
7c690ea
 
 
 
 
7076fa3
7c690ea
 
 
 
 
e635ff3
 
 
 
676b1ac
 
 
e635ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cdea10
e635ff3
 
 
 
 
 
b003e83
e635ff3
 
 
 
 
 
 
 
676b1ac
e635ff3
676b1ac
 
 
 
 
 
 
 
 
 
6aed840
676b1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5de336d
676b1ac
 
 
 
 
 
 
 
e635ff3
 
c91fc10
 
 
 
 
 
 
 
 
 
 
0cd60d0
e635ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c690ea
 
 
 
c91fc10
7c690ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39c3455
7c690ea
 
 
 
 
 
 
 
 
491c0cd
e635ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff9c156
c91fc10
e635ff3
 
 
 
 
676b1ac
e635ff3
 
 
 
c91fc10
e635ff3
 
 
 
39c3455
 
 
 
 
e635ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503b56b
 
 
 
 
 
 
e635ff3
503b56b
 
 
 
 
 
 
491c0cd
e635ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9613a
e635ff3
 
676b1ac
8cdea10
 
 
 
676b1ac
8cdea10
676b1ac
 
 
 
e635ff3
 
 
 
 
 
39c3455
e635ff3
 
 
 
7e9613a
 
e635ff3
 
 
 
 
 
 
 
 
 
 
 
676b1ac
e635ff3
 
676b1ac
e635ff3
 
 
 
 
 
 
 
676b1ac
 
e635ff3
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
import spaces  # Must be imported first for ZeroGPU
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxKontextPipeline, FluxImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image
import os
import gc
import numpy as np
import peft  # Required for LoRA support
import io, base64, requests


# Guarantee that HF_TOKEN exists in the environment (empty string when unset).
os.environ.setdefault("HF_TOKEN", "")

# Style LoRA mapping: dropdown label -> LoRA weight filename.
# "None" maps to an empty filename and means "no style LoRA".
# Order matters: the UI dropdown lists styles in this order.
_STYLE_NAMES = (
    "3D_Chibi",
    "American_Cartoon",
    "Chinese_Ink",
    "Clay_Toy",
    "Fabric",
    "Ghibli",
    "Irasutoya",
    "Jojo",
    "Oil_Painting",
    "Pixel",
    "Snoopy",
    "Poly",
    "LEGO",
    "Origami",
    "Pop_Art",
    "Van_Gogh",
    "Paper_Cutting",
    "Line",
    "Vector",
    "Picasso",
    "Macaron",
    "Rick_Morty",
)
STYLE_TYPE_LORA_DICT = {'None': ""}
STYLE_TYPE_LORA_DICT.update(
    (name, f"{name}_lora_weights.safetensors") for name in _STYLE_NAMES
)

# Helper function to convert input image to PIL image
def _to_pil(x):
    """Convert an image supplied in any supported form to an RGB PIL image.

    Accepted inputs:
      * a ``PIL.Image.Image`` (what the Gradio UI passes);
      * a dict with a ``"data"`` key (API clients) holding raw base64 or a
        data URL such as ``"data:image/png;base64,...."`` (str or bytes);
      * a string that is an http(s) URL, a ``data:image`` URL, or an existing
        server-side file path.

    Raises:
        ValueError: if *x* matches none of the forms above.
        requests.HTTPError: if fetching a URL returns an error status.
    """
    if isinstance(x, Image.Image):
        return x.convert("RGB")

    # Dict from an API client: {"data": "<raw base64 or data URL>", "name": "..."}
    if isinstance(x, dict) and "data" in x:
        return _b64_to_pil(x["data"])

    # String: URL, data URL, or server-side file path.
    if isinstance(x, str):
        if x.startswith(("http://", "https://")):
            # NOTE(review): fetches caller-supplied URLs from the server
            # (SSRF risk on a public Space); consider an allow-list.
            response = requests.get(x, timeout=20)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert("RGB")
        if x.startswith("data:image"):
            return _b64_to_pil(x)
        if os.path.exists(x):
            # NOTE(review): API callers can point this at any readable path
            # on the server; restrict to an upload directory if that matters.
            with Image.open(x) as img:
                return img.convert("RGB")

    raise ValueError("Unsupported image input. Provide a PIL image, filepath/URL, or {data: <base64>}.")


def _b64_to_pil(payload):
    """Decode raw base64 or a base64 data URL into an RGB PIL image.

    Commas never occur in the base64 alphabet, so everything up to and
    including the first comma is treated as a data-URL header
    (e.g. ``"data:image/png;base64,"``) and dropped. A payload without a
    comma is decoded as-is. Previously a raw payload crashed with IndexError.
    """
    if isinstance(payload, str) and "," in payload:
        payload = payload.split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")

# Global variables for pipeline management
# Shared, lazily populated module state:
#   pipeline     - FLUX pipeline created by load_pipeline() on first request
#   current_lora - style name whose adapters generate_styled_image last set up
pipeline, current_lora = None, None

BACKGROUND_LORA_REPO = "peteromallet/Flux-Kontext-InScene"

def _attach_background_lora(pipe, token):
    """Load BACKGROUND_LORA_REPO into *pipe* as adapter "background_lora" and activate it at weight 0.6."""
    pipe.load_lora_weights(
        BACKGROUND_LORA_REPO,
        token=token,
        adapter_name="background_lora",
    )
    pipe.set_adapters(["background_lora"], adapter_weights=[0.6])


def load_pipeline():
    """Return the shared FLUX Kontext pipeline, creating it on first use.

    The first call loads ``black-forest-labs/FLUX.1-Kontext-dev`` as a
    ``FluxKontextPipeline`` on CUDA and attaches the background LoRA
    (adapter ``"background_lora"``, weight 0.6). If any part of that fails,
    it falls back to ``FluxImg2ImgPipeline`` for the same checkpoint and makes
    a best-effort attempt to attach the same LoRA. Later callers rely on that
    adapter name.

    The global ``pipeline`` is only assigned once a pipeline is fully built.
    Every later call returns it unchanged.

    Raises:
        Exception: whatever ``FluxImg2ImgPipeline.from_pretrained`` raises if
            the fallback load also fails.
    """
    global pipeline
    if pipeline is not None:
        return pipeline

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    token = os.environ.get("HF_TOKEN")
    model_id = "black-forest-labs/FLUX.1-Kontext-dev"

    print("Loading FLUX Kontext pipeline...")
    pipe = None
    try:
        pipe = FluxKontextPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            token=token,
        ).to("cuda")
        _attach_background_lora(pipe, token)
        print("Pipeline loaded successfully with FluxKontextPipeline!")
    except Exception as e:
        print(f"FluxKontextPipeline failed: {e}")
        pipe = None

    if pipe is None:
        # Free any half-initialised Kontext pipeline before loading a second
        # copy of the weights. This runs outside the except block so the
        # exception's traceback no longer keeps that pipeline alive.
        gc.collect()
        torch.cuda.empty_cache()
        print("Trying with FluxImg2ImgPipeline...")
        pipe = FluxImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            token=token,
        ).to("cuda")
        try:
            _attach_background_lora(pipe, token)
        except Exception as lora_error:
            # Without this adapter, later set_adapters(["background_lora"]) calls fail.
            print(f"Warning: could not attach background LoRA to fallback pipeline: {lora_error}")
        print("Pipeline loaded successfully with FluxImg2ImgPipeline!")

    pipeline = pipe
    return pipeline

def download_lora(style_name):
    """Return the local path of the LoRA weights for *style_name*, downloading them if needed.

    Files are fetched from the ``Owen777/Kontext-Style-Loras`` repo into
    ``./LoRAs``. A file already present there is reused without a network call.

    Args:
        style_name: key of ``STYLE_TYPE_LORA_DICT``.

    Returns:
        None for the "None" style, otherwise the path of the ``.safetensors``
        file.

    Raises:
        ValueError: if *style_name* is not a known style.
    """
    if style_name == "None":
        return None

    lora_filename = STYLE_TYPE_LORA_DICT.get(style_name)
    if not lora_filename:
        raise ValueError(f"Unknown style: {style_name!r}")

    lora_dir = "./LoRAs"
    local_path = os.path.join(lora_dir, lora_filename)

    if not os.path.exists(local_path):
        print(f"Downloading LoRA for {style_name}...")
        os.makedirs(lora_dir, exist_ok=True)
        # Use the path the hub client reports instead of assuming its layout.
        local_path = hf_hub_download(
            repo_id="Owen777/Kontext-Style-Loras",
            filename=lora_filename,
            local_dir=lora_dir,
        )
        print(f"LoRA downloaded: {local_path}")
    return local_path

def _release_style_lora(pipe):
    """Best-effort removal of the "style_lora" adapter from *pipe*.

    Failures are logged, not raised. The adapter may legitimately be absent,
    e.g. when a previous load failed before registering it.
    """
    try:
        pipe.delete_adapters(["style_lora"])
    except Exception as exc:
        print(f"Could not delete style LoRA adapter: {exc}")


def _apply_style_lora(pipe, style_name):
    """Configure *pipe*'s adapters for *style_name* and record it in ``current_lora``.

    "None" activates only the background LoRA (weight 0.6). Any other style
    loads its LoRA file as adapter "style_lora" and blends it with the
    background LoRA at weights [0.4, 1.0]. Nothing happens when
    *style_name* is already active.

    The pipeline is first reset to the background-only configuration and
    ``current_lora`` set to "None". If the new style then fails to download
    or load, ``current_lora`` still matches the adapters actually active, and
    the next request retries the style instead of silently generating
    without it.
    """
    global current_lora
    if current_lora == style_name:
        return  # adapters already configured by a previous call

    if current_lora not in (None, "None"):
        _release_style_lora(pipe)
    pipe.set_adapters(["background_lora"], adapter_weights=[0.6])
    current_lora = "None"

    if style_name == "None":
        print("Using only background LoRA (no style applied)")
        return

    lora_path = download_lora(style_name)
    try:
        pipe.load_lora_weights(lora_path, adapter_name="style_lora")
        pipe.set_adapters(["background_lora", "style_lora"], adapter_weights=[0.4, 1.0])
    except Exception as e:
        print(f"Error loading LoRA {style_name}: {str(e)}")
        # Drop any partially registered adapter so a retry can reuse the
        # name, then fall back to the background-only configuration.
        _release_style_lora(pipe)
        pipe.set_adapters(["background_lora"], adapter_weights=[0.6])
        raise
    current_lora = style_name
    print(f"Loaded style LoRA: {style_name}")


def _load_input_image(input_image, input_base64):
    """Return the request's source image as an RGB PIL image.

    ``input_image`` (any form ``_to_pil`` accepts) takes precedence.
    Otherwise ``input_base64`` (raw base64 or a ``data:image`` URL) is
    decoded.

    Raises:
        ValueError: if neither input is provided.
    """
    if input_image is not None:
        return _to_pil(input_image)
    b64 = (input_base64 or "").strip()
    if not b64:
        raise ValueError("Please provide an image (upload) or input_base64.")
    if b64.startswith("data:image"):
        b64 = b64.split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")


def _run_generation(pipe, img, prompt, width, height, steps, guidance, seed):
    """Run *pipe* once under bf16 autocast and return the first image as RGB.

    A fresh CUDA generator seeded with *seed* is used unless *seed* is -1, in
    which case the pipeline picks its own noise.

    Raises:
        ValueError: if the pipeline returns no image.
    """
    generator = torch.Generator(device="cuda").manual_seed(seed) if seed != -1 else None
    with torch.autocast("cuda", dtype=torch.bfloat16):
        result = pipe(
            image=img,
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
        )
    output_image = result.images[0]
    if output_image is None:
        raise ValueError("Generated image is None")
    if output_image.mode != "RGB":
        output_image = output_image.convert("RGB")
    return output_image


@spaces.GPU
def generate_styled_image(
    input_image,
    input_base64,
    style_name,
    custom_prompt="",
    num_inference_steps=24,
    guidance_scale=7.5,
    width=1024,
    height=1024,
    seed=-1,
):
    """Generate a styled image using FLUX Kontext with LoRA adapters.

    Args:
        input_image: source image in any form ``_to_pil`` accepts. Takes
            precedence over ``input_base64``.
        input_base64: raw base64 or data URL, used when ``input_image`` is None.
        style_name: key of ``STYLE_TYPE_LORA_DICT``. "None" uses only the
            background LoRA.
        custom_prompt: replaces the default "Turn this image into the ...
            style." prompt when non-blank. None is treated as blank.
        num_inference_steps, guidance_scale, width, height: forwarded to the
            pipeline. They are coerced to int/float because UI sliders or API
            clients may send floats.
        seed: generator seed. -1 (or None) means random.

    Returns:
        The generated RGB PIL image. If the first attempt fails, one retry is
        made with fewer steps and guidance capped at 7.0, using the same seed.

    Raises:
        gr.Error: if pipeline loading, LoRA setup, input decoding or
            generation fails. Gradio then shows the message instead of
            returning an empty output.
    """
    global pipeline

    try:
        pipeline = load_pipeline()
        _apply_style_lora(pipeline, style_name)

        img = _load_input_image(input_image, input_base64)

        prompt = (custom_prompt or "").strip() or (
            f"Turn this image into the {style_name.replace('_', ' ')} style."
        )

        # torch.Generator.manual_seed and the pipeline's shape logic need ints.
        seed = -1 if seed is None else int(seed)
        width, height = int(width), int(height)
        num_inference_steps = int(num_inference_steps)
        guidance_scale = float(guidance_scale)

        if seed != -1:
            torch.manual_seed(seed)

        print(f"Generating image with style: {style_name}")
        print(f"Prompt: {prompt}")

        try:
            return _run_generation(
                pipeline, img, prompt, width, height,
                num_inference_steps, guidance_scale, seed,
            )
        except Exception as generation_error:
            print(f"Error during generation: {str(generation_error)}")
            print("Attempting generation with fallback parameters...")
            return _run_generation(
                pipeline, img, prompt, width, height,
                max(15, num_inference_steps // 2),
                min(guidance_scale, 7.0),
                seed,  # keep the user's seed so the retry stays reproducible
            )
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        raise gr.Error(f"Error generating image: {e}") from e
    finally:
        # Release cached GPU memory whether or not generation succeeded.
        torch.cuda.empty_cache()
        gc.collect()

# Custom CSS for better UI
# Stylesheet passed to gr.Blocks(css=...): sets the page font and defines the
# `.title` / `.subtitle` classes used by the two gr.HTML header divs.
css = """
.gradio-container {
    font-family: 'Helvetica Neue', Arial, sans-serif;
}
.title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 1em;
    color: #2c3e50;
}
.subtitle {
    text-align: center;
    font-size: 1.2em;
    color: #7f8c8d;
    margin-bottom: 2em;
}
"""

# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.HTML('<div class="title">🎨 FLUX Kontext Style Transfer</div>')
    gr.HTML('<div class="subtitle">Transform your images with 20+ artistic styles using LoRA adapters</div>')

    with gr.Row():
        # Left column: source image, style choice and generation controls.
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_image = gr.Image(label="Upload Image", height=400, type="pil")

            style_dropdown = gr.Dropdown(
                choices=list(STYLE_TYPE_LORA_DICT),
                label="Choose Style",
                value="None",
                interactive=True,
            )

            custom_prompt = gr.Textbox(
                label="Custom Prompt (Discouraged when selecting a style)",
                placeholder="Leave empty to use default style prompt",
                lines=2,
            )

            input_base64 = gr.Textbox(
                label="Base64 / data URL (optional, for API callers)",
                placeholder="data:image/png;base64,... or raw base64",
            )

            with gr.Accordion("Advanced Settings", open=False):
                num_inference_steps = gr.Slider(
                    minimum=10, maximum=50, value=24, step=1, label="Inference Steps"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Guidance Scale"
                )
                with gr.Row():
                    width = gr.Slider(
                        minimum=512, maximum=1536, value=1024, step=64, label="Width"
                    )
                    height = gr.Slider(
                        minimum=512, maximum=1536, value=1024, step=64, label="Height"
                    )
                seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

            generate_btn = gr.Button("🎨 Generate Styled Image", variant="primary", size="lg")

        # Right column: the generated result.
        with gr.Column(scale=1):
            gr.Markdown("### Output")
            output_image = gr.Image(label="Styled Image", height=400)

    # Examples and the button share one input ordering, which must match
    # generate_styled_image's positional parameters.
    generation_inputs = [
        input_image,
        input_base64,
        style_dropdown,
        custom_prompt,
        num_inference_steps,
        guidance_scale,
        width,
        height,
        seed,
    ]

    gr.Markdown("### Examples")
    example_url = "https://huggingface.co/datasets/black-forest-labs/kontext-bench/resolve/main/test/images/0003.jpg"
    gr.Examples(
        examples=[
            [example_url, "", example_style, "", 24, 7.5, 1024, 1024, -1]
            for example_style in ("None", "Ghibli", "Pixel", "Van_Gogh")
        ],
        inputs=list(generation_inputs),
        outputs=[output_image],
        fn=generate_styled_image,
        cache_examples=False,
    )

    generate_btn.click(
        fn=generate_styled_image,
        inputs=list(generation_inputs),
        outputs=[output_image],
    )

    # Collapsible description of the app, styles and usage tips.
    with gr.Accordion("About", open=False):
        gr.Markdown("""
        ### FLUX Kontext Style Transfer
        
        This application uses the FLUX.1 Kontext model with dual LoRA adapters: a background LoRA for scene understanding and optional style LoRAs for artistic transformation.
        
        **Available Styles:**
        - **None**: Uses only the background LoRA for natural scene enhancement
        - 3D Chibi, American Cartoon, Chinese Ink, Clay Toy
        - Fabric, Ghibli, Irasutoya, Jojo, Oil Painting
        - Pixel, Snoopy, Poly, LEGO, Origami
        - Pop Art, Van Gogh, Paper Cutting, Line, Vector
        - Picasso, Macaron, Rick & Morty
        
        **Tips:**
        - Upload high-quality images for best results
        - Select "None" for natural scene enhancement without style transfer
        - Try different style LoRAs for various artistic effects
        - Use custom prompts for more specific styling
        - Higher inference steps generally produce better quality
        
        **Model:** [Owen777/Kontext-Style-Loras](https://huggingface.co/Owen777/Kontext-Style-Loras)
        
        **Training Code:** [GitHub Repository](https://github.com/Owen718/Kontext-Lora-Trainer)
        """)

# Start the Gradio server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()