Wrap assistant messages inside "{% generation %}" markers in chat_template.jinja

#126

Added "{% generation %}" markers enables the TRL SFTTrainer's assistant_only_loss config option. assistant_only_loss tells the SFTTrainer to only enable gradients on the assistant messages, which are wrapped around {% generation %} by this PR. I confirmed that this behaves as expected by using this custom template for the gpt-oss-20b tokenizer as the processing_class for SFTTrainer.

See this transformers PR that introduced this change
See also how trl/trainer/sft_trainer.py uses this marker in transformers/utils/chat_template_utils.py.
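
For context, here is a minimal sketch of how this is wired up on the TRL side (the model name, dataset, and tokenizer variables below are placeholders; the corrected template is assumed to already be set on the tokenizer):

from trl import SFTConfig, SFTTrainer

# Sketch only: with assistant_only_loss=True, SFTTrainer derives the loss mask
# from the {% generation %} regions of the tokenizer's chat template.
config = SFTConfig(assistant_only_loss=True, output_dir="gpt-oss-20b-sft")
trainer = SFTTrainer(
    model="openai/gpt-oss-20b",
    args=config,
    train_dataset=dataset,        # conversational dataset with a "messages" column
    processing_class=tokenizer,   # tokenizer carrying the corrected chat template
)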

Code segment to verify that the masking is done correctly; assistant tokens are printed in green:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('openai/gpt-oss-20b', trust_remote_code=True)
tokenizer.chat_template = CORRECTED_JINJA_TEMPLATE  # the updated template containing {% generation %} markers

templated_output = tokenizer.apply_chat_template(
    sample['messages'],
    tokenize=True, 
    add_generation_prompt=False,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

print("Visualizing token masks. Green text is used for loss calculation.\n")
GREEN = "\033[92m"
RESET = "\033[0m"

input_ids = templated_output['input_ids']
assistant_mask = templated_output['assistant_masks']

if len(input_ids) != len(assistant_mask):
    raise ValueError("Mismatch between input_ids and assistant_masks length.")

current_chunk_tokens = []
current_mask_status = None

for token_id, is_assistant in zip(input_ids, assistant_mask):
    mask_status = bool(is_assistant)
    if current_mask_status is None:
        current_mask_status = mask_status
    
    if mask_status != current_mask_status:
        # Decode and print the completed chunk
        decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
        if current_mask_status:
            print(f"{GREEN}{decoded_text}{RESET}", end="")
        else:
            print(decoded_text, end="")
        
        # Start a new chunk
        current_chunk_tokens = [token_id]
        current_mask_status = mask_status
    else:
        current_chunk_tokens.append(token_id)

# Print the final chunk after the loop
if current_chunk_tokens:
    decoded_text = tokenizer.decode(current_chunk_tokens, skip_special_tokens=False)
    if current_mask_status:
        print(f"{GREEN}{decoded_text}{RESET}", end="")
    else:
        print(decoded_text, end="")

Prints something like:

<|start|>user<|message|>USER_MESSAGE<|end|>[GREEN_STARTS]<|start|>assistant<|channel|>analysis<|message|>...<|call|>[GREEN_ENDS]
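
For reference, my understanding is that assistant_only_loss then boils down to turning this mask into labels, so everything outside the green regions is ignored by the loss (a rough sketch, not the actual TRL code):

import torch

# Positions where the assistant mask is 0 get label -100, the default
# ignore_index of cross-entropy, so only the green tokens contribute to the loss.
input_ids = torch.tensor(templated_output["input_ids"])
assistant_mask = torch.tensor(templated_output["assistant_masks"]).bool()
labels = torch.where(assistant_mask, input_ids, torch.full_like(input_ids, -100))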

Hi,

Good point. To solve this issue, in my case I developed a custom data collator that masks everything except the assistant's final-channel response.
I found it useful for preserving the native CoT capability of GPT-OSS 20B while fine-tuning only on the final channel.

Here's the code; I'm happy to share it:

import torch
from transformers import DataCollatorForLanguageModeling


class DataCollatorForAssistantOnlyLM(DataCollatorForLanguageModeling):
    """Custom data collator that only computes loss on the assistant's final-channel responses."""

    RESPONSE_TEMPLATE: str = "<|start|>assistant<|channel|>final<|message|>"
    response_template_ids: list[int]

    def __init__(self, tokenizer):
        super().__init__(tokenizer=tokenizer, mlm=False)

        # tokenizer.encode returns a plain list of token ids here
        self.response_template_ids = tokenizer.encode(
            self.RESPONSE_TEMPLATE, add_special_tokens=False
        )

    def torch_call(self, examples):
        batch = super().torch_call(examples)

        # For each example in the batch, mask everything except assistant responses
        labels = batch["labels"].clone()

        for i, label in enumerate(labels):
            # Find where assistant responses start
            response_token_ids_start_idx = []

            # Search for the response template in the token sequence
            for idx in range(len(label) - len(self.response_template_ids) + 1):
                if (
                    label[idx : idx + len(self.response_template_ids)].tolist()
                    == self.response_template_ids
                ):
                    response_token_ids_start_idx.append(
                        idx + len(self.response_template_ids)
                    )

            # Mask all tokens except those after response templates
            if len(response_token_ids_start_idx) > 0:
                # Start with all tokens masked
                mask = torch.ones_like(label, dtype=torch.bool)

                # Unmask assistant response regions
                for start_idx in response_token_ids_start_idx:
                    # Find the next response template or end of sequence
                    next_idx = len(label)
                    for next_start in response_token_ids_start_idx:
                        if next_start > start_idx:
                            next_idx = next_start - len(self.response_template_ids)
                            break

                    mask[start_idx:next_idx] = False

                # Apply mask (set masked tokens to -100)
                labels[i] = torch.where(mask, -100, label)

        batch["labels"] = labels
        return batch
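
For completeness, a hypothetical usage sketch (the model name, dataset, and tokenizer variables are placeholders), passing the collator to SFTTrainer instead of using assistant_only_loss:

from trl import SFTTrainer

collator = DataCollatorForAssistantOnlyLM(tokenizer)
trainer = SFTTrainer(
    model="openai/gpt-oss-20b",
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=collator,       # overrides the default collator with final-channel-only masking
)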

I'm available to contribute, if needed.

Best wishes,
Manuel
