# app.py
"""Gradio front-end for registering concepts and generating steered text.

The UI has two tabs: one to register a named concept from positive and
negative example prompts, and one to generate text with a per-concept
steering strength chosen on a slider.
"""
import gradio as gr
from concept_steerer import ConceptSteerer

# Number of sliders pre-created in the UI. Concepts beyond this count get no
# slider and therefore cannot be steered from the Generate tab.
MAX_SLIDERS = 10

# Lazily created, process-wide ConceptSteerer (see get_steerer).
steerer = None


def get_steerer():
    """Return the shared ConceptSteerer, creating it on first use."""
    global steerer
    if steerer is None:
        steerer = ConceptSteerer(model_name="unsloth/Llama-3.2-1B-Instruct")
    return steerer


def _parse_examples(text: str) -> list[str]:
    """Split *text* on newlines and return the non-blank, stripped lines."""
    return [line.strip() for line in text.strip().split('\n') if line.strip()]


def create_concept(name: str, pos_examples: str, neg_examples: str) -> str:
    """Register a concept from newline-separated example prompts.

    Args:
        name: Concept name; surrounding whitespace is ignored.
        pos_examples: Positive example prompts, one per line.
        neg_examples: Negative example prompts, one per line.

    Returns:
        A status message for the UI: a "✅" line on success, or a "❌" line
        describing the validation failure or the exception raised.
    """
    clean_name = name.strip()
    if not clean_name:
        return "❌ Concept name is required."

    pos_list = _parse_examples(pos_examples)
    neg_list = _parse_examples(neg_examples)
    if not pos_list or not neg_list:
        return "❌ Provide at least one positive and one negative example."

    # UI boundary: surface any failure as a status string, not a traceback.
    try:
        # NOTE(review): layer=-2 presumably selects the second-to-last layer;
        # confirm against ConceptSteerer.register_concept.
        get_steerer().register_concept(clean_name, pos_list, neg_list, layer=-2)
    except Exception as e:
        return f"❌ Error: {e}"
    # Report the name that was actually registered (the stripped one).
    return f"✅ Concept '{clean_name}' registered!"


def update_sliders() -> list:
    """Return one update per pre-created slider.

    The first ``len(concepts)`` sliders (up to MAX_SLIDERS) are made visible
    and labelled with the concept names, in the order the steerer reports
    them; the remaining sliders are hidden.
    """
    concepts = get_steerer().get_concept_names()
    updates = []
    for i in range(MAX_SLIDERS):
        if i < len(concepts):
            updates.append(gr.update(visible=True, label=concepts[i]))
        else:
            updates.append(gr.update(visible=False))
    return updates


def generate_text(prompt: str, *slider_vals: float) -> str:
    """Generate text for *prompt*, steered by the slider values.

    Slider values are matched positionally to the steerer's concept names.
    Values within 1e-6 of zero are left out of the steering config.

    Returns:
        The generated text, or a "❌ Error: ..." message if anything fails.
    """
    # UI boundary: everything, including obtaining the steerer, is inside
    # the try so failures come back as a message in the output box.
    try:
        s = get_steerer()
        steering = {
            name: float(val)
            for name, val in zip(s.get_concept_names(), slider_vals)
            if abs(val) > 1e-6
        }
        return s.generate(prompt, steering_config=steering, max_new_tokens=150)
    except Exception as e:
        return f"❌ Error: {e}"


# Build UI with a fixed number of sliders (hidden by default)
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 LLM Concept Steering — Working Version")

    with gr.Tab("Create Concepts"):
        with gr.Row():
            with gr.Column():
                name_in = gr.Textbox(label="Concept Name")
                pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
                neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
                create_btn = gr.Button("Register Concept")
            with gr.Column():
                status_out = gr.Textbox(label="Status", interactive=False)

    with gr.Tab("Generate"):
        prompt_in = gr.Textbox(
            label="Prompt", lines=2, placeholder="e.g., Tell me a story."
        )
        # Pre-create the sliders (shown/hidden dynamically by update_sliders)
        sliders = [
            gr.Slider(
                -5.0, 5.0, value=0.0, step=0.1,
                label=f"Concept {i+1}", visible=False,
            )
            for i in range(MAX_SLIDERS)
        ]
        gen_btn = gr.Button("Generate")
        output_out = gr.Textbox(label="Output", lines=8, interactive=False)
        gen_btn.click(
            generate_text,
            inputs=[prompt_in] + sliders,
            outputs=output_out,
        )

    # Wired here, after the sliders exist, so that registering a concept also
    # refreshes slider visibility and labels without a page reload.
    create_btn.click(
        create_concept,
        inputs=[name_in, pos_in, neg_in],
        outputs=status_out,
    ).then(update_sliders, inputs=None, outputs=sliders)

    # Also sync the sliders when the app first loads.
    demo.load(update_sliders, inputs=None, outputs=sliders)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)