File size: 3,618 Bytes
1d5b5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialGridMemory(nn.Module):
    """Compress context-frame tokens into ``num_tokens`` memory tokens.

    Context tokens are averaged over frames, adaptively average-pooled down to
    ``grid_size * grid_size`` cells, then mixed into ``num_tokens`` memory
    tokens with a learned ``(grid_size**2, num_tokens)`` matrix that is
    softmax-normalised over the grid axis, so every memory token is a convex
    combination of pooled grid cells.

    Checkpoint compatibility: the legacy parameter name ``mix`` is remapped to
    ``spatial_to_tokens`` and deprecated ``out.weight`` / ``out.bias`` keys are
    dropped. This is done in ``_load_from_state_dict`` so it applies both when
    this module is loaded directly and when it is a submodule of a larger model
    (PyTorch only calls ``_load_from_state_dict`` on submodules).
    """

    def __init__(self, dim: int, grid_size: int = 8, num_tokens: int = 64):
        super().__init__()
        self.dim = int(dim)
        self.grid_size = int(grid_size)
        self.num_tokens = int(num_tokens)
        g2 = self.grid_size * self.grid_size
        # Keep key name aligned with ckpt loading in loop_utils.py (spatial_to_tokens).
        self.spatial_to_tokens = nn.Parameter(torch.zeros(g2, self.num_tokens))
        nn.init.normal_(self.spatial_to_tokens, std=0.02)

    @property
    def mix(self):
        """Alias of ``spatial_to_tokens`` for code using the old attribute name."""
        return self.spatial_to_tokens

    def forward(self, x_context: torch.Tensor, num_context_frames: int, h: int, w: int):
        """Pool context tokens into memory tokens.

        Args:
            x_context: Context tokens, expected as ``(B, K*H*W, D)``. Anything
                that is None or not 3-D is returned unchanged.
            num_context_frames: Number of frames ``K`` (values < 1 are treated as 1).
            h: Spatial height of one frame's token map.
            w: Spatial width of one frame's token map.

        Returns:
            Memory tokens of shape ``(B, num_tokens, D)`` in ``x_context``'s dtype.

        Raises:
            ValueError: If the last dimension of ``x_context`` is not ``dim``.
        """
        if x_context is None or x_context.ndim != 3:
            return x_context
        b, n, d = x_context.shape
        if d != self.dim:
            raise ValueError(f"SpatialGridMemory dim mismatch: x={d} module={self.dim}")
        k = max(int(num_context_frames), 1)
        spatial = int(h) * int(w)
        if n != k * spatial:
            # Best effort fallback: treat x as a flat token map and pool directly.
            x_mean = x_context
        else:
            # Average over the K context frames -> (B, H*W, D).
            x_mean = x_context.reshape(b, k, spatial, d).mean(dim=1)

        g2 = self.grid_size * self.grid_size
        pooled = F.adaptive_avg_pool1d(x_mean.transpose(1, 2), g2).transpose(1, 2)  # (B, G2, D)
        # Normalise over the grid axis, then match the activations' dtype:
        # einsum (bmm underneath) requires equal dtypes, so fp32 weights against
        # bf16/fp16/fp64 activations would otherwise raise outside autocast.
        mix = torch.softmax(self.spatial_to_tokens, dim=0).to(dtype=pooled.dtype)  # (G2, M)
        return torch.einsum("bgd,gm->bmd", pooled, mix)  # (B, M, D)

    def load_state_dict(self, state_dict, strict: bool = True, **kwargs):
        """Load a state dict, accepting legacy key names.

        Key remapping is performed in ``_load_from_state_dict``; this override
        only normalises the input to a plain ``dict`` (as before) and forwards
        any extra keyword arguments (e.g. ``assign``) to ``nn.Module``.
        """
        return super().load_state_dict(dict(state_dict), strict=strict, **kwargs)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Legacy name "mix" -> current "spatial_to_tokens". If both are present,
        # the current key wins and the legacy one is discarded so it is not
        # reported as unexpected.
        legacy_key = prefix + "mix"
        current_key = prefix + "spatial_to_tokens"
        if legacy_key in state_dict:
            legacy_value = state_dict.pop(legacy_key)
            state_dict.setdefault(current_key, legacy_value)
        # Ignore deprecated projection keys from prior experiments.
        for stale_name in ("out.weight", "out.bias"):
            state_dict.pop(prefix + stale_name, None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


class SpatialCrossAttnReadout(nn.Module):
    """Gated residual cross-attention from target tokens into memory tokens.

    Output is ``x_target + tanh(gate) * attn(x_target, mem, mem)``. The scalar
    gate starts at zero, so a freshly built module returns ``x_target``'s values
    unchanged until the gate is trained away from zero.
    """

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        embed_dim = int(dim)
        head_count = int(num_heads)
        # batch_first=True: inputs are laid out as (batch, sequence, embedding).
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=head_count, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x_target: torch.Tensor, mem_tokens: torch.Tensor):
        """Return ``x_target`` enriched with attention over ``mem_tokens``.

        If either input is None or empty, ``x_target`` is returned as-is.
        """
        nothing_to_read = (
            x_target is None
            or mem_tokens is None
            or x_target.numel() == 0
            or mem_tokens.numel() == 0
        )
        if nothing_to_read:
            return x_target
        attended = self.attn(x_target, mem_tokens, mem_tokens, need_weights=False)[0]
        scale = torch.tanh(self.gate)
        return x_target + scale * attended


