"""
Custom grid-sampling ops. PyTorch's grid_sample does not support
second-order derivatives, so these versions are implemented with plain
differentiable tensor operations instead.
"""
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| |
|
| |
|
def grid_sample_2d(image, optical):
    """
    Bilinear sampling of ``image`` at ``optical`` using plain differentiable
    tensor ops, so second-order derivatives w.r.t. ``optical`` are available.

    Equivalent to ``F.grid_sample(image, optical, mode='bilinear',
    padding_mode='border', align_corners=True)``. Neighbour indices are
    clamped to the image while the bilinear weights are not, so samples
    outside [-1, 1] read edge pixels.

    :param image: [N, C, IH, IW]
    :param optical: [N, H, W, 2], normalized (x, y) coordinates
    :return: [N, C, H, W]
    """
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    # Normalized [-1, 1] -> pixel coordinates (align_corners=True convention).
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)

    # Cell corners are constants for autograd; gradients reach `optical`
    # only through the (polynomial) weights below.
    with torch.no_grad():
        ix0 = torch.floor(ix)
        iy0 = torch.floor(iy)
        # Integer copies so the flat index is exact even beyond 2**24 pixels.
        x0 = ix0.long()
        y0 = iy0.long()

    w_nw = (ix0 + 1 - ix) * (iy0 + 1 - iy)
    w_ne = (ix - ix0) * (iy0 + 1 - iy)
    w_sw = (ix0 + 1 - ix) * (iy - iy0)
    w_se = (ix - ix0) * (iy - iy0)

    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    flat = image.reshape(N, C, IH * IW)

    def corner(xi, yi):
        """Gather values at integer pixel (xi, yi), clamped to the image."""
        idx = yi.clamp(0, IH - 1) * IW + xi.clamp(0, IW - 1)
        idx = idx.reshape(N, 1, H * W).expand(N, C, H * W)
        return torch.gather(flat, 2, idx).view(N, C, H, W)

    return (corner(x0, y0) * w_nw[:, None]
            + corner(x0 + 1, y0) * w_ne[:, None]
            + corner(x0, y0 + 1) * w_sw[:, None]
            + corner(x0 + 1, y0 + 1) * w_se[:, None])
| |
|
| |
|
| | |
def grid_sample_3d(volume, optical):
    """
    Trilinear sampling of ``volume`` at ``optical`` built from plain
    differentiable tensor ops, so second-order derivatives w.r.t. ``optical``
    are available. For sample points inside [-1, 1]^3 this matches
    ``F.grid_sample(volume, optical, mode='bilinear', align_corners=True)``.

    Trilinear weights are only piecewise polynomial, so the first-order
    gradient is not continuous across voxel borders.

    Samples whose pixel coordinate lies outside [0, size) on any axis are set
    to 0. NOTE: the upper bound admits up to one voxel past the last sample,
    where the clamped corner indices make the result border-padded; kept
    as-is for backward compatibility.

    :param volume: [N, C, ID, IH, IW]
    :param optical: [N, D, H, W, 3], normalized (x, y, z) coordinates
    :return: [N, C, D, H, W]
    """
    N, C, ID, IH, IW = volume.shape
    _, D, H, W, _ = optical.shape
    n_pts = D * H * W

    # Normalized [-1, 1] -> pixel coordinates (align_corners=True convention).
    ix = ((optical[..., 0] + 1) / 2) * (IW - 1)
    iy = ((optical[..., 1] + 1) / 2) * (IH - 1)
    iz = ((optical[..., 2] + 1) / 2) * (ID - 1)

    # `>= 0` keeps normalized coordinate -1 (first voxel), just as +1 is kept.
    mask = ((ix >= 0) & (ix < IW)
            & (iy >= 0) & (iy < IH)
            & (iz >= 0) & (iz < ID))

    # Cell corners are constants for autograd.
    with torch.no_grad():
        ix0 = torch.floor(ix)
        iy0 = torch.floor(iy)
        iz0 = torch.floor(iz)
        # Integer copies so flat indices are exact (float32 loses precision
        # beyond 2**24 voxels).
        x_lo, y_lo, z_lo = ix0.long(), iy0.long(), iz0.long()

    # Per-axis linear weights for the lower (0) and upper (1) neighbour;
    # gradients reach `optical` only through these.
    wx = (ix0 + 1 - ix, ix - ix0)
    wy = (iy0 + 1 - iy, iy - iy0)
    wz = (iz0 + 1 - iz, iz - iz0)

    # reshape (not view) so non-contiguous volumes work.
    flat = volume.reshape(N, C, ID * IH * IW)

    out_val = 0
    for dz in (0, 1):
        z = (z_lo + dz).clamp(0, ID - 1)
        for dy in (0, 1):
            y = (y_lo + dy).clamp(0, IH - 1)
            for dx in (0, 1):
                x = (x_lo + dx).clamp(0, IW - 1)
                # Row-major index into [ID, IH, IW]: one z-slice spans IH*IW.
                idx = (z * (IH * IW) + y * IW + x).reshape(N, 1, n_pts).expand(N, C, n_pts)
                corner = torch.gather(flat, 2, idx).view(N, C, D, H, W)
                weight = wx[dx] * wy[dy] * wz[dz]
                out_val = out_val + corner * weight[:, None]

    return torch.where(mask[:, None], out_val, torch.zeros_like(out_val))
