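"""Gradio front-end for the AMD Triton KernelAgent (KernelBench) service.

Submits kernel-generation requests to the backend REST API, polls them until
they complete, and provides a browsable history of past requests. The backend
URL defaults to http://localhost:8009 and can be overridden through the
API_BASE_URL environment variable when run as a script.
"""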
import gradio as gr
import requests
import time
import os
from typing import Optional, Dict, Any, List, Tuple
class KernelBenchGradioApp:
def __init__(self, api_base_url: str = "http://localhost:8009"):
self.api_base_url = api_base_url
self.current_request_id: Optional[str] = None
def load_sample_files(self) -> List[str]:
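        """Fetch the available sample problems from the backend and format them for the dropdown."""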
try:
response = requests.get(f"{self.api_base_url}/api/samples")
if response.ok:
data = response.json()
samples = []
for sample in data.get("samples", []):
level = sample.get("level", "")
name = sample.get("name", "")
samples.append(f"[{level}] {name}")
return ["-- Select a sample to load --"] + samples
return ["-- Select a sample to load --"]
except Exception as e:
print(f"Error loading samples: {e}")
return ["-- Select a sample to load --"]
def load_sample_content(self, sample_selection: str) -> Tuple[str, str]:
"""Load sample content and return both content and problem name"""
if sample_selection == "-- Select a sample to load --":
return ("", "")
try:
if sample_selection.startswith("[level1]"):
level = "level1"
name = sample_selection[9:].strip()
elif sample_selection.startswith("[level2]"):
level = "level2"
name = sample_selection[9:].strip()
else:
return ("", "")
filename = name + ".py"
response = requests.get(f"{self.api_base_url}/api/samples/{level}/{filename}")
if response.ok:
data = response.json()
content = data.get("content", "")
# Return content and the problem name (without .py extension)
return (content, name)
return ("", "")
except Exception as e:
return (f"Error loading sample: {str(e)}", "")
def update_model_name(self, server_type: str) -> str:
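        """Return the default model name for the selected provider."""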
model_map = {
"deepseek": "deepseek-coder",
"openai": "gpt-5",
"anthropic": "claude-sonnet-4.5",
"google": "gemini-1.5-flash-002",
"nim": "qwen/qwen3-coder-480b-a35b-instruct"
}
return model_map.get(server_type, "gpt-5")
def submit_generation(
self,
ref_arch_src: str,
backend: str,
server_type: str,
model_name: str,
gpu_arch: str,
max_tokens: int,
temperature: float,
custom_prompt: str,
problem_name: str,
max_retries: int,
progress=gr.Progress()
) -> Tuple[str, str, str, str]:
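        """Submit a generation request and poll the backend until it finishes.

        Returns a tuple of (generated kernel code, formatted evaluation
        results, status message, request id). Polling gives up after
        ``max_wait_time`` seconds (15 minutes).
        """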
if not ref_arch_src or not ref_arch_src.strip():
return ("", "", "❌ Please provide reference architecture source code", "")
try:
request_data = {
"ref_arch_src": ref_arch_src,
"gpu_arch": [gpu_arch],
"backend": backend,
"model_name": model_name,
"server_type": server_type,
"max_tokens": int(max_tokens),
"temperature": float(temperature),
"custom_prompt": custom_prompt if custom_prompt and custom_prompt.strip() else None,
"problem_name": problem_name if problem_name and problem_name.strip() else None,
"max_retries": int(max_retries) if max_retries is not None else None
}
response = requests.post(
f"{self.api_base_url}/api/generate",
json=request_data
)
if not response.ok:
error_detail = response.json().get("detail", "Unknown error")
return ("", "", f"❌ Failed to submit request: {error_detail}", "")
result = response.json()
request_id = result.get("request_id")
self.current_request_id = request_id
progress(0, desc="Request submitted, waiting for processing...")
start_time = time.time()
            max_wait_time = 900  # stop polling after 15 minutes
while True:
if time.time() - start_time > max_wait_time:
return ("", "", f"⏱️ Request timed out (ID: {request_id[:8]}...)", "")
status_response = requests.get(f"{self.api_base_url}/api/status/{request_id}")
if not status_response.ok:
return ("", "", f"❌ Failed to check status for request {request_id[:8]}...", "")
status_data = status_response.json()
current_status = status_data.get("status")
if current_status == "pending":
progress(0.25, desc="⏳ Request queued...")
elif current_status == "processing":
elapsed = int(time.time() - start_time)
current_retry = status_data.get("current_retry", 0)
max_retries = status_data.get("max_retries", 0)
if current_retry > 0:
progress(0.5, desc=f"🔄 Retrying with reflection... (Attempt {current_retry + 1}/{max_retries + 1}, {elapsed}s elapsed)")
else:
progress(0.5, desc=f"🔄 Generating kernel... ({elapsed}s elapsed)")
elif current_status == "completed":
progress(1.0, desc="✅ Generation completed!")
generated_kernel = status_data.get("generated_kernel", "No kernel generated")
eval_result_str = status_data.get("eval_result", "")
eval_formatted = self.format_eval_results(eval_result_str)
current_retry = status_data.get("current_retry", 0)
retry_info = f"\n**Attempts:** {current_retry + 1}" if current_retry > 0 else ""
success_msg = f"✅ Generation completed successfully!{retry_info}\n**Request ID:** `{request_id[:8]}...`"
return (generated_kernel, eval_formatted, success_msg, request_id)
elif current_status == "failed":
error_msg = status_data.get("error_message", "Unknown error")
current_retry = status_data.get("current_retry", 0)
max_retries = status_data.get("max_retries", 0)
retry_info = f"\n**Total Attempts:** {current_retry + 1}/{max_retries + 1}" if max_retries > 0 else ""
return ("", "", f"❌ Generation failed: {error_msg}{retry_info}\n**Request ID:** `{request_id[:8]}...`", request_id)
                time.sleep(2)  # poll the backend every 2 seconds
except Exception as e:
return ("", "", f"❌ Error: {str(e)}", "")
def format_eval_results(self, eval_result_str: str) -> str:
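        """Render the backend's eval-result string as Markdown status lines and tables."""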
if not eval_result_str:
return "⚠️ No evaluation results available"
try:
result = self.parse_eval_string(eval_result_str)
sections = []
sections.append("### Evaluation Results\n")
compiled_status = "✅ Compiled" if result.get("compiled") else "❌ Failed to Compile"
correctness_status = "✅ Correct" if result.get("correctness") else "❌ Incorrect"
sections.append(f"**Status:** {compiled_status} | {correctness_status}\n")
metadata = result.get("metadata", {})
if metadata:
hardware = metadata.get("hardware", "Unknown")
device = metadata.get("device", "")
device_str = f" (Device {device})" if device else ""
sections.append(f"**Hardware:** {hardware}{device_str}")
correctness_trials = metadata.get("correctness_trials")
if correctness_trials:
sections.append(f"**Correctness Trials:** {correctness_trials}")
speedup = result.get("speedup")
if speedup and speedup > 0:
speedup_emoji = "🚀" if speedup > 1 else "⚠️" if speedup < 1 else "➖"
speedup_desc = "Faster than reference!" if speedup > 1 else "Slower than reference" if speedup < 1 else "Same as reference"
sections.append(f"\n**Speedup:** {speedup_emoji} **{speedup:.2f}x** - {speedup_desc}")
ref_runtime = result.get("ref_runtime")
runtime = result.get("runtime")
if ref_runtime is not None and runtime is not None:
sections.append("\n### Performance Comparison\n")
ref_stats = result.get("ref_runtime_stats", {})
runtime_stats = result.get("runtime_stats", {})
sections.append("| Model | Mean Runtime | Std Dev | Min | Max |")
sections.append("|-------|--------------|---------|-----|-----|")
sections.append(
f"| **Reference (PyTorch)** | {ref_runtime:.2f} ms | "
f"{ref_stats.get('std', 0):.4f} ms | "
f"{ref_stats.get('min', 0):.2f} ms | "
f"{ref_stats.get('max', 0):.2f} ms |"
)
sections.append(
f"| **Custom Kernel** | {runtime:.2f} ms | "
f"{runtime_stats.get('std', 0):.4f} ms | "
f"{runtime_stats.get('min', 0):.2f} ms | "
f"{runtime_stats.get('max', 0):.2f} ms |"
)
num_trials = runtime_stats.get("num_trials", "N/A")
sections.append(f"\n*Number of trials: {num_trials}*")
elif runtime is not None:
sections.append(f"\n**Runtime:** {runtime:.2f} ms")
runtime_stats = result.get("runtime_stats", {})
if runtime_stats:
sections.append("\n### Runtime Statistics\n")
sections.append("| Metric | Value |")
sections.append("|--------|-------|")
sections.append(f"| Mean | {runtime_stats.get('mean', 0):.2f} ms |")
sections.append(f"| Std Dev | {runtime_stats.get('std', 0):.4f} ms |")
sections.append(f"| Min | {runtime_stats.get('min', 0):.2f} ms |")
sections.append(f"| Max | {runtime_stats.get('max', 0):.2f} ms |")
sections.append(f"| Trials | {runtime_stats.get('num_trials', 'N/A')} |")
return "\n".join(sections)
except Exception as e:
return f"⚠️ Error formatting results:\n```\n{eval_result_str}\n```"
def parse_eval_string(self, s: str) -> Dict[str, Any]:
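        """Best-effort parser for the backend's repr-style eval string.

        The string is expected to look roughly like
        "compiled=True correctness=True metadata={...} runtime=1.23
        runtime_stats={...} ref_runtime=2.34 ref_runtime_stats={...}
        speedup=1.9"; any field that cannot be recovered keeps its default.
        """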
import re
import json
result = {
"compiled": None,
"correctness": None,
"metadata": {},
"runtime": None,
"runtime_stats": {},
"ref_runtime": None,
"ref_runtime_stats": {},
"speedup": None
}
try:
compiled_match = re.search(r'compiled=(True|False)', s)
if compiled_match:
result["compiled"] = compiled_match.group(1) == "True"
correctness_match = re.search(r'correctness=(True|False)', s)
if correctness_match:
result["correctness"] = correctness_match.group(1) == "True"
metadata_match = re.search(r'metadata=(\{[^}]*\})', s)
if metadata_match:
try:
metadata_str = metadata_match.group(1).replace("'", '"')
metadata_str = re.sub(r'\((\d+)\s*/\s*(\d+)\)', r'"(\1 / \2)"', metadata_str)
result["metadata"] = json.loads(metadata_str)
                except Exception:
pass
            runtime_match = re.search(r'\bruntime=([\d.]+)(?=[\s,]|$|runtime_stats|ref_)', s)
if runtime_match:
result["runtime"] = float(runtime_match.group(1))
ref_runtime_match = re.search(r'ref_runtime=([\d.]+)', s)
if ref_runtime_match:
result["ref_runtime"] = float(ref_runtime_match.group(1))
speedup_match = re.search(r'speedup=([\d.]+)', s)
if speedup_match:
result["speedup"] = float(speedup_match.group(1))
            def extract_dict(start_str: str) -> Optional[Dict]:
                # Anchor on the exact key so that 'runtime_stats=' cannot match
                # the tail end of 'ref_runtime_stats='.
                key_match = re.search(r'(?<![\w.])' + re.escape(start_str), s)
                if key_match is None:
                    return None
                start_pos = key_match.end()
brace_count = 0
end_pos = start_pos
for i in range(start_pos, len(s)):
if s[i] == '{':
brace_count += 1
elif s[i] == '}':
brace_count -= 1
if brace_count == 0:
end_pos = i + 1
break
if end_pos > start_pos:
dict_str = s[start_pos:end_pos]
dict_str = dict_str.replace("'", '"').replace("None", "null").replace("True", "true").replace("False", "false")
try:
return json.loads(dict_str)
                    except Exception:
pass
return None
            runtime_stats = extract_dict('runtime_stats=')
if runtime_stats:
result["runtime_stats"] = runtime_stats
            ref_runtime_stats = extract_dict('ref_runtime_stats=')
if ref_runtime_stats:
result["ref_runtime_stats"] = ref_runtime_stats
except Exception as e:
print(f"Error parsing eval string: {e}")
return result
    def load_request_history(self, limit: Optional[int] = None) -> Tuple[List[List[Any]], List[str]]:
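        """Load requests from the backend for the history table.

        Returns (table rows, request ids); the ids are kept in a parallel list
        so a clicked row can be mapped back to its full request id. The
        ``limit`` argument is currently ignored; the API is queried with a
        large fixed limit instead.
        """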
try:
# Request all records by setting a very high limit
# The API will return all available requests
url = f"{self.api_base_url}/api/requests?limit=10000"
response = requests.get(url)
if not response.ok:
return ([], [])
data = response.json()
requests_list = data.get("requests", [])
if not requests_list:
return ([], [])
table_data = []
request_ids = []
for req in requests_list:
req_id = req.get("id", "")
request_ids.append(req_id)
req_id_short = req_id[:12] + "..."
status = req.get("status", "unknown")
backend = req.get("backend", "-")
model_name = req.get("model_name", "-")
problem_name = req.get("problem_name", "-")
created_at = req.get("created_at", "")
compiled = "-"
correctness = "-"
runtime = "-"
speedup = "-"
if req.get("eval_result") and status == "completed":
try:
eval_data = self.parse_eval_string(req.get("eval_result", ""))
if eval_data.get("compiled") is not None:
compiled = "✅" if eval_data["compiled"] else "❌"
if eval_data.get("correctness") is not None:
correctness = "✅" if eval_data["correctness"] else "❌"
if eval_data.get("runtime") is not None:
runtime = f"{eval_data['runtime']:.2f}"
if eval_data.get("speedup") is not None and eval_data["speedup"] > 0:
speedup = f"{eval_data['speedup']:.2f}x"
                    except Exception:
pass
status_emoji = {
"pending": "⏳",
"processing": "🔄",
"completed": "✅",
"failed": "❌"
}.get(status, "❓")
try:
from datetime import datetime
dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
created_str = dt.strftime("%m/%d %H:%M")
                except Exception:
created_str = created_at[:16] if len(created_at) > 16 else created_at
table_data.append([
req_id_short,
f"{status_emoji} {status}",
backend,
model_name,
problem_name,
compiled,
correctness,
runtime,
speedup,
created_str
])
return (table_data, request_ids)
except Exception as e:
print(f"Error loading history: {e}")
return ([], [])
def view_request_by_id(self, request_id: str) -> Tuple[str, str, str, str, str]:
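        """Fetch a single request by id and return its details for the viewer tabs."""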
if not request_id or not request_id.strip():
return ("", "", "", "", "⚠️ Please enter a request ID")
try:
response = requests.get(f"{self.api_base_url}/api/status/{request_id}")
if not response.ok:
return ("", "", "", "", f"❌ Request not found: {request_id}")
status_data = response.json()
status = status_data.get("status")
if status == "completed":
ref_code = status_data.get("ref_arch_src", "No reference code")
generated_kernel = status_data.get("generated_kernel", "No kernel generated")
eval_result = self.format_eval_results(status_data.get("eval_result", ""))
error_message = status_data.get("error_message", "")
msg = f"✅ Request `{request_id[:12]}...` loaded successfully"
return (ref_code, generated_kernel, eval_result, error_message, msg)
elif status == "failed":
# For failed requests, still load ref_code, generated_kernel (last attempt), and error
ref_code = status_data.get("ref_arch_src", "No reference code")
generated_kernel = status_data.get("generated_kernel", "No kernel generated (all attempts failed)")
error_msg = status_data.get("error_message", "Unknown error")
eval_result = self.format_eval_results(status_data.get("eval_result", "")) if status_data.get("eval_result") else ""
# Show retry information if available
current_retry = status_data.get("current_retry", 0)
max_retries = status_data.get("max_retries", 0)
retry_info = f"\n\n**Total Attempts:** {current_retry + 1}/{max_retries + 1}" if max_retries > 0 else ""
msg = f"❌ Request failed after all attempts{retry_info}"
return (ref_code, generated_kernel, eval_result, error_msg, msg)
else:
return ("", "", "", "", f"⏳ Request is still {status}")
except Exception as e:
return ("", "", "", "", f"❌ Error: {str(e)}")
def view_request_from_table(self, history_table_data: gr.SelectData, request_ids_state: List[str]) -> Tuple[str, str, str, str, str]:
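        """Map the clicked history-table row back to its request id and load it."""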
try:
row_index = history_table_data.index[0]
if row_index < len(request_ids_state):
request_id = request_ids_state[row_index]
return self.view_request_by_id(request_id)
return ("", "", "", "", "⚠️ Invalid selection")
except Exception as e:
return ("", "", "", "", f"❌ Error: {str(e)}")
def create_interface(self) -> gr.Blocks:
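        """Build the Gradio Blocks UI: a kernel-generation tab and a request-history tab."""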
with gr.Blocks(title="KernelAgent - GPU Kernel Generator", theme=gr.themes.Soft()) as app:
gr.Markdown(
"""
# 🚀 KernelAgent GPU Kernel Generator
**Powered by AMD** - Generate and evaluate optimized Triton kernels on AMD MI300x
                ***This work was inspired by KernelBench, and the LLM backend uses qwen3-coder.***
"""
)
with gr.Tabs():
with gr.Tab("Generate Kernel"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Input Configuration")
sample_dropdown = gr.Dropdown(
choices=self.load_sample_files(),
label="Load Sample Code",
value="-- Select a sample to load --",
interactive=True
)
ref_arch_src = gr.Textbox(
label="Reference Architecture Source Code",
placeholder="Paste your PyTorch reference implementation here...",
lines=12,
max_lines=20
)
custom_prompt_input = gr.Textbox(
label="Custom Prompt (Optional)",
placeholder="Add custom instructions to append to the generation prompt...",
lines=3,
max_lines=5
)
problem_name_state = gr.State("")
with gr.Row():
backend = gr.Dropdown(
choices=["triton"],
value="triton",
label="Backend"
)
server_type = gr.Dropdown(
choices=["nim"],
value="nim",
label="Model Provider"
)
with gr.Row():
model_name = gr.Textbox(
label="Model Name",
value="qwen/qwen3-coder-480b-a35b-instruct"
)
gpu_arch = gr.Dropdown(
choices=["MI300x"],
value="MI300x",
label="GPU Architecture"
)
with gr.Row():
max_tokens = gr.Number(
label="Max Tokens",
value=4096,
minimum=256,
maximum=8192
)
temperature = gr.Slider(
label="Temperature",
minimum=1.0,
maximum=1.0,
value=1.0,
step=0.1
)
with gr.Row():
max_retries = gr.Number(
label="Max Retries (Reflection)",
value=3,
minimum=0,
maximum=10,
info="Number of times to retry generation with error feedback"
)
generate_btn = gr.Button("🚀 Generate Kernel", variant="primary", size="lg")
with gr.Column(scale=1):
gr.Markdown("### Results")
status_msg = gr.Markdown("")
generated_kernel = gr.Code(
label="Generated Kernel",
language="cpp",
lines=15
)
eval_results = gr.Markdown(
label="Evaluation Results"
)
request_id_state = gr.State("")
sample_dropdown.change(
fn=self.load_sample_content,
inputs=[sample_dropdown],
outputs=[ref_arch_src, problem_name_state]
)
server_type.change(
fn=self.update_model_name,
inputs=[server_type],
outputs=[model_name]
)
generate_btn.click(
fn=self.submit_generation,
inputs=[
ref_arch_src, backend, server_type, model_name,
gpu_arch, max_tokens, temperature, custom_prompt_input, problem_name_state, max_retries
],
outputs=[generated_kernel, eval_results, status_msg, request_id_state]
)
with gr.Tab("Request History"):
gr.Markdown("### All Generation Requests")
gr.Markdown("💡 **Tip:** Click on any row in the table to view its details below")
request_ids_state = gr.State([])
with gr.Row():
refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
history_table = gr.Dataframe(
headers=["ID", "Status", "Backend", "Model", "Problem", "Compiled", "Correct", "Runtime", "Speedup", "Created"],
datatype=["str", "str", "str", "str", "str", "str", "str", "str", "str", "str"],
value=[],
interactive=False,
wrap=True
)
gr.Markdown("---")
gr.Markdown("### Request Details")
view_status_msg = gr.Markdown("")
with gr.Tabs():
with gr.Tab("Reference Code"):
view_ref_code = gr.Code(
label="Reference Architecture Source Code",
language="python",
lines=12
)
with gr.Tab("Generated Kernel"):
view_kernel_output = gr.Code(
label="Generated Kernel Code",
language="cpp",
lines=12
)
with gr.Tab("Evaluation Results"):
view_eval_output = gr.Markdown(
label="Evaluation Results"
)
with gr.Tab("Error Message"):
view_error_message = gr.Textbox(
label="Error Details",
lines=12,
max_lines=100,
interactive=False,
placeholder="No error message available"
)
def refresh_history():
table_data, req_ids = self.load_request_history()
return table_data, req_ids
refresh_btn.click(
fn=refresh_history,
inputs=[],
outputs=[history_table, request_ids_state]
)
history_table.select(
fn=self.view_request_from_table,
inputs=[request_ids_state],
outputs=[view_ref_code, view_kernel_output, view_eval_output, view_error_message, view_status_msg]
)
app.load(
fn=refresh_history,
inputs=[],
outputs=[history_table, request_ids_state]
)
# gr.Markdown(
# """
# ---
# **Note:** This interface provides the same functionality as the web frontend.
# Backend API must be running at `http://localhost:8009` for this to work.
# """
# )
return app
def create_gradio_app(api_base_url: str = "http://localhost:8009") -> gr.Blocks:
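    """Create the Gradio app wired to the backend at ``api_base_url``."""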
app_instance = KernelBenchGradioApp(api_base_url=api_base_url)
return app_instance.create_interface()
if __name__ == "__main__":
api_url = os.environ.get("API_BASE_URL", "http://134.199.134.28/proxy/8009")
app = create_gradio_app(api_base_url=api_url)
app.queue().launch()