anewspace / app.py
igardner's picture
Just fucking around, still
6b3f556
# app.py
import threading

import gradio as gr

from concept_steerer import ConceptSteerer
# Lazily-created ConceptSteerer shared by every Gradio handler in this process.
steerer = None
# Gradio runs synchronous event handlers on worker threads, so the page-load
# slider refresh and a button click can both reach get_steerer() before the
# model exists; this lock ensures the model is constructed only once.
_steerer_lock = threading.Lock()


def get_steerer():
    """Return the shared ConceptSteerer, creating it on the first call.

    Construction is deferred until first use, so importing this module does
    not load the model. Creation happens under ``_steerer_lock`` with a second
    ``None`` check, so concurrent first callers all receive the same instance.

    Returns:
        The module-level ``ConceptSteerer`` built for
        ``unsloth/Llama-3.2-1B-Instruct``.
    """
    global steerer
    if steerer is None:
        with _steerer_lock:
            # Re-check: another thread may have finished loading while we waited.
            if steerer is None:
                steerer = ConceptSteerer(model_name="unsloth/Llama-3.2-1B-Instruct")
    return steerer
def create_concept(name, pos_examples, neg_examples):
    """Validate the "Create Concepts" form and register a steering concept.

    Args:
        name: Concept name; surrounding whitespace is stripped before use.
        pos_examples: Positive prompts, one per line. Blank lines are ignored
            and each line is stripped.
        neg_examples: Negative prompts, same format as ``pos_examples``.

    Returns:
        A status string for the "Status" textbox: a ✅ message on success, or
        a ❌ message for a validation failure or any exception raised while
        loading the steerer or registering the concept.
    """
    # Treat a missing (None) field like an empty one instead of crashing.
    name = (name or "").strip()
    if not name:
        return "❌ Concept name is required."
    pos_list = [p.strip() for p in (pos_examples or "").splitlines() if p.strip()]
    neg_list = [n.strip() for n in (neg_examples or "").splitlines() if n.strip()]
    if not pos_list or not neg_list:
        return "❌ Provide at least one positive and one negative example."
    try:
        s = get_steerer()
        # layer=-2: presumably the second-to-last layer via negative indexing;
        # confirm against ConceptSteerer.register_concept.
        s.register_concept(name, pos_list, neg_list, layer=-2)
        # Echo the stripped name, i.e. exactly what was registered.
        return f"✅ Concept '{name}' registered!"
    except Exception as e:
        # UI boundary: surface the error in the status box rather than raising.
        return f"❌ Error: {e}"
def update_sliders():
    """Compute visibility/label updates for the fixed pool of 10 concept sliders.

    Slider ``idx`` is shown and labelled with the ``idx``-th registered concept
    name (in ``get_concept_names()`` order); sliders with no matching concept
    are hidden. Concepts beyond the tenth get no slider.

    Returns:
        A list of exactly 10 ``gr.update`` objects, one per slider.
    """
    slider_count = 10
    names = get_steerer().get_concept_names()
    return [
        gr.update(visible=True, label=names[idx])
        if idx < len(names)
        else gr.update(visible=False)
        for idx in range(slider_count)
    ]
def generate_text(prompt, *slider_vals):
    """Generate up to 150 new tokens for *prompt*, steered by the non-zero sliders.

    Args:
        prompt: Text from the "Prompt" textbox (passed to the model unchanged).
        *slider_vals: Current values of the slider pool, in slider order.
            Value ``i`` is paired with the ``i``-th name from
            ``get_concept_names()``; values for sliders without a concept are
            ignored.

    Returns:
        The generated text, or a ❌ message if the prompt is blank or if
        anything (including first-time model loading) raises.
    """
    if not prompt or not prompt.strip():
        return "❌ Prompt is required."
    try:
        # Inside the try so model-load failures are reported like other errors,
        # matching create_concept.
        s = get_steerer()
        concepts = s.get_concept_names()
        # zip stops at the shorter sequence, dropping values of hidden sliders.
        # Sliders at (effectively) zero are left out of the config entirely.
        steering = {
            name: float(val)
            for name, val in zip(concepts, slider_vals)
            if val is not None and abs(val) > 1e-6
        }
        return s.generate(prompt, steering_config=steering, max_new_tokens=150)
    except Exception as e:
        # UI boundary: show the error in the output box rather than raising.
        return f"❌ Error: {e}"
# Build the UI. A fixed pool of 10 concept sliders is pre-created (hidden) and
# shown/relabelled by update_sliders(); the pool size must match the 10 that
# update_sliders() produces updates for.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 LLM Concept Steering — Working Version")

    with gr.Tab("Create Concepts"):
        with gr.Row():
            with gr.Column():
                name_in = gr.Textbox(label="Concept Name")
                pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
                neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
                create_btn = gr.Button("Register Concept")
            with gr.Column():
                status_out = gr.Textbox(label="Status", interactive=False)

    with gr.Tab("Generate"):
        prompt_in = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., Tell me a story.")
        # Hidden until update_sliders() assigns a concept to them.
        sliders = [
            gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=f"Concept {i + 1}", visible=False)
            for i in range(10)
        ]
        gen_btn = gr.Button("Generate")
        output_out = gr.Textbox(label="Output", lines=8, interactive=False)

    # Events are wired after every component exists so that registering a
    # concept can immediately refresh the sliders on the Generate tab
    # (previously they only updated on page reload).
    create_btn.click(
        create_concept,
        inputs=[name_in, pos_in, neg_in],
        outputs=status_out,
    ).then(update_sliders, inputs=None, outputs=sliders)
    gen_btn.click(
        generate_text,
        inputs=[prompt_in] + sliders,
        outputs=output_out,
    )
    # Also populate the sliders on page load, for concepts already held by the
    # module-level steerer (e.g. registered in an earlier browser session).
    demo.load(update_sliders, inputs=None, outputs=sliders)
if __name__ == "__main__":
    # Bind to all network interfaces rather than only localhost; share=False
    # means no public Gradio share link is created.
    demo.launch(share=False, server_name="0.0.0.0")