"""
Camera Encoder for RT (Rotation-Translation) matrix injection.
CAM paper [2506.03141]: Maps camera pose to DiT hidden dimension for spatial attention conditioning.

Design notes:
- Primary purpose: dimension alignment (action/RT [12] -> DiT hidden D). A one-layer MLP (shallow=True) is sufficient; a deeper encoder is not needed.
- Per-frame MLP: each frame's RT [12] -> hidden_size independently (no temporal context).
- Optional zero-init scale (zero_init_scale=True): conditioning starts at zero (scale=0) and grows during training, for stability.
- shallow: single Linear(rt_dim, D) to match the CAM "single-layer MLP" wording (ablation). When shallow=True
  and separate_t_r=False, the encoder is exactly one layer (merged MLP, RT not split): RT [12] -> D.
- separate_t_r: encode translation (t) and rotation (R) with separate MLPs, then add, for scale balance.
  Only supported for rt_dim=12, and it takes precedence over shallow (shallow is ignored when separate_t_r=True).
- explicit_yaw: add a signed yaw scalar branch (Z-only) so CW/CCW are explicitly encoded; helps when
  the model is insensitive to rotation direction (Zhou et al. CVPR 2019, sign continuity).
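- sincos_yaw: add a [cos(yaw), sin(yaw)] branch (2D) for direction; sin carries the sign explicitly. Unlike the scalar
  yaw (normalized to [-1, 1], which jumps between +1 and -1 at yaw = +/-pi), the (cos, sin) pair is continuous.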

Design limitations / caveats:
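- No input normalization: t (translation) and R (rotation) have different scales, so a single Linear(12, D) may be
  sensitive to units. Callers should use a consistent RT scale or relative RT; an optional input LayerNorm is not implemented.
- Per-frame only: no temporal context (each frame is encoded independently). Fine for the CAM ablation; temporal modeling is not supported.
- 16-dim input: the layout of the flattened 4x4 is unspecified, yaw branches are disabled, and separate_t_r is not supported.
  Prefer 12-dim in practice.
- explicit_yaw and sincos_yaw can both be True (redundant encoding of yaw); usually use only one.
- full_zero_init zero-initializes weights only in the shallow path. In the deep and separate_t_r paths it only removes the
  learnable scale, and yaw_embed / sincos_embed are never zero-initialized, so the output is not zero at init when a yaw branch is on.
- The default deep path (shallow=False, separate_t_r=False) is large: mid_dim = hidden_size * mlp_hidden_mult, so with the
  defaults (5120, 4) the middle Linear(20480, 20480) alone has about 419M parameters.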
"""

import torch
import torch.nn as nn
from typing import Optional

# For Z-only rotation: yaw = atan2(R_21, R_11). R is stored row-major at rt[..., 3:12]
# ([R_11, R_12, R_13, R_21, ...]), so R_11 = rt[..., 3] and R_21 = rt[..., 6].
def _yaw_from_rt_12(rt: torch.Tensor) -> torch.Tensor:
    """rt [..., 12] -> yaw in [-1, 1] (normalized by pi)."""
    R11 = rt[..., 3]
    R21 = rt[..., 6]
    yaw_rad = torch.atan2(R21, R11)
    return yaw_rad / 3.141592653589793


def _sincos_yaw_from_rt_12(rt: torch.Tensor) -> torch.Tensor:
    """rt [..., 12] -> [..., 2] (cos(yaw), sin(yaw))."""
    R11 = rt[..., 3]
    R21 = rt[..., 6]
    yaw_rad = torch.atan2(R21, R11)
    return torch.stack([torch.cos(yaw_rad), torch.sin(yaw_rad)], dim=-1)
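

# Hypothetical helper (not part of the original encoder; a sketch for sanity checks only).
# Builds a 12-dim RT in the layout used above, [t_x, t_y, t_z, R_11..R_33] with R row-major,
# for a pure rotation about Z: R = [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]]. R_11 = cos(yaw)
# lands at index 3 and R_21 = sin(yaw) at index 6, which are the entries _yaw_from_rt_12 and
# _sincos_yaw_from_rt_12 read. For yaw strictly inside (-pi, pi), their outputs on this RT
# should therefore be yaw / pi and (cos(yaw), sin(yaw)).
def _rt_12_from_z_yaw(yaw_rad: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical sanity-check helper: yaw_rad [...] in radians, t [..., 3] or None -> rt [..., 12]."""
    c, s = torch.cos(yaw_rad), torch.sin(yaw_rad)
    zero, one = torch.zeros_like(c), torch.ones_like(c)
    if t is None:
        # Zero translation with the same batch shape, dtype and device as yaw_rad.
        t = torch.zeros(*yaw_rad.shape, 3, dtype=yaw_rad.dtype, device=yaw_rad.device)
    # Row-major flattening of the Z rotation matrix.
    rot = torch.stack([c, -s, zero, s, c, zero, zero, zero, one], dim=-1)
    return torch.cat([t, rot], dim=-1)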


class CameraEncoder(nn.Module):
    """
    Encode RT matrices (camera pose) to DiT hidden dimension.
    
    Input: rt_matrices [B, F, 12] or [B, F, 16]
      - 12: [t_x, t_y, t_z, R_11..R_33] (3 translation + 9 rotation), R row-major. No input normalization:
        t and R often differ in scale (e.g. t in meters, R in [-1,1]); single Linear(12,D) may be sensitive to units.
      - 16: 4x4 matrix flattened (layout/order unspecified; yaw branches disabled when rt_dim=16).
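    Output: camera_emb [B, F, D] where D = hidden_size, multiplied by the learnable scale (initialized to 0 when
      zero_init_scale=True, else 1; absent when full_zero_init=True) and by conditioning_scale.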
    """
    
    def __init__(
        self,
        rt_dim: int = 12,
        hidden_size: int = 5120,
        mlp_hidden_mult: int = 4,
        eps: float = 1e-6,
        zero_init_scale: bool = False,
        full_zero_init: bool = False,
        shallow: bool = False,
        separate_t_r: bool = False,
        explicit_yaw: bool = False,
        sincos_yaw: bool = False,
        conditioning_scale: float = 1.0,
        r_mlp_no_layernorm: bool = False,
    ):
        super().__init__()
        self.rt_dim = rt_dim
        self.hidden_size = hidden_size
        self.zero_init_scale = zero_init_scale
        self.full_zero_init = full_zero_init
        self.shallow = shallow
        self.separate_t_r = separate_t_r
        self.explicit_yaw = explicit_yaw and rt_dim == 12
        self.sincos_yaw = sincos_yaw and rt_dim == 12
        self.conditioning_scale = float(conditioning_scale)
        self.r_mlp_no_layernorm = r_mlp_no_layernorm and separate_t_r
        
        if separate_t_r:
            # Plan B: separate t (3) and R (9) encoders for scale balance; only for rt_dim=12.
            assert rt_dim == 12, "separate_t_r only supported for rt_dim=12"
            mid = max(hidden_size // 2, 256)
            self.t_mlp = nn.Sequential(
                nn.Linear(3, mid),
                nn.LayerNorm(mid, eps=eps),
                nn.GELU(),
                nn.Linear(mid, hidden_size),
                nn.LayerNorm(hidden_size, eps=eps),
            )
            if r_mlp_no_layernorm:
                # No LayerNorm on R so sign of R_12/R_21 (yaw direction) is not normalized away.
                self.r_mlp = nn.Sequential(
                    nn.Linear(9, mid),
                    nn.GELU(),
                    nn.Linear(mid, hidden_size),
                )
            else:
                self.r_mlp = nn.Sequential(
                    nn.Linear(9, mid),
                    nn.LayerNorm(mid, eps=eps),
                    nn.GELU(),
                    nn.Linear(mid, hidden_size),
                    nn.LayerNorm(hidden_size, eps=eps),
                )
            self.mlp = None
        elif shallow:
            # Merged single-layer MLP: one Linear(rt_dim, hidden_size), no separate t/R.
            self.mlp = nn.Linear(rt_dim, hidden_size)
            assert isinstance(self.mlp, nn.Linear), "shallow path must be exactly one Linear layer"
            if full_zero_init:
                nn.init.zeros_(self.mlp.weight)
                nn.init.zeros_(self.mlp.bias)
        else:
            mid_dim = hidden_size * mlp_hidden_mult
            self.mlp = nn.Sequential(
                nn.Linear(rt_dim, mid_dim),
                nn.LayerNorm(mid_dim, eps=eps),
                nn.GELU(),
                nn.Linear(mid_dim, mid_dim),
                nn.LayerNorm(mid_dim, eps=eps),
                nn.GELU(),
                nn.Linear(mid_dim, hidden_size),
                nn.LayerNorm(hidden_size, eps=eps),
            )
        
        if self.explicit_yaw:
            self.yaw_embed = nn.Linear(1, hidden_size)
        else:
            self.yaw_embed = None
        if self.sincos_yaw:
            self.sincos_embed = nn.Linear(2, hidden_size)
        else:
            self.sincos_embed = None
        
        # Learnable scale: when zero_init_scale=True, init to 0 so conditioning grows with training (stable).
        # When full_zero_init=True, skip scale (GF-ICL style: Linear output directly, no extra scale).
        if full_zero_init:
            self.scale = None  # no scale, use 1.0 in forward
        else:
            self.scale = nn.Parameter(torch.zeros(1) if zero_init_scale else torch.ones(1))

    def is_single_layer_merged(self) -> bool:
        """True if the encoder is exactly one Linear(rt_dim, D) with no separate t/R (merged MLP)."""
        return self.shallow and not self.separate_t_r and isinstance(self.mlp, nn.Linear)

    def forward(self, rt_matrices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            rt_matrices: [B, F, 12] or [B, F, 16]
        Returns:
            camera_emb: [B, F, hidden_size], multiplied by self.scale (when present) and conditioning_scale.
        """
        # Cast inputs to the parameter dtype so that, e.g., fp32 poses work with bf16 weights.
        d = next(self.parameters()).dtype
        if self.separate_t_r:
            t = rt_matrices[..., :3].to(d)
            r = rt_matrices[..., 3:12].to(d)
            out = self.t_mlp(t) + self.r_mlp(r)
        else:
            out = self.mlp(rt_matrices.to(d))
        if self.yaw_embed is not None and rt_matrices.shape[-1] >= 12:
            yaw_norm = _yaw_from_rt_12(rt_matrices[..., :12]).unsqueeze(-1).to(d)
            out = out + self.yaw_embed(yaw_norm)
        if self.sincos_embed is not None and rt_matrices.shape[-1] >= 12:
            sincos = _sincos_yaw_from_rt_12(rt_matrices[..., :12]).to(d)
            out = out + self.sincos_embed(sincos)
        if self.scale is not None:
            # No scale when full_zero_init=True (equivalent to multiplying by 1.0).
            out = out * self.scale.to(d)
        return out * self.conditioning_scale
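

# Hypothetical helper (not part of the original encoder): a sketch of the "relative RT" preprocessing
# mentioned in the caveats, applied by the caller before CameraEncoder. It assumes each RT-12 row
# encodes a 4x4 pose T_i = [[R_i, t_i], [0, 1]] and returns T_0^{-1} T_i for every frame:
#   R_rel = R_0^T R_i,   t_rel = R_0^T (t_i - t_0),
# so frame 0 becomes the identity pose. Whether frame 0 (and this composition order) is the right
# reference depends on the pose convention of the data, which this module does not specify.
def _relative_rt_12(rt: torch.Tensor) -> torch.Tensor:
    """Hypothetical preprocessing: rt [B, F, 12] -> poses relative to frame 0, [B, F, 12], same layout."""
    lead = rt.shape[:-1]                                       # (B, F)
    t = rt[..., :3]                                            # [B, F, 3]
    R = rt[..., 3:12].reshape(*lead, 3, 3)                     # [B, F, 3, 3], row-major
    R0_T = R[:, :1].transpose(-1, -2)                          # [B, 1, 3, 3], R_0^T
    R_rel = R0_T @ R                                           # broadcasts over frames
    t_rel = (R0_T @ (t - t[:, :1]).unsqueeze(-1)).squeeze(-1)  # [B, F, 3]
    return torch.cat([t_rel, R_rel.reshape(*lead, 9)], dim=-1)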


def expand_camera_emb_to_tokens(
    camera_emb: torch.Tensor,
    num_frames: int,
    h: int,
    w: int,
) -> torch.Tensor:
    """
    Expand per-frame camera_emb [B, F, D] to per-token [B, N, D]
    where N = F * h * w (tokens ordered as frame0_all_patches, frame1_all_patches, ...).

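    Dimension alignment (consistent with DiT patchify):
    - Encoder output: one vector per frame, [B, F, D], i.e. equivalent to [B, F, 1, D] (one embedding for each of the F frames).
    - Alignment: repeat that singleton h*w times along the spatial dimension to get [B, F, h*w, D], then flatten to [B, F*h*w, D].
    - Token order: the h*w tokens of frame0 share frame0's camera_emb, the h*w tokens of frame1 share frame1's camera_emb, and so on,
      matching the order of wan_video_dit patchify's rearrange(..., 'b c f h w -> b (f h w) c') (frame-major, then spatial).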

    Args:
        camera_emb: [B, F, D]
        num_frames: F (must equal camera_emb.shape[1]; used only for validation, a mismatch raises ValueError).
        h, w: spatial grid (patches per frame)
    Returns:
        [B, F*h*w, D]
    """
    B, F, D = camera_emb.shape
    if F != num_frames:
        raise ValueError(f"expand_camera_emb_to_tokens: camera_emb has F={F}, num_frames={num_frames}")
    # [B, F, D] -> [B, F, 1, D] (one per frame) -> expand to [B, F, h*w, D] (repeated h*w times per frame) -> [B, F*h*w, D]
    return camera_emb.unsqueeze(2).expand(B, F, h * w, D).reshape(B, F * h * w, D)
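

# Hypothetical helper (not part of the original module): a plain-torch sketch of the token order that
# expand_camera_emb_to_tokens relies on. Given a patch-embedded grid x [B, C, F, h, w] (assumed to be the
# tensor that wan_video_dit's patchify rearranges), permute + reshape reproduces
# rearrange(x, 'b c f h w -> b (f h w) c'), so token n belongs to frame n // (h * w). Repeating each
# frame's embedding h * w times along dim 1 with repeat_interleave gives the same frame-major layout as
# the unsqueeze/expand/reshape above.
def _flatten_tokens_with_camera(x: torch.Tensor, camera_emb: torch.Tensor):
    """Hypothetical check helper: x [B, C, F, h, w], camera_emb [B, F, D] -> (tokens [B, F*h*w, C], cam [B, F*h*w, D])."""
    B, C, F, h, w = x.shape
    if camera_emb.shape[:2] != (B, F):
        raise ValueError(f"camera_emb batch/frames {tuple(camera_emb.shape[:2])} != {(B, F)}")
    tokens = x.permute(0, 2, 3, 4, 1).reshape(B, F * h * w, C)  # 'b c f h w -> b (f h w) c'
    cam = camera_emb.repeat_interleave(h * w, dim=1)            # frame-major, then spatial
    return tokens, cam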