BiliSakura commited on
Commit
f461189
·
verified ·
1 Parent(s): e278359

Update SiT-L-2-256/pipeline.py

Browse files
Files changed (1) hide show
  1. SiT-L-2-256/pipeline.py +17 -1
SiT-L-2-256/pipeline.py CHANGED
@@ -4,6 +4,8 @@ Load with native Hugging Face diffusers and trust_remote_code=True.
4
 
5
  from __future__ import annotations
6
 
 
 
7
  # Copyright 2026 The HuggingFace Team. All rights reserved.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +22,7 @@ from __future__ import annotations
20
 
21
  import json
22
  from pathlib import Path
23
- from typing import Dict, List, Optional, Tuple, Union
24
 
25
  import torch
26
 
@@ -76,6 +78,20 @@ class SiTPipeline(DiffusionPipeline):
76
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
77
  """
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  model_cpu_offload_seq = "transformer->vae"
80
 
81
  def __init__(
 
4
 
5
  from __future__ import annotations
6
 
7
+ import inspect
8
+
9
  # Copyright 2026 The HuggingFace Team. All rights reserved.
10
  #
11
  # Licensed under the Apache License, Version 2.0 (the "License");
 
22
 
23
  import json
24
  from pathlib import Path
25
+ from typing import Dict, List, Optional, Tuple, Union, Any
26
 
27
  import torch
28
 
 
78
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
79
  """
80
 
81
+ @staticmethod
82
+ def prepare_extra_step_kwargs(
83
+ scheduler,
84
+ generator=None,
85
+ eta: float | None = None,
86
+ ):
87
+ kwargs = {}
88
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
89
+ if "generator" in step_params:
90
+ kwargs["generator"] = generator
91
+ if eta is not None and "eta" in step_params:
92
+ kwargs["eta"] = eta
93
+ return kwargs
94
+
95
  model_cpu_offload_seq = "transformer->vae"
96
 
97
  def __init__(