| import torch |
| from torch import nn |
| from transformers import PreTrainedModel |
| from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder |
|
|
| from .configuration_radzero import AlignTransformerConfig |
|
|
|
|
def build_align_transformer(config):
    """Factory for alignment models selected by ``config.model_type``.

    Args:
        config: Configuration object exposing a ``model_type`` attribute;
            only ``"align_transformer"`` is currently supported.

    Returns:
        An initialized ``AlignTransformer``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not a supported
            model type.
    """
    # Guard clause with an informative message so callers can see which
    # model_type was rejected (the original raised a bare NotImplementedError).
    if config.model_type != "align_transformer":
        raise NotImplementedError(
            f"Unsupported model_type: {config.model_type!r}"
        )
    return AlignTransformer(config)
|
|
|
|
class AlignTransformer(PreTrainedModel):
    """Post-processes vision tokens with optional projection, Dinov2 encoder
    layers, and a final layer norm, each enabled via the config."""

    def __init__(self, config: AlignTransformerConfig):
        super().__init__(config)

        # Patch-token projector; left unset here — presumably assigned
        # externally after construction (TODO confirm against callers).
        self.projector = None

        # Optional stack of Dinov2 encoder layers (skipped when
        # config.num_hidden_layers is falsy/zero).
        self.transformer_layers = (
            Dinov2Encoder(config) if config.num_hidden_layers else None
        )

        # Optional final normalization over the hidden dimension.
        self.layer_norm = (
            nn.LayerNorm(config.hidden_size) if config.use_layer_norm else None
        )

    def forward(self, vision_tokens):
        """Apply the configured stages to *vision_tokens* in order.

        Args:
            vision_tokens: Token tensor; assumes layout
                (batch, 1 + num_patches, hidden) with the CLS token first —
                TODO confirm with the upstream vision encoder.

        Returns:
            Tensor of the same layout after the enabled stages.
        """
        if self.projector is not None:
            # Project only the patch tokens; the CLS token bypasses the
            # projector and is re-attached afterwards.
            cls_tok = vision_tokens[:, :1]
            patches = vision_tokens[:, 1:]
            projected = self.projector(patches)["last_hidden_state"]
            vision_tokens = torch.cat((cls_tok, projected), dim=1)

        if self.transformer_layers is not None:
            vision_tokens = self.transformer_layers(vision_tokens)["last_hidden_state"]

        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)

        return vision_tokens
|
|