| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
# Disable TorchScript's profiling executor and profiling mode. These are
# private, global ``torch._C`` switches, so they presumably affect every
# scripted function in the process (including ``fused_mean_variance`` in this
# file), not only this module.
for _disable_jit_profiling in (torch._C._jit_set_profiling_executor,
                               torch._C._jit_set_profiling_mode):
    _disable_jit_profiling(False)
del _disable_jit_profiling
| |
|
| |
|
| | |
def weights_init(m):
    """Initialise an ``nn.Linear`` layer in place; meant for ``nn.Module.apply``.

    Linear weights get Kaiming-normal initialisation (``nn.init.kaiming_normal_``
    with its default arguments) and the bias, when present, is zeroed. Any other
    module type is left untouched, so the function can be applied to a whole
    ``nn.Sequential`` containing activations.

    :param m: a module visited by ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    # nn.init functions already run under torch.no_grad(), so pass the
    # Parameters directly instead of going through the legacy ``.data`` alias.
    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
| |
|
| |
|
@torch.jit.script
def fused_mean_variance(x, weight):
    """Weighted mean and variance of ``x`` over dim 2, keeping that dim.

    ``weight`` must broadcast against ``x``. The weights are used exactly as
    given (no normalisation happens here), so the result is a true weighted
    mean/variance only when ``weight`` already sums to one over dim 2.

    Returns ``(mean, var)``, each with dim 2 reduced to size 1.
    """
    mean = (x * weight).sum(dim=2, keepdim=True)
    centred = x - mean
    var = (weight * centred.pow(2)).sum(dim=2, keepdim=True)
    return mean, var
| |
|
| |
|
class GeneralRenderingNetwork(nn.Module):
    """Blend per-view colours of ray sample points into one colour per point.

    For every sample point, the features of all source views are pooled into a
    cross-view mean/variance, combined with a geometry feature, passed through
    a visibility-gated MLP, and turned into per-view blending logits. The raw
    source-view RGB values are then softmax-blended. The network also reports
    which rays have enough multi-view support to be considered valid.

    This model is not sensitive to finetuning.

    :param in_geometry_feat_ch: channels of the per-point geometry feature.
    :param in_rendering_feat_ch: channels of the per-view image feature,
        excluding the 3 RGB channels that precede it in ``rgb_feat``.
    :param anti_alias_pooling: if True, pool views with a learnable weighting
        driven by the ray-direction inner product; otherwise weight every
        valid view uniformly.
    :param min_valid_views: a sample is "supported" when at least this many
        views are valid for it.
    :param min_valid_samples: a ray is reported valid when more than this many
        of its samples are supported.
    """

    # Class-level defaults keep instances created (or pickled) before these
    # options existed working.
    min_valid_views = 2
    min_valid_samples = 8

    def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56, anti_alias_pooling=True,
                 min_valid_views=2, min_valid_samples=8):
        super().__init__()

        self.in_geometry_feat_ch = in_geometry_feat_ch
        self.in_rendering_feat_ch = in_rendering_feat_ch
        self.anti_alias_pooling = anti_alias_pooling
        self.min_valid_views = min_valid_views
        self.min_valid_samples = min_valid_samples

        if self.anti_alias_pooling:
            # Learnable sharpness of the direction-based view weighting (used as |s|).
            self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
        # One parameter-free activation module shared by every MLP below.
        activation_func = nn.ELU(inplace=True)

        # Per-view feature width: 3 RGB channels + image feature channels.
        view_feat_ch = in_rendering_feat_ch + 3

        # ray_diff (4 channels) -> additive offset for the per-view feature.
        self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16),
                                        activation_func,
                                        nn.Linear(16, view_feat_ch),
                                        activation_func)

        # Input: geometry feature + cross-view mean + cross-view variance + own view feature.
        self.base_fc = nn.Sequential(nn.Linear(view_feat_ch * 3 + in_geometry_feat_ch, 64),
                                     activation_func,
                                     nn.Linear(64, 32),
                                     activation_func)

        # 33 outputs = 32-channel residual for the feature + 1 visibility logit.
        self.vis_fc = nn.Sequential(nn.Linear(32, 32),
                                    activation_func,
                                    nn.Linear(32, 33),
                                    activation_func,
                                    )

        self.vis_fc2 = nn.Sequential(nn.Linear(32, 32),
                                     activation_func,
                                     nn.Linear(32, 1),
                                     nn.Sigmoid()
                                     )

        # Input: 32-channel feature + 1 visibility + 4-channel ray_diff -> one blending logit.
        self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16),
                                    activation_func,
                                    nn.Linear(16, 8),
                                    activation_func,
                                    nn.Linear(8, 1))

        # ray_dir_fc is not re-initialised and keeps PyTorch's default nn.Linear init.
        self.base_fc.apply(weights_init)
        self.vis_fc2.apply(weights_init)
        self.vis_fc.apply(weights_init)
        self.rgb_fc.apply(weights_init)

    def forward(self, geometry_feat, rgb_feat, ray_diff, mask):
        '''
        :param geometry_feat: geometry features indicates sdf [n_rays, n_samples, n_feat]
        :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat];
            the first 3 channels are the RGB colour, n_feat == in_rendering_feat_ch + 3
        :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4], first 3 channels are directions,
            last channel is inner product
        :param mask: mask for whether each projection is valid or not. [n_views, n_rays, n_samples]
        :return: tuple ``(rgb_out, valid_mask)``:
            rgb_out -- blended colour, [n_rays, n_samples, 3];
            valid_mask -- bool, [n_rays, 1], True for rays with more than
            ``min_valid_samples`` samples seen by at least ``min_valid_views`` views
        '''
        # Move the view axis to dim 2: [n_rays, n_samples, n_views, C].
        rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous()
        ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous()
        mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous()
        num_views = rgb_feat.shape[2]
        # A broadcast view is enough: torch.cat below materialises its own copy.
        geometry_feat = geometry_feat[:, :, None, :].expand(-1, -1, num_views, -1)

        direction_feat = self.ray_dir_fc(ray_diff)
        rgb_in = rgb_feat[..., :3]  # raw source-view colours, blended at the end
        rgb_feat = rgb_feat + direction_feat

        weight = self._view_weights(ray_diff, mask)

        # Cross-view statistics per point: [n_rays, n_samples, 1, n_feat] each.
        mean, var = fused_mean_variance(rgb_feat, weight)
        globalfeat = torch.cat([mean, var], dim=-1)

        x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat],
                      dim=-1)
        x = self.base_fc(x)

        x_vis = self.vis_fc(x * weight)
        x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
        vis = torch.sigmoid(vis) * mask
        x = x + x_res
        vis = self.vis_fc2(x * vis) * mask

        # Per-view blending logits; invalid views get (almost) zero softmax weight.
        x = torch.cat([x, vis, ray_diff], dim=-1)
        x = self.rgb_fc(x)
        # -1e9 is not representable in float16 (masked_fill would raise), so clamp
        # the fill value to the dtype's range; float32/bfloat16 still use -1e9.
        fill_value = max(-1e9, torch.finfo(x.dtype).min)
        x = x.masked_fill(mask == 0, fill_value)
        blending_weights_valid = F.softmax(x, dim=2)
        rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2)

        valid_mask = self._valid_ray_mask(mask, rgb_out.dtype)
        return rgb_out, valid_mask

    def _view_weights(self, ray_diff, mask):
        """Per-view pooling weights [n_rays, n_samples, n_views, 1], normalised over dim 2.

        The weights sum to ~1 over the views unless every weight is zero.
        """
        if self.anti_alias_pooling:
            # The last ray_diff channel is the direction inner product. Weight views by
            # exp(|s| * (dot - 1)), shifted by the minimum over all views (valid or not).
            _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
            exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
            weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True).values) * mask
            return weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
        # Uniform weights over the valid views.
        return mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)

    def _valid_ray_mask(self, mask, dtype):
        """Bool [n_rays, 1]: rays with enough samples that enough valid views observe."""
        views_per_sample = torch.sum(mask.detach().to(dtype), dim=2)  # [n_rays, n_samples, 1]
        supported = views_per_sample >= self.min_valid_views
        samples_per_ray = torch.sum(supported.to(dtype), dim=1)  # [n_rays, 1]
        return samples_per_ray > self.min_valid_samples
| |
|