MeganEFlynn committed (verified)
Commit a076cd6 · 1 parent: aa22b6d

Upload folder using huggingface_hub

Files changed (4):
  1. config.json +57 -0
  2. eagle3.py +466 -0
  3. generation_config.json +4 -0
  4. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "architectures": [
+     "Eagle3DraftModel"
+   ],
+   "auto_map": {
+     "": "eagle3.Eagle3SpeculatorConfig"
+   },
+   "base_model_ep_plan": null,
+   "draft_vocab_size": 64000,
+   "dtype": "float32",
+   "eagle_aux_hidden_state_layer_ids": null,
+   "has_no_defaults_at_init": false,
+   "norm_before_residual": true,
+   "speculators_config": {
+     "algorithm": "eagle3",
+     "default_proposal_method": "greedy",
+     "proposal_methods": [
+       {
+         "accept_tolerance": 0.0,
+         "proposal_type": "greedy",
+         "speculative_tokens": 3,
+         "verifier_accept_k": 1
+       }
+     ],
+     "verifier": {
+       "architectures": [
+         "LlamaForCausalLM"
+       ],
+       "name_or_path": "Qwen/Qwen3-235B-A22B-Instruct-2507"
+     }
+   },
+   "speculators_model_type": "eagle3",
+   "speculators_version": "0.3.0.dev13",
+   "target_hidden_size": null,
+   "transformer_layer_config": {
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "head_dim": 128,
+     "hidden_act": "silu",
+     "hidden_size": 4096,
+     "initializer_range": 0.02,
+     "intermediate_size": 12288,
+     "max_position_embeddings": 262144,
+     "mlp_bias": false,
+     "model_type": "llama",
+     "num_attention_heads": 64,
+     "num_hidden_layers": 1,
+     "num_key_value_heads": 4,
+     "pretraining_tp": 1,
+     "rms_norm_eps": 1e-06,
+     "rope_scaling": null,
+     "rope_theta": 10000.0,
+     "use_cache": true,
+     "vocab_size": 151936
+   },
+   "transformers_version": "4.57.1"
+ }
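
Note: the "transformer_layer_config" block above is a standard single-layer Llama decoder configuration. A minimal sketch (assuming only that transformers is installed) of the same settings expressed directly as a LlamaConfig object:

from transformers import LlamaConfig

# Values copied from the "transformer_layer_config" section of config.json above.
layer_config = LlamaConfig(
    hidden_size=4096,
    intermediate_size=12288,
    num_hidden_layers=1,
    num_attention_heads=64,
    num_key_value_heads=4,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=262144,
    rms_norm_eps=1e-06,
    rope_theta=10000.0,
    attention_bias=False,
    vocab_size=151936,
)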
eagle3.py ADDED
@@ -0,0 +1,466 @@
+ """
+ Speculators implementation of EAGLE-3:
+ - https://arxiv.org/abs/2503.01840
+
+ Classes:
+     Eagle3SpeculatorConfig: Configuration class for EAGLE-3 speculator model
+     Eagle3Speculator: Main model implementation for EAGLE-3 speculators
+     Eagle3Attention: Custom attention layer for EAGLE-3, processes
+         concatenated embeddings and hidden states
+     Eagle3DecoderLayer: Custom decoder layer for EAGLE-3, processes
+         concatenated embeddings and hidden states with Eagle3Attention
+         and support for moving hidden layernorm before residual
+ """
+
+ import os
+ from typing import Any, ClassVar, Literal
+
+ import torch
+ from pydantic import Field, field_serializer, field_validator
+ from torch import nn
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
+ from transformers.models.llama.configuration_llama import LlamaConfig
+ from transformers.models.llama.modeling_llama import (
+     LlamaMLP,
+     LlamaRMSNorm,
+     apply_rotary_pos_emb,
+     repeat_kv,
+ )
+
+ from speculators import SpeculatorModel, SpeculatorModelConfig
+
+ __all__ = [
+     "Eagle3Attention",
+     "Eagle3DecoderLayer",
+     "Eagle3Speculator",
+     "Eagle3SpeculatorConfig",
+ ]
+
+
+ @SpeculatorModelConfig.register("eagle3")
+ class Eagle3SpeculatorConfig(SpeculatorModelConfig):
+     """
+     Configuration for EAGLE-3 speculator with vocabulary mapping.
+
+     EAGLE-3 features vocabulary mapping between draft (32K) and target (128K)
+     vocabularies, enabling cross-tokenizer speculation.
+
+     :param transformer_layer_config: Configuration for the transformer decoder layer
+     :param draft_vocab_size: Size of draft model vocabulary for speculation
+     :param norm_before_residual: Apply hidden_norm before storing residual
+     """
+
+     speculators_model_type: Literal["eagle3"] = "eagle3"
+     architectures: list[str] = Field(
+         default_factory=lambda: ["Eagle3Speculator"],
+         description="Model architectures that can load these weights",
+     )
+
+     transformer_layer_config: PretrainedConfig = Field(
+         default_factory=LlamaConfig,
+         description="Configuration for the transformer decoder layer",
+     )
+
+     draft_vocab_size: int = Field(
+         default=32000,
+         description="Size of draft model vocabulary for speculation",
+     )
+
+     norm_before_residual: bool = Field(
+         default=False,
+         description="Apply hidden_norm before storing residual",
+     )
+
+     target_hidden_size: int | None = Field(
+         default=None,
+         description="Hidden size of the target model (if different from draft model)",
+     )
+
+     eagle_aux_hidden_state_layer_ids: list[int] | None = Field(
+         default=None,
+         description="Layer IDs of the Eagle auxiliary hidden state layers",
+     )
+
+     @property
+     def target_vocab_size(self) -> int:
+         """Get target vocabulary size from transformer config."""
+         return self.transformer_layer_config.vocab_size
+
+     @field_serializer("transformer_layer_config")
+     def serialize_transformer_config(self, value: PretrainedConfig) -> dict:
+         """Serialize transformer config to dict."""
+         return value.to_diff_dict()
+
+     @field_validator("transformer_layer_config", mode="before")
+     @classmethod
+     def validate_transformer_config(cls, value: Any) -> PretrainedConfig:
+         """Validate and convert transformer config."""
+         if isinstance(value, dict):
+             config_class: type[PretrainedConfig] = LlamaConfig
+             if "model_type" in value:
+                 config_class = AutoConfig.for_model(
+                     model_type=value["model_type"]
+                 ).__class__
+             return config_class(**value)
+         return value
+
+
+ class Eagle3Attention(nn.Module):
+     """
+     Eagle-3 attention module that processes concatenated embeddings and hidden states.
+
+     Modified from standard Llama attention to accept 2x hidden_size input
+     for Q/K/V projections while maintaining standard output size.
+     """
+
+     def __init__(self, config: PretrainedConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+
+         self.num_heads = config.num_attention_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.hidden_size = config.hidden_size
+         self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         input_size = 2 * self.hidden_size
+         self.q_proj = nn.Linear(
+             input_size, self.num_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             input_size,
+             self.num_key_value_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.v_proj = nn.Linear(
+             input_size,
+             self.num_key_value_heads * self.head_dim,
+             bias=config.attention_bias,
+         )
+         self.o_proj = nn.Linear(
+             self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         **kwargs,  # noqa: ARG002
+     ) -> tuple:
+         """
+         Forward pass for Eagle-3 attention.
+         Taken from Llama Attention but modified to accept 2x hidden_size input.
+
+         :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
+         :param attention_mask: Optional attention mask
+         :param position_ids: Optional position IDs for rotary embeddings
+         :param past_key_value: Optional cached key-value pairs
+         :param output_attentions: Whether to return attention weights
+         :param use_cache: Whether to cache key-value pairs
+         :param position_embeddings: Optional precomputed rotary embeddings
+         :return: Tuple of (hidden_states, [attention_weights], [past_key_value])
+         """
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(
+             bsz, q_len, self.num_heads, self.head_dim
+         ).transpose(1, 2)
+         key_states = key_states.view(
+             bsz, q_len, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 2)
+         value_states = value_states.view(
+             bsz, q_len, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 2)
+
+         if position_embeddings is not None:
+             cos, sin = position_embeddings
+             query_states, key_states = apply_rotary_pos_emb(
+                 query_states, key_states, cos, sin, position_ids
+             )
+
+         past_key_value_out = None
+         if past_key_value is not None:
+             past_key = past_key_value[0]
+             past_value = past_key_value[1]
+             key_states = torch.cat([past_key, key_states], dim=2)
+             value_states = torch.cat([past_value, value_states], dim=2)
+
+         if use_cache:
+             past_key_value_out = (key_states, value_states)
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / (
+             self.head_dim**0.5
+         )
+
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = nn.functional.softmax(
+             attn_weights, dim=-1, dtype=torch.float32
+         ).to(query_states.dtype)
+
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value_out
+
+
+ class Eagle3DecoderLayer(nn.Module):
+     """
+     Eagle-3 decoder layer that processes concatenated embeddings and hidden states.
+
+     Accepts 2x hidden_size input from concatenated embeddings and fused hidden states.
+     Uses Eagle3Attention for the self-attention computation.
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_idx: int,
+         norm_before_residual: bool = False,
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.norm_before_residual = norm_before_residual
+
+         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = LlamaRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+         self.self_attn = Eagle3Attention(config, layer_idx)
+
+         self.mlp = LlamaMLP(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
+         output_attentions: bool | None = False,
+         use_cache: bool | None = False,
+         cache_position: torch.LongTensor | None = None,  # noqa: ARG002
+         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         **kwargs,  # noqa: ARG002
+     ) -> tuple:
+         """
+         Process concatenated embeddings and hidden states through modified decoder
+         layer.
+
+         :param hidden_states: Input tensor of shape [batch, seq_len, 2*hidden_size]
+         :return: Tuple of layer outputs
+         """
+         embeds = hidden_states[:, :, : self.hidden_size]
+         hidden = hidden_states[:, :, self.hidden_size : 2 * self.hidden_size]
+
+         if self.norm_before_residual:
+             hidden = self.hidden_norm(hidden)
+             residual = hidden
+         else:
+             residual = hidden
+             hidden = self.hidden_norm(hidden)
+
+         embeds = self.input_layernorm(embeds)
+
+         attn_input = torch.cat([embeds, hidden], dim=-1)
+
+         attn_output, attn_weights, past_key_value_out = self.self_attn(
+             hidden_states=attn_input,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             position_embeddings=position_embeddings,
+         )
+
+         hidden_states = residual + attn_output
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights,)  # type: ignore[assignment]
+
+         if use_cache:
+             outputs += (past_key_value_out,)  # type: ignore[assignment]
+
+         return outputs
+
+
+ @SpeculatorModel.register("eagle3")
+ class Eagle3Speculator(SpeculatorModel):
+     """
+     EAGLE-3 speculator with vocabulary mapping and multi-layer fusion.
+
+     EAGLE-3 processes concatenated hidden states from multiple verifier layers
+     through a fusion layer, then combines with embeddings for a custom decoder
+     layer that accepts 2x hidden_size input.
+     """
+
+     config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig  # type: ignore[misc]
+     _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [  # type: ignore[misc]
+         "verifier*",
+     ]
+     _keys_to_ignore_on_save: ClassVar[list[str]] = []  # type: ignore[misc,assignment]
+
+     def __init__(
+         self,
+         config: Eagle3SpeculatorConfig,
+         verifier: str | os.PathLike | PreTrainedModel | None = None,
+         verifier_attachment_mode: Literal["detached", "full", "train_only"]
+         | None = None,
+         reduce_vocab_size: bool = True,
+         has_drafter_embedding: bool = True,
+     ):
+         """
+         Initialize Eagle3 speculator.
+
+         :param config: Eagle3SpeculatorConfig instance
+         :param verifier: Optional verifier model
+         :param verifier_attachment_mode: How to attach the verifier
+         :param reduce_vocab_size: Whether to reduce vocabulary size with mapping
+         :param has_drafter_embedding: Whether drafter embedding weights are provided
+         """
+         if not isinstance(config, Eagle3SpeculatorConfig):
+             raise ValueError(
+                 f"config must be Eagle3SpeculatorConfig, got {type(config)}"
+             )
+
+         self.config: Eagle3SpeculatorConfig = config
+
+         self.hidden_size = config.transformer_layer_config.hidden_size
+         self.draft_vocab_size = config.draft_vocab_size
+         self.target_vocab_size = config.target_vocab_size
+
+         # Use target_hidden_size if specified, otherwise use draft model's hidden_size
+         self.target_hidden_size = (
+             config.target_hidden_size
+             if config.target_hidden_size is not None
+             else self.hidden_size
+         )
+
+         super().__init__(
+             config=config,
+             verifier=verifier,
+             verifier_attachment_mode=verifier_attachment_mode,
+         )
+
+         if has_drafter_embedding:
+             self.embed_tokens = nn.Embedding(
+                 self.target_vocab_size,
+                 self.hidden_size,
+                 padding_idx=config.transformer_layer_config.pad_token_id
+                 if hasattr(config.transformer_layer_config, "pad_token_id")
+                 else None,
+             )
+
+         self.fc = nn.Linear(
+             3 * self.target_hidden_size,  # Use target model's hidden size
+             self.hidden_size,
+             bias=False,
+         )
+
+         self.layers = nn.ModuleList(
+             [
+                 Eagle3DecoderLayer(
+                     config.transformer_layer_config,
+                     layer_idx=0,
+                     norm_before_residual=config.norm_before_residual,
+                 )
+             ]
+         )
+
+         self.norm = LlamaRMSNorm(
+             self.hidden_size,
+             eps=config.transformer_layer_config.rms_norm_eps,
+         )
+
+         self.lm_head = nn.Linear(
+             self.hidden_size,
+             self.draft_vocab_size,
+             bias=False,
+         )
+         if reduce_vocab_size:
+             self.register_buffer(  # type: ignore[attr-defined]
+                 "d2t",
+                 torch.zeros(self.draft_vocab_size, dtype=torch.long),
+             )
+             self.register_buffer(  # type: ignore[attr-defined]
+                 "t2d",
+                 torch.zeros(self.target_vocab_size, dtype=torch.bool),
+             )
+
+         # Type hints for buffers
+         self.d2t: torch.Tensor
+         self.t2d: torch.Tensor
+         self.post_init()  # type: ignore[attr-defined]
+
+     def tie_weights(self):
+         """
+         Override tie_weights to prevent vocabulary corruption in transformers 4.54.1+
+
+         Eagle3 intentionally uses different vocabulary sizes:
+         - Input embeddings (embed_tokens): 128256 (full vocabulary)
+         - Output embeddings (lm_head): 32000 (draft vocabulary)
+
+         The default tie_weights() tries to make them identical, breaking Eagle3.
+         This override preserves the intentional vocabulary size difference.
+         """
+         # Don't call super().tie_weights() - this prevents vocabulary corruption
+         # that occurs when _tie_or_clone_weights replaces lm_head.weight with
+         # embed_tokens.weight
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         hidden_states: torch.FloatTensor,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: tuple[tuple[torch.FloatTensor]] | None = None,
+         use_cache: bool | None = None,
+         output_attentions: bool | None = None,
+         output_hidden_states: bool | None = None,  # noqa: ARG002
+         return_dict: bool | None = None,
+     ) -> torch.FloatTensor:
+         """
+         Forward pass for EAGLE-3 speculation.
+
+         :param input_ids: Input token IDs from draft vocabulary
+         :param hidden_states: Concatenated hidden states from 3 verifier layers
+             [B, L, 3*target_H] where target_H is the target model's hidden size
+         :param attention_mask: Optional attention mask
+         :param position_ids: Optional position IDs
+         :param past_key_values: Optional cached key-values
+         :param use_cache: Whether to cache key-values
+         :param output_attentions: Return attention weights
+         :param output_hidden_states: Return hidden states
+         :param return_dict: Return dict output
+         :return: Model outputs with draft vocabulary logits
+         """
+         raise NotImplementedError("Eagle3Speculator.forward is not implemented yet.")
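
Note: a minimal shape-check sketch of the Eagle3DecoderLayer defined above (hypothetical: it assumes this eagle3.py and the speculators package it imports are on the Python path, and uses a tiny made-up LlamaConfig rather than the shipped 4096-wide one):

import torch
from transformers.models.llama.configuration_llama import LlamaConfig

from eagle3 import Eagle3DecoderLayer

# Tiny hypothetical config, just to exercise the 2x-hidden-size input path.
cfg = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    attention_bias=False,
)

layer = Eagle3DecoderLayer(cfg, layer_idx=0, norm_before_residual=True)

# The layer slices its input into embeddings and fused verifier hidden states,
# so it expects [batch, seq_len, 2 * hidden_size].
x = torch.randn(1, 5, 2 * cfg.hidden_size)
(out,) = layer(x)
print(out.shape)  # torch.Size([1, 5, 64])

In real use the caller supplies precomputed rotary position_embeddings; with none given, the forward above simply skips the rotary application.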
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.57.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2234959f4ebf015a7a79c2706e9483676748e7eee74752a66195554de65e4c7
+ size 2390403048
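
Note: a hypothetical back-of-envelope check (not part of the upload) that the checkpoint size is plausible given the shapes implied by config.json and eagle3.py above:

H, I = 4096, 12288                 # hidden_size, intermediate_size
V_target, V_draft = 151936, 64000  # target and draft vocabulary sizes
heads, kv_heads, head_dim = 64, 4, 128

attn = (2 * H) * (heads * head_dim)          # q_proj takes 2*hidden_size input
attn += 2 * (2 * H) * (kv_heads * head_dim)  # k_proj and v_proj
attn += (heads * head_dim) * H               # o_proj
mlp = 3 * H * I                              # LlamaMLP gate/up/down projections
norms = 4 * H                                # three layer norms plus the final norm

params = V_target * H + 3 * H * H + attn + mlp + norms + H * V_draft
print(f"{params:,}")  # ~1.19 billion parameters

# At 2 bytes per weight (a 16-bit serialization), plus the int64 d2t and bool t2d
# buffers, this lands within a few kilobytes of the 2,390,403,048-byte file above.
print(2 * params + 8 * V_draft + V_target)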