Spaces:
Running on Zero
Running on Zero
File size: 9,430 Bytes
1d5b5ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | """
Camera Encoder for RT (Rotation-Translation) matrix injection.
CAM paper [2506.03141]: Maps camera pose to DiT hidden dimension for spatial attention conditioning.
Design notes:
- Primary purpose: dimension alignment (action/RT [12] -> DiT hidden D). One-layer MLP is sufficient; no need for deeper encoder.
- Per-frame MLP: each frame's RT [12] -> hidden_size independently (no temporal context).
- Optional zero-init scale: conditioning starts weak (scale=0) and grows with training for stability.
- shallow: single Linear(12, D) to match CAM "single-layer MLP" wording (ablation). When shallow=True
and separate_t_r=False, the encoder is exactly one layer (merged MLP): RT [12] -> D.
- separate_t_r: encode translation (t) and rotation (R) with separate MLPs and sum the two outputs, for scale balance.
Takes precedence over shallow when both are set. Use shallow=True and separate_t_r=False for the merged
single-layer MLP (RT not split).
- explicit_yaw: add a signed yaw scalar branch (Z-only) so CW/CCW are explicitly encoded; helps when
the model is insensitive to rotation direction (Zhou et al. CVPR 2019, sign continuity).
- sincos_yaw: add [cos(yaw), sin(yaw)] branch (2D) for direction; sin carries sign explicitly.
Design limitations / caveats:
- No input normalization: t (translation) and R (rotation) have different scales; one Linear(12,D) may be
sensitive to units. Caller should use consistent RT scale or relative RT; optional input LayerNorm not implemented.
- Per-frame only: no temporal context (each frame encoded independently). Fine for CAM ablation; temporal modeling not supported.
- 16-dim input: layout for flattened 4x4 is unspecified; yaw branches are disabled. Prefer 12-dim in practice.
- explicit_yaw and sincos_yaw can both be True (redundant encoding of yaw); usually use one.
"""
import torch
import torch.nn as nn
from typing import Optional
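# For Z-only rotation: yaw = atan2(R_21, R_11). R is stored row-major right after the translation,
# i.e. rt = [t_x, t_y, t_z, R_11, R_12, R_13, R_21, ...], so R_11 is rt[..., 3] and R_21 is rt[..., 6].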
def _yaw_from_rt_12(rt: torch.Tensor) -> torch.Tensor:
    """Return the signed yaw of each pose, normalized to [-1, 1] by dividing by pi.

    Args:
        rt: tensor of shape [..., 12]; only entries 3 and 6 of the last axis are read
            (used as R_11 and R_21 respectively).

    Returns:
        Tensor of shape [...] holding atan2(rt[..., 6], rt[..., 3]) / pi.
    """
    # atan2(R_21, R_11): the sign of the result follows the sign of R_21.
    angle = torch.atan2(rt[..., 6], rt[..., 3])
    return angle / 3.141592653589793
def _sincos_yaw_from_rt_12(rt: torch.Tensor) -> torch.Tensor:
    """Return (cos(yaw), sin(yaw)) for each pose.

    Args:
        rt: tensor of shape [..., 12]; only entries 3 and 6 of the last axis are read
            (used as R_11 and R_21 respectively).

    Returns:
        Tensor of shape [..., 2] whose last axis is [cos(yaw), sin(yaw)],
        with yaw = atan2(rt[..., 6], rt[..., 3]).
    """
    angle = torch.atan2(rt[..., 6], rt[..., 3])
    # cos first, sin second: the sin component carries the rotation direction.
    return torch.stack((angle.cos(), angle.sin()), dim=-1)
class CameraEncoder(nn.Module):
    """
    Encode per-frame RT matrices (camera pose) to the DiT hidden dimension.

    Input: rt_matrices [B, F, 12] or [B, F, 16]
    - 12: [t_x, t_y, t_z, R_11..R_33] (3 translation + 9 rotation), R row-major. No input normalization:
      t and R often differ in scale (e.g. t in meters, R in [-1,1]); a single Linear(12, D) may be
      sensitive to units.
    - 16: 4x4 matrix flattened (layout/order unspecified; yaw branches are disabled when rt_dim=16).
    Output: camera_emb [B, F, D] where D = hidden_size, multiplied by the learnable scale (when one
    exists) and by the constant conditioning_scale.

    Encoder variants (the first matching flag wins):
    - separate_t_r=True: two 2-layer MLPs, one for t (3 values) and one for R (9 values), summed.
    - shallow=True: a single nn.Linear(rt_dim, hidden_size).
    - otherwise: a 3-Linear MLP with LayerNorm + GELU, hidden width hidden_size * mlp_hidden_mult.

    Args:
        rt_dim: size of the last input axis (12 or 16).
        hidden_size: output embedding size D.
        mlp_hidden_mult: width multiplier of the default (deep) MLP; unused by the other variants.
        eps: epsilon passed to every LayerNorm.
        zero_init_scale: if True the learnable scale starts at 0, otherwise at 1 (the default).
            Ignored when full_zero_init=True.
        full_zero_init: if True no learnable scale is created (self.scale is None). The weights and
            bias are additionally zero-initialized only for the shallow single-Linear variant.
        shallow: select the single-Linear variant (ignored when separate_t_r=True).
        separate_t_r: select the separate t / R variant; requires rt_dim == 12.
        explicit_yaw: add a Linear(1, D) branch fed with the normalized signed yaw (rt_dim == 12 only).
        sincos_yaw: add a Linear(2, D) branch fed with [cos(yaw), sin(yaw)] (rt_dim == 12 only).
        conditioning_scale: constant (non-learnable) multiplier applied to the output.
        r_mlp_no_layernorm: drop the LayerNorms from the R MLP (only meaningful with separate_t_r=True).

    Raises:
        ValueError: if separate_t_r=True and rt_dim != 12.
    """

    def __init__(
        self,
        rt_dim: int = 12,
        hidden_size: int = 5120,
        mlp_hidden_mult: int = 4,
        eps: float = 1e-6,
        zero_init_scale: bool = False,
        full_zero_init: bool = False,
        shallow: bool = False,
        separate_t_r: bool = False,
        explicit_yaw: bool = False,
        sincos_yaw: bool = False,
        conditioning_scale: float = 1.0,
        r_mlp_no_layernorm: bool = False,
    ):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`, which would let
        # forward() silently slice a 16-dim input as if it were [t(3), R(9)].
        if separate_t_r and rt_dim != 12:
            raise ValueError("separate_t_r only supported for rt_dim=12")

        self.rt_dim = rt_dim
        self.hidden_size = hidden_size
        self.zero_init_scale = zero_init_scale
        self.full_zero_init = full_zero_init
        self.shallow = shallow
        self.separate_t_r = separate_t_r
        # Yaw is read from fixed positions of the 12-dim layout, so both yaw branches are
        # switched off for any other rt_dim.
        self.explicit_yaw = explicit_yaw and rt_dim == 12
        self.sincos_yaw = sincos_yaw and rt_dim == 12
        self.conditioning_scale = float(conditioning_scale)
        self.r_mlp_no_layernorm = r_mlp_no_layernorm and separate_t_r

        if separate_t_r:
            # Plan B: separate t (3) and R (9) encoders for scale balance.
            mid = max(hidden_size // 2, 256)
            self.t_mlp = nn.Sequential(
                nn.Linear(3, mid),
                nn.LayerNorm(mid, eps=eps),
                nn.GELU(),
                nn.Linear(mid, hidden_size),
                nn.LayerNorm(hidden_size, eps=eps),
            )
            if r_mlp_no_layernorm:
                # No LayerNorm on R so the sign of R_12/R_21 (yaw direction) is not normalized away.
                self.r_mlp = nn.Sequential(
                    nn.Linear(9, mid),
                    nn.GELU(),
                    nn.Linear(mid, hidden_size),
                )
            else:
                self.r_mlp = nn.Sequential(
                    nn.Linear(9, mid),
                    nn.LayerNorm(mid, eps=eps),
                    nn.GELU(),
                    nn.Linear(mid, hidden_size),
                    nn.LayerNorm(hidden_size, eps=eps),
                )
            self.mlp = None
        elif shallow:
            # Merged single-layer MLP: one Linear(rt_dim, hidden_size), no separate t/R.
            self.mlp = nn.Linear(rt_dim, hidden_size)
            if full_zero_init:
                nn.init.zeros_(self.mlp.weight)
                nn.init.zeros_(self.mlp.bias)
        else:
            mid_dim = hidden_size * mlp_hidden_mult
            self.mlp = nn.Sequential(
                nn.Linear(rt_dim, mid_dim),
                nn.LayerNorm(mid_dim, eps=eps),
                nn.GELU(),
                nn.Linear(mid_dim, mid_dim),
                nn.LayerNorm(mid_dim, eps=eps),
                nn.GELU(),
                nn.Linear(mid_dim, hidden_size),
                nn.LayerNorm(hidden_size, eps=eps),
            )

        # Optional additive yaw branches (None when disabled).
        self.yaw_embed = nn.Linear(1, hidden_size) if self.explicit_yaw else None
        self.sincos_embed = nn.Linear(2, hidden_size) if self.sincos_yaw else None

        # Learnable scale: when zero_init_scale=True, init to 0 so conditioning grows with training (stable).
        # When full_zero_init=True, skip scale (GF-ICL style: encoder output used directly, no extra scale).
        if full_zero_init:
            self.scale = None
        else:
            self.scale = nn.Parameter(torch.zeros(1) if zero_init_scale else torch.ones(1))

    def is_single_layer_merged(self) -> bool:
        """True if the encoder is exactly one Linear(rt_dim, D) with no separate t/R (merged MLP)."""
        return self.shallow and not self.separate_t_r and isinstance(self.mlp, nn.Linear)

    def forward(self, rt_matrices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            rt_matrices: [B, F, 12] or [B, F, 16]. The tensor is fed to the layers as-is (no dtype
                conversion is performed here).
        Returns:
            camera_emb: [B, F, hidden_size], multiplied by self.scale (if present) and by
            self.conditioning_scale.
        """
        if self.separate_t_r:
            # First 3 values are the translation, the following 9 the row-major rotation.
            out = self.t_mlp(rt_matrices[..., :3]) + self.r_mlp(rt_matrices[..., 3:12])
        else:
            out = self.mlp(rt_matrices)

        has_rt12 = rt_matrices.shape[-1] >= 12
        if self.yaw_embed is not None and has_rt12:
            yaw_norm = _yaw_from_rt_12(rt_matrices[..., :12]).unsqueeze(-1)
            out = out + self.yaw_embed(yaw_norm)
        if self.sincos_embed is not None and has_rt12:
            out = out + self.sincos_embed(_sincos_yaw_from_rt_12(rt_matrices[..., :12]))

        if self.scale is not None:
            # The learnable scale is cast to the input dtype before being applied.
            out = out * self.scale.to(rt_matrices.dtype)
        return out * self.conditioning_scale
def expand_camera_emb_to_tokens(
    camera_emb: torch.Tensor,
    num_frames: int,
    h: int,
    w: int,
) -> torch.Tensor:
    """
    Broadcast a per-frame camera embedding [B, F, D] to one embedding per token [B, N, D],
    where N = F * h * w and tokens are ordered frame0_all_patches, frame1_all_patches, ...

    Dimension alignment (kept consistent with the DiT patchify step):
    - Encoder output: one vector per frame, [B, F, D], i.e. effectively [B, F, 1, D].
    - Alignment: that singleton axis is repeated h*w times along the spatial axis, giving
      [B, F, h*w, D], which is then flattened to [B, F*h*w, D].
    - Token order: the h*w tokens of frame0 all share frame0's camera_emb, the h*w tokens of frame1
      share frame1's camera_emb, and so on. This matches the order produced by the wan_video_dit
      patchify rearrange(..., 'b c f h w -> b (f h w) c') (frame-major, then spatial).

    Args:
        camera_emb: [B, F, D]
        num_frames: F (must equal camera_emb.shape[1]; used for validation only).
        h, w: spatial grid (patches per frame)
    Returns:
        [B, F*h*w, D]
    Raises:
        ValueError: if camera_emb.shape[1] != num_frames.
    """
    B, F, D = camera_emb.shape
    if F != num_frames:
        raise ValueError(f"expand_camera_emb_to_tokens: camera_emb has F={F}, num_frames={num_frames}")

    tokens_per_frame = h * w
    # [B, F, D] -> [B, F, 1, D] -> broadcast to [B, F, h*w, D] -> flatten frames and patches together.
    per_token = camera_emb[:, :, None, :].expand(B, F, tokens_per_frame, D)
    return per_token.reshape(B, F * tokens_per_frame, D)
|