| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Literal |
| |
|
| | from transformers.models.modernbert.configuration_modernbert import ModernBertConfig |
| |
|
| |
|
class ModChemBertConfig(ModernBertConfig):
    """
    Configuration class for ModChemBert models.

    This configuration class extends ModernBertConfig with additional parameters specific to
    chemical molecule modeling and custom pooling strategies for classification/regression tasks.
    It accepts all arguments and keyword arguments from ModernBertConfig.

    Args:
        classifier_pooling (str, optional): Pooling strategy for sequence classification.
            Available options:
            - "cls": Use CLS token representation
            - "mean": Attention-weighted average pooling
            - "sum_mean": Sum all hidden states across layers, then mean pool over sequence (ChemLM approach)
            - "sum_sum": Sum all hidden states across layers, then sum pool over sequence
            - "mean_mean": Mean all hidden states across layers, then mean pool over sequence
            - "mean_sum": Mean all hidden states across layers, then sum pool over sequence
            - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
            - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
            - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
            - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
            Defaults to "max_seq_mha".
        classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
            pooling strategies (cls_mha, max_seq_mha). Defaults to 4.
        classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
            pooling strategies (cls_mha, max_seq_mha). Defaults to 0.0.
        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max pooling
            strategies (max_cls, max_seq_mha, max_seq_mean). Defaults to 8.
        *args: Variable length argument list passed to ModernBertConfig.
        **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.

    Raises:
        ValueError: If ``classifier_pooling`` is not one of the strategies listed above.

    Note:
        This class inherits all configuration parameters from ModernBertConfig including
        hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, etc.
        The parent ModernBertConfig is always initialized with ``classifier_pooling="cls"``;
        the requested strategy is then stored on the instance, overwriting that value.
    """

    model_type = "modchembert"

    # Every value accepted for ``classifier_pooling``; keep in sync with the ``Literal``
    # annotation on ``__init__`` and the list in the class docstring.
    _SUPPORTED_CLASSIFIER_POOLING = (
        "cls",
        "mean",
        "sum_mean",
        "sum_sum",
        "mean_mean",
        "mean_sum",
        "max_cls",
        "cls_mha",
        "max_seq_mha",
        "max_seq_mean",
    )

    def __init__(
        self,
        *args,
        classifier_pooling: Literal[
            "cls",
            "mean",
            "sum_mean",
            "sum_sum",
            "mean_mean",
            "mean_sum",
            "max_cls",
            "cls_mha",
            "max_seq_mha",
            "max_seq_mean",
        ] = "max_seq_mha",
        classifier_pooling_num_attention_heads: int = 4,
        classifier_pooling_attention_dropout: float = 0.0,
        classifier_pooling_last_k: int = 8,
        **kwargs,
    ):
        # Fail fast on a typo or unsupported strategy instead of storing it silently
        # and failing later when the classification head is built.
        if classifier_pooling not in self._SUPPORTED_CLASSIFIER_POOLING:
            raise ValueError(
                f"Invalid value for `classifier_pooling`: {classifier_pooling!r}. "
                f"Expected one of {self._SUPPORTED_CLASSIFIER_POOLING}."
            )

        # Pass the parent a pooling mode from its own ("cls"/"mean") vocabulary.
        # NOTE(review): presumably ModernBertConfig rejects the extended strategies above;
        # the real value is assigned right after the parent initializer returns.
        super().__init__(*args, classifier_pooling="cls", **kwargs)

        self.classifier_pooling = classifier_pooling
        self.classifier_pooling_num_attention_heads = classifier_pooling_num_attention_heads
        self.classifier_pooling_attention_dropout = classifier_pooling_attention_dropout
        self.classifier_pooling_last_k = classifier_pooling_last_k
| |
|