import torch
from torch import nn
from einops.layers.torch import Rearrange
def init_layer(layer):
    """Initialize a convolutional/linear layer with Xavier-uniform weights."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias") and layer.bias is not None:
        layer.bias.data.fill_(0.0)
def init_bn(bn):
    """Initialize a batch-norm layer to the identity transform."""
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)
    bn.running_mean.data.fill_(0.0)
    bn.running_var.data.fill_(1.0)
class BiGRU(nn.Module):
    def __init__(self, patch_size, channels, depth):
        super(BiGRU, self).__init__()
        patch_width, patch_height = patch_size
        patch_dim = channels * patch_height * patch_width

        # Split the (batch, channels, width, height) input into a sequence of
        # flattened patches: (batch, num_patches, patch_dim).
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                'b c (w p1) (h p2) -> b (w h) (p1 p2 c)',
                p1=patch_width,
                p2=patch_height
            )
        )

        # Bidirectional GRU; the two directions' hidden states concatenate
        # back to patch_dim.
        self.gru = nn.GRU(
            patch_dim,
            patch_dim // 2,
            num_layers=depth,
            batch_first=True,
            bidirectional=True
        )
    def forward(self, x):
        x = self.to_patch_embedding(x)

        try:
            return self.gru(x)[0]
        except RuntimeError:
            # cuDNN can fail on some GRU configurations; fall back to the
            # native PyTorch implementation and retry.
            torch.backends.cudnn.enabled = False
            return self.gru(x)[0]
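
# A minimal shape sketch (hypothetical sizes, not from the original code):
# with patch_size=(16, 16) and channels=1, an input of shape (2, 1, 64, 320)
# is split into 4 * 20 = 80 patches of dimension 16 * 16 * 1 = 256, and the
# bidirectional GRU (patch_dim // 2 = 128 units per direction) preserves the
# (2, 80, 256) shape:
#
#     bigru = BiGRU(patch_size=(16, 16), channels=1, depth=2)
#     out = bigru(torch.randn(2, 1, 64, 320))   # -> torch.Size([2, 80, 256])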

class ResConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(ResConvBlock, self).__init__()
        # Pre-activation residual block: BN -> PReLU -> 3x3 conv, applied twice.
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.01)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.01)
        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()
        self.conv1 = nn.Conv2d(
            in_planes,
            out_planes,
            (3, 3),
            padding=(1, 1),
            bias=False
        )
        self.conv2 = nn.Conv2d(
            out_planes,
            out_planes,
            (3, 3),
            padding=(1, 1),
            bias=False
        )
        self.is_shortcut = False

        # A 1x1 convolution projects the residual input when the channel
        # count changes.
        if in_planes != out_planes:
            self.shortcut = nn.Conv2d(in_planes, out_planes, (1, 1))
            self.is_shortcut = True

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)

        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        out = self.conv1(self.act1(self.bn1(x)))
        out = self.conv2(self.act2(self.bn2(out)))

        if self.is_shortcut:
            return self.shortcut(x) + out
        else:
            return out + x
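
# A minimal smoke test, assuming arbitrary example shapes: the 3x3
# convolutions use padding 1, so ResConvBlock preserves spatial size and only
# changes the channel count (via the 1x1 shortcut when in_planes != out_planes).
if __name__ == "__main__":
    block = ResConvBlock(in_planes=32, out_planes=64)
    y = block(torch.randn(2, 32, 40, 40))
    assert y.shape == (2, 64, 40, 40)

    identity_block = ResConvBlock(in_planes=64, out_planes=64)
    assert identity_block(y).shape == y.shape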