| |
|
| |
|
| | |
def get_weight(s, a=-0.5):
    """
    Keys cubic-convolution kernel evaluated element-wise at ``s``.

        W(t) = (a + 2)|t|^3 - (a + 3)|t|^2 + 1        for |t| <= 1
        W(t) = a|t|^3 - 5a|t|^2 + 8a|t| - 4a          for 1 < |t| <= 2
        W(t) = 0                                       otherwise

    With the default a = -0.5 this is the Catmull-Rom kernel.

    :param s: tensor of signed distances (any shape)
    :param a: kernel free parameter
    :return: tensor of weights with the same shape as ``s``
    """
    t = torch.abs(s)
    near = (a + 2) * (t ** 3) - (a + 3) * (t ** 2) + 1
    far = a * (t ** 3) - (5 * a) * (t ** 2) + (8 * a) * t - 4 * a
    # NaN distances fail both comparisons and fall through to 0.
    return torch.where(t <= 1, near, torch.where(t <= 2, far, torch.zeros_like(s)))
| |
|
| |
|
def cubic_interpolate(p, x):
    """
    One-dimensional Catmull-Rom interpolation between p[:, 1] and p[:, 2].

    :param p: [N, 4] control values in order (p0, p1, p2, p3)
    :param x: [N] offset from p1 (0 -> p1, 1 -> p2)
    :return: [N] interpolated values
    """
    p0, p1, p2, p3 = p[:, 0], p[:, 1], p[:, 2], p[:, 3]
    # Horner form, evaluated innermost first.
    cubic_term = 3.0 * (p1 - p2) + p3 - p0
    quad_term = 2.0 * p0 - 5.0 * p1 + 4.0 * p2 - p3 + x * cubic_term
    lin_term = p2 - p0 + x * quad_term
    return p1 + 0.5 * x * lin_term
| |
|
| |
|
def bicubic_interpolate(p, x, y, if_batch=True):
    """
    Two-dimensional Catmull-Rom interpolation.

    Each of the 4 rows is interpolated along its last axis with ``x``, then
    the 4 row results are interpolated with ``y``.

    :param p: [N, 4, 4] control values indexed [row (y), column (x)]
    :param x: [N] offset along the columns
    :param y: [N] offset along the rows
    :param if_batch: if True, interpolate all rows in a single flattened call;
        otherwise interpolate row by row (equivalent result)
    :return: [N]
    """
    if if_batch:
        count = p.shape[0]
        x_per_row = x[:, None].repeat(1, 4).view(-1)
        flat_rows = p.contiguous().view(count * 4, 4)
        row_vals = cubic_interpolate(flat_rows, x_per_row).view(count, 4)
    else:
        row_vals = torch.stack(
            [cubic_interpolate(p[:, row, :], x) for row in range(4)], dim=-1)
    return cubic_interpolate(row_vals, y)
