Wrap assistant messages inside "{% generation %}" markers in chat_template.jinja
Added "{% generation %}" markers enables the TRL SFTTrainer's assistant_only_loss config option. assistant_only_loss tells the SFTTrainer to only enable gradients on the assistant messages, which are wrapped around {% generation %} by this PR. I confirmed that this behaves as expected by using this custom template for the gpt-oss-20b tokenizer as the processing_class for SFTTrainer.
See this transformers PR that introduced the feature.
See also how trl/trainer/sft_trainer.py consumes this marker via transformers/utils/chat_template_utils.py.
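For context, here is a minimal sketch of how the corrected template wires into training. The dataset variable is illustrative, and it assumes a TRL release whose SFTConfig exposes assistant_only_loss:

from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE  # template from this PR

trainer = SFTTrainer(
    model='openai/gpt-oss-20b',
    args=SFTConfig(assistant_only_loss=True),  # loss only on {% generation %} spans
    train_dataset=dataset,         # assumed: a dataset with a "messages" column
    processing_class=tokenizer,    # tokenizer carrying the corrected template
)
trainer.train()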
Code segment to verify that the masking is applied correctly; assistant tokens are printed in green:
# Assumes: CORRECTED_JINJA_TEMPLATE holds the template from this PR and
# `sample` is a single dataset row with a "messages" list of chat turns.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE

templated_output = tokenizer.apply_chat_template(
    sample['messages'],
    tokenize=True,
    add_generation_prompt=False,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

print("Visualizing token masks. Green text is used for loss calculation.\n")
GREEN = "\033[92m"
RESET = "\033[0m"

input_ids = templated_output['input_ids']
assistant_mask = templated_output['assistant_masks']
if len(input_ids) != len(assistant_mask):
    raise ValueError("Mismatch between input_ids and assistant_masks length.")

# Group consecutive tokens with the same mask value and print each chunk,
# using green for chunks that are part of an assistant message.
current_chunk_tokens = []
current_mask_status = None
for token_id, is_assistant in zip(input_ids, assistant_mask):
    mask_status = bool(is_assistant)
    if current_mask_status is None:
        current_mask_status = mask_status
    if mask_status != current_mask_status:
        # Decode and print the completed chunk
        decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
        if current_mask_status:
            print(f"{GREEN}{decoded_text}{RESET}", end="")
        else:
            print(decoded_text, end="")
        # Start a new chunk
        current_chunk_tokens = [token_id]
        current_mask_status = mask_status
    else:
        current_chunk_tokens.append(token_id)

# Print the final chunk after the loop
if current_chunk_tokens:
    decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
    if current_mask_status:
        print(f"{GREEN}{decoded_text}{RESET}", end="")
    else:
        print(decoded_text, end="")
Prints something like:
<|start|>user<|message|>USER_MESSAGE<|end|>[GREEN_STARTS]<|start|>assistant<|channel|>analysis<|message|>...<|call|>[GREEN_ENDS]
Hi,
Good point. In my case, I solved this by writing a custom data collator that masks everything except the assistant's final-channel response.
I found this useful for preserving the native CoT capability of GPT-OSS 20B while fine-tuning only on the final channel.
Here's the code; I'm happy to share it:
import torch
from transformers import DataCollatorForLanguageModeling


class DataCollatorForAssistantOnlyLM(DataCollatorForLanguageModeling):
    """Custom data collator that only computes loss on assistant responses."""

    RESPONSE_TEMPLATE: str = "<|start|>assistant<|channel|>final<|message|>"
    response_template_ids: list[int]

    def __init__(self, tokenizer):
        super().__init__(tokenizer=tokenizer, mlm=False)
        self.response_template_ids = tokenizer.encode(
            self.RESPONSE_TEMPLATE, add_special_tokens=False
        )

    def torch_call(self, examples):
        batch = super().torch_call(examples)
        # For each example in the batch, mask everything except assistant responses
        labels = batch["labels"].clone()
        for i, label in enumerate(labels):
            # Find where assistant responses start
            response_token_ids_start_idx = []
            # Search for the response template in the token sequence
            for idx in range(len(label) - len(self.response_template_ids) + 1):
                if (
                    label[idx : idx + len(self.response_template_ids)].tolist()
                    == self.response_template_ids
                ):
                    response_token_ids_start_idx.append(
                        idx + len(self.response_template_ids)
                    )
            # Mask all tokens except those after response templates
            if len(response_token_ids_start_idx) > 0:
                # Start with all tokens masked
                mask = torch.ones_like(label, dtype=torch.bool)
                # Unmask assistant response regions
                for start_idx in response_token_ids_start_idx:
                    # Find the next response template or end of sequence
                    next_idx = len(label)
                    for next_start in response_token_ids_start_idx:
                        if next_start > start_idx:
                            next_idx = next_start - len(self.response_template_ids)
                            break
                    mask[start_idx:next_idx] = False
                # Apply mask (set masked tokens to -100)
                labels[i] = torch.where(mask, -100, label)
        batch["labels"] = labels
        return batch
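To use it, I just pass an instance as the data collator when building the trainer. A minimal sketch; the model, tokenizer, and train_dataset variables are assumed to be defined elsewhere:

from trl import SFTTrainer

collator = DataCollatorForAssistantOnlyLM(tokenizer)

trainer = SFTTrainer(
    model=model,                  # assumed: GPT-OSS 20B model or model id string
    train_dataset=train_dataset,  # assumed: chat dataset in "messages" format
    processing_class=tokenizer,
    data_collator=collator,       # keeps loss only on the final-channel response
)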
I'm available to contribute, if needed.
Best wishes,
Manuel