| | from typing import Any, Callable, Dict, List, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange |
| | from enformer_pytorch import Enformer |
| | from transformers import PretrainedConfig, PreTrainedModel |
| |
|
| |
|
def get_activation_fn(activation_name: str) -> Callable:
    """
    Returns the torch activation function registered under a given name.

    Args:
        activation_name (str): Name of the activation function. Possible values are
            'swish', 'silu', 'relu', 'gelu', 'gelu-no-approx', 'sin'.
            'silu' is an alias of 'swish'. Both 'gelu' and 'gelu-no-approx' map to
            ``nn.functional.gelu`` called with its default arguments.

    Raises:
        ValueError: If activation_name is not supported

    Returns:
        Callable: Activation function
    """
    # Dispatch table instead of an if/elif chain: adding a name is a one-line change
    # and aliases are visible side by side.
    activations: Dict[str, Callable] = {
        "swish": nn.functional.silu,
        "silu": nn.functional.silu,
        "relu": nn.functional.relu,
        "gelu": nn.functional.gelu,
        "gelu-no-approx": nn.functional.gelu,
        "sin": torch.sin,
    }
    try:
        return activations[activation_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs so that every unsupported value is
        # reported the same way as before.
        raise ValueError(
            f"Unsupported activation function: {activation_name}"
        ) from None
| |
|
| |
|
class TorchDownSample1D(nn.Module):
    """
    Torch adaptation of DownSample1D in trix.layers.heads.unet_segmentation_head.py

    Applies a stack of kernel-size-3, stride-1, padding-1 convolutions (each followed
    by the activation) and then an average pooling with kernel 2 / stride 2.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: number of channels expected by the first convolution.
            output_channels: number of channels produced by every convolution.
            activation_fn: name of the activation function, resolved through
                get_activation_fn (which raises ValueError on unknown names).
            num_layers: number of convolution layers.
        """
        super().__init__()

        # Only the first convolution sees `input_channels`; every following one
        # consumes the output of its predecessor.
        convolutions = []
        current_channels = input_channels
        for _ in range(num_layers):
            convolutions.append(
                nn.Conv1d(
                    in_channels=current_channels,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            current_channels = output_channels
        self.conv_layers = nn.ModuleList(convolutions)

        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)

        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            A pair ``(pooled, hidden)`` where ``hidden`` is the output of the
            convolution stack and ``pooled`` is ``hidden`` after average pooling.
        """
        hidden = x
        for convolution in self.conv_layers:
            hidden = self.activation_fn(convolution(hidden))
        pooled = self.avg_pool(hidden)
        return pooled, hidden
| |
|
| |
|
class TorchUpSample1D(nn.Module):
    """
    Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py

    Applies a stack of kernel-size-3, stride-1, padding-1 transposed convolutions
    (each followed by the activation) and then doubles the last dimension with
    ``torch.nn.functional.interpolate``.
    """

    # `interpolate` raises ValueError when `align_corners` is set for one of these
    # modes, so it must be left as None for them.
    _MODES_WITHOUT_ALIGN_CORNERS = ("nearest", "area", "nearest-exact")

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
        interpolation_method: str = "nearest",
    ):
        """
        Args:
            input_channels: number of input channels.
            output_channels: number of output channels.
            activation_fn: name of the activation function, resolved through
                get_activation_fn (which raises ValueError on unknown names).
            num_layers: number of transposed convolution layers.
            interpolation_method: value forwarded as ``mode`` to
                ``torch.nn.functional.interpolate``. For the 3D tensors handled
                here PyTorch documents "nearest", "linear", "area" and
                "nearest-exact". It is not validated here; an unsupported value
                fails when ``forward`` is called.
        """
        super().__init__()

        self.conv_transpose_layers = nn.ModuleList(
            [
                nn.ConvTranspose1d(
                    in_channels=input_channels if i == 0 else output_channels,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
                for i in range(num_layers)
            ]
        )

        self.interpolation_mode = interpolation_method
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns:
            The activated output of the transposed convolution stack, upsampled
            by a factor of 2 along the last dimension.
        """
        for conv_layer in self.conv_transpose_layers:
            x = self.activation_fn(conv_layer(x))

        # Only interpolating modes (e.g. "linear") accept an explicit
        # align_corners; passing False for "area"/"nearest-exact" would raise.
        if self.interpolation_mode in self._MODES_WITHOUT_ALIGN_CORNERS:
            align_corners = None
        else:
            align_corners = False

        return nn.functional.interpolate(
            x,
            scale_factor=2,
            mode=self.interpolation_mode,
            align_corners=align_corners,
        )
| |
|
| |
|
class TorchFinalConv1D(nn.Module):
    """
    Torch adaptation of FinalConv1D in trix.layers.heads.unet_segmentation_head.py

    A stack of kernel-size-3, stride-1, padding-1 convolutions in which the
    activation is applied after every layer except the last one.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: number of channels expected by the first convolution.
            output_channels: number of channels produced by every convolution.
            activation_fn: name of the activation function, resolved through
                get_activation_fn (which raises ValueError on unknown names).
            num_layers: number of convolution layers.
        """
        super().__init__()

        # Only the first convolution sees `input_channels`.
        convolutions = []
        current_channels = input_channels
        for _ in range(num_layers):
            convolutions.append(
                nn.Conv1d(
                    in_channels=current_channels,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            current_channels = output_channels
        self.conv_layers = nn.ModuleList(convolutions)

        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns:
            Output of the convolution stack; the last convolution is left
            un-activated.
        """
        last_index = len(self.conv_layers) - 1
        for index, convolution in enumerate(self.conv_layers):
            x = convolution(x)
            if index != last_index:
                x = self.activation_fn(x)
        return x
| |
|
| |
|
class TorchUNET1DSegmentationHead(nn.Module):
    """
    Torch adaptation of UNET1DSegmentationHead in
    trix.layers.heads.unet_segmentation_head.py

    A 1D U-Net: a chain of TorchDownSample1D blocks, a mirrored chain of
    TorchUpSample1D blocks whose outputs are summed with the matching
    down-sampling hidden states, and a TorchFinalConv1D block.
    """

    def __init__(
        self,
        num_classes: int,
        input_embed_dim: int,
        output_channels_list: Tuple[int, ...] = (64, 128, 256),
        activation_fn: str = "swish",
        num_conv_layers_per_block: int = 2,
        upsampling_interpolation_method: str = "nearest",
    ):
        """
        Args:
            num_classes: number of classes to segment. The final block outputs
                ``num_classes * 2`` channels.
            input_embed_dim: number of channels of the input tensor.
            output_channels_list: number of output channels at each level of
                the UNET. Any sequence of ints (tuple or list) is accepted.
            activation_fn: name of the activation function, forwarded to every
                block.
            num_conv_layers_per_block: number of convolution layers per block.
            upsampling_interpolation_method: interpolation method forwarded to
                every TorchUpSample1D block.
        """
        super().__init__()

        # Normalise to a tuple: the concatenations below would raise TypeError
        # for a list (e.g. a value that went through JSON serialisation).
        output_channels_list = tuple(output_channels_list)

        input_channels_list = (input_embed_dim,) + output_channels_list[:-1]

        self.num_pooling_layers = len(output_channels_list)
        self.downsample_blocks = nn.ModuleList(
            [
                TorchDownSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                )
                for input_channels, output_channels in zip(
                    input_channels_list, output_channels_list
                )
            ]
        )

        # Up path mirrors the down path: the first block keeps the deepest
        # channel count, each following block narrows to the previous level so
        # its output can be summed with that level's hidden state.
        reversed_channels = output_channels_list[::-1]
        input_channels_list = (output_channels_list[-1],) + reversed_channels[:-1]

        self.upsample_blocks = nn.ModuleList(
            [
                TorchUpSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                    interpolation_method=upsampling_interpolation_method,
                )
                for input_channels, output_channels in zip(
                    input_channels_list, reversed_channels
                )
            ]
        )

        self.final_block = TorchFinalConv1D(
            activation_fn=activation_fn,
            input_channels=output_channels_list[0],
            output_channels=num_classes * 2,
            num_layers=num_conv_layers_per_block,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input tensor whose last dimension is the sequence length.

        Raises:
            ValueError: if the last dimension of ``x`` is not divisible by
                ``2 ** num_pooling_layers``.

        Returns:
            Output of the final convolution block.
        """
        # Each down block halves the length and each up block doubles it; the
        # skip-connection sums only line up if no halving truncates.
        if x.shape[-1] % 2**self.num_pooling_layers:
            raise ValueError(
                "Input length must be divisible by 2 to the power of "
                "the number of pooling layers."
            )

        hiddens = []
        for downsample_block in self.downsample_blocks:
            x, hidden = downsample_block(x)
            hiddens.append(hidden)

        # Deepest hidden state is consumed first.
        for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)):
            x = upsample_block(x) + hidden

        return self.final_block(x)
| |
|
| |
|
class TorchUNetHead(nn.Module):
    """
    Torch adaptation of UNetHead in
    genomics_research/segmentnt/layers/segmentation_head.py

    Runs a TorchUNET1DSegmentationHead over channel-first embeddings, then a
    linear layer that produces ``nucl_per_token * num_classes * num_features``
    values per position.
    """

    def __init__(
        self,
        features: List[str],
        num_classes: int = 2,
        embed_dimension: int = 1024,
        nucl_per_token: int = 6,
        num_layers: int = 2,
        remove_cls_token: bool = True,
    ):
        """
        Args:
            features (List[str]): List of features names. Only its length is used.
            num_classes (int): Number of classes.
            embed_dimension (int): Embedding dimension.
            nucl_per_token (int): Number of nucleotides per token.
            num_layers (int): Number of levels of the inner U-Net.
            remove_cls_token (bool): Whether to remove the CLS token.
        """
        super().__init__()
        self._num_features = len(features)
        self._num_classes = num_classes
        self.nucl_per_token = nucl_per_token
        self.remove_cls_token = remove_cls_token

        # The inner head outputs `num_classes * 2` channels, so asking for
        # embed_dimension // 2 classes yields embed_dimension channels (for even
        # embed_dimension), which is what `self.fc` consumes.
        self.unet = TorchUNET1DSegmentationHead(
            num_classes=embed_dimension // 2,
            output_channels_list=tuple(
                embed_dimension * (2**i) for i in range(num_layers)
            ),
            input_embed_dim=embed_dimension,
        )

        self.fc = nn.Linear(
            embed_dimension,
            self.nucl_per_token * self._num_classes * self._num_features,
        )

    def forward(
        self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: channel-first embeddings, i.e. (batch, embed_dimension, length);
                the convolutions of the inner U-Net require the channels on
                axis 1.
            sequence_mask: not used by this implementation.

        Returns:
            Dictionary with a single key ``"logits"`` holding a tensor of shape
            (batch, length * nucl_per_token, num_features, num_classes), where
            ``length`` is the sequence length after the optional CLS removal.
        """
        if self.remove_cls_token:
            # Tokens live on the last axis here. Slicing axis 1 would drop an
            # embedding channel and make the first convolution fail.
            x = x[..., 1:]

        x = self.unet(x)
        x = nn.functional.silu(x)

        # (batch, channels, length) -> (batch, length, channels) for the linear layer.
        x = x.transpose(2, 1)
        logits = self.fc(x)

        batch_size, seq_len, _ = x.shape
        logits = logits.reshape(
            batch_size,
            seq_len * self.nucl_per_token,
            self._num_features,
            self._num_classes,
        )

        return {"logits": logits}
| |
|
| |
|
# Default list of feature names; used as the default value of
# SegmentEnformerConfig.features. The segmentation head only relies on the
# number of entries, which sets the size of the features axis of the logits.
FEATURES = [
    "protein_coding_gene",
    "lncRNA",
    "exon",
    "intron",
    "splice_donor",
    "splice_acceptor",
    "5UTR",
    "3UTR",
    "CTCF-bound",
    "polyA_signal",
    "enhancer_Tissue_specific",
    "enhancer_Tissue_invariant",
    "promoter_Tissue_specific",
    "promoter_Tissue_invariant",
]
| |
|
| |
|
class SegmentEnformerConfig(PretrainedConfig):
    """
    Configuration of a SegmentEnformer model.

    Attributes:
        features: names of the features to predict.
        embed_dim: embedding dimension handed to the segmentation head.
        dim_divisible_by: handed to the segmentation head as its
            ``nucl_per_token`` argument.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            features: names of the features to predict. Defaults to FEATURES.
            embed_dim: embedding dimension.
            dim_divisible_by: see the class docstring.
            **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Store a copy: keeping a reference to the default would let a mutation
        # of `config.features` silently alter the module-level FEATURES list
        # (and therefore every other config built with the default).
        self.features = list(features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by

        super().__init__(**kwargs)
| |
|
| |
|
class SegmentEnformer(PreTrainedModel):
    """
    Segmentation model that reuses the stem, convolution tower and transformer
    of a pretrained Enformer and replaces the rest with a TorchUNetHead.
    """

    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig) -> None:
        """
        Args:
            config: model configuration; ``features``, ``embed_dim`` and
                ``dim_divisible_by`` parameterise the segmentation head.
        """
        super().__init__(config=config)

        # NOTE(review): this loads the pretrained Enformer checkpoint every time
        # the model is instantiated, including when the weights are about to be
        # overwritten by a SegmentEnformer checkpoint.
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")

        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer

        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Defined as ``forward`` (not ``__call__``) so that ``model(x)`` goes
        through ``nn.Module.__call__`` and registered hooks are honoured.

        Args:
            x: input laid out as (batch, length, channels), as implied by the
                first rearrange — presumably one-hot encoded DNA; confirm
                against the caller.

        Returns:
            The dictionary produced by the segmentation head, with the single
            key ``"logits"``.
        """
        # The convolutional parts work channel-first, the transformer
        # channel-last, hence the back-and-forth rearranges.
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)

        x = self.conv_tower(x)

        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)

        x = rearrange(x, "b n d -> b d n")
        return self.unet_head(x)
| |
|