import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualRenderBlock(nn.Module):
    """Shape-preserving residual block: ``x + GN(Conv(SiLU(GN(Conv(x)))))``.

    Both convolutions are 3x3 with ``padding=1``, so channel count and
    spatial size are unchanged and the skip connection can be a plain add.

    Args:
        dim: Number of input (and output) channels. Must be divisible by
            ``num_groups`` (GroupNorm raises otherwise).
        num_groups: Number of groups used by both GroupNorm layers.
    """

    def __init__(self, dim: int, num_groups: int = 8):
        super().__init__()
        # The Sequential indices (block.0 ... block.4) define the state_dict
        # keys; keep this order stable so existing checkpoints still load.
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups, dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the conv/norm stack."""
        return x + self.block(x)


class ResidualBlockRender(nn.Module):
    """Two hidden 3x3 conv+ReLU layers and a 1x1 head, with a skip path, then sigmoid.

    The skip path is a 1x1 convolution when ``in_channels != out_channels``
    (so it matches the head's output channels) and an identity otherwise.
    The result of ``head + skip`` is squashed into (0, 1) by a sigmoid.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        hidden_channels: Width of the two hidden 3x3 convolutions.
    """

    def __init__(
        self,
        in_channels: int = 768,
        out_channels: int = 3,
        hidden_channels: int = 256,
    ):
        super().__init__()
        # Attribute names define state_dict keys; do not rename them.
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.out = nn.Sigmoid()

        # The skip tensor is added to conv3's output, so it must have
        # out_channels channels.
        if in_channels != out_channels:
            self.residual_proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.residual_proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(conv3(relu(conv2(relu(conv1(x))))) + skip(x))``."""
        residual = self.residual_proj(x)
        h = self.relu1(self.conv1(x))
        h = self.relu2(self.conv2(h))
        h = self.conv3(h)
        return self.out(h + residual)


class RenderEncoder(nn.Module):
    """Map an ``in_channels`` feature map to ``out_channels`` values in (0, 1).

    The architecture is selected by ``encoder_type``:

    * ``"1d"``: a single 1x1 convolution followed by a sigmoid.
    * ``"residual"``: :class:`ResidualBlockRender`.
    * ``"expressive"``: a 3x3 conv stem to 256 channels, GroupNorm + SiLU,
      three :class:`ResidualRenderBlock` layers, a 1x1 conv head and a sigmoid.

    Every variant uses only 1x1 convolutions or 3x3 convolutions with
    ``padding=1``, so the spatial size of the input is preserved.

    Args:
        encoder_type: One of ``"1d"``, ``"residual"``, ``"expressive"``.
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    Raises:
        ValueError: If ``encoder_type`` is not one of the names above.
    """

    def __init__(self, encoder_type="1d", in_channels=768, out_channels=3):
        super().__init__()
        self.encoder_type = encoder_type

        # NOTE: the attribute name ``model`` and the Sequential indices below
        # form the state_dict keys that saved checkpoints rely on.
        if encoder_type == "1d":
            self.model = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.Sigmoid(),
            )
        elif encoder_type == "residual":
            self.model = ResidualBlockRender(in_channels, out_channels)
        elif encoder_type == "expressive":
            mid_channels = 256
            self.model = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
                nn.GroupNorm(8, mid_channels),
                nn.SiLU(),
                ResidualRenderBlock(mid_channels),
                ResidualRenderBlock(mid_channels),
                ResidualRenderBlock(mid_channels),
                nn.Conv2d(mid_channels, out_channels, kernel_size=1),
                nn.Sigmoid(),
            )
        else:
            raise ValueError(
                f"Unknown encoder_type '{encoder_type}'. "
                "Use '1d', 'residual', or 'expressive'."
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the selected encoder on ``x`` (any Conv2d-compatible input)."""
        return self.model(x)


def load_render_encoder(checkpoint_path, device="cpu"):
    """Load standalone RenderEncoder from checkpoint.

    The checkpoint is presumably a dict saved with ``torch.save``. It must
    contain two entries:

    * ``"model_config"``: a mapping with ``encoder_type``, ``in_channels``
      and ``out_channels``.
    * ``"model_state_dict"``: the encoder's ``state_dict``.

    Args:
        checkpoint_path: Anything accepted by ``torch.load`` (path or file).
        device: Device used both as ``map_location`` and for ``model.to``.

    Returns:
        The reconstructed :class:`RenderEncoder`, on ``device``, in eval mode.

    Raises:
        KeyError: If a required checkpoint or config key is missing.
    """
    # SECURITY NOTE: depending on the installed PyTorch version, torch.load
    # may unpickle arbitrary objects. Only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint["model_config"]

    model = RenderEncoder(
        encoder_type=config["encoder_type"],
        in_channels=config["in_channels"],
        out_channels=config["out_channels"],
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()  # inference mode (affects e.g. dropout/batchnorm if present)

    print(f"Loaded RenderEncoder: {config}")
    return model