| | |
| | |
| | import math |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from diffusers.configuration_utils import ConfigMixin |
| | from diffusers.loaders.single_file_model import FromOriginalModelMixin |
| | from diffusers.models.modeling_utils import ModelMixin |
| |
|
| |
|
| | def fp16_clamp(x): |
| | if x.dtype == torch.float16 and torch.isinf(x).any(): |
| | clamp = torch.finfo(x.dtype).max - 1000 |
| | x = torch.clamp(x, min=-clamp, max=clamp) |
| | return x |
| |
|
| |
|
def init_weights(m):
    """Initialise the parameters of one T5 sub-module in place.

    Intended to be passed to ``nn.Module.apply``. Modules that are not one of
    the four handled types are left unchanged.

    * ``T5LayerNorm``: weight set to ones.
    * ``T5FeedForward``: gate/fc1 ~ N(0, dim^-0.5), fc2 ~ N(0, dim_ffn^-0.5).
    * ``T5Attention``: q ~ N(0, (dim * dim_attn)^-0.5), k/v ~ N(0, dim^-0.5),
      o ~ N(0, (num_heads * dim_attn)^-0.5).
    * ``T5RelativeEmbedding``: embedding ~ N(0, (2 * num_buckets * num_heads)^-0.5).
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
        return

    if isinstance(m, T5FeedForward):
        input_std = m.dim**-0.5
        nn.init.normal_(m.gate[0].weight, std=input_std)
        nn.init.normal_(m.fc1.weight, std=input_std)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
        return

    if isinstance(m, T5Attention):
        input_std = m.dim**-0.5
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=input_std)
        nn.init.normal_(m.v.weight, std=input_std)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
        return

    if isinstance(m, T5RelativeEmbedding):
        bias_std = (2 * m.num_buckets * m.num_heads)**-0.5
        nn.init.normal_(m.embedding.weight, std=bias_std)
| |
|
| |
|
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation."""

    def forward(self, x):
        """Apply ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``."""
        cubic_term = x + 0.044715 * torch.pow(x, 3.0)
        tanh_arg = math.sqrt(2.0 / math.pi) * cubic_term
        return 0.5 * x * (1.0 + torch.tanh(tanh_arg))
| |
|
| |
|
class T5LayerNorm(nn.Module):
    """Scale-only RMS normalisation over the last dimension (no mean, no bias)."""

    def __init__(self, dim, eps=1e-6):
        """
        dim: size of the normalised (last) dimension.
        eps: constant added to the mean square before the reciprocal sqrt.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``weight * x / sqrt(mean(x^2) + eps)``."""
        # The mean square is accumulated in float32 regardless of x's dtype.
        mean_square = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.eps)
        # With half-precision weights, cast back before the final scaling.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.type_as(self.weight)
        return self.weight * normed
| |
|
| |
|
class T5Attention(nn.Module):
    """Multi-head attention with an additive bias and no score scaling."""

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        """
        dim: model width (input and output feature size).
        dim_attn: total width of the q/k/v projections; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the output projection.
        """
        assert dim_attn % num_heads == 0
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # Bias-free projections.
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:        [B, L1, C].
        context:  [B, L2, C] or None (None means self-attention over ``x``).
        mask:     [B, L2] or [B, L1, L2] or None; positions equal to 0 are
                  excluded from attention.
        pos_bias: optional tensor added to the attention logits.
        """
        if context is None:
            context = x
        batch, heads, head_dim = x.size(0), self.num_heads, self.head_dim

        # Project and split into heads: [B, L, heads, head_dim].
        q = self.q(x).view(batch, -1, heads, head_dim)
        k = self.k(context).view(batch, -1, heads, head_dim)
        v = self.v(context).view(batch, -1, heads, head_dim)

        # Additive logit bias: positional bias plus the masked-out positions.
        bias = x.new_zeros(batch, heads, q.size(1), k.size(1))
        if pos_bias is not None:
            bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            if mask.ndim == 2:
                mask = mask.view(batch, 1, 1, -1)
            else:
                mask = mask.unsqueeze(1)
            bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # Softmax runs in float32 and is cast back to the logits' dtype.
        logits = torch.einsum('binc,bjnc->bnij', q, k) + bias
        weights = F.softmax(logits.float(), dim=-1).type_as(logits)
        mixed = torch.einsum('bnij,bjnc->binc', weights, v)

        # Merge heads and project back to the model width.
        merged = mixed.reshape(batch, -1, heads * head_dim)
        return self.dropout(self.o(merged))
| |
|
| |
|
class T5FeedForward(nn.Module):
    """Gated feed-forward block: ``fc2(dropout(fc1(x) * gate(x)))``."""

    def __init__(self, dim, dim_ffn, dropout=0.1):
        """
        dim: model width (input and output feature size).
        dim_ffn: hidden width of the feed-forward layer.
        dropout: dropout probability, applied after the gating product and
            again after the output projection.
        """
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # Bias-free projections; the gate branch is a linear layer plus GELU.
        gate_proj = nn.Linear(dim, dim_ffn, bias=False)
        self.gate = nn.Sequential(gate_proj, GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the gated feed-forward transform to ``x``."""
        hidden = self.fc1(x) * self.gate(x)
        hidden = self.dropout(hidden)
        return self.dropout(self.fc2(hidden))