def apply_spatial_cross_attn_readout(
    x_target: torch.Tensor,
    mem_tokens: torch.Tensor,
    module: nn.Module = None,
    num_heads: int = 8,
):
    """Read spatial memory tokens into target tokens via cross-attention.

    Args:
        x_target: Query tokens. When ``module`` is None, its last dimension is
            used as the embedding size of the fallback readout module.
        mem_tokens: Memory tokens used as attention keys/values.
        module: Readout callable invoked as ``module(x_target, mem_tokens)``.
            When None, a fresh ``SpatialCrossAttnReadout`` is built per call.
        num_heads: Head count for the fallback module only; the last dimension
            of ``x_target`` must be divisible by it.

    Returns:
        The result of ``module``. With the fallback module and nothing to read
        (None or empty inputs), ``x_target`` itself is returned.
    """
    if module is None:
        # The fallback module would return x_target unchanged for these inputs;
        # bail out first so a None x_target does not crash on ``.shape`` and no
        # throwaway module is allocated.
        if x_target is None or mem_tokens is None:
            return x_target
        if x_target.numel() == 0 or mem_tokens.numel() == 0:
            return x_target
        # NOTE(review): this module is created per call and never registered or
        # trained; its gate is initialised to zeros, so it contributes
        # tanh(0) * delta == 0. Callers wanting a learned readout should pass
        # a persistent module.
        module = SpatialCrossAttnReadout(dim=int(x_target.shape[-1]), num_heads=int(num_heads)).to(
            device=x_target.device, dtype=x_target.dtype
        )
    return module(x_target, mem_tokens)


def inject_spatial_memory(context: torch.Tensor, mem_tokens: torch.Tensor, mode: str = "concat_text"):
    """Combine conditioning ``context`` with spatial memory tokens.

    ``mode`` is case-insensitive; a falsy ``mode`` means ``"concat_text"``.
    - ``"none"`` or ``mem_tokens is None``: return ``context`` untouched.
    - ``context is None``: return ``mem_tokens`` on their own.
    - ``"concat_text"`` / ``"cross_attn_readout"``: append the memory tokens
      after the context along dim 1.
    - any other mode: return ``context`` untouched.
    """
    normalized_mode = str(mode or "concat_text").lower()
    if mem_tokens is None or normalized_mode == "none":
        return context
    if context is None:
        return mem_tokens
    concat_modes = {"concat_text", "cross_attn_readout"}
    return torch.cat((context, mem_tokens), dim=1) if normalized_mode in concat_modes else context