Other
Transformers
Safetensors
ldf_motion
feature-extraction
text-to-motion
motion-generation
diffusion-forcing
humanml3d
computer-animation
custom_code
Instructions to use ShandaAI/FloodDiffusionTiny with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ShandaAI/FloodDiffusionTiny with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ShandaAI/FloodDiffusionTiny", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from .tools.wan_vae_1d import WanVAE_ | |
class VAEWanModel(nn.Module):
    """Normalising wrapper around ``WanVAE_`` with training and streaming APIs.

    The wrapper z-normalises inputs with the ``mean`` / ``std`` buffers,
    converts between the ``(bs, T, C)`` layout used by callers and the
    ``(bs, C, T)`` layout passed to the wrapped model, and computes the
    training losses (reconstruction, velocity-channel reconstruction, KL).
    """

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=(1, 1, 1),
        temperal_downsample=(True, True),
        vel_window=(0, 0),
        **kwargs,
    ):
        """Build the wrapped VAE and register the normalisation statistics.

        Args:
            input_dim: Number of channels per frame of the input features.
            mean_path: Optional path to a ``.npy`` file holding the feature
                mean; zeros are used when omitted.
            std_path: Optional path to a ``.npy`` file holding the feature
                standard deviation; ones are used when omitted.
            z_dim, dim, dec_dim, num_res_blocks, dropout, dim_mult,
                temperal_downsample: Forwarded unchanged to ``WanVAE_``.
            vel_window: ``[start, end)`` channel slice that receives the
                additional velocity reconstruction loss. An empty window
                (the default) disables that term.
            **kwargs: Optional loss weights ``LAMBDA_FEATURE``,
                ``LAMBDA_VELOCITY`` and ``LAMBDA_KL``.
        """
        super().__init__()
        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        # Copy sequence arguments into fresh lists so that no state is shared
        # between instances (or with the caller) through a mutable default.
        self.dim_mult = list(dim_mult)
        self.temperal_downsample = list(temperal_downsample)
        self.vel_window = list(vel_window)

        self.RECONS_LOSS = nn.SmoothL1Loss()
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
        # NOTE: 10e-6 is 1e-5, kept exactly as in the original configuration.
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        if self.mean_path is not None:
            self.register_buffer(
                "mean", torch.from_numpy(np.load(self.mean_path)).float()
            )
        else:
            self.register_buffer("mean", torch.zeros(input_dim))
        if self.std_path is not None:
            self.register_buffer(
                "std", torch.from_numpy(np.load(self.std_path)).float()
            )
        else:
            self.register_buffer("std", torch.ones(input_dim))

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            dropout=self.dropout,
        )

        # Every enabled temporal-downsample stage halves the time axis.
        self.downsample_factor = 2 ** sum(
            1 for flag in self.temperal_downsample if flag
        )

    @staticmethod
    def _length_mask(lengths, max_len, device):
        """Return a ``(bs, max_len)`` bool mask, True where ``t < lengths[b]``."""
        lengths = torch.as_tensor(lengths, device=device).reshape(-1, 1)
        steps = torch.arange(max_len, device=device).unsqueeze(0)
        return steps < lengths

    def preprocess(self, x):
        """Swap the last two axes: (bs, T, C) -> (bs, C, T)."""
        return x.permute(0, 2, 1)

    def postprocess(self, x):
        """Swap the last two axes: (bs, C, T) -> (bs, T, C)."""
        return x.permute(0, 2, 1)

    def forward(self, x):
        """Compute the training losses for one padded batch.

        Args:
            x: Mapping with ``"feature"`` (padded features, ``(bs, T, C)``)
                and ``"feature_length"`` (number of valid frames per sample).

        Returns:
            Dict with scalar tensors ``"total"``, ``"recons"``, ``"velocity"``
            and ``"kl"``.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        features = (features - self.mean) / self.std

        device = features.device
        seq_len = features.size(1)
        lengths = torch.as_tensor(feature_length, device=device).reshape(-1)
        mask = self._length_mask(lengths, seq_len, device)

        x_in = self.preprocess(features)  # (bs, input_dim, T)
        mu, log_var = self.model.encode(
            x_in, scale=[0, 1], return_dist=True
        )  # (bs, z_dim, T_latent)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])  # (bs, input_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)

        # The decoder may return a different number of frames than it was
        # given; compare only the overlapping prefix.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len, :]
            features = features[:, :min_len, :]
            mask = mask[:, :min_len]

        # Padded frames are zeroed on both sides, so they contribute zero
        # error; the mean reduction still divides by the full element count.
        mask_expanded = mask.unsqueeze(-1)
        x_out_masked = x_out * mask_expanded
        features_masked = features * mask_expanded
        loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)

        vel_pred = x_out_masked[..., self.vel_window[0] : self.vel_window[1]]
        vel_target = features_masked[
            ..., self.vel_window[0] : self.vel_window[1]
        ]
        if vel_pred.numel() > 0:
            loss_vel = self.RECONS_LOSS(vel_pred, vel_target)
        else:
            # A mean-reduced loss over an empty slice is NaN and would poison
            # the total loss; an empty window means "no velocity term".
            loss_vel = x_out.new_zeros(())

        # KL(N(mu, sigma) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
        # with log_var = log(sigma^2), restricted to valid latent steps.
        latent_lengths = torch.div(
            lengths + self.downsample_factor - 1,
            self.downsample_factor,
            rounding_mode="floor",
        )  # ceil(length / downsample_factor)
        mask_downsampled = self._length_mask(latent_lengths, mu.size(2), device)
        mask_latent = mask_downsampled.unsqueeze(1)  # (bs, 1, T_latent)

        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        kl_masked = kl_per_element * mask_latent
        # Normalise by valid steps * latent channels; the clamp avoids 0 / 0
        # when the batch contains no valid step at all.
        valid_elements = (mask_downsampled.sum() * mu.size(1)).clamp(min=1)
        kl_loss = kl_masked.sum() / valid_elements

        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_VELOCITY * loss_vel
            + self.LAMBDA_KL * kl_loss
        )

        return {
            "total": total_loss,
            "recons": loss_recons,
            "velocity": loss_vel,
            "kl": kl_loss,
        }

    def encode(self, x):
        """Normalise ``x`` (bs, T, input_dim) and return latents (bs, T', z_dim)."""
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, T, input_dim) -> (bs, input_dim, T)
        mu = self.model.encode(x_in, scale=[0, 1])  # (bs, z_dim, T')
        return self.postprocess(mu)  # (bs, T', z_dim)

    def decode(self, mu):
        """Decode latents (bs, T', z_dim) to de-normalised features."""
        mu_in = self.preprocess(mu)  # (bs, T', z_dim) -> (bs, z_dim, T')
        x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (bs, input_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        return x_out * self.std + self.mean

    def stream_encode(self, x, first_chunk=True):
        """Chunk-wise counterpart of :meth:`encode`.

        ``first_chunk`` is forwarded to the wrapped model's ``stream_encode``.
        """
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, input_dim, T)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        return self.postprocess(mu)  # (bs, T', z_dim)

    def stream_decode(self, mu, first_chunk=True):
        """Chunk-wise counterpart of :meth:`decode`.

        ``first_chunk`` is forwarded to the wrapped model's ``stream_decode``.
        """
        mu_in = self.preprocess(mu)  # (bs, z_dim, T')
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        return x_out * self.std + self.mean

    def clear_cache(self):
        """Delegate to the wrapped model's ``clear_cache``."""
        self.model.clear_cache()

    def generate(self, x):
        """Reconstruct a padded batch and trim each sample to its valid length.

        Args:
            x: Mapping with ``"feature"`` (bs, T, input_dim) and
                ``"feature_length"`` (valid frames per sample).

        Returns:
            Dict whose ``"generated"`` entry is a list with one
            ``(valid_len, input_dim)`` tensor per sample.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        factor = self.downsample_factor
        y_hat_out = []
        for i in range(y_hat.shape[0]):
            # Drop the padding and round the length down to the form
            # 1 + k * factor. Presumably this is the set of lengths the
            # wrapped VAE reproduces exactly -- TODO confirm against WanVAE_.
            valid_len = (int(feature_length[i]) - 1) // factor * factor + 1
            y_hat_out.append(y_hat[i, :valid_len, :])

        return {"generated": y_hat_out}