|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import time |
|
|
import random |
|
|
import json |
|
|
import mysql.connector |
|
|
import os |
|
|
import csv |
|
|
import spaces |
|
|
import torch |
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
from threading import Thread |
|
|
from typing import Iterator |
|
|
from huggingface_hub import Repository, hf_hub_download |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
import mysql.connector |
|
|
import urllib.parse |
|
|
import urllib.request |
|
|
|
|
|
|
|
|
import atexit |
|
|
import os |
|
|
from huggingface_hub import HfApi, HfFolder |
|
|
|
|
|
|
|
|
import huggingface_hub |
|
|
from huggingface_hub import Repository |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
import sqlite3 |
|
|
import huggingface_hub |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import shutil |
|
|
import os |
|
|
import datetime |
|
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
|
|
|
|
|
|
|
# --- Chat-log storage ------------------------------------------------------
# Hugging Face dataset repository the chat logs are pushed to.
DATASET_REPO_URL = "https://huggingface.co/datasets/botsi/trust-game-llama-2-chat-history"
# Local directory holding the working copy of that repository.
DATA_DIRECTORY = "data"
# Base CSV name; generate() appends it to every per-participant log file name.
DATA_FILENAME = "7B.csv"
DATA_FILE = os.path.join(DATA_DIRECTORY, DATA_FILENAME)

# --- Secrets (None when the environment variable is not set) ---------------
DB_PASSWORD = os.getenv("DB_PASSWORD")
HF_TOKEN = os.getenv("HF_TOKEN")

# Start-up diagnostics: token presence (never its value) and library version.
print("is none?", HF_TOKEN is None)
print("hfh", huggingface_hub.__version__)

# Git working copy of the dataset repository, cloned into DATA_DIRECTORY.
repo = Repository(clone_from=DATASET_REPO_URL, local_dir=DATA_DIRECTORY)

# --- Generation limits -----------------------------------------------------
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are trimmed in generate().
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
|
|
|
|
|
# Markdown blurb describing what the chat is for. The trailing backslash after
# the opening quotes keeps the text from starting with an empty line. A CPU
# notice is appended to this string further down when CUDA is unavailable.
DESCRIPTION = """\
# This is your personal space to chat.
You can ask anything: From discussing strategic game tactics to enjoying casual conversation.
For example you could ask, what happened in the last round, what is your probability to win when you invest amount xy, what is my current balance etc.
"""


# Markdown footer with the Llama 2 license / acceptable-use links; rendered
# below the chat through gr.Markdown(LICENSE).
LICENSE = """
<p/>

---
This demo is governed by the [original license](https://ai.meta.com/llama/license/) and [acceptable use policy](https://ai.meta.com/llama/use-policy/).
The most recent copy of this policy can be found at ai.meta.com/llama/use-policy.
"""
|
|
|
|
|
# The model and tokenizer are only created when a CUDA device is available.
# Without one, the description gets a notice appended instead and `model` /
# `tokenizer` are left undefined.
if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The application supplies its own system prompt for every request.
    tokenizer.use_default_system_prompt = False
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
|
|
|
|
|
|
|
# Common prefix of the three tables queried below.
_TABLE_PREFIX = "e5390g37899"

# Per-round columns of the decisions table, in selection order.
_ROUND_FIELDS = ("transfer", "tripledAmount", "keptForSelf", "returned", "totalRound")


def _personalized_columns():
    """Return the ordered ``(table suffix, column)`` pairs selected per player.

    This is the single source of truth for both the SELECT list and the keys
    of the dictionaries returned by ``fetch_personalized_data``.
    """
    columns = [
        ("core", name)
        for name in ("playerNr", "groupNrStart", "subjectNr", "onPage", "role")
    ]
    columns.append(("session", "externalID"))
    columns.extend(("decisions", name) for name in ("initialCredit", "part"))
    # Rounds 1-5 are selected with all five per-round fields ...
    for round_num in range(1, 6):
        columns.extend(("decisions", f"{field}{round_num}") for field in _ROUND_FIELDS)
    # ... round 6 only up to keptForSelf6.
    columns.extend(("decisions", f"{field}6") for field in _ROUND_FIELDS[:3])
    return columns


