| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from models.embedder import get_embedder |
| |
|
| |
|
class SDFNetwork(nn.Module):
    """MLP mapping points to a signed-distance value plus extra feature channels.

    Column 0 of the output is the SDF, divided by ``scale`` to undo the input
    scaling applied in ``forward``; the remaining ``d_out - 1`` columns are
    returned as produced by the last linear layer.
    """

    def __init__(self,
                 d_in,
                 d_out,
                 d_hidden,
                 n_layers,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 scale=1,
                 geometric_init=True,
                 weight_norm=True,
                 activation='softplus',
                 conditional_type='multiply'):
        """Build the layers ``lin0 .. lin{len(dims) - 2}``.

        Args:
            d_in: number of raw input channels.
            d_out: number of output channels (SDF first, then features).
            d_hidden: width of every hidden layer.
            n_layers: number of hidden layers.
            skip_in: indices of layers whose input is concatenated with the
                (possibly embedded) network input.
            multires: positional-encoding frequency count; 0 disables the embedder.
            bias: with ``geometric_init`` the last layer's bias is set to ``-bias``.
            scale: inputs are multiplied by ``scale``; the SDF output is divided by it.
            geometric_init: use the custom initialisation below instead of the
                PyTorch defaults.
            weight_norm: wrap each linear layer with ``nn.utils.weight_norm``.
            activation: ``'softplus'`` (beta=100) or ``'relu'``.
            conditional_type: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``activation`` is neither ``'softplus'`` nor ``'relu'``.
        """
        super(SDFNetwork, self).__init__()

        # Explicit check (not ``assert``) so validation survives ``python -O``.
        if activation not in ('softplus', 'relu'):
            raise ValueError(f"activation must be 'softplus' or 'relu', got {activation!r}")

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for i in range(0, self.num_layers - 1):
            # A layer feeding a skip layer emits fewer channels so that, once the
            # network input (dims[0] channels) is concatenated in ``forward``,
            # the skip layer receives exactly dims[i + 1] channels.
            if i + 1 in self.skip_in:
                out_dim = dims[i + 1] - dims[0]
            else:
                out_dim = dims[i + 1]

            lin = nn.Linear(dims[i], out_dim)

            if geometric_init:
                if i == self.num_layers - 2:
                    # Output layer: positive-mean weights and bias -bias; presumably
                    # so the initial SDF approximates a sphere of radius ``bias``
                    # (SAL-style geometric init) -- TODO confirm.
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[i]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                elif multires > 0 and i == 0:
                    # Only the first 3 input columns get non-zero weights; the
                    # remaining (encoding) columns start at zero.
                    # NOTE(review): the literal 3 assumes d_in == 3 and that the
                    # embedder places the raw input first -- confirm for d_in != 3.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and i in self.skip_in:
                    # Zero the weights on the concatenated encoding columns (the
                    # last dims[0] - 3 input channels) of a skip layer.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(i), lin)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            self.activation = nn.ReLU()

    def forward(self, inputs):
        """Evaluate the network.

        Expects a 2-D ``(N, d_in)`` batch (the skip concatenation on dim 1 and
        the ``[:, :1]`` slicing assume rank 2). Returns ``(N, d_out)`` with
        column 0 divided by ``self.scale``.
        """
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for i in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(i))

            if i in self.skip_in:
                # Re-inject the network input; dividing by sqrt(2) rescales the concatenation.
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            # No activation after the final layer.
            if i < self.num_layers - 2:
                x = self.activation(x)
        return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)

    def sdf(self, x):
        """Return only the SDF column, shape ``(N, 1)``."""
        return self.forward(x)[:, :1]

    def sdf_hidden_appearance(self, x):
        """Return the full ``(N, d_out)`` output (SDF plus feature columns)."""
        return self.forward(x)

    def gradient(self, x):
        """Return d(sdf)/dx with shape ``(N, 1, d_in)``.

        Sets ``requires_grad`` on ``x`` in place and builds a differentiable
        graph (``create_graph=True``) so the gradient itself can be backpropagated.
        """
        x.requires_grad_(True)
        y = self.sdf(x)
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)
| |
|
| |
|
class VarianceNetwork(nn.Module):
    """ReLU MLP whose outputs are mapped elementwise to positive values.

    ``forward`` returns ``1 / (softplus_beta100(x + 0.5) + 1e-3)``, so every
    output is positive and at most 1000. ``coarse``/``fine`` split the output
    into its first column and the remaining columns.
    """

    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0):
        """Build the layers ``lin0 .. lin{len(dims) - 2}``.

        Args:
            d_in: number of raw input channels.
            d_out: number of output channels.
            d_hidden: width of every hidden layer.
            n_layers: number of hidden layers.
            skip_in: indices of layers whose input is concatenated with the
                (possibly embedded) network input.
            multires: positional-encoding frequency count; 0 disables the embedder.
        """
        super(VarianceNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            # Pass the real input width, as SDFNetwork and NeRF do, so the
            # embedder agrees with the declared d_in.
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in

        for i in range(0, self.num_layers - 1):
            # Shrink the layer feeding a skip layer to leave room for the
            # dims[0] re-injected input channels.
            if i + 1 in self.skip_in:
                out_dim = dims[i + 1] - dims[0]
            else:
                out_dim = dims[i + 1]

            lin = nn.Linear(dims[i], out_dim)
            setattr(self, "lin" + str(i), lin)

        self.relu = nn.ReLU()
        self.softplus = nn.Softplus(beta=100)

    def forward(self, inputs):
        """Return positive values of shape ``(N, d_out)`` for a 2-D ``(N, d_in)`` batch."""
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for i in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(i))

            if i in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            # No activation after the final layer.
            if i < self.num_layers - 2:
                x = self.relu(x)

        # The +1e-3 keeps the denominator away from zero, capping the output at 1000.
        return 1.0 / (self.softplus(x + 0.5) + 1e-3)

    def coarse(self, inputs):
        """Return the first output column, shape ``(N, 1)``."""
        return self.forward(inputs)[:, :1]

    def fine(self, inputs):
        """Return all output columns after the first, shape ``(N, d_out - 1)``."""
        return self.forward(inputs)[:, 1:]
