| import os |
| import sys |
| import torch |
|
|
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| sys.path.append(os.getcwd()) |
|
|
| from infer.lib.predictors.DJCM.encoder import ResEncoderBlock |
| from infer.lib.predictors.DJCM.utils import ResConvBlock, BiGRU, init_bn, init_layer |
|
|
class ResDecoderBlock(nn.Module):
    """Decoder stage: BN -> ReLU -> transposed conv, skip concat, ResConvBlocks.

    The transposed convolution uses ``stride`` as both its kernel size and its
    stride with zero padding, so each spatial dimension is multiplied by the
    matching entry of ``stride``.

    Args:
        in_channels: Channel count of the tensor fed to ``forward``.
        out_channels: Channel count produced by the transposed convolution and
            by every residual block.
        n_blocks: Number of ``ResConvBlock`` layers applied after the skip
            concatenation. One block is always created, so values below 1
            still yield a single block.
        stride: Kernel size and stride of the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, n_blocks, stride):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=stride,
            stride=stride,
            padding=(0, 0),
            bias=False,
        )
        # Normalisation runs *before* conv1, hence it is sized to in_channels.
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.01)

        # The first block consumes the concatenation of the upsampled tensor
        # and the skip tensor, which is why it expects out_channels * 2 inputs
        # (assumes the skip tensor carries out_channels channels -- TODO confirm).
        self.conv = nn.ModuleList(
            [ResConvBlock(out_channels * 2, out_channels)]
            + [
                ResConvBlock(out_channels, out_channels)
                for _ in range(n_blocks - 1)
            ]
        )

        self.init_weights()

    def init_weights(self):
        """Initialise the batch norm and the transposed convolution."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, x, concat):
        """Upsample ``x``, join it with ``concat`` on dim 1, then refine.

        Args:
            x: Input tensor with ``in_channels`` channels.
            concat: Skip tensor concatenated along the channel dimension; it
                must match the upsampled tensor in every other dimension.

        Returns:
            The output of the last residual block.
        """
        h = self.bn1(x)
        h = F.relu_(h)  # in-place on the fresh batch-norm output
        h = self.conv1(h)
        h = torch.cat((h, concat), dim=1)

        for res_block in self.conv:
            h = res_block(h)

        return h
|
|
class Decoder(nn.Module):
    """Stack of six ``ResDecoderBlock`` stages with a fixed channel plan.

    Channels go 384 -> 384 -> 384 -> 256 -> 128 -> 64 -> 32 and every stage
    uses a stride of ``(1, 2)``.

    Args:
        n_blocks: Number of residual blocks inside each decoder stage.
    """

    def __init__(self, n_blocks):
        super().__init__()
        # (in_channels, out_channels) for each stage, in execution order.
        channel_plan = (
            (384, 384),
            (384, 384),
            (384, 256),
            (256, 128),
            (128, 64),
            (64, 32),
        )
        self.de_blocks = nn.ModuleList(
            [
                ResDecoderBlock(n_in, n_out, n_blocks, (1, 2))
                for n_in, n_out in channel_plan
            ]
        )

    def forward(self, x, concat_tensors):
        """Run every stage, pairing it with a skip tensor taken from the end.

        Args:
            x: Input tensor for the first stage.
            concat_tensors: Indexable collection of skip tensors. It is read
                back-to-front: the first stage receives the last element, the
                second stage the one before it, and so on.

        Returns:
            The output of the final stage.
        """
        # position is 1-based so that -position walks backwards from the end.
        for position, stage in enumerate(self.de_blocks, start=1):
            skip = concat_tensors[-position]
            x = stage(x, skip)

        return x
|
|
class PE_Decoder(nn.Module):
    """Decoder head ending in a sigmoid over ``n_class`` outputs.

    Pipeline: ``Decoder`` -> ``ResEncoderBlock(32, 32)`` -> 1x1 conv down to a
    single channel -> ``BiGRU`` -> ``Linear`` -> ``Sigmoid``.
    NOTE(review): "PE" presumably stands for pitch estimation -- confirm.

    Args:
        n_blocks: Residual block count forwarded to the decoder stages and to
            the trailing ``ResEncoderBlock``.
        seq_layers: Third argument passed to ``BiGRU`` (presumably the number
            of recurrent layers -- verify against ``BiGRU``).
        window_length: Half of this value is used as the feature size for
            ``BiGRU`` and as the input size of the final linear layer.
        n_class: Output size of the final linear layer.
    """

    def __init__(self, n_blocks, seq_layers=1, window_length=1024, n_class=360):
        super().__init__()
        half_window = window_length // 2

        self.de_blocks = Decoder(n_blocks)
        # The fourth argument is None here; its meaning is defined by
        # ResEncoderBlock (presumably "no pooling" -- TODO confirm).
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        self.after_conv2 = nn.Conv2d(32, 1, (1, 1))
        self.fc = nn.Sequential(
            BiGRU((1, half_window), 1, seq_layers),
            nn.Linear(half_window, n_class),
            nn.Sigmoid(),
        )
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode ``x`` with its skip tensors and apply the output head.

        Args:
            x: Input tensor for the decoder stack.
            concat_tensors: Skip tensors handed to ``Decoder.forward``.

        Returns:
            The sigmoid-activated output with dimension 1 squeezed away (only
            removed when that dimension has size 1).
        """
        decoded = self.de_blocks(x, concat_tensors)
        refined = self.after_conv1(decoded)
        single_channel = self.after_conv2(refined)
        scores = self.fc(single_channel)
        return scores.squeeze(1)
| |
class SVS_Decoder(nn.Module):
    """Decoder head producing ``in_channels * 4`` output channels.

    Pipeline: ``Decoder`` -> ``ResEncoderBlock(32, 32)`` -> 1x1 convolution.
    NOTE(review): "SVS" presumably stands for singing voice separation --
    confirm.

    Args:
        in_channels: Base channel count; the final 1x1 convolution emits four
            times this many channels.
        n_blocks: Residual block count forwarded to the decoder stages and to
            the trailing ``ResEncoderBlock``.
    """

    def __init__(self, in_channels, n_blocks):
        super().__init__()
        self.de_blocks = Decoder(n_blocks)
        # The fourth argument is None here; its meaning is defined by
        # ResEncoderBlock (presumably "no pooling" -- TODO confirm).
        self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=in_channels * 4,
            kernel_size=(1, 1),
        )
        self.init_weights()

    def init_weights(self):
        """Initialise the final 1x1 convolution."""
        init_layer(self.after_conv2)

    def forward(self, x, concat_tensors):
        """Decode ``x`` with its skip tensors and project to the output channels.

        Args:
            x: Input tensor for the decoder stack.
            concat_tensors: Skip tensors handed to ``Decoder.forward``.

        Returns:
            The output of the final 1x1 convolution.
        """
        decoded = self.de_blocks(x, concat_tensors)
        refined = self.after_conv1(decoded)
        return self.after_conv2(refined)