multimodalart committed (verified)
Commit 025a849 · Parent: ca637ea

Update app.py

Files changed (1): app.py (+44 −26)
app.py CHANGED
@@ -56,23 +56,29 @@ Rules:
 
 Output only the final instruction in plain text and nothing else."""
 
-# Model repository IDs
-REPO_ID_4B = "diffusers-internal-dev/dummy-1015-4b"  # 4b model
-REPO_ID_9B = "diffusers-internal-dev/dummy-1015-9b"  # 9b model
+# Model repository IDs for 4B
+REPO_ID_REGULAR = "diffusers-internal-dev/dummy-1015-4b"
+REPO_ID_DISTILLED = "diffusers-internal-dev/dummy-1015-4b-distilled"
 
-# Load both models
-print("Loading 4B model...")
-pipe_4b = Flux2KleinPipeline.from_pretrained(REPO_ID_4B, torch_dtype=dtype)
-pipe_4b.to("cuda")
+# Load both 4B models
+print("Loading 4B Regular model...")
+pipe_regular = Flux2KleinPipeline.from_pretrained(REPO_ID_REGULAR, torch_dtype=dtype)
+pipe_regular.to("cuda")
 
-print("Loading 9B model...")
-pipe_9b = Flux2KleinPipeline.from_pretrained(REPO_ID_9B, torch_dtype=dtype)
-pipe_9b.to("cuda")
+print("Loading 4B Distilled model...")
+pipe_distilled = Flux2KleinPipeline.from_pretrained(REPO_ID_DISTILLED, torch_dtype=dtype)
+pipe_distilled.to("cuda")
 
 # Dictionary for easy access
 pipes = {
-    "4B": pipe_4b,
-    "9B": pipe_9b,
+    "Distilled (4 steps)": pipe_distilled,
+    "Regular (30 steps)": pipe_regular,
+}
+
+# Default steps for each mode
+DEFAULT_STEPS = {
+    "Distilled (4 steps)": 4,
+    "Regular (30 steps)": 30,
 }
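Note: the change pairs each UI label with both a pipeline and a default step count, keyed by the same string. A minimal, self-contained sketch of that lookup pattern (placeholder strings stand in for the real pipeline objects):

    # Toy illustration of the label-keyed lookup introduced above; the strings
    # must match the gr.Radio choices exactly, since they double as dict keys.
    pipes = {"Distilled (4 steps)": "pipe_distilled", "Regular (30 steps)": "pipe_regular"}
    DEFAULT_STEPS = {"Distilled (4 steps)": 4, "Regular (30 steps)": 30}

    mode = "Regular (30 steps)"
    print(pipes[mode], DEFAULT_STEPS[mode])  # pipe_regular 30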
 
@@ -153,14 +159,19 @@ def update_dimensions_from_image(image_list):
     return new_width, new_height
 
 
+def update_steps_from_mode(mode_choice):
+    """Update the number of inference steps based on the selected mode."""
+    return DEFAULT_STEPS[mode_choice]
+
+
 @spaces.GPU(duration=85)
-def infer(prompt, input_images=None, model_choice="4B", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=4.0, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt, input_images=None, mode_choice="Distilled (4 steps)", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, guidance_scale=4.0, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    # Select the appropriate pipeline based on model choice
-    pipe = pipes[model_choice]
+    # Select the appropriate pipeline based on mode choice
+    pipe = pipes[mode_choice]
 
     # Prepare image list (convert None or empty gallery to None)
     image_list = None
@@ -178,7 +189,7 @@ def infer(prompt, input_images=None, model_choice="4B", seed=42, randomize_seed=
     print(f"Upsampled Prompt: {final_prompt}")
 
     # 2. Image Generation
-    progress(0.2, desc=f"Generating image with {model_choice} model...")
+    progress(0.2, desc=f"Generating image with 4B {mode_choice}...")
 
     generator = torch.Generator(device=device).manual_seed(seed)
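Note: the generation call itself falls outside the changed hunks. With a diffusers-style pipeline such as `Flux2KleinPipeline`, it presumably resembles the hedged sketch below (the argument names for any image conditioning are an assumption, not shown in this diff):

    # Hedged sketch of the step that follows; not part of this commit's diff.
    # `pipe`, `final_prompt`, `generator`, and the slider values come from the
    # surrounding code; image-conditioning kwargs are omitted as unknown.
    image = pipe(
        prompt=final_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,  # 4 or 30 depending on mode
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]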
 
@@ -221,11 +232,11 @@ css = """
 }
 """
 
-with gr.Blocks() as demo:
+with gr.Blocks(css=css) as demo:
 
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.2 [Klein]
-FLUX.2 [Klein] is a distilled model capable of generating, editing and combining images based on text instructions [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
+        gr.Markdown(f"""# FLUX.2 [Klein] - 4B (Apache 2.0)
+FLUX.2 [Klein] is a... [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
         """)
         with gr.Row():
             with gr.Column():
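Note: moving the stylesheet from `demo.launch(css=css)` (removed in the last hunk below) to `gr.Blocks(css=css)` matches Gradio's API, where custom CSS is a `Blocks` constructor argument rather than a `launch()` keyword. The resulting shape, shown here only to make the pairing of the two hunks explicit:

    # CSS is bound at construction time, not at launch:
    with gr.Blocks(css=css) as demo:
        ...
    demo.launch()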
@@ -249,10 +260,10 @@ FLUX.2 [Klein] is a distilled model capable of generating, editing and combining
                     rows=1,
                 )
 
-                model_choice = gr.Radio(
-                    label="Model Size",
-                    choices=["4B", "9B"],
-                    value="4B",
+                mode_choice = gr.Radio(
+                    label="Mode",
+                    choices=["Distilled (4 steps)", "Regular (30 steps)"],
+                    value="Distilled (4 steps)",
                 )
 
                 with gr.Accordion("Advanced Settings", open=False):
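Note: the radio's choice strings double as the keys of `pipes` and `DEFAULT_STEPS`, so relabeling a mode means updating all three places. A cheap startup guard one could add (hypothetical, not part of this commit):

    # Hypothetical sanity check: fail fast if the labels ever drift apart.
    assert set(pipes) == set(DEFAULT_STEPS), "mode labels out of sync"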
@@ -298,7 +309,7 @@ FLUX.2 [Klein] is a distilled model capable of generating, editing and combining
                         minimum=1,
                         maximum=100,
                         step=1,
-                        value=50,
+                        value=4,
                     )
 
                     guidance_scale = gr.Slider(
@@ -338,12 +349,19 @@ FLUX.2 [Klein] is a distilled model capable of generating, editing and combining
         inputs=[input_images],
         outputs=[width, height]
     )
+
+    # Auto-update steps when mode changes
+    mode_choice.change(
+        fn=update_steps_from_mode,
+        inputs=[mode_choice],
+        outputs=[num_inference_steps]
+    )
 
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer,
-        inputs=[prompt, input_images, model_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
+        inputs=[prompt, input_images, mode_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
         outputs=[result, seed]
     )
 
-demo.launch(css=css)
+demo.launch()
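Note: with this wiring, switching modes first fires `mode_choice.change`, which rewrites the steps slider to 4 or 30; `infer` then reads whatever value the slider holds, so a user can still override the default before generating. A minimal standalone reproduction of the radio-to-slider sync pattern (assumes only `gradio` is installed):

    import gradio as gr

    # Toy version of the sync: picking a mode pushes its default into the slider.
    STEPS = {"Distilled (4 steps)": 4, "Regular (30 steps)": 30}

    with gr.Blocks() as toy:
        mode = gr.Radio(choices=list(STEPS), value="Distilled (4 steps)", label="Mode")
        steps = gr.Slider(minimum=1, maximum=100, step=1, value=4, label="Steps")
        mode.change(fn=lambda m: STEPS[m], inputs=mode, outputs=steps)

    toy.launch()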
 