financial-task-env / train_grpo.py
bpHigh's picture
Add blog and final data
18d1157
# ruff: noqa: E402
"""
GRPO training notebook for the office-document task env.
Continues from `bpHigh/qwen3b-office-sft-kimi` (the SFT'd Qwen2.5-Coder-3B
LoRA) with GRPO on the 97 train tasks, using TRL's `environment_factory`
integration. The env is loaded as a remote OpenEnv Space via WebSocket —
no env code runs in the training process.
Designed for **Modal notebooks** on a single A100 40GB ($2.50/hr).
Designed budget: ~$25–30 (3-5h training + 1-2h eval).
Format: each `# %%` is a cell — opens in Modal, Jupyter, or runs as a
plain script with `python train_grpo.py`.
References
----------
* TRL OpenEnv guide: https://huggingface.co/docs/trl/openenv
* TRL Trackio guide: https://huggingface.co/docs/trl/trackio_integration
* SFT base model: https://huggingface.co/bpHigh/qwen3b-office-sft-kimi
* Env Space: https://huggingface.co/spaces/bpHigh/financial-task-env
Pipeline
--------
1. Install deps + clone repo (for the env client + manifest)
2. Imports + auth check
3. Duplicate the Space to your account, set FINANCIAL_ENV_GOLD_STASH=copy
4. Define the OpenEnv tool wrapper (one tool per env action_type)
5. Build the train dataset (one prompt per train task)
6. Reward function (env.reward → list[float])
7. Load model + SFT adapter for trainable continuation
8. GRPO config + trainer (with Trackio logging)
9. Train (~3-5 hr) — live reward / loss curves on Trackio Space
10. Eval on 22-task held-out split (~1-2 hr)
11. Save + summarize results
Required env vars / Modal secrets
---------------------------------
* HF_TOKEN — pull SFT adapter, push GRPO adapter to your Hub repo,
and host Trackio logs on a Space
* TRACKIO_SPACE_ID — (recommended) HF Space ID where Trackio runs are
hosted live, e.g. 'bpHigh/trackio-office-grpo'.
If unset, logs land locally only (still inspectable
after the run from /tmp/grpo_qwen3b_office/).
* TRACKIO_PROJECT — project group name (default 'office-doc-grpo')
IMPORTANT: before running, duplicate the env Space to your account and set
the env-var `FINANCIAL_ENV_GOLD_STASH=copy` in its Settings → Variables.
This switches gold-file stashing to per-session COPY mode so concurrent
GRPO rollouts don't race on the same source's rename.
"""
# %% [markdown]
# # GRPO training: SFT'd Qwen3B → GRPO on the office-document env
#
# **Starting point:** [bpHigh/qwen3b-office-sft-kimi](https://huggingface.co/bpHigh/qwen3b-office-sft-kimi)
# (SFT'd on 53 Kimi-K2.5 trajectories, train_loss 0.196).
#
# **Goal:** continue with GRPO on the 97 train tasks via TRL's
# `environment_factory`, evaluate on the held-out 22-task split, push
# the trained adapter to HF Hub.
#
# **Budget:** ~$25–30 on Modal A100 40GB.
# %% [markdown]
# ## 1. Install deps + clone the env repo
# %%
import os
import subprocess
import sys
import time
from pathlib import Path
REPO_URL = "https://github.com/bp-high/openenv_financial_task_env.git"
REPO_DIR = Path("/work/openenv_financial_task_env")
RUN_START = time.time()
if not REPO_DIR.exists():
REPO_DIR.parent.mkdir(parents=True, exist_ok=True)
subprocess.run(["git", "clone", "--depth=1", REPO_URL, str(REPO_DIR)], check=True)
subprocess.run(["git", "-C", str(REPO_DIR), "fetch", "origin", "main"], check=True)
subprocess.run(["git", "-C", str(REPO_DIR), "reset", "--hard", "origin/main"], check=True)
# Pip install the stack. vLLM is required for `use_vllm=True` in colocate mode.
subprocess.run([
sys.executable, "-m", "pip", "install", "-q", "-U",
# typing_extensions>=4.15 is required by recent pydantic_core — the
# Sentinel symbol pydantic_core imports was first added in 4.15.0.
# Pin both first so the rest of the install resolves cleanly.
"typing_extensions>=4.15",
"pydantic>=2.9", "pydantic_core>=2.23",
"trl>=0.11", "peft>=0.13", "accelerate>=1.0",
"datasets>=3.0", "bitsandbytes>=0.43",
"vllm>=0.6.0",
"openenv-core>=0.2.0",
"trackio",
], check=True)
# Add repo to path so we can import the env CLIENT + manifest helpers
sys.path.insert(0, str(REPO_DIR))
sys.path.insert(0, str(REPO_DIR / "server"))
os.chdir(REPO_DIR)
print(f"✓ Repo + deps ready at {REPO_DIR}")
# %% [markdown]
# ## 2. Imports + auth check
# %%
import json
from typing import List, Optional
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from trl import GRPOConfig, GRPOTrainer
# Env client + helpers (from the cloned repo)
from client import FinancialTaskEnv
from models import FinancialAction
from tasks import TASKS, split_ids
assert os.environ.get("HF_TOKEN"), \
"Set HF_TOKEN env var (Modal: add as a Secret) — needed for SFT-adapter pull and GRPO-adapter push"
# Trackio config — TRL picks these up via report_to='trackio'.
# TRACKIO_SPACE_ID is optional; if unset, runs are logged locally only.
os.environ.setdefault("TRACKIO_PROJECT", "office-doc-grpo")
if os.environ.get("TRACKIO_SPACE_ID"):
print(f"✓ Trackio Space: {os.environ['TRACKIO_SPACE_ID']} (project='{os.environ['TRACKIO_PROJECT']}')")
else:
print(f"⚠ TRACKIO_SPACE_ID unset — runs will log locally only "
f"(project='{os.environ['TRACKIO_PROJECT']}')")
print(f"✓ Imports OK · CUDA: {torch.cuda.is_available()} · "
f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
# %% [markdown]
# ## 3. Configure the env Space URL
#
# **Important:** before training, duplicate `bpHigh/financial-task-env`
# to your own HF account so you have a private Space with concurrency
# capacity for your training run. Set the env-var
# `FINANCIAL_ENV_GOLD_STASH=copy` in the Space's Settings → Variables
# so concurrent GRPO rollouts don't race on the same gold file's rename.
#
# Then update `ENV_URL` below to your duplicate.
# %%
# Default points at the public Space — duplicate this for training.
ENV_URL = os.environ.get("ENV_URL", "https://bphigh-financial-task-env.hf.space")
print(f"Using env Space: {ENV_URL}")
# Quick health check — fail fast if the Space is sleeping or wrong URL
import urllib.request
try:
with urllib.request.urlopen(f"{ENV_URL}/health", timeout=15) as r:
if r.status != 200:
raise RuntimeError(f"Space /health returned {r.status}")
print("✓ Env Space is alive")
except Exception as e:
print(f"⚠ Env Space health check failed: {e}")
print(f" Wake it up by visiting {ENV_URL} once, then re-run.")
raise
# %% [markdown]
# ## 4. The OpenEnv tool wrapper
#
# This class is what TRL's `environment_factory` instantiates per generation.
# Each public method becomes a tool the model can call. We expose three
# tools matching the env's three action types (`code`, `submit`, `submit_file`).
#
# The model sees these as function-callable tools with typed args + docstrings;
# TRL handles parsing + multi-turn rollout automatically.
# %%
class OfficeDocumentEnv:
"""OpenEnv wrapper for the cross-format office-document env.
Exposes three tools matching the env's three action types:
- run_python_code(code) → action_type='code'
- submit_text_answer(answer) → action_type='submit' (QA tasks)
- submit_file(path) → action_type='submit_file' (MODIFY tasks)
The model picks tools based on task_type (returned in reset). Tool
methods raise ValueError when called after the episode is done; the
trainer catches these and feeds the message back to the model.
"""
def __init__(self):
# FinancialTaskEnv is async-by-default; wrap with .sync() and open
# the WebSocket connection up-front so per-call latency is just RPC.
self.client = FinancialTaskEnv(base_url=ENV_URL).sync()
self.client.connect()
self.reward = 0.0
self.done = False
self._task_type = "MODIFY"
def __del__(self):
# Best-effort WS cleanup — TRL doesn't manage env lifecycle explicitly.
try:
self.client.close()
except Exception:
pass
def reset(self, task_id: Optional[str] = None, **kwargs) -> Optional[str]:
"""Receives task_id (and any other dataset columns) as kwargs."""
result = self.client.reset(task_id=task_id) if task_id else self.client.reset()
obs = result.observation
self.reward = 0.0
self.done = False
self._task_type = obs.task_type
# Initial observation as a single string the model sees
return (
f"{obs.task_description}\n\n"
f"Source file: {obs.source_file}\n"
f"Family: {getattr(obs, 'task_type', 'MODIFY')}\n\n"
f"{obs.feedback}"
)
def run_python_code(self, code: str) -> str:
"""Execute Python code in the env's sandbox.
Use this to read or modify the source file. Variables do NOT
persist between calls — each call runs in a fresh subprocess.
Available libs: openpyxl, python-docx, python-pptx, Pillow.
Args:
code: Python source to execute.
Returns:
stdout/stderr from the code, plus per-step reward decomposition.
"""
if self.done:
raise ValueError("Episode already finished — submit your answer.")
result = self.client.step(FinancialAction(action_type="code", content=code))
self.reward = result.reward
self.done = result.done
return result.observation.feedback
def submit_file(self, path: str) -> str:
"""Submit the modified file as the final answer (MODIFY tasks).
Args:
path: Absolute filesystem path to the modified file. Use the
source_file path from the initial observation.
Returns:
Grading result. The episode ends after this call.
"""
if self.done:
raise ValueError("Episode already finished.")
result = self.client.step(FinancialAction(action_type="submit_file", content=path))
self.reward = result.reward
self.done = True
return result.observation.feedback
def submit_text_answer(self, answer: str) -> str:
"""Submit a text answer (QA tasks like 'How many plants?').
Args:
answer: The text answer to submit. May include numbers or
descriptive text — the grader extracts numbers and
matches keywords.
Returns:
Grading result. The episode ends after this call.
"""
if self.done:
raise ValueError("Episode already finished.")
result = self.client.step(FinancialAction(action_type="submit", content=answer))
self.reward = result.reward
self.done = True
return result.observation.feedback
# Smoke test the wrapper
print("Smoke testing the env wrapper on a single task...")
_smoke = OfficeDocumentEnv()
obs = _smoke.reset(task_id="finch_10")
print(f" reset OK: obs is {len(obs)} chars")
fb = _smoke.run_python_code("print('hello')")
print(f" run_python_code OK: reward={_smoke.reward:.3f}, done={_smoke.done}")
del _smoke
# %% [markdown]
# ## 5. Build the train dataset
#
# One row per train task. The `task_id` column is passed through to
# `reset()` via kwargs. TRL handles the rollout loop — we don't need to
# generate prompts; the user message is the system prompt only, the env's
# initial observation comes from `reset()`.
# %%
# This system prompt mirrors the one in inference.py that produced the
# Kimi-K2.5 / MiniMax baseline scores AND the SFT corpus. Markdown
# ```python``` blocks for code, plain SUBMIT_FILE: / SUBMIT_ANSWER: text
# directives for finalization. This is the format the SFT'd model already
# emits fluently from training - matching it removes the format mismatch
# that was causing 0 rewards under TRL's <tool_call>-XML expectation.
SYSTEM_PROMPT = """You are an expert at editing office documents (Excel, Word, PowerPoint) with Python.
You are working with a real .xlsx, .docx, or .pptx file. Pick the right
library based on the source_file path shown in the initial observation:
- .xlsx -> openpyxl (load_workbook / wb.save)
- .docx -> python-docx (Document / doc.save)
- .pptx -> python-pptx (Presentation / prs.save)
CRITICAL RULES:
1. Do NOT call reset(). Just write plain Python code in a ```python``` block.
2. Use the EXACT file path provided in the observation. Do not guess paths.
3. Each code block runs in a FRESH subprocess - you must re-import and re-open
the file every time. Variables do NOT persist between steps.
4. Use print() liberally to see data. Read the output carefully before your
next step.
5. You have at most 12 steps per episode. Be efficient: explore in step 1,
solve in step 2-3, submit.
6. **You MUST execute at least one code step before submitting.** The env
will reject SUBMIT_ANSWER and SUBMIT_FILE on step 1 - you need to read
or modify the file with code first. Submitting the source file unchanged
is never a correct solve.
RESPONSE FORMAT - use EXACTLY one of these three forms per response:
To run Python code:
```python
your code here
```
To submit a text answer (QA tasks like 'How many plants?'):
SUBMIT_ANSWER: your answer here
To submit a modified file (MODIFY tasks):
SUBMIT_FILE: /absolute/path/to/saved.<ext>
For MODIFY tasks: load with the right library, mutate, save to the SAME path,
then SUBMIT_FILE that path."""
train_ids = split_ids("train")
# Drop hand-curated tasks (task_*) to focus GRPO on the larger Round-2 pool;
# the SFT was already exposed to the hand-curated set via Kimi trajectories.
train_ids = [tid for tid in train_ids if not tid.startswith("task_")]
print(f"Train tasks for GRPO: {len(train_ids)}")
train_data = [
{"prompt": [{"role": "user", "content": SYSTEM_PROMPT}], "task_id": tid}
for tid in train_ids
]
train_ds = Dataset.from_list(train_data)
# %% [markdown]
# ## 6. Reward function — read env.reward after each episode
# %%
def env_reward(environments, **kwargs) -> List[float]:
"""TRL's environment_factory passes the list of OfficeDocumentEnv instances
after the multi-turn rollout ends. We read `env.reward` (set by the last
tool call before done=True, OR the per-step reward if the episode ran out
of steps)."""
return [float(env.reward) for env in environments]
# %% [markdown]
# ## 7. Load base model + SFT adapter (trainable continuation)
# %%
BASE_MODEL = "Qwen/Qwen2.5-Coder-3B-Instruct"
SFT_ADAPTER = "bpHigh/qwen3b-office-sft-kimi"
GRPO_HUB_ID = "bpHigh/qwen3b-office-grpo"
print(f"Loading tokenizer: {BASE_MODEL}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.truncation_side = "left"
# TRL's environment_factory path runs add_response_schema(tokenizer) to parse
# tool calls from completions. Auto-detection fails for Qwen2.5-Coder's chat
# template (it only knows qwen3, qwen3.5, llama3, glm4, gptoss), so we attach
# qwen3_schema manually — it parses <tool_call>...</tool_call> blocks via
# regex and is template-agnostic, so it works fine for Qwen2.5 too.
from trl.chat_template_utils import qwen3_schema
tokenizer.response_schema = qwen3_schema
print(f"Loading base model: {BASE_MODEL}")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="sdpa",
)
print(f"Attaching SFT adapter (trainable): {SFT_ADAPTER}")
model = PeftModel.from_pretrained(base_model, SFT_ADAPTER, is_trainable=True)
model.print_trainable_parameters()
# %% [markdown]
# ## 8. GRPO config + trainer
# %%
config = GRPOConfig(
output_dir="/tmp/grpo_qwen3b_office",
num_train_epochs=1,
learning_rate=1e-5, # gentler than SFT's 2e-4
per_device_train_batch_size=1,
gradient_accumulation_steps=4, # 4 prompts × 2 gens = 8 rollouts/step (fits 40GB w/ vLLM colocate)
num_generations=2,
max_completion_length=2048, # cumulative assistant tokens across all tool-call turns
vllm_max_model_length=10240, # prompt + completion + env-feedback growth across ≤12 turns
vllm_gpu_memory_utilization=0.22, # leave more VRAM for trainer forward pass
max_tool_calling_iterations=12, # cap multi-turn rollouts (env's MAX_STEPS=15 is the hard ceiling)
temperature=0.8,
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
logging_steps=1,
save_steps=20,
save_total_limit=2,
warmup_ratio=0.05,
use_vllm=True, # 2-3× faster generation
vllm_mode="colocate", # single GPU
chat_template_kwargs={"enable_thinking": False},
log_completions=True,
push_to_hub=True,
hub_model_id=GRPO_HUB_ID,
hub_strategy="end",
hub_private_repo=False,
report_to="trackio",
run_name="grpo_qwen3b_office_sft-kimi",
seed=42,
)
print("Creating GRPOTrainer with environment_factory=OfficeDocumentEnv...")
trainer = GRPOTrainer(
model=model,
args=config,
train_dataset=train_ds,
reward_funcs=env_reward,
environment_factory=OfficeDocumentEnv, # ← TRL handles the multi-turn loop + XML tool-call parsing
processing_class=tokenizer,
)
# %% [markdown]
# ## 9. Train
# %%
print("Starting GRPO training...")
train_start = time.time()
trainer.train()
train_dur = time.time() - train_start
print(f"✓ Training complete in {train_dur / 60:.1f} min")
trainer.save_model("/tmp/grpo_qwen3b_office")
tokenizer.save_pretrained("/tmp/grpo_qwen3b_office")
trainer.push_to_hub()
print(f"✓ Pushed to HF Hub: {GRPO_HUB_ID}")
# %% [markdown]
# ## 10. Evaluate on the 22-task held-out eval split
#
# Run greedy multi-step rollouts via the same env wrapper, return per-task
# scores in the same format as `runs/sft_eval_v2/*/results.json` so the
# comparison plot can pick them up directly.
# %%
@torch.inference_mode()
def run_eval_episode(task_id: str, max_steps: int = 15) -> dict:
"""Greedy rollout via the env wrapper; return final score + step count."""
env = OfficeDocumentEnv()
actions = []
obs = env.reset(task_id=task_id)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": obs},
]
final_reward = 0.0
n_steps = 0
for step in range(1, max_steps + 1):
n_steps = step
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=24000)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
out = model.generate(
**inputs,
max_new_tokens=1500,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
# Naive tool-call parser for eval (we don't use TRL's loop here).
# The model should emit either:
# ```python ... ``` for code
# SUBMIT_FILE: <path> for file submit
# SUBMIT_ANSWER: <text> for text submit
# If the trained policy uses TRL tool-calling format instead, this
# naive parser will fail and the eval score will be low — see
# eval_lora.py for a more robust parser.
import re
if "SUBMIT_FILE:" in response:
path = response.split("SUBMIT_FILE:", 1)[1].strip().splitlines()[0].strip()
try:
env.submit_file(path)
except Exception:
pass
final_reward = env.reward
break
if "SUBMIT_ANSWER:" in response:
ans = response.split("SUBMIT_ANSWER:", 1)[1].strip()
try:
env.submit_text_answer(ans)
except Exception:
pass
final_reward = env.reward
break
m = re.search(r"```(?:python)?\s*\n(.*?)```", response, re.DOTALL)
code = m.group(1).strip() if m else response.strip()
try:
fb = env.run_python_code(code)
except ValueError:
final_reward = env.reward
break
actions.append("code")
messages.append({"role": "assistant", "content": response})
messages.append({"role": "user", "content": (
f"Code execution result (step {step}/{max_steps}):\n{fb}"
)})
if env.done:
final_reward = env.reward
break
return {
"task_id": task_id,
"family": TASKS[task_id].get("family", "?") if task_id in TASKS else "?",
"score": max(0.001, min(0.999, final_reward)),
"steps": n_steps,
}
eval_ids = split_ids("eval")
print(f"\nEvaluating GRPO model on {len(eval_ids)} held-out tasks...")
eval_start = time.time()
eval_results = []
for tid in eval_ids:
r = run_eval_episode(tid)
eval_results.append(r)
print(f" {r['task_id']:55s} ({r['family']:4s}) score={r['score']:.3f} steps={r['steps']}")
eval_dur = time.time() - eval_start
print(f"\n✓ Eval done in {eval_dur / 60:.1f} min")
# %% [markdown]
# ## 11. Save + summarize
# %%
avg_score = sum(r["score"] for r in eval_results) / len(eval_results)
success_rate = sum(1 for r in eval_results if r["score"] >= 0.5) / len(eval_results)
by_family = {}
for r in eval_results:
by_family.setdefault(r["family"], []).append(r["score"])
fam_summary = {
fam: {"n": len(scores), "avg": round(sum(scores) / len(scores), 4)}
for fam, scores in by_family.items()
}
total_dur_s = time.time() - RUN_START
output = {
"model": GRPO_HUB_ID,
"base": SFT_ADAPTER,
"n_tasks": len(eval_results),
"avg_score": round(avg_score, 4),
"success_rate": round(success_rate, 4),
"by_family": fam_summary,
"results": eval_results,
"wall_clock": {
"total_min": round(total_dur_s / 60, 1),
"training_min": round(train_dur / 60, 1),
"eval_min": round(eval_dur / 60, 1),
"estimated_cost_usd": round((total_dur_s / 3600) * 2.50, 2),
},
}
with open("/tmp/grpo_eval_results.json", "w") as f:
json.dump(output, f, indent=2)
print("\n" + "=" * 70)
print(f"GRPO training + eval complete")
print(f" model: {GRPO_HUB_ID}")
print(f" avg score: {avg_score:.3f}")
print(f" success rate: {success_rate:.0%}")
print(f" by family:")
for fam, info in fam_summary.items():
print(f" {fam}: avg={info['avg']:.3f} n={info['n']}")
print(f" wall-clock: {output['wall_clock']['total_min']:.1f} min")
print(f" est cost: ${output['wall_clock']['estimated_cost_usd']:.2f}")
print("=" * 70)
print(f"\nResults: /tmp/grpo_eval_results.json")
print(f"Adapter: https://huggingface.co/{GRPO_HUB_ID}")