mindchain committed on
Commit
74c1152
·
verified ·
1 Parent(s): 3577db9

Upload train_arithmetic_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_arithmetic_v2.py +1 -9
train_arithmetic_v2.py CHANGED
@@ -16,7 +16,7 @@ import re
16
  import random
17
  import torch
18
  from datasets import Dataset
19
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
20
  from trl import GRPOConfig, GRPOTrainer
21
 
22
  # ============================================================================
@@ -243,14 +243,6 @@ def main():
243
  save_steps=MAX_STEPS,
244
  push_to_hub=False,
245
  report_to="none",
246
- # Force EOS in generation
247
- generation_config=GenerationConfig(
248
- max_new_tokens=30,
249
- do_sample=True,
250
- temperature=0.7,
251
- pad_token_id=tokenizer.eos_token_id,
252
- eos_token_id=tokenizer.eos_token_id,
253
- ),
254
  )
255
 
256
  # Eval callback
 
16
  import random
17
  import torch
18
  from datasets import Dataset
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer
20
  from trl import GRPOConfig, GRPOTrainer
21
 
22
  # ============================================================================
 
243
  save_steps=MAX_STEPS,
244
  push_to_hub=False,
245
  report_to="none",
 
 
 
 
 
 
 
 
246
  )
247
 
248
  # Eval callback