| |
|
| |
|
def tricubic_interpolate(p, x, y, z):
    """
    Three-dimensional Catmull-Rom interpolation.

    Each of the 4 slices p[:, k] is interpolated bicubically with (x, y), then
    the 4 slice results are interpolated with ``z``.

    :param p: [N, 4, 4, 4] control values indexed [z, y, x]
    :param x: [N]
    :param y: [N]
    :param z: [N]
    :return: [N]
    """
    slice_vals = [bicubic_interpolate(p[:, k, :, :], x, y) for k in range(4)]
    return cubic_interpolate(torch.stack(slice_vals, dim=-1), z)
| |
|
| |
|
def cubic_interpolate_batch(p, x):
    """
    Batched one-dimensional Catmull-Rom interpolation between p[..., 1] and
    p[..., 2].

    :param p: [B, N, 4] control values in order (p0, p1, p2, p3)
    :param x: [B, N] offset from p1 (0 -> p1, 1 -> p2)
    :return: [B, N]
    """
    p0, p1, p2, p3 = p[:, :, 0], p[:, :, 1], p[:, :, 2], p[:, :, 3]
    # Horner form, evaluated innermost first.
    cubic_term = 3.0 * (p1 - p2) + p3 - p0
    quad_term = 2.0 * p0 - 5.0 * p1 + 4.0 * p2 - p3 + x * cubic_term
    lin_term = p2 - p0 + x * quad_term
    return p1 + 0.5 * x * lin_term
| |
|
| |
|
def bicubic_interpolate_batch(p, x, y):
    """
    Batched two-dimensional Catmull-Rom interpolation.

    :param p: [B, N, 4, 4] control values indexed [..., row (y), column (x)]
    :param x: [B, N] offset along the columns
    :param y: [B, N] offset along the rows
    :return: [B, N]
    """
    B, N, _, _ = p.shape
    # Flatten the 4 rows of every patch so one call handles all of them.
    x_rows = x[:, :, None].repeat(1, 1, 4).view(B, N * 4)
    rows = p.contiguous().view(B, N * 4, 4)
    row_vals = cubic_interpolate_batch(rows, x_rows).view(B, N, 4)
    return cubic_interpolate_batch(row_vals, y)
| |
|
| |
|
| | |
def tricubic_interpolate_batch(p, x, y, z):
    """
    Three-dimensional Catmull-Rom interpolation that processes all 4 z-slices
    with one batched bicubic call.

    :param p: [N, 4, 4, 4] control values indexed [z, y, x]
    :param x: [N]
    :param y: [N]
    :param z: [N]
    :return: [N]
    """
    # Move the slice axis to the front so each slice is one "batch" entry.
    slices = p.permute(1, 0, 2, 3).contiguous()  # [4, N, 4, 4]
    x_per_slice = x[None, :].repeat(4, 1)
    y_per_slice = y[None, :].repeat(4, 1)
    slice_vals = bicubic_interpolate_batch(slices, x_per_slice, y_per_slice)  # [4, N]
    return cubic_interpolate(slice_vals.permute(1, 0).contiguous(), z)
