pstjohn committed
Commit 76ff34a · verified · 1 Parent(s): 2fcafda

Upload folder using huggingface_hub

Files changed (5)
  1. config.json +3 -2
  2. esm_nv.py +145 -74
  3. special_tokens_map.json +42 -5
  4. tokenizer.json +176 -0
  5. tokenizer_config.json +8 -1
config.json CHANGED
@@ -8,7 +8,8 @@
   "auto_map": {
     "AutoConfig": "esm_nv.NVEsmConfig",
     "AutoModel": "esm_nv.NVEsmModel",
-    "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM"
+    "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM",
+    "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification"
   },
   "classifier_dropout": null,
   "dtype": "float32",
@@ -35,7 +36,7 @@
   "position_embedding_type": "rotary",
   "qkv_weight_interleaved": true,
   "token_dropout": true,
-  "transformers_version": "4.57.0",
+  "transformers_version": "4.57.3",
   "use_cache": true,
   "vocab_list": null,
   "vocab_size": 33
esm_nv.py CHANGED
@@ -23,7 +23,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """
 
-from typing import Literal, Optional
+from typing import Literal, Optional, Unpack
 
 # TODO: put import guard around transformer_engine here, with an informative error message around
 # installation and the nvidia docker container.
@@ -36,15 +36,26 @@ from transformers.modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
     MaskedLMOutput,
+    TokenClassifierOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.esm.configuration_esm import EsmConfig
 from transformers.models.esm.modeling_esm import EsmPooler
 from transformers.utils import logging
+from transformers.utils.generic import TransformersKwargs
 
 
 logger = logging.get_logger(__name__)
 
+# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below.
+# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints.
+AUTO_MAP = {
+    "AutoConfig": "esm_nv.NVEsmConfig",
+    "AutoModel": "esm_nv.NVEsmModel",
+    "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM",
+    "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification",
+}
+
 
 class NVEsmConfig(EsmConfig):
     """NVEsmConfig is a configuration for the NVEsm model."""
@@ -149,7 +160,9 @@ class NVEsmEncoder(nn.Module):
                 for i in range(config.num_hidden_layers)
             ]
         )
-        self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps, params_dtype=config.dtype
+        )
         if config.position_embedding_type == "rotary":
             self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
 
@@ -157,27 +170,28 @@
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        output_hidden_states: bool = False,
-        cu_seq_lens_q: torch.IntTensor | None = None,
-        cu_seq_lens_k: torch.IntTensor | None = None,
-        max_length_q: int | None = None,
-        max_length_k: int | None = None,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         """Forward pass of the NVEsmEncoder.
 
         Args:
             hidden_states (torch.Tensor): The hidden states.
             attention_mask (torch.Tensor): The attention mask.
-            output_hidden_states (bool): Whether to output the hidden states.
-            cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
-            cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
-            max_length_q (int): The maximum length for the query state, if using THD inputs.
-            max_length_k (int): The maximum length for the key state, if using THD inputs.
+            **kwargs: Additional arguments, see TransformersKwargs for more details.
         """
         all_hidden_states: tuple[torch.Tensor, ...] = ()
+        has_thd_input = [
+            x is not None
+            for x in [
+                kwargs.get("cu_seq_lens_q", None),
+                kwargs.get("cu_seq_lens_k", None),
+                kwargs.get("max_length_q", None),
+                kwargs.get("max_length_k", None),
+            ]
+        ]
 
         if self.config.attn_input_format == "thd":
-            if any(x is None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
+            if not all(has_thd_input):
                 raise ValueError(
                     "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
                 )
@@ -187,11 +201,10 @@ class NVEsmEncoder(nn.Module):
             hidden_states = hidden_states.squeeze(0)
             attention_mask = None
 
-        elif self.config.attn_input_format == "bshd":
-            if any(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
-                raise ValueError(
-                    "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
-                )
+        elif self.config.attn_input_format == "bshd" and any(has_thd_input):
+            raise ValueError(
+                "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
+            )
 
         # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context.
         with torch.autocast(device_type="cuda", enabled=False):
@@ -199,26 +212,33 @@ class NVEsmEncoder(nn.Module):
             if self.config.attn_input_format == "bshd":
                 te_rope_emb = self.rotary_embeddings(max_seq_len=hidden_states.shape[1])
             elif self.config.attn_input_format == "thd":
-                te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1])
-            te_rope_emb = te_rope_emb.to(hidden_states.device, dtype=hidden_states.dtype, non_blocking=True)
+                te_rope_emb = self.rotary_embeddings(
+                    max_seq_len=kwargs["cu_seq_lens_q_padded"][-1]
+                    if "cu_seq_lens_q_padded" in kwargs
+                    else kwargs["cu_seq_lens_q"][-1]
+                )
+            te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
 
         for layer_module in self.layers:
-            if output_hidden_states:
+            if kwargs.get("output_hidden_states", False):
                 all_hidden_states = (*all_hidden_states, hidden_states)
 
             hidden_states = layer_module(
                 hidden_states,
                 attention_mask,
                 rotary_pos_emb=te_rope_emb,
-                cu_seqlens_q=cu_seq_lens_q,
-                cu_seqlens_kv=cu_seq_lens_k,
-                max_seqlen_q=max_length_q,
-                max_seqlen_kv=max_length_k,
+                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
+                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
+                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
+                max_seqlen_q=kwargs.get("max_length_q", None),
+                max_seqlen_kv=kwargs.get("max_length_k", None),
+                pad_between_seqs=kwargs.get("pad_between_seqs", None),
            )
 
         hidden_states = self.emb_layer_norm_after(hidden_states)
 
-        if output_hidden_states:
+        if kwargs.get("output_hidden_states", False):
             all_hidden_states = (*all_hidden_states, hidden_states)
 
         return BaseModelOutput(
@@ -233,6 +253,7 @@ class NVEsmPreTrainedModel(PreTrainedModel):
     config_class = NVEsmConfig
     base_model_prefix = "esm"
     supports_gradient_checkpointing = False
+    accepts_loss_kwargs = False
     _no_split_modules = (
         "TransformerLayer",
         "EsmEmbeddings",
@@ -265,6 +286,11 @@ class NVEsmPreTrainedModel(PreTrainedModel):
             if module.layer_norm_bias is not None:
                 module.layer_norm_bias.data.zero_()
 
+    @classmethod
+    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
+        """Override the default get_init_context method to allow for fp8 model initialization."""
+        return []
+
 
 class NVEsmModel(NVEsmPreTrainedModel):
     """The ESM Encoder-only protein language model.
@@ -310,11 +336,7 @@ class NVEsmModel(NVEsmPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        cu_seq_lens_q: torch.IntTensor | None = None,
-        cu_seq_lens_k: torch.IntTensor | None = None,
-        max_length_q: int | None = None,
-        max_length_k: int | None = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPooling:
         """Forward pass of the NVEsmModel.
 
@@ -323,19 +345,11 @@ class NVEsmModel(NVEsmPreTrainedModel):
             attention_mask (torch.Tensor): The attention mask.
             position_ids (torch.Tensor): The position ids.
             inputs_embeds (torch.Tensor): The input embeddings.
-            output_hidden_states (bool): Whether to output the hidden states.
-            cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
-            cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
-            max_length_q (int): The maximum length for the query state, if using THD inputs.
-            max_length_k (int): The maximum length for the key state, if using THD inputs.
+            **kwargs: Additional arguments, see TransformersKwargs for more details.
 
         Returns:
             BaseModelOutputWithPooling: The output of the model.
         """
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -363,19 +377,12 @@ class NVEsmModel(NVEsmPreTrainedModel):
             input_ids=input_ids,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
-            cu_seq_lens_q=cu_seq_lens_q,
-            cu_seq_lens_k=cu_seq_lens_k,
-            max_length_q=max_length_q,
-            max_length_k=max_length_k,
+            **kwargs,
         )
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
-            output_hidden_states=output_hidden_states,
-            cu_seq_lens_q=cu_seq_lens_q,
-            cu_seq_lens_k=cu_seq_lens_k,
-            max_length_q=max_length_q,
-            max_length_k=max_length_k,
+            **kwargs,
        )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -427,11 +434,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        cu_seq_lens_q: torch.IntTensor | None = None,
-        cu_seq_lens_k: torch.IntTensor | None = None,
-        max_length_q: int | None = None,
-        max_length_k: int | None = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> MaskedLMOutput:
         """Forward pass of the NVEsmForMaskedLM.
 
@@ -441,11 +444,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
             position_ids (torch.LongTensor): The position ids.
             inputs_embeds (torch.FloatTensor): The input embeddings.
             labels (torch.LongTensor): The labels.
-            output_hidden_states (bool): Whether to output the hidden states.
-            cu_seq_lens_q (torch.IntTensor): The cumulative sequence lengths for the query state, if using THD inputs.
-            cu_seq_lens_k (torch.IntTensor): The cumulative sequence lengths for the key state, if using THD inputs.
-            max_length_q (int): The maximum length for the query state, if using THD inputs.
-            max_length_k (int): The maximum length for the key state, if using THD inputs.
+            **kwargs: Additional arguments, see TransformersKwargs for more details.
 
         Returns:
             MaskedLMOutput: The output of the model.
@@ -455,11 +454,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            output_hidden_states=output_hidden_states,
-            cu_seq_lens_q=cu_seq_lens_q,
-            cu_seq_lens_k=cu_seq_lens_k,
-            max_length_q=max_length_q,
-            max_length_k=max_length_k,
+            **kwargs,
        )
         sequence_output = outputs[0]
         prediction_scores = self.lm_head(sequence_output)
@@ -493,13 +488,18 @@ class NVEsmLMHead(nn.Module):
             config (NVEsmConfig): The configuration of the model.
         """
         super().__init__()
-        self.dense = transformer_engine.pytorch.Linear(config.hidden_size, config.hidden_size)
+        self.dense = transformer_engine.pytorch.Linear(
+            config.hidden_size,
+            config.hidden_size,
+            params_dtype=config.dtype,
+        )
 
         self.decoder = transformer_engine.pytorch.LayerNormLinear(
             config.hidden_size,
             config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size,
             bias=True,
             eps=config.layer_norm_eps,
+            params_dtype=config.dtype,
         )
 
     def forward(self, features, **kwargs):
@@ -522,11 +522,16 @@
         """Initialize a NVEsmEmbeddings."""
         super().__init__()
         self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+            config.padded_vocab_size,
+            config.hidden_size,
+            padding_idx=config.pad_token_id,
+            dtype=config.dtype,
         )
 
         self.layer_norm = (
-            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.emb_layer_norm_before else None
+            transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+            if config.emb_layer_norm_before
+            else None
         )
 
         if config.position_embedding_type != "rotary":
@@ -544,10 +549,7 @@
         input_ids=None,
         attention_mask=None,
         inputs_embeds=None,
-        cu_seq_lens_q: torch.IntTensor | None = None,
-        cu_seq_lens_k: torch.IntTensor | None = None,
-        max_length_q: int | None = None,
-        max_length_k: int | None = None,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         """Forward pass of the NVEsmEmbeddings."""
         if inputs_embeds is None:
@@ -557,7 +559,12 @@
         # embedding_scale factor here.
         embeddings = inputs_embeds
 
-        if all(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
+        if (
+            kwargs.get("cu_seq_lens_q") is not None
+            and kwargs.get("cu_seq_lens_k") is not None
+            and kwargs.get("max_length_q") is not None
+            and kwargs.get("max_length_k") is not None
+        ):
             using_thd = True
             attention_mask = None
         else:
@@ -583,10 +590,12 @@
                 embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)
 
             else:
-                src_lengths = torch.diff(cu_seq_lens_q)
+                src_lengths = torch.diff(kwargs["cu_seq_lens_q"])
                 # We need to find the number of masked tokens in each sequence in the padded batch.
                 is_masked = (input_ids == self.mask_token_id).squeeze(0)
-                n_masked_per_seq = torch.nested.nested_tensor_from_jagged(is_masked, offsets=cu_seq_lens_q).sum(1)
+                n_masked_per_seq = torch.nested.nested_tensor_from_jagged(
+                    is_masked, offsets=kwargs["cu_seq_lens_q"]
+                ).sum(1)
                 mask_ratio_observed = n_masked_per_seq.float() / src_lengths
                 scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
                 reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0)
@@ -599,3 +608,65 @@
             embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
 
         return embeddings
+
+
+class NVEsmForTokenClassification(NVEsmPreTrainedModel):
+    """Adds a token classification head to the model.
+
+    Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`.
+    """
+
+    def __init__(self, config):
+        """Initialize NVEsmForTokenClassification."""
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.esm = NVEsmModel(config, add_pooling_layer=False)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = transformer_engine.pytorch.Linear(
+            config.hidden_size, config.num_labels, params_dtype=config.dtype
+        )
+
+        self.init_weights()
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> TokenClassifierOutput:
+        """Forward pass for the token classification head.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
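A minimal usage sketch for the new NVEsmForTokenClassification head, assuming a checkpoint exported with this esm_nv.py; the path and per-token labels below are placeholders:

    import torch
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")  # placeholder path
    model = AutoModelForTokenClassification.from_pretrained(
        "path/to/checkpoint", trust_remote_code=True, num_labels=2
    )

    batch = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    labels = torch.zeros_like(batch["input_ids"])  # dummy per-token labels, shape (1, seq_len)

    out = model(**batch, labels=labels)
    print(out.loss)          # scalar cross-entropy loss over all tokens
    print(out.logits.shape)  # (1, seq_len, num_labels)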
special_tokens_map.json CHANGED
@@ -1,7 +1,44 @@
 {
-  "cls_token": "<cls>",
-  "eos_token": "<eos>",
-  "mask_token": "<mask>",
-  "pad_token": "<pad>",
-  "unk_token": "<unk>"
+  "bos_token": {
+    "content": "<cls>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "cls_token": {
+    "content": "<cls>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<eos>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "mask_token": {
+    "content": "<mask>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "<pad>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
 }
tokenizer.json ADDED
@@ -0,0 +1,176 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [
+    {
+      "id": 0,
+      "content": "<cls>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 1,
+      "content": "<pad>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 2,
+      "content": "<eos>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 3,
+      "content": "<unk>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 32,
+      "content": "<mask>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    }
+  ],
+  "normalizer": null,
+  "pre_tokenizer": {
+    "type": "Split",
+    "pattern": {
+      "String": ""
+    },
+    "behavior": "Isolated",
+    "invert": false
+  },
+  "post_processor": {
+    "type": "TemplateProcessing",
+    "single": [
+      {
+        "SpecialToken": {
+          "id": "<cls>",
+          "type_id": 0
+        }
+      },
+      {
+        "Sequence": {
+          "id": "A",
+          "type_id": 0
+        }
+      },
+      {
+        "SpecialToken": {
+          "id": "<eos>",
+          "type_id": 0
+        }
+      }
+    ],
+    "pair": [
+      {
+        "SpecialToken": {
+          "id": "<cls>",
+          "type_id": 0
+        }
+      },
+      {
+        "Sequence": {
+          "id": "A",
+          "type_id": 0
+        }
+      },
+      {
+        "SpecialToken": {
+          "id": "<eos>",
+          "type_id": 0
+        }
+      },
+      {
+        "Sequence": {
+          "id": "B",
+          "type_id": 1
+        }
+      },
+      {
+        "SpecialToken": {
+          "id": "<eos>",
+          "type_id": 1
+        }
+      }
+    ],
+    "special_tokens": {
+      "<cls>": {
+        "id": "<cls>",
+        "ids": [
+          0
+        ],
+        "tokens": [
+          "<cls>"
+        ]
+      },
+      "<eos>": {
+        "id": "<eos>",
+        "ids": [
+          2
+        ],
+        "tokens": [
+          "<eos>"
+        ]
+      }
+    }
+  },
+  "decoder": null,
+  "model": {
+    "type": "WordLevel",
+    "vocab": {
+      "<cls>": 0,
+      "<pad>": 1,
+      "<eos>": 2,
+      "<unk>": 3,
+      "L": 4,
+      "A": 5,
+      "G": 6,
+      "V": 7,
+      "S": 8,
+      "E": 9,
+      "R": 10,
+      "T": 11,
+      "I": 12,
+      "D": 13,
+      "P": 14,
+      "K": 15,
+      "Q": 16,
+      "N": 17,
+      "F": 18,
+      "Y": 19,
+      "M": 20,
+      "H": 21,
+      "W": 22,
+      "C": 23,
+      "X": 24,
+      "B": 25,
+      "U": 26,
+      "Z": 27,
+      "O": 28,
+      ".": 29,
+      "-": 30,
+      "<null_1>": 31,
+      "<mask>": 32
+    },
+    "unk_token": "<unk>"
+  }
+}
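The added tokenizer.json defines a character-level WordLevel vocabulary with a TemplateProcessing post-processor that wraps each sequence in <cls> ... <eos>. A small sketch with the tokenizers library, run from the checkpoint folder:

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("tokenizer.json")
    enc = tok.encode("MKTA")
    print(enc.tokens)  # expected: ['<cls>', 'M', 'K', 'T', 'A', '<eos>']
    print(enc.ids)     # expected: [0, 20, 15, 11, 5, 2], per the vocab above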
tokenizer_config.json CHANGED
@@ -1,4 +1,6 @@
 {
+  "add_bos_token": true,
+  "add_eos_token": true,
   "added_tokens_decoder": {
     "0": {
       "content": "<cls>",
@@ -41,13 +43,18 @@
       "special": true
     }
   },
+  "bos_token": "<cls>",
   "clean_up_tokenization_spaces": false,
   "cls_token": "<cls>",
   "eos_token": "<eos>",
   "extra_special_tokens": {},
   "mask_token": "<mask>",
+  "model_input_names": [
+    "input_ids",
+    "attention_mask"
+  ],
   "model_max_length": 1000000000000000019884624838656,
   "pad_token": "<pad>",
-  "tokenizer_class": "EsmTokenizer",
+  "tokenizer_class": "PreTrainedTokenizerFast",
   "unk_token": "<unk>"
 }
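With tokenizer_class switched to PreTrainedTokenizerFast, AutoTokenizer now loads the fast tokenizer defined in tokenizer.json directly. A brief sketch with a placeholder path; the outputs are what the updated config should produce:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/checkpoint")  # placeholder path
    print(type(tok).__name__)        # expected: PreTrainedTokenizerFast
    print(tok.model_input_names)     # expected: ['input_ids', 'attention_mask']
    print(tok("MKTA")["input_ids"])  # expected: [0, 20, 15, 11, 5, 2]; <cls>/<eos> added by the post-processor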