def _build_personalized_query(columns):
    """Build the UNION ALL query: the player's own row plus the group mates' rows."""
    core = f"{_TABLE_PREFIX}_core"
    session = f"{_TABLE_PREFIX}_session"
    decisions = f"{_TABLE_PREFIX}_decisions"
    select_list = ",\n       ".join(
        f"{_TABLE_PREFIX}_{table}.{name}" for table, name in columns
    )
    base_select = (
        f"SELECT {select_list}\n"
        f"FROM {core}\n"
        f"JOIN {session} ON {core}.playerNr = {session}.playerNr\n"
        f"JOIN {decisions} ON {core}.playerNr = {decisions}.playerNr\n"
    )
    # Only constant table/column names are interpolated; the external ID is
    # always passed through %s placeholders.
    return (
        f"{base_select}"
        f"WHERE {session}.externalID = %s\n"
        "UNION ALL\n"
        f"{base_select}"
        f"WHERE {core}.groupNrStart IN (\n"
        "    SELECT DISTINCT groupNrStart\n"
        f"    FROM {core}\n"
        f"    JOIN {session} ON {core}.playerNr = {session}.playerNr\n"
        f"    WHERE {session}.externalID = %s\n"
        f") AND {session}.externalID != %s\n"
    )


def fetch_personalized_data(externalID):
    """Fetch the game state of a participant and of everyone in their group.

    Args:
        externalID: External participant ID matched against
            ``session.externalID``.

    Returns:
        A list of dicts, one per database row, keyed by column name. Rows of
        the participant come first, followed by rows of other players whose
        ``groupNrStart`` matches. ``None`` is returned when the MySQL
        connector reports an error (the error is printed).
    """
    columns = _personalized_columns()
    keys = [name for _, name in columns]
    query = _build_personalized_query(columns)

    try:
        # NOTE(review): host and the root account are hard-coded; consider
        # moving them to environment variables like DB_PASSWORD.
        with mysql.connector.connect(
            host="3.125.179.74",
            user="root",
            password=DB_PASSWORD,
            database="lionessdb",
        ) as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, (externalID, externalID, externalID))
                # Keys and SELECT list come from the same `columns` sequence,
                # so positions cannot drift apart.
                data = [dict(zip(keys, row)) for row in cursor]
        print(data)
        return data
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
|
|
|
|
|
def extract_variables(all_personalized_data, part):
    """Collect the per-round values of one game part, grouped by role.

    Args:
        all_personalized_data: Iterable of row dicts as produced by
            ``fetch_personalized_data``; ``None`` is treated as empty.
        part: Game part, ``"1"`` (database rounds 1-3) or ``"2"`` (database
            rounds 4-6). Non-string values are converted with ``str`` first.

    Returns:
        ``{role label: {"round1": {...}, "round2": {...}, "round3": {...}}}``
        where each round dict holds the non-``None`` values of that round.
        ``tripledAmount`` is omitted for "The Dealer" and ``keptForSelf`` for
        "The Investor". Returns ``None`` (after printing a message) when
        ``part`` is neither 1 nor 2.
    """
    # Presumably the database can hand `part` back as a number (the role is
    # str()-ed by every caller for the same reason), so normalise it before
    # comparing against the string constants.
    part = str(part)

    # Offset between the round number shown to the player and the database
    # column suffix: part 2 stores its rounds 1-3 in columns 4-6.
    round_offsets = {"1": 0, "2": 3}
    if part not in round_offsets:
        print("No data for the particular part found")
        return None
    offset = round_offsets[part]

    fields = ('transfer', 'tripledAmount', 'keptForSelf', 'returned', 'totalRound')
    hidden_field_by_role = {'The Dealer': 'tripledAmount', 'The Investor': 'keptForSelf'}

    extracted_data = {}
    for data in all_personalized_data or ():
        role = map_role(str(data.get('role', 'unknown')))
        hidden_field = hidden_field_by_role.get(role)

        player_data = {}
        for shown_round in range(1, 4):
            db_round = shown_round + offset
            round_values = {}
            for var in fields:
                if var == hidden_field:
                    continue
                value = data.get(f'{var}{db_round}')
                if value is not None:
                    round_values[var] = value
            player_data[f'round{shown_round}'] = round_values

        # Several rows can map to the same role label; later rows replace the
        # round dicts of earlier ones.
        extracted_data.setdefault(role, {}).update(player_data)

    return extracted_data