| |
|
| |
|
def tricubic_sample_3d(volume, optical):
    """
    Tricubic (Catmull-Rom) sampling of ``volume`` at ``optical``.

    The interpolant is C1 in the sample position, so the gradient stays
    continuous across voxel borders (unlike trilinear sampling). Normalized
    coordinates use the align_corners=True convention. Stencil indices are
    clamped to the volume (border padding).

    :param volume: [B, C, ID, IH, IW]
    :param optical: [B, D, H, W, 3], normalized (x, y, z) coordinates
    :return: [B, C, D, H, W]
    """
    N, C, ID, IH, IW = volume.shape
    _, D, H, W, _ = optical.shape
    n_pts = D * H * W

    # Normalized [-1, 1] -> pixel coordinates, flattened to [N * n_pts].
    ix = (((optical[..., 0] + 1) / 2) * (IW - 1)).reshape(-1)
    iy = (((optical[..., 1] + 1) / 2) * (IH - 1)).reshape(-1)
    iz = (((optical[..., 2] + 1) / 2) * (ID - 1)).reshape(-1)

    with torch.no_grad():
        # 4-tap stencil offsets around floor(coord).
        taps = torch.arange(-1, 3, device=ix.device)
        # Integer stencil indices per axis, [N * n_pts, 4], clamped to the
        # volume. Built from floor() directly so float rounding cannot shift
        # a tap (the old `(ix + shift).long()` truncation could).
        xs = (torch.floor(ix).long()[:, None] + taps).clamp(0, IW - 1)
        ys = (torch.floor(iy).long()[:, None] + taps).clamp(0, IH - 1)
        zs = (torch.floor(iz).long()[:, None] + taps).clamp(0, ID - 1)
        # Row-major flat voxel index of every stencil point, laid out
        # [z, y, x]; one z-slice spans IH * IW voxels.
        idx = (zs[:, :, None, None] * (IH * IW)
               + ys[:, None, :, None] * IW
               + xs[:, None, None, :])  # [N * n_pts, 4, 4, 4]

    # Local offset from the (clamped) second tap. Gradients reach `optical`
    # only through ix / iy / iz. Repeated per channel to match `samples` rows.
    local_x = (ix - xs[:, 1]).view(N, 1, n_pts).expand(N, C, n_pts).reshape(-1)
    local_y = (iy - ys[:, 1]).view(N, 1, n_pts).expand(N, C, n_pts).reshape(-1)
    local_z = (iz - zs[:, 1]).view(N, 1, n_pts).expand(N, C, n_pts).reshape(-1)

    # reshape (not view) so non-contiguous volumes work.
    flat = volume.reshape(N, C, ID * IH * IW)
    gather_idx = idx.reshape(N, 1, n_pts * 64).expand(N, C, n_pts * 64)
    samples = torch.gather(flat, 2, gather_idx).view(N * C * n_pts, 4, 4, 4)

    return tricubic_interpolate(samples, local_x, local_y, local_z).view(N, C, D, H, W)
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # Project-local helper used to build the demo volume.
    from ops.generate_grids import generate_grid

    # 1-D Catmull-Rom check on the ramp [0, 1, 2, 3] at offset 0.5
    # (the result is not printed).
    ramp = torch.tensor(list(range(4))).view(1, 4).float()
    v = cubic_interpolate(ramp, torch.tensor([0.5]).view(1))

    vsize = 9
    volume = generate_grid([vsize, vsize, vsize], 1)

    # Normalized (align_corners=True) coordinates of voxel (X, Y, Z);
    # computed but not used below.
    X, Y, Z = 0, 0, 6
    x, y, z = [2 * c / (vsize - 1) - 1 for c in (X, Y, Z)]

    # Two query points shaped [1, 1, 1, 2, 3].
    optical = torch.Tensor([-0.6, -0.7, 0.5, 0.3, 0.5, 0.5]).view(1, 1, 1, 2, 3)

    # Reference trilinear result first, then the custom samplers.
    print(F.grid_sample(volume, optical, padding_mode='border', align_corners=True))
    print(grid_sample_3d(volume, optical))
    print(tricubic_sample_3d(volume, optical))
| | |
| | |
| |
|