| |
|
| |
|
class FixVarianceNetwork(nn.Module):
    """Non-learned variance that decays exponentially with a step counter.

    ``forward(x)`` returns a ``(len(x), 1)`` tensor filled with
    ``exp(-iter_step / base)``.
    """

    def __init__(self, base):
        """Args:
            base: decay constant; larger values make the output decay more slowly.
        """
        super(FixVarianceNetwork, self).__init__()
        self.base = base
        self.iter_step = 0

    def set_iter_step(self, iter_step):
        """Set the step counter read by ``forward``."""
        self.iter_step = iter_step

    def forward(self, x):
        """Return ``exp(-iter_step / base)`` broadcast to shape ``(len(x), 1)``."""
        # Allocate on x's device so the result can be combined with GPU tensors;
        # non-tensor inputs fall back to the default device.
        device = x.device if torch.is_tensor(x) else None
        return torch.ones([len(x), 1], device=device) * np.exp(-self.iter_step / self.base)
| |
|
| |
|
class SingleVarianceNetwork(nn.Module):
    """Learned scalar variance: one parameter ``variance``, output ``exp(10 * variance)``."""

    def __init__(self, init_val=1.0):
        """Args:
            init_val: initial value of the learnable ``variance`` parameter.
        """
        super(SingleVarianceNetwork, self).__init__()
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        """Return ``exp(10 * variance)`` broadcast to shape ``(len(x), 1)`` on ``x``'s device."""
        # Allocating directly on x.device avoids a host allocation plus a
        # device copy on every call.
        return torch.ones([len(x), 1], device=x.device) * torch.exp(self.variance * 10.0)