| |
|
| |
|
class T5SelfAttention(nn.Module):
    """Pre-norm encoder block: self-attention then feed-forward, each residual."""

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        """
        dim / dim_attn / dim_ffn: model, attention and feed-forward widths.
        num_heads: number of attention heads.
        num_buckets: number of relative-position buckets.
        shared_pos: when True the positional bias is supplied by the caller;
            when False the block owns a bidirectional relative embedding.
        dropout: dropout probability passed to the sub-layers.
        """
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # Sub-layers.
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(
                num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        """Run the block on ``x``; ``pos_bias`` is only used when ``shared_pos``."""
        if self.shared_pos:
            bias = pos_bias
        else:
            seq_len = x.size(1)
            bias = self.pos_embedding(seq_len, seq_len)

        # Each residual sum goes through fp16_clamp.
        attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + attn_out)
        ffn_out = self.ffn(self.norm2(x))
        return fp16_clamp(x + ffn_out)
| |
|
| |
|
class T5CrossAttention(nn.Module):
    """Pre-norm decoder block: self-attention, cross-attention, feed-forward."""

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        """
        dim / dim_attn / dim_ffn: model, attention and feed-forward widths.
        num_heads: number of attention heads.
        num_buckets: number of relative-position buckets.
        shared_pos: when True the positional bias is supplied by the caller;
            when False the block owns a unidirectional relative embedding.
        dropout: dropout probability passed to the sub-layers.
        """
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # Sub-layers.
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(
                num_buckets, num_heads, bidirectional=False)

    def forward(self,
                x,
                mask=None,
                encoder_states=None,
                encoder_mask=None,
                pos_bias=None):
        """Run the block on ``x``, attending to ``encoder_states`` in the middle step.

        ``pos_bias`` is only used when ``shared_pos`` is True, and only by the
        self-attention step; cross-attention receives no positional bias.
        """
        if self.shared_pos:
            bias = pos_bias
        else:
            tgt_len = x.size(1)
            bias = self.pos_embedding(tgt_len, tgt_len)

        # Each residual sum goes through fp16_clamp.
        self_out = self.self_attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + self_out)
        cross_out = self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask)
        x = fp16_clamp(x + cross_out)
        ffn_out = self.ffn(self.norm3(x))
        return fp16_clamp(x + ffn_out)
