|
|
|
|
|
import gradio as gr |
|
|
from concept_steerer import ConceptSteerer |
|
|
|
|
|
# Process-wide ConceptSteerer instance; stays None until get_steerer() is
# first called so that importing this module does not load the model.
steerer = None


def get_steerer():
    """Return the shared ConceptSteerer, creating it on the first call.

    The instance is stored in the module-level ``steerer`` global and reused
    by every later call.
    """
    global steerer

    # Fast path: already constructed by an earlier call.
    if steerer is not None:
        return steerer

    steerer = ConceptSteerer(model_name="unsloth/Llama-3.2-1B-Instruct")
    return steerer
|
|
|
|
|
def create_concept(name, pos_examples, neg_examples):
    """Register a steering concept from newline-separated example prompts.

    Args:
        name: Concept name entered by the user; surrounding whitespace is
            ignored. ``None`` is treated like an empty string.
        pos_examples: Positive example prompts, one per line.
        neg_examples: Negative example prompts, one per line.

    Returns:
        A status string for display in the UI: a success message prefixed
        with a check mark, or a message prefixed with a cross mark that
        explains what went wrong. Exceptions raised while registering are
        reported in the returned string rather than propagated.
    """
    concept = (name or "").strip()
    if not concept:
        return "❌ Concept name is required."

    # Blank and whitespace-only lines are dropped; each kept line is trimmed.
    pos_list = [line.strip() for line in (pos_examples or "").split("\n") if line.strip()]
    neg_list = [line.strip() for line in (neg_examples or "").split("\n") if line.strip()]
    if not pos_list or not neg_list:
        return "❌ Provide at least one positive and one negative example."

    try:
        s = get_steerer()
        # NOTE(review): layer=-2 is passed straight through to
        # ConceptSteerer.register_concept; presumably it selects the
        # second-to-last layer -- confirm against ConceptSteerer.
        s.register_concept(concept, pos_list, neg_list, layer=-2)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"❌ Error: {e}"

    # Echo the same (stripped) name that was actually registered.
    return f"✅ Concept '{concept}' registered!"
|
|
|
|
|
def update_sliders():
    """Build the visibility/label updates for the fixed pool of sliders.

    Returns:
        A list of exactly 10 ``gr.update`` objects. Slot ``i`` is made
        visible and labelled with the i-th registered concept name when such
        a concept exists; every remaining slot is hidden. Concepts beyond the
        tenth get no slot.
    """
    slot_count = 10
    names = get_steerer().get_concept_names()
    known = len(names)

    return [
        gr.update(visible=True, label=names[slot])
        if slot < known
        else gr.update(visible=False)
        for slot in range(slot_count)
    ]
|
|
|
|
|
def generate_text(prompt, *slider_vals):
    """Generate text for ``prompt`` with the steering strengths from the sliders.

    Args:
        prompt: The user's prompt text.
        *slider_vals: Slider values in slider order; the i-th value is paired
            with the i-th name returned by ``get_concept_names()``. Extra
            values (for sliders with no concept) are ignored.

    Returns:
        The text returned by the steerer's ``generate`` call, or a message
        prefixed with a cross mark when the prompt is empty or when loading
        the steerer or generating raises.
    """
    if not (prompt or "").strip():
        return "❌ Prompt is required."

    try:
        # Fetching the steerer is inside the try so that a model-loading
        # failure is reported the same way create_concept reports it.
        s = get_steerer()
        concepts = s.get_concept_names()

        # zip() already stops at the shorter sequence, so no slicing is
        # needed. Sliders left at (effectively) zero are omitted entirely.
        steering = {
            name: float(val)
            for name, val in zip(concepts, slider_vals)
            if val is not None and abs(val) > 1e-6
        }
        return s.generate(prompt, steering_config=steering, max_new_tokens=150)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"❌ Error: {e}"
|
|
|
|
|
|
|
|
# Build the two-tab UI: one tab to register concepts, one to generate text
# with a per-concept steering slider.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 LLM Concept Steering — Working Version")

    with gr.Tab("Create Concepts"):
        with gr.Row():
            with gr.Column():
                name_in = gr.Textbox(label="Concept Name")
                pos_in = gr.Textbox(label="Positive Prompts (one per line)", lines=4)
                neg_in = gr.Textbox(label="Negative Prompts (one per line)", lines=4)
                create_btn = gr.Button("Register Concept")
            with gr.Column():
                status_out = gr.Textbox(label="Status", interactive=False)

        # Keep the event handle: the slider refresh is chained onto it below,
        # once the sliders exist.
        register_event = create_btn.click(
            create_concept,
            inputs=[name_in, pos_in, neg_in],
            outputs=status_out,
        )

    with gr.Tab("Generate"):
        prompt_in = gr.Textbox(label="Prompt", lines=2, placeholder="e.g., Tell me a story.")

        # Fixed pool of hidden sliders; update_sliders() reveals and labels
        # one per registered concept. The count (10) must match the number of
        # updates update_sliders() returns.
        sliders = [
            gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label=f"Concept {i+1}", visible=False)
            for i in range(10)
        ]

        gen_btn = gr.Button("Generate")
        output_out = gr.Textbox(label="Output", lines=8, interactive=False)

        gen_btn.click(
            generate_text,
            inputs=[prompt_in] + sliders,
            outputs=output_out,
        )

    # Refresh the sliders right after a registration attempt so a new concept
    # gets its slider without the user having to reload the page.
    register_event.then(update_sliders, inputs=None, outputs=sliders)

    # Also sync the sliders with the registered concepts on every page load.
    demo.load(update_sliders, inputs=None, outputs=sliders)


if __name__ == "__main__":
    # NOTE: 0.0.0.0 listens on all network interfaces (e.g. for containers);
    # share=False means no public Gradio share link is created.
    demo.launch(server_name="0.0.0.0", share=False)