|
|
|
|
|
|
|
|
def map_onPage(onPage):
    """Translate a page file name into a stage label and a prompt phrase.

    Args:
        onPage: Page identifier such as ``"stage411228.php"``.

    Returns:
        A ``(stage label, round/turn description)`` tuple, or
        ``("unknown", "unknown")`` for any page not listed below.
    """
    # Pages alternate Investor -> Dealer within every round. The entries for
    # stage 9 and stage 16 previously repeated "Investor’s turn", which broke
    # that sequence.
    onPage_mapping_dict = {
        # Stages 6-11
        "stage411228.php": ("stage 6", "Round 1: Investor’s turn"),
        "stage411229.php": ("stage 7", "Round 1: Dealer’s turn"),
        "stage411230.php": ("stage 8", "Round 2: Investor’s turn"),
        "stage411231.php": ("stage 9", "Round 2: Dealer’s turn"),
        "stage411232.php": ("stage 10", "Round 3: Investor’s turn"),
        "stage411233.php": ("stage 11", "Round 3: Dealer’s turn"),
        # Stages 13-18
        "stage411235.php": ("stage 13", "Round 1: Investor’s turn"),
        "stage411236.php": ("stage 14", "Round 1: Dealer’s turn"),
        "stage411237.php": ("stage 15", "Round 2: Investor’s turn"),
        "stage411238.php": ("stage 16", "Round 2: Dealer’s turn"),
        "stage411239.php": ("stage 17", "Round 3: Investor’s turn"),
        "stage411240.php": ("stage 18", "Round 3: Dealer’s turn"),
    }

    onPage_filename, onPage_prompt = onPage_mapping_dict.get(onPage, ("unknown", "unknown"))
    return onPage_filename, onPage_prompt
|
|
|
|
|
def map_role(role):
    """Return the display name for a role code.

    ``"1"`` maps to "The Investor" and ``"2"`` to "The Dealer"; every other
    value (including the integers 1 and 2) yields ``"unknown"``.
    """
    labels = {"1": "The Investor", "2": "The Dealer"}
    return labels.get(role, "unknown")
|
|
|
|
|
|
|
|
|
|
|
def get_default_system_prompt(extracted_data, onPage_prompt, role_prompt):
    """Build the system prompt describing the game rules and current state.

    Args:
        extracted_data: Summary of earlier rounds; interpolated with ``str``.
        onPage_prompt: Description of the current round and turn.
        role_prompt: Display name of the player being assisted.

    Returns:
        The prompt text. It is also printed for debugging.
    """
    # NOTE(review): the text contains the typos "investo" and "Each or your".
    # They are left untouched on purpose so the model-facing prompt stays
    # exactly as before; fix them deliberately if desired.
    system_prompt = f""" You are a smart game assistant for a Trust Game outside of this chat.
Trust Game rules: Two players, The Investor and The Dealer, each play to maximize their own earnings.
There are 3 rounds. Every round follows the same pattern.
- First, each player gets a virtual starting credit of 10 coins.
- Investor's turn: The Investor decides how much they want to investo into a shared pot. The shared pot is tripled automatically before the Dealer's turn.
- Dealer's turn: The Dealer can keep and return as much of the tripled amount as they like. Their virtual starting credit remains untouched.
Earnings from each round are not transferred to the next round. Each or your answers should be maximum 2 sentences long.
Answer in a consistent style. If you are unsure about an answer, do not guess.
Currently it is {role_prompt}’s turn so you are assisting {role_prompt}. Answer directly to the player. The currency is coins.
The game is currently in {onPage_prompt}.
This is what happened in the last rounds: {extracted_data}.
"""
    print(system_prompt)
    return system_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct_input_prompt(chat_history, message, extracted_data, onPage_prompt, role_prompt):
    """Assemble a single Llama-2 style prompt string by hand.

    The result starts with the system prompt wrapped in ``<<SYS>>`` markers,
    followed by every ``(user, assistant)`` pair of ``chat_history`` and
    finally the new ``message``.

    NOTE(review): the Llama-2 reference format closes each assistant turn
    with ``</s>``; this template does not - confirm that this is intended.
    """
    system_prompt = get_default_system_prompt(extracted_data, onPage_prompt, role_prompt)

    segments = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "]
    segments.extend(
        f"{user} [/INST] {assistant} <s>[INST] " for user, assistant in chat_history
    )
    segments.append(f"{message} [/INST] ")
    return "".join(segments)
|
|
|
|
|
|
|
|
|
|
|
def _find_participant_entry(all_personalized_data, externalID):
    """Return the first row whose externalID matches, or None when there is none."""
    for entry in all_personalized_data or ():
        if entry['externalID'] == externalID:
            return entry
    return None


def _safe_filename_part(value):
    """Replace every character that is not a letter, digit, '-', '_' or '.' with '_'."""
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))


