Chengyue Wu committed
Commit da56081 · 1 Parent(s): 0417cd5
Files changed (2):
  1. README.md +0 -1
  2. modeling.py +1 -3
README.md CHANGED
@@ -41,7 +41,6 @@ Our approach introduces a novel decoding recipe incorporating a complementary at
 - **Params**: 1.54B (non-embedding: 1.31B)
 - **Layers**: 28
 - **Attention Heads**: 12 (Q), 2 (KV, GQA)
-- **Context Window**: 32,768 tokens (generation length: 8,192)
 - **Key Feature**: Parallel **block-wise decoding** + **hierarchical caching**
 
 ---
modeling.py CHANGED
@@ -581,7 +581,6 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
         x_init = torch.cat([input_ids, x_init], dim=1)
 
         x_t = x_init.clone()
-        step = 0
         block_past_key_values = None
         while True:
             if stop_token in x_t[:, prompt_length:]:
@@ -612,7 +611,7 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
                 break
 
             if use_block_cache:
-                if step % block_cache_refresh_interval == 0 or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
+                if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
                     output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
                     logits, block_past_key_values = output.logits, output.block_past_key_values
                     logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
@@ -638,7 +637,6 @@ class Fast_dLLM_QwenForCausalLM(Fast_dLLM_QwenPreTrainedModel, GenerationMixin):
 
             x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
 
-            step += 1
         input_ids = x_t
         # Truncate stop_token
         if stop_token in input_ids[:, original_input_length:]:
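Net effect of the modeling.py change: the per-step counter and the fixed `block_cache_refresh_interval` cadence are removed, and the block KV cache is recomputed only when it has not been built yet (`block_past_key_values is None`) or when the probed position of the current block is still a mask token. Below is a minimal standalone sketch contrasting the two refresh predicates; the function names, the toy `mask_id` value, and the tensor layout are illustrative assumptions taken from the surrounding code, not part of the commit itself.

```python
import torch

def should_refresh_old(step: int, refresh_interval: int, x_t: torch.Tensor,
                       block_size: int, small_block_start_idx: int,
                       mask_id: int) -> bool:
    # Old policy (removed in this commit): refresh on a fixed cadence,
    # or whenever the probed position in the current block is still masked.
    return (step % refresh_interval == 0
            or (x_t[:, -block_size + small_block_start_idx] == mask_id).any().item())

def should_refresh_new(block_past_key_values, x_t: torch.Tensor,
                       block_size: int, small_block_start_idx: int,
                       mask_id: int) -> bool:
    # New policy: refresh only when the block cache has not been built yet,
    # or when the probed position is still masked. No step counter to maintain.
    return (block_past_key_values is None
            or (x_t[:, -block_size + small_block_start_idx] == mask_id).any().item())

# Toy check: with an existing cache and no mask tokens left in the block,
# the old predicate can still fire on the interval; the new one never does.
mask_id = 151666                # hypothetical mask token id
x_t = torch.full((1, 8), 42)    # one block, fully unmasked
cache = object()                # stand-in for an existing block_past_key_values
print(should_refresh_old(step=4, refresh_interval=4, x_t=x_t,
                         block_size=8, small_block_start_idx=0,
                         mask_id=mask_id))   # True (interval hit)
print(should_refresh_new(cache, x_t, block_size=8,
                         small_block_start_idx=0, mask_id=mask_id))  # False
```

Read this way, the commit ties the refresh to cache validity rather than to an arbitrary interval, which avoids redundant forward passes once every position in the block has been committed.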