| |
|
| |
|
class T5RelativeEmbedding(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position."""

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        """
        num_buckets: number of rows in the bias embedding table.
        num_heads: number of attention heads (embedding width).
        bidirectional: when True, half of the buckets are used for positive
            offsets and half for the rest; when False, positive offsets all
            map to distance 0.
        max_dist: distance scale used by the logarithmic bucketing.
        """
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # One bias value per (bucket, head).
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        """Return the bias tensor of shape [1, num_heads, lq, lk]."""
        device = self.embedding.weight.device
        # When the weights live on the meta device the index grids are built
        # on the default device instead.
        arange_kwargs = {}
        if torch.device(type="meta") != device:
            arange_kwargs["device"] = device
        key_pos = torch.arange(lk, **arange_kwargs).unsqueeze(0)
        query_pos = torch.arange(lq, **arange_kwargs).unsqueeze(1)

        buckets = self._relative_position_bucket(key_pos - query_pos)
        bias = self.embedding(buckets)  # [lq, lk, num_heads]
        bias = bias.permute(2, 0, 1).unsqueeze(0)  # [1, num_heads, lq, lk]
        return bias.contiguous()

    def _relative_position_bucket(self, rel_pos):
        """Map integer relative positions (key - query) to bucket indices."""
        if self.bidirectional:
            # Positive offsets use the upper half of the bucket range.
            num_buckets = self.num_buckets // 2
            offset = (rel_pos > 0).long() * num_buckets
            distance = torch.abs(rel_pos)
        else:
            # Positive offsets collapse to distance 0.
            num_buckets = self.num_buckets
            offset = 0
            distance = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # Distances below max_exact keep their own bucket; larger ones are
        # placed logarithmically and capped at the last bucket.
        max_exact = num_buckets // 2
        log_scaled = (torch.log(distance.float() / max_exact) /
                      math.log(self.max_dist / max_exact) *
                      (num_buckets - max_exact)).long()
        far_bucket = max_exact + log_scaled
        last_bucket = torch.full_like(far_bucket, num_buckets - 1)
        far_bucket = torch.min(far_bucket, last_bucket)
        return offset + torch.where(distance < max_exact, distance, far_bucket)
| |
|
class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Encoder-only T5-style text model.

    Token embedding, a stack of ``T5SelfAttention`` blocks and a final
    ``T5LayerNorm``, with dropout after the embedding and after the norm.
    """

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        """
        vocab: either an existing ``nn.Embedding`` to reuse as the token
            embedding, or the vocabulary size for a new one.
        dim / dim_attn / dim_ffn: model, attention and feed-forward widths.
        num_heads: number of attention heads.
        num_layers: number of encoder blocks.
        num_buckets: number of relative-position buckets.
        shared_pos: when True a single bidirectional relative embedding is
            shared by every block; when False each block owns its own.
        dropout: dropout probability.
        """
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # Layers.
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # Initialise weights.
        self.apply(init_weights)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Encode ``input_ids`` and return a 1-tuple holding the hidden states.

        ``attention_mask`` is forwarded unchanged to every block as its mask.
        """
        x = self.token_embedding(input_ids)
        x = self.dropout(x)
        # The shared positional bias is computed once and reused by all blocks.
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, attention_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return (x, )

    @classmethod
    def _filter_init_kwargs(cls, kwargs):
        """Keep only the entries of ``kwargs`` that ``cls.__init__`` accepts."""
        import inspect

        accepted = set(inspect.signature(cls.__init__).parameters) - {'self', 'cls'}
        return {k: v for k, v in kwargs.items() if k in accepted}

    @staticmethod
    def _parse_version_tuple(version):
        """Return the leading ``major.minor[.patch]`` of ``version`` as ints.

        Used instead of comparing version strings directly, which orders
        them lexicographically (e.g. ``"0.4.0" >= "0.33.0"`` is True).
        """
        import re

        match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
        if match is None:
            raise ValueError(f"Unrecognized version string: {version!r}")
        return tuple(int(part or 0) for part in match.groups())

    @staticmethod
    def _load_checkpoint_state_dict(path):
        """Read a state dict from a ``.safetensors`` file or a torch checkpoint."""
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            return load_file(path)
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from trusted sources should be loaded here.
        return torch.load(path, map_location="cpu")

    @classmethod
    def _from_pretrained_low_mem(cls, pretrained_model_path, init_kwargs, torch_dtype):
        """Build the model on empty weights and fill it from the checkpoint.

        Raises on any problem; the caller falls back to the regular path.
        """
        import re

        from diffusers import __version__ as diffusers_version
        from diffusers.utils import is_accelerate_available

        # `load_model_dict_into_meta` is imported from a different module and
        # called with different arguments starting with diffusers 0.33.0.
        new_loader_api = cls._parse_version_tuple(diffusers_version) >= (0, 33, 0)
        if new_loader_api:
            from diffusers.models.model_loading_utils import \
                load_model_dict_into_meta
        else:
            from diffusers.models.modeling_utils import \
                load_model_dict_into_meta

        if not is_accelerate_available():
            raise ImportError("`accelerate` is required for low_cpu_mem_usage=True")
        import accelerate

        # Instantiate the model without allocating real weights.
        with accelerate.init_empty_weights():
            model = cls(**init_kwargs)

        param_device = "cpu"
        state_dict = cls._load_checkpoint_state_dict(pretrained_model_path)

        if new_loader_api:
            load_model_dict_into_meta(
                model,
                state_dict,
                dtype=torch_dtype,
                model_name_or_path=pretrained_model_path,
            )
            return model

        # Older API: every model key must be present in the checkpoint.
        missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
        if len(missing_keys) > 0:
            raise ValueError(
                f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                " those weights or else make sure your checkpoint file is correct."
            )

        unexpected_keys = load_model_dict_into_meta(
            model,
            state_dict,
            device=param_device,
            dtype=torch_dtype,
            model_name_or_path=pretrained_model_path,
        )

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            print(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_path, additional_kwargs=None, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
        """Create the model and load weights from a checkpoint.

        pretrained_model_path: path to a ``.safetensors`` file or torch
            checkpoint. If the path does not exist it is treated as a Hugging
            Face repo id and ``models_t5_umt5-xxl-enc-bf16.pth`` is downloaded
            from it.
        additional_kwargs: constructor arguments; entries that ``__init__``
            does not accept are dropped. ``None`` means no arguments.
        low_cpu_mem_usage: try to build the model on empty weights first; on
            any failure a message is printed and the regular path is used.
        torch_dtype: dtype of the returned model's weights.

        The regular path loads with ``strict=False`` and prints the missing
        and unexpected keys instead of raising.
        """
        import os

        init_kwargs = cls._filter_init_kwargs(
            {} if additional_kwargs is None else additional_kwargs)

        if not os.path.exists(pretrained_model_path):
            # Imported here so purely local paths do not need huggingface_hub.
            from huggingface_hub import hf_hub_download
            try:
                print(f"Downloading text encoder from {pretrained_model_path}...")
                pretrained_model_path = hf_hub_download(repo_id=pretrained_model_path, filename="models_t5_umt5-xxl-enc-bf16.pth")
            except Exception as e:
                # Best effort: loading below will fail if the path is unusable.
                print(f"Failed to download Text Encoder from HF: {e}")

        if low_cpu_mem_usage:
            try:
                return cls._from_pretrained_low_mem(
                    pretrained_model_path, init_kwargs, torch_dtype)
            except Exception as e:
                print(
                    f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
                )

        model = cls(**init_kwargs)
        state_dict = cls._load_checkpoint_state_dict(pretrained_model_path)
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)

        model = model.to(torch_dtype)
        return model