def _log_turn(filename, message, answer):
    """Append one question/answer pair to the CSV log and push it to the hub."""
    pd.set_option("display.max_colwidth", None)

    data_file = os.path.join(DATA_DIRECTORY, filename)
    timestamp = datetime.datetime.now()

    # Continue an existing log when there is one so turn ids keep counting up.
    existing_data = pd.read_csv(data_file) if os.path.exists(data_file) else None

    turn_df = pd.DataFrame([{
        "turn_id": len(existing_data) + 1 if existing_data is not None else 1,
        "question": message,
        "answer": answer,
        "timestamp": timestamp,
    }])

    if existing_data is not None:
        updated_data = pd.concat([existing_data, turn_df], ignore_index=True)
    else:
        updated_data = turn_df

    updated_data.to_csv(data_file, index=False, quoting=csv.QUOTE_ALL)

    print("Updating .csv")
    repo.push_to_hub(blocking=False, commit_message=f"Updating data at {timestamp}")


@spaces.GPU
def generate(
    request: gr.Request,
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the assistant's answer to ``message`` and log the finished turn.

    The participant is identified through the ``PROLIFIC_PID`` query
    parameter. Their game state is loaded from the database and turned into
    a system prompt; the answer is generated on a background thread and
    streamed back.

    Args:
        request: Gradio request; only its query parameters are read.
        message: The new user message.
        chat_history: Earlier ``(user, assistant)`` pairs of this chat.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cut-off.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The answer text accumulated so far, once per streamed chunk.
    """
    params = request.query_params
    print('those are the query params')
    print(params)

    externalID = params.get('PROLIFIC_PID')
    if externalID is not None:
        print("PROLIFIC_PID:", externalID)
    else:
        externalID = 'no_externalID'
        print("PROLIFIC_PID not found or has no value.")

    all_personalized_data = fetch_personalized_data(externalID)

    # Everything stays None when the participant is not found.
    onPage = playerNr = groupNrStart = role = part = None
    entry = _find_participant_entry(all_personalized_data, externalID)
    if entry is not None:
        playerNr = entry.get('playerNr', "no_playerNr")
        groupNrStart = entry.get('groupNrStart', "no_groupNrStart")
        onPage = entry.get('onPage', "no_onPage")
        role = entry.get('role', "no_role")
        part = entry.get('part', "no_part")

    print("onPage:", onPage)
    print("playerNr:", playerNr)
    print("groupNrStart:", groupNrStart)
    print("role:", role)
    print("part:", part)

    onPage_filename, onPage_prompt = map_onPage(onPage)
    print("onPage_filename:", onPage_filename)
    print("onPage_prompt:", onPage_prompt)

    role_prompt = map_role(str(role))
    print("role_prompt:", role_prompt)

    extracted_data = extract_variables(all_personalized_data, part)
    print(extracted_data)

    # Only the system prompt goes into the system message. Previously the
    # fully hand-templated prompt (which already contained the history and
    # the new message) was used here, so after apply_chat_template every
    # turn appeared twice and [INST] markers were nested.
    system_prompt = get_default_system_prompt(extracted_data, onPage_prompt, role_prompt)
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens when the prompt is too long.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # Generation runs on a worker thread so chunks can be yielded as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # Log exactly what the user saw. Re-joining the chunks with extra spaces
    # (as before) doubled spaces and dropped whitespace-only chunks.
    readable_sentence = "".join(outputs)
    print(readable_sentence)

    # The external ID comes from the URL, so sanitise it before it becomes
    # part of a file name.
    filename = (
        f"{groupNrStart}_{playerNr}_{_safe_filename_part(externalID)}"
        f"_{onPage_filename}_{DATA_FILENAME}"
    )
    _log_turn(filename, message, readable_sentence)
|
|
|
|
|
# NOTE(review): this inline stylesheet is not passed to any component below
# (gr.Blocks loads "style.css" instead), so it currently has no effect.
css = """
share-button svelte-1lcyrx4 {visibility: hidden}
"""

# Chat widget driven by generate(); retry / clear / undo buttons are disabled.
chat_interface = gr.ChatInterface(
    fn=generate,
    retry_btn=None,
    clear_btn=None,
    undo_btn=None,
    chatbot=gr.Chatbot(
        avatar_images=('user.png', 'bot.png'),
        bubble_full_width=True,
        elem_id='chatbot',
    ),
)

# Page layout: the chat followed by the license footer.
with gr.Blocks(css="style.css") as demo:
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    # Launch exactly once. The former second `demo.launch(share=True,
    # debug=True)` could only run after this blocking call returned, i.e. it
    # started a second server when the first one was shut down.
    demo.queue(max_size=20).launch()
|
|
|