| |
|
| |
|
| |
|
class RenderingNetwork(nn.Module):
    """ReLU MLP predicting a color-like output from points, normals, view
    directions and feature vectors, concatenated according to ``mode``."""

    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
            d_conditional_colors=0
    ):
        """Build the layers ``lin0 .. lin{len(dims) - 2}``.

        Args:
            d_feature: width of ``feature_vectors``.
            mode: which inputs to concatenate: ``'idr'``, ``'no_view_dir'``,
                ``'no_normal'``, ``'no_points'`` or ``'no_points_no_view_dir'``.
            d_in: total width of the geometric inputs used by ``mode``
                (excluding features).
            d_out: output width.
            d_hidden: width of every hidden layer.
            n_layers: number of hidden layers.
            weight_norm: wrap each linear layer with ``nn.utils.weight_norm``.
            multires_view: positional-encoding frequency count for view
                directions; 0 disables the embedder.
            squeeze_out: apply a sigmoid to the output.
            d_conditional_colors: accepted for interface compatibility; unused here.
        """
        super().__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out
        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            # d_in already counts 3 raw view-dir channels (assumed); add the extra
            # channels introduced by the embedding.
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for i in range(0, self.num_layers - 1):
            lin = nn.Linear(dims[i], dims[i + 1])

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(i), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        """Concatenate the inputs selected by ``self.mode`` along the last
        dim and run the MLP.

        Raises:
            ValueError: if ``self.mode`` is not one of the supported modes.
        """
        if self.embedview_fn is not None:
            view_dirs = self.embedview_fn(view_dirs)

        if self.mode == 'idr':
            parts = [points, view_dirs, normals, feature_vectors]
        elif self.mode == 'no_view_dir':
            parts = [points, normals, feature_vectors]
        elif self.mode == 'no_normal':
            parts = [points, view_dirs, feature_vectors]
        elif self.mode == 'no_points':
            parts = [view_dirs, normals, feature_vectors]
        elif self.mode == 'no_points_no_view_dir':
            parts = [normals, feature_vectors]
        else:
            # Fail clearly instead of passing None into the first linear layer.
            raise ValueError(f"Unknown rendering mode: {self.mode!r}")

        x = torch.cat(parts, dim=-1)

        for i in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(i))

            x = lin(x)

            # No activation after the final layer.
            if i < self.num_layers - 2:
                x = self.relu(x)

        if self.squeeze_out:
            x = torch.sigmoid(x)
        return x
| |
|
| |
|
| | |
class NeRF(nn.Module):
    """NeRF-style MLP returning ``(alpha + 1.0, rgb)`` for points and view directions.

    Only the ``use_viewdirs=True`` path is implemented in ``forward``.
    """

    def __init__(self, D=8, W=256, d_in=3, d_in_view=3, multires=0, multires_view=0, output_ch=4, skips=(4,),
                 use_viewdirs=False):
        """Build the point and view branches.

        Args:
            D: number of point-branch linear layers.
            W: point-branch width.
            d_in: raw point dimensionality.
            d_in_view: raw view-direction dimensionality.
            multires: positional-encoding frequencies for points (0 = none).
            multires_view: positional-encoding frequencies for views (0 = none).
            output_ch: output width of ``output_linear`` (non-viewdirs layout only).
            skips: point-branch indices after which the input is re-concatenated.
            use_viewdirs: build the separate alpha / view-conditioned rgb heads.
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.d_in = d_in
        self.d_in_view = d_in_view
        # Without an embedder the raw inputs are consumed directly, so the input
        # widths are the raw dimensionalities.
        self.input_ch = d_in
        self.input_ch_view = d_in_view
        self.embed_fn = None
        self.embed_fn_view = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn = embed_fn
            self.input_ch = input_ch

        if multires_view > 0:
            embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view, normalize=False)
            self.embed_fn_view = embed_fn_view
            self.input_ch_view = input_ch_view

        # Copy into a fresh list: the default is an immutable tuple, and callers'
        # lists are not aliased.
        self.skips = list(skips)
        self.use_viewdirs = use_viewdirs

        # Layer i + 1 takes W + input_ch channels when i is a skip index, because
        # forward concatenates the input after layer i.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(self.input_ch, W)] +
            [nn.Linear(W + self.input_ch, W) if i in self.skips else nn.Linear(W, W)
             for i in range(D - 1)])

        self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, input_pts, input_views):
        """Return ``(alpha + 1.0, rgb)`` with last dims 1 and 3.

        Raises:
            NotImplementedError: if the model was built with ``use_viewdirs=False``.
        """
        # Explicit raise instead of ``assert False`` so the guard survives ``python -O``.
        if not self.use_viewdirs:
            raise NotImplementedError("NeRF.forward is only implemented for use_viewdirs=True")

        if self.embed_fn is not None:
            input_pts = self.embed_fn(input_pts)
        if self.embed_fn_view is not None:
            input_views = self.embed_fn_view(input_views)

        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        alpha = self.alpha_linear(h)
        feature = self.feature_linear(h)
        h = torch.cat([feature, input_views], -1)

        for layer in self.views_linears:
            h = F.relu(layer(h))

        rgb = self.rgb_linear(h)
        return alpha + 1.0, rgb
| |
|