import os
import csv
import json
import time
import random
import atexit
import urllib.parse
import urllib.request
from datetime import datetime
from threading import Thread
from typing import Iterator

import gradio as gr
import mysql.connector
import spaces
import torch

import huggingface_hub
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
# Dataset repo that stores the chat history; DATA_FILE is the local path after cloning.
DATASET_REPO_URL = "https://huggingface.co/datasets/botsi/trust-game-llama-2-chat-history"
DATA_FILENAME = "data.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)

HF_TOKEN = os.environ.get("HF_TOKEN")
print("HF_TOKEN is None?", HF_TOKEN is None)
print("huggingface_hub version:", huggingface_hub.__version__)
|
# Clone the chat-history dataset repo into ./data so DATA_FILE points at a local copy.
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL)
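# The clone makes data/data.csv available on local disk; appending chat rows and
# pushing them back (e.g. via repo.push_to_hub()) is assumed to happen elsewhere,
# as this script itself never writes to DATA_FILE.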
|
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
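# Hypothetical usage: cap the prompt budget at deploy time with
#   MAX_INPUT_TOKEN_LENGTH=2048 python app.py
# (the value is read once, at import time).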
|
DESCRIPTION = """\ |
|
|
# This is your personal space to chat. |
|
|
You can ask anything: From discussing strategic game tactics to enjoying casual conversation. |
|
|
""" |
|
LICENSE = """ |
|
|
<p/> |
|
|
|
|
|
--- |
|
|
This demo is governed by the [original license](https://ai.meta.com/llama/license/) and [acceptable use policy](https://ai.meta.com/llama/use-policy/). |
|
|
The most recent copy of this policy can be found at ai.meta.com/llama/use-policy. |
|
|
""" |
|
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
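# Loading note: torch_dtype=torch.float16 roughly halves the 7B model's weight
# memory (to ~14 GB), and device_map="auto" relies on the `accelerate` package to
# place the layers; disabling the default system prompt keeps the chat template
# empty so the game-specific prompt below fully controls the model's behavior.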
|
def fetch_personalized_data(session_index):
    """Fetch this session's trust-game decisions plus those of its group members."""
    try:
        with mysql.connector.connect(
            # NOTE: hard-coded credentials from the original deployment; these would
            # normally be read from environment variables or a secrets store.
            host="18.153.94.89",
            user="root",
            password="N12RXMKtKxRj",
            database="lionessdb",
        ) as conn:
            with conn.cursor() as cursor:
                # Both halves of the UNION ALL select the same column list: the
                # first returns the requesting session's rows, the second the rows
                # of the other sessions in the same group.
                columns = """
                    e5390g37504_core.playerNr,
                    e5390g37504_core.groupNrStart,
                    e5390g37504_core.subjectNr,
                    e5390g37504_core.onPage,
                    e5390g37504_decisions.session_index,
                    e5390g37504_decisions.transfer1,
                    e5390g37504_decisions.tripledAmount1,
                    e5390g37504_decisions.keptForSelf1,
                    e5390g37504_decisions.returned1,
                    e5390g37504_decisions.newCreditRound2,
                    e5390g37504_decisions.transfer2,
                    e5390g37504_decisions.tripledAmount2,
                    e5390g37504_decisions.keptForSelf2,
                    e5390g37504_decisions.returned2,
                    e5390g37504_decisions.results2rounds,
                    e5390g37504_decisions.newCreditRound3,
                    e5390g37504_decisions.transfer3,
                    e5390g37504_decisions.tripledAmount3,
                    e5390g37504_decisions.keptForSelf3,
                    e5390g37504_decisions.returned3,
                    e5390g37504_decisions.results3rounds
                """
                join = """
                    FROM e5390g37504_core
                    JOIN e5390g37504_decisions
                        ON e5390g37504_core.playerNr = e5390g37504_decisions.playerNr
                """
                query = f"""
                    SELECT {columns} {join}
                    WHERE e5390g37504_decisions.session_index = %s
                    UNION ALL
                    SELECT {columns} {join}
                    WHERE e5390g37504_core.groupNrStart IN (
                        SELECT DISTINCT groupNrStart {join}
                        WHERE e5390g37504_decisions.session_index = %s
                    ) AND e5390g37504_decisions.session_index != %s
                """
                cursor.execute(query, (session_index, session_index, session_index))

                # Map each row tuple onto named fields, in the SELECT column order.
                field_names = (
                    "playerNr", "groupNrStart", "subjectNr", "onPage", "session_index",
                    "transfer1", "tripledAmount1", "keptForSelf1", "returned1",
                    "newCreditRound2", "transfer2", "tripledAmount2", "keptForSelf2",
                    "returned2", "results2rounds", "newCreditRound3", "transfer3",
                    "tripledAmount3", "keptForSelf3", "returned3", "results3rounds",
                )
                data = [dict(zip(field_names, row)) for row in cursor]
                print(data)
                return data
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
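# Illustrative (hypothetical) result for session_index=7: one dict per matching row,
# e.g. [{'playerNr': 1, 'groupNrStart': 3, 'transfer1': 5, 'returned1': 8, ...}, ...],
# or None if the connection or query fails.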
|
def get_default_system_prompt(personalized_data):
    # Llama-2 system-prompt delimiters, used below to fence the few-shot examples.
    BSYS, ESYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    DEFAULT_SYSTEM_PROMPT = f"""You are an intelligent and fair game guide in a 2-player trust game.
You are assisting players in making decisions to win.
Answer in a consistent style. Each of your answers should be at most 2 sentences long.
The players are called The Investor and The Dealer and keep their roles throughout the whole game.
Both start with 10€ in round 1. The game consists of 3 rounds. In round 1, The Investor invests between 0€ and 10€.
This amount is tripled automatically, and The Dealer can then distribute the tripled amount. After that, round 1 is over.
Both go into the next round with their current assets: The Investor with 10€ minus what he invested plus what he received back from The Dealer;
The Dealer with 10€ plus what he kept from the tripled amount.
You will receive a JSON with information on who trusted whom with how much money after each round as context.
Your goal is to guide players through the game, providing clear instructions and explanations.
If any question or action seems unclear, explain it rather than providing inaccurate information.
If you're unsure about an answer, it's better not to guess.

Example JSON context after a round: {personalized_data}

Few-shot training examples:
{BSYS} Give an overview of the trust game. {ESYS}
{BSYS} Explain how trust amounts are calculated. {ESYS}
{BSYS} What happens if a player doesn't trust in a round? {ESYS}
"""
    print(DEFAULT_SYSTEM_PROMPT)
    return DEFAULT_SYSTEM_PROMPT
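# BSYS/ESYS above are the Llama-2 <<SYS>> delimiters; inside this prompt they fence
# only the few-shot example questions, while the prompt as a whole is wrapped in a
# single <<SYS>> block later (by construct_input_prompt or the chat template).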
|
def construct_input_prompt(chat_history, message, personalized_data):
    # Llama-2 chat format: <s>[INST] <<SYS>> system <</SYS>> user [/INST] assistant
    input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt(personalized_data)}\n<</SYS>>\n\n "
    for user, assistant in chat_history:
        input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
    input_prompt += f"{message} [/INST] "
    return input_prompt
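# Sketch of the string this builds for history [("Hi", "Hello!")] and new message
# "Rules?" (system prompt abbreviated to "..."):
#   <s>[INST] <<SYS>> ... <</SYS>> Hi [/INST] Hello! <s>[INST] Rules? [/INST]
# This mirrors the Llama-2 chat format that tokenizer.apply_chat_template produces.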
|
@spaces.GPU
def generate(
    request: gr.Request,
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # The session index is passed to the Space as a URL query parameter.
    params = request.query_params
    print("query params:", params)

    session_index = params.get("session_index")
    if session_index is not None:
        print("Session index:", session_index)
    else:
        print("Session index not found or has no value.")

    personalized_data = fetch_personalized_data(session_index)

    # Build the conversation in chat-template format. The system message carries
    # the game rules plus the player's personalized round data; the history and
    # the new user message follow as alternating turns. (Tokenizing a separately
    # constructed prompt here would duplicate the history, so it is avoided.)
    conversation = [{"role": "system", "content": get_default_system_prompt(personalized_data)}]
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens if the conversation grew too long.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream tokens out of model.generate() running in a background thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
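# Each iteration yields the full text accumulated so far rather than the newest
# fragment alone; gr.ChatInterface re-renders the growing message, which is what
# produces the token-by-token typing effect in the UI.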
|
chat_interface = gr.ChatInterface(
    fn=generate,
    retry_btn=None,
    clear_btn=None,
    undo_btn=None,
    chatbot=gr.Chatbot(avatar_images=("user.png", "bot.png"), bubble_full_width=False),
    examples=[
        ["Can you explain the rules very briefly again?"],
        ["How much should I invest in order to win?"],
        ["What happened in the last round?"],
        ["What is my probability to win if I do not share anything?"],
    ],
)
|
with gr.Blocks(css="style.css", theme=gr.themes.Default(primary_hue=gr.themes.colors.emerald,secondary_hue=gr.themes.colors.indigo)) as demo: |
|
|
gr.Markdown(DESCRIPTION) |
|
|
chat_interface.render() |
|
|
gr.Markdown(LICENSE) |
|
|
if __name__ == "__main__": |
|
|
demo.queue(max_size=20).launch() |
|
|
|
|
|
demo.launch(share=True, debug=True) |
|
|
|