from transformers import PretrainedConfig


class NicheformerConfig(PretrainedConfig):
    model_type = "nicheformer"

    def __init__(
        self,
        dim_model=512,
        nheads=16,
        dim_feedforward=1024,
        nlayers=12,
        dropout=0.0,
        batch_first=True,
        masking_p=0.15,
        n_tokens=20340,
        context_length=1500,
        cls_classes=164,
        supervised_task=None,
        learnable_pe=True,
        specie=True,
        assay=True,
        modality=True,
        **kwargs,
    ):
        """Initialize NicheformerConfig.

        Args:
            dim_model: Dimensionality of the model
            nheads: Number of attention heads
            dim_feedforward: Dimensionality of the MLPs in the attention blocks
            nlayers: Number of transformer layers
            dropout: Dropout probability
            batch_first: Whether the batch dimension comes first
            masking_p: Probability of masking a token
            n_tokens: Total number of tokens (excluding auxiliary tokens)
            context_length: Length of the context window
            cls_classes: Number of classification classes
            supervised_task: Type of supervised task
            learnable_pe: Whether to use learnable positional embeddings
            specie: Whether to add a species token
            assay: Whether to add an assay token
            modality: Whether to add a modality token
        """
        super().__init__(**kwargs)
        self.dim_model = dim_model
        self.nheads = nheads
        self.dim_feedforward = dim_feedforward
        self.nlayers = nlayers
        self.dropout = dropout
        self.batch_first = batch_first
        self.masking_p = masking_p
        self.n_tokens = n_tokens
        self.context_length = context_length
        self.cls_classes = cls_classes
        self.supervised_task = supervised_task
        self.learnable_pe = learnable_pe
        self.specie = specie
        self.assay = assay
        self.modality = modality
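

# Usage sketch (illustrative, not part of the library): because the class
# subclasses PretrainedConfig, it inherits the standard HuggingFace
# serialization API, so a config can be round-tripped through JSON. The
# directory name "nicheformer-config" below is a hypothetical example path.
if __name__ == "__main__":
    # Override a couple of defaults at construction time.
    config = NicheformerConfig(dim_model=256, nlayers=6)

    # save_pretrained() writes config.json into the given directory;
    # from_pretrained() rebuilds an equivalent config object from it.
    config.save_pretrained("nicheformer-config")
    reloaded = NicheformerConfig.from_pretrained("nicheformer-config")
    assert reloaded.dim_model == 256 and reloaded.nlayers == 6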