Update README.md
Browse files
README.md
CHANGED
|
@@ -58,6 +58,108 @@ Source code is available at https://github.com/NVlabs/Sana.
|
|
| 58 |
|
| 59 |
Refer to: https://github.com/NVlabs/Sana/blob/main/asset/docs/sana_video.md#1-inference-with-txt-file
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
### Model Description
|
| 62 |
|
| 63 |
- **Developed by:** NVIDIA, Sana
|
|
|
|
| 58 |
|
| 59 |
Refer to: https://github.com/NVlabs/Sana/blob/main/asset/docs/sana_video.md#1-inference-with-txt-file
|
| 60 |
|
| 61 |
+
# diffusers pipeline
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
"""Sana Video + LTX2 Refiner: Stage 1 generate latent β Stage 2 refine (3 steps)."""
|
| 65 |
+
|
| 66 |
+
import gc
|
| 67 |
+
import torch
|
| 68 |
+
from diffusers import SanaVideoPipeline, FlowMatchEulerDiscreteScheduler
|
| 69 |
+
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
|
| 70 |
+
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
|
| 71 |
+
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
|
| 72 |
+
from diffusers.pipelines.ltx2.export_utils import encode_video
|
| 73 |
+
|
| 74 |
+
device = "cuda"
|
| 75 |
+
dtype = torch.bfloat16
|
| 76 |
+
prompt = "A cat walking on the grass, facing the camera."
|
| 77 |
+
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
| 78 |
+
motion_score = 30
|
| 79 |
+
height, width, frames, frame_rate = 704, 1280, 81, 16.0
|
| 80 |
+
seed = 42
|
| 81 |
+
|
| 82 |
+
# ββ Load all models ββ
|
| 83 |
+
sana_pipe = SanaVideoPipeline.from_pretrained(
|
| 84 |
+
"Efficient-Large-Model/SANA-Video_2B_720p_diffusers", torch_dtype=dtype,
|
| 85 |
+
)
|
| 86 |
+
sana_pipe.text_encoder.to(dtype)
|
| 87 |
+
sana_pipe.enable_model_cpu_offload()
|
| 88 |
+
|
| 89 |
+
ltx_pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=dtype)
|
| 90 |
+
ltx_pipe.load_lora_weights(
|
| 91 |
+
"Lightricks/LTX-2", adapter_name="stage_2_distilled",
|
| 92 |
+
weight_name="ltx-2-19b-distilled-lora-384.safetensors",
|
| 93 |
+
)
|
| 94 |
+
ltx_pipe.set_adapters("stage_2_distilled", 1.0)
|
| 95 |
+
ltx_pipe.vae.enable_tiling()
|
| 96 |
+
ltx_pipe.enable_model_cpu_offload()
|
| 97 |
+
|
| 98 |
+
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
|
| 99 |
+
"Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=dtype,
|
| 100 |
+
)
|
| 101 |
+
upsample_pipe = LTX2LatentUpsamplePipeline(vae=ltx_pipe.vae, latent_upsampler=latent_upsampler)
|
| 102 |
+
upsample_pipe.enable_model_cpu_offload(device=device)
|
| 103 |
+
|
| 104 |
+
# ββ Stage 1: Sana Video ββ
|
| 105 |
+
video_latent = sana_pipe(
|
| 106 |
+
prompt=prompt + f" motion score: {motion_score}.", negative_prompt=negative_prompt,
|
| 107 |
+
height=height, width=width, frames=frames,
|
| 108 |
+
guidance_scale=6.0, num_inference_steps=50,
|
| 109 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
| 110 |
+
output_type="latent", return_dict=True,
|
| 111 |
+
).frames
|
| 112 |
+
|
| 113 |
+
del sana_pipe; gc.collect(); torch.cuda.empty_cache()
|
| 114 |
+
|
| 115 |
+
# ββ Stage 1.5: Latent Upsample (2x spatial) ββ
|
| 116 |
+
video_latent = upsample_pipe(
|
| 117 |
+
latents=video_latent.to(device=device, dtype=dtype),
|
| 118 |
+
latents_normalized=True,
|
| 119 |
+
height=height, width=width, num_frames=frames,
|
| 120 |
+
output_type="latent", return_dict=False,
|
| 121 |
+
)[0]
|
| 122 |
+
latents_mean = ltx_pipe.vae.latents_mean.view(1, -1, 1, 1, 1).to(video_latent.device, video_latent.dtype)
|
| 123 |
+
latents_std = ltx_pipe.vae.latents_std.view(1, -1, 1, 1, 1).to(video_latent.device, video_latent.dtype)
|
| 124 |
+
video_latent = (video_latent - latents_mean) * ltx_pipe.vae.config.scaling_factor / latents_std
|
| 125 |
+
|
| 126 |
+
# ββ Stage 2: LTX2 Refine ββ
|
| 127 |
+
packed = LTX2Pipeline._pack_latents(
|
| 128 |
+
video_latent.to(device=device, dtype=dtype),
|
| 129 |
+
patch_size=ltx_pipe.transformer_spatial_patch_size,
|
| 130 |
+
patch_size_t=ltx_pipe.transformer_temporal_patch_size,
|
| 131 |
+
)
|
| 132 |
+
_, _, lF, lH, lW = video_latent.shape
|
| 133 |
+
pH, pW, pT = lH * ltx_pipe.vae_spatial_compression_ratio, lW * ltx_pipe.vae_spatial_compression_ratio, (lF - 1) * ltx_pipe.vae_temporal_compression_ratio + 1
|
| 134 |
+
|
| 135 |
+
dur = pT / frame_rate
|
| 136 |
+
audio_frames = round(dur * ltx_pipe.audio_sampling_rate / ltx_pipe.audio_hop_length / ltx_pipe.audio_vae_temporal_compression_ratio)
|
| 137 |
+
nch = ltx_pipe.audio_vae.config.latent_channels
|
| 138 |
+
mel = ltx_pipe.audio_vae.config.mel_bins // ltx_pipe.audio_vae_mel_compression_ratio
|
| 139 |
+
audio_latent = (
|
| 140 |
+
ltx_pipe.audio_vae.latents_mean.unsqueeze(0).unsqueeze(0)
|
| 141 |
+
.expand(1, audio_frames, nch * mel).to(dtype=dtype, device=device).contiguous()
|
| 142 |
+
.unflatten(2, (nch, mel)).permute(0, 2, 1, 3).contiguous()
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
del video_latent; gc.collect(); torch.cuda.empty_cache()
|
| 146 |
+
|
| 147 |
+
video, _ = ltx_pipe(
|
| 148 |
+
latents=packed, audio_latents=audio_latent,
|
| 149 |
+
prompt=prompt, negative_prompt=negative_prompt,
|
| 150 |
+
height=pH, width=pW, num_frames=pT,
|
| 151 |
+
num_inference_steps=3,
|
| 152 |
+
noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
|
| 153 |
+
sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
|
| 154 |
+
guidance_scale=1.0, frame_rate=frame_rate,
|
| 155 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
| 156 |
+
output_type="np", return_dict=False,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
video = torch.from_numpy((video * 255).round().astype("uint8"))
|
| 160 |
+
encode_video(video[0], fps=frame_rate, audio=None, audio_sample_rate=None, output_path="sana_ltx2_refined.mp4")
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
### Model Description
|
| 164 |
|
| 165 |
- **Developed by:** NVIDIA, Sana
|