Chengyue Wu committed on
Commit · da56081
1 Parent(s): 0417cd5
update
Browse files:
- README.md +0 -1
- modeling.py +1 -3
README.md
CHANGED
@@ -41,7 +41,6 @@ Our approach introduces a novel decoding recipe incorporating a complementary at
 - **Params**: 1.54B (non-embedding: 1.31B)
 - **Layers**: 28
 - **Attention Heads**: 12 (Q), 2 (KV, GQA)
-- **Context Window**: 32,768 tokens (generation length: 8,192)
 - **Key Feature**: Parallel **block-wise decoding** + **hierarchical caching**
 
 ---
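The head counts in the list above describe grouped-query attention: 12 query heads share 2 KV heads, so 6 query heads attend through each cached KV head and the KV cache shrinks 6x versus a one-KV-head-per-query layout. A back-of-envelope sketch of that saving (`head_dim = 128` and the helper below are assumptions for illustration, not values read from this repo):

```python
# Rough KV-cache estimate for the spec list above. head_dim=128 is an
# assumption (typical at this model scale), not taken from this repo.
NUM_LAYERS, HEAD_DIM, DTYPE_BYTES = 28, 128, 2  # fp16/bf16

def kv_cache_bytes(seq_len: int, num_kv_heads: int) -> int:
    # The leading 2 counts the K and V tensors, one pair per layer per position.
    return 2 * NUM_LAYERS * seq_len * num_kv_heads * HEAD_DIM * DTYPE_BYTES

full_mha = kv_cache_bytes(32_768, 12)  # if every query head kept its own KV head
gqa      = kv_cache_bytes(32_768, 2)   # the 2-KV-head GQA layout above
print(f"MHA: {full_mha / 2**20:.0f} MiB  GQA: {gqa / 2**20:.0f} MiB  ({full_mha // gqa}x smaller)")
```

The ratio holds at any sequence length, since both layouts scale linearly in `seq_len`.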
modeling.py
CHANGED
@@ -581,7 +581,6 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         x_init = torch.cat([input_ids, x_init], dim=1)
 
         x_t = x_init.clone()
-        step = 0
         block_past_key_values = None
         while True:
             if stop_token in x_t[:, prompt_length:]:
@@ -612,7 +611,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
                 break
 
             if use_block_cache:
-                if
+                if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
                     output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
                     logits, block_past_key_values = output.logits, output.block_past_key_values
                     logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
@@ -638,7 +637,6 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
 
             x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
 
-            step += 1
             input_ids = x_t
             # Truncate stop_token
             if stop_token in input_ids[:, original_input_length:]:
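The one functional change is the guard added at new line 614: the block-level KV cache is recomputed when no cache exists yet, or when the token at offset `small_block_start_idx` inside the trailing block is still the mask token for any sequence; otherwise the cached block is reused. A self-contained sketch of just that test (toy tensors and a hypothetical `mask_id`, not the repo's real decode state):

```python
import torch

# Stand-alone re-creation of the refresh test from new line 614.
# Variable names mirror the diff; the values are toy stand-ins.
mask_id = 151666                  # hypothetical mask-token id
block_size = 8
small_block_start_idx = 4         # offset of the current small block

# One sequence: 8 decoded tokens, then a block whose first 4 positions
# are decoded and whose last 4 are still masked.
x_t = torch.tensor([[11, 12, 13, 14, 15, 16, 17, 18,
                     21, 22, 23, 24, mask_id, mask_id, mask_id, mask_id]])
block_past_key_values = object()  # pretend a cached block exists

# x_t[:, -block_size + small_block_start_idx] selects the single column at
# offset `small_block_start_idx` inside the trailing block (column 12 here).
needs_refresh = block_past_key_values is None or \
    (x_t[:, -block_size + small_block_start_idx] == mask_id).any()
print(bool(needs_refresh))  # True: that position is still masked, so the
                            # cached block KVs were computed from mask tokens
                            # and cannot be reused.
```

The `step` counter removed in the first and third hunks appears to have been write-only state, so dropping it is pure cleanup and does not affect decoding.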