fix queue (#6)
Commit: 9431ee6c6359396371b5551bbcb2e678cfa4b060
Co-authored-by: Radamés Ajna <radames@users.noreply.huggingface.co>
app.py
CHANGED
```diff
@@ -4,11 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Stopping
 import time
 import numpy as np
 from torch.nn import functional as F
 import os
-auth_key = os.environ["HF_ACCESS_TOKEN"]
+# auth_key = os.environ["HF_ACCESS_TOKEN"]
 print(f"Starting to load the model to memory")
 m = AutoModelForCausalLM.from_pretrained(
-    "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16, use_auth_token=auth_key).cuda()
+    "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
+tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")
 generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
 print(f"Sucessfully loaded the model to the memory")
 
```
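The first hunk drops the `HF_ACCESS_TOKEN` requirement (the checkpoint is public) and loads the tokenizer explicitly. Consolidated, the load path after this change looks roughly like this — a sketch, assuming a CUDA GPU with enough memory for the 7B weights in fp16:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "stabilityai/stablelm-tuned-alpha-7b"

# fp16 halves memory versus fp32; .cuda() puts the weights on GPU 0,
# matching device=0 in the pipeline call below.
m = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda()
tok = AutoTokenizer.from_pretrained(model_id)
generator = pipeline("text-generation", model=m, tokenizer=tok, device=0)
```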
```diff
@@ -30,8 +31,10 @@ class StopOnTokens(StoppingCriteria):
 
 def contrastive_generate(text, bad_text):
     with torch.no_grad():
-        tokens = tok(text, return_tensors="pt")[
-            'input_ids'].cuda()
+        tokens = tok(text, return_tensors="pt")[
+            'input_ids'].cuda()[:, :4096-1024]
+        bad_tokens = tok(bad_text, return_tensors="pt")[
+            'input_ids'].cuda()[:, :4096-1024]
         history = None
         bad_history = None
         curr_output = list()
```
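The new `[:, :4096-1024]` slice appears to budget the model's 4096-token context window: up to 1024 tokens are reserved for generation, so the prompt is capped at the remaining 3072. A sketch of that arithmetic (the constant names are mine, not the Space's):

```python
CONTEXT_LEN = 4096      # assumed context window of stablelm-tuned-alpha-7b
MAX_NEW_TOKENS = 1024   # matches max_new_tokens in generate() below

def cap_prompt(input_ids):
    # Keep at most CONTEXT_LEN - MAX_NEW_TOKENS prompt tokens so that the
    # prompt plus the generated continuation still fits in the window.
    return input_ids[:, :CONTEXT_LEN - MAX_NEW_TOKENS]
```

Note the slice keeps the *start* of the prompt and drops the end; chat UIs often keep the most recent tokens instead (`input_ids[:, -(CONTEXT_LEN - MAX_NEW_TOKENS):]`), which would preserve the latest turns.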
```diff
@@ -39,7 +42,8 @@ def contrastive_generate(text, bad_text):
         out = m(tokens, past_key_values=history, use_cache=True)
         logits = out.logits
         history = out.past_key_values
-        bad_out = m(bad_tokens, past_key_values=bad_history, use_cache=True)
+        bad_out = m(bad_tokens, past_key_values=bad_history,
+                    use_cache=True)
         bad_logits = bad_out.logits
         bad_history = bad_out.past_key_values
         probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
```
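Both model calls thread `past_key_values` back in with `use_cache=True`, so each step reuses the attention keys/values computed so far instead of re-encoding the whole prefix. A minimal sketch of that incremental-decoding pattern (a hypothetical greedy loop, with `m` and `tok` as loaded above):

```python
import torch

@torch.no_grad()
def greedy_decode(prompt, steps=16):
    ids = tok(prompt, return_tensors="pt")["input_ids"].cuda()
    history = None
    out_tokens = []
    for _ in range(steps):
        out = m(ids, past_key_values=history, use_cache=True)
        history = out.past_key_values         # cache grows by the new tokens
        next_id = out.logits[0, -1].argmax()  # distribution at the last position
        out_tokens.append(next_id.item())
        ids = next_id.view(1, 1)              # only the new token is fed next step
    return tok.decode(out_tokens)
```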
```diff
@@ -60,39 +64,48 @@ def contrastive_generate(text, bad_text):
                 tokens.device)
         return tok.decode(curr_output)
 
+
 def generate(text, bad_text=None):
     stop = StopOnTokens()
-    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True, temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
+    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
+                       temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
     return result[0]["generated_text"].replace(text, "")
 
 
 def user(user_message, history):
-    …
+    history = history + [[user_message, ""]]
+    return "", history, history
 
 
 def bot(history, curr_system_message):
-    messages = curr_system_message + "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]]) for item in history])
+    messages = curr_system_message + \
+        "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
+                 for item in history])
     output = generate(messages)
     history[-1][1] = output
     time.sleep(1)
-    return history
-    …
-    …
+    return history, history
 
 
 with gr.Blocks() as demo:
-    …
+    history = gr.State([])
    gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
     gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
-    chatbot = gr.Chatbot(
-        …
-        …
-        …
-        …
+    chatbot = gr.Chatbot().style(height=500)
+    with gr.Row():
+        with gr.Column(scale=0.70):
+            msg = gr.Textbox(label="", placeholder="Chat Message Box")
+        with gr.Column(scale=0.30, min_width=0):
+            with gr.Row():
+                submit = gr.Button("Submit")
+                clear = gr.Button("Clear")
+    system_msg = gr.Textbox(
+        start_message, label="System Message", interactive=False, visible=False)
 
-    msg.submit(user, [msg, …
-        bot, [chatbot, system_msg], chatbot
-    )
-    …
+    msg.submit(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
+        fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
+    submit.click(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
+        fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
+    clear.click(lambda: [None, []], None, [chatbot, history], queue=False)
 demo.queue(concurrency_count=5)
-demo.launch()
+demo.launch()
```

Removed lines shown as `…` were elided in the rendered diff and could not be recovered.
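The last hunk is the actual queue fix the commit title refers to. Each user action is now a two-step chain: `user` appends the turn to the transcript with `queue=False`, so the message renders immediately, and only the slow model call in `bot` goes through the shared queue with `queue=True`. A stripped-down sketch of the pattern (Gradio 3.x, with a stand-in for the real generator):

```python
import gradio as gr

def user(user_message, history):
    history = history + [[user_message, ""]]
    return "", history, history               # clear textbox, update chat + state

def bot(history, system_msg):
    history[-1][1] = "echo: " + history[-1][0]  # stand-in for generate(messages)
    return history, history

with gr.Blocks() as demo:
    history = gr.State([])                    # per-session transcript
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    system_msg = gr.Textbox(visible=False)
    # Fast UI update outside the queue, slow generation inside it.
    msg.submit(user, [msg, history], [msg, chatbot, history], queue=False).then(
        bot, [chatbot, system_msg], [chatbot, history], queue=True)

demo.queue(concurrency_count=5)
demo.launch()
```

Wiring the same chain to both `msg.submit` and `submit.click`, as the diff does, keeps pressing Enter and clicking the button behaviorally identical.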
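The hunk headers reference a `StopOnTokens` criterion whose body lies outside this diff. For orientation, the StableLM-Tuned examples implement it along these lines (a sketch; the exact stop IDs in the Space's copy are not visible in this change):

```python
import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        # Assumed IDs: StableLM-Tuned's <|SYSTEM|>/<|USER|>/<|ASSISTANT|>
        # special tokens plus end-of-text and padding.
        stop_ids = [50278, 50279, 50277, 1, 0]
        return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)
```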