# Original code from https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat and https://huggingface.co/spaces/radames/gradio-chatbot-read-query-param 
import gradio as gr
import time
import random
import json
import mysql.connector
import os
import csv
from huggingface_hub import Repository, hf_hub_download
from datetime import datetime

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator

# Imports originally from data_fetcher.py (mysql.connector is imported above)
import urllib.parse
import urllib.request


# Save chat history as JSON (os is already imported above)
import atexit
from huggingface_hub import HfApi  # used by the optional dataset-upload block below

'''# Define dataset repository URL and ID
DATASET_REPO_URL = "https://huggingface.co/datasets/botsi/trust-game-llama-2-7b-chat"
DATASET_REPO_ID = "botsi/trust-game-llama-2-7b-chat"

# Define data file information
DATA_FILENAME = "history_trust-game-llama-2-7b-chat.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)

# Get Hugging Face token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")

# Check if the data file exists
if os.path.exists(DATA_FILE):
    # Initialize Hugging Face API
    api = HfApi()

    # Upload the file to the dataset repository. HfFolder has no upload
    # method; HfApi.upload_file() is the call that does this and returns
    # the URL of the uploaded file.
    with open(DATA_FILE, "rb") as f:
        url = api.upload_file(
            path_or_fileobj=f,
            path_in_repo=DATA_FILENAME,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN,
        )

    # Print uploaded file information
    print(f"Uploaded file to dataset repository: {url}")
else:
    print("Data file does not exist.")
'''

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Llama-2 7B Chat
This is your personal space to chat. 
You can ask anything, from strategic questions about the game to casual chat.
"""
'''LICENSE = """
<p/>

---
As a derivative work of [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/USE_POLICY.md).
"""
'''

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False


def fetch_personalized_data(session_index):
    try:
        # Connect to the database
        with mysql.connector.connect(
            host="18.153.94.89",
            user="root",
            password="N12RXMKtKxRj",
            database="lionessdb"
        ) as conn:
            # Create a cursor object
            with conn.cursor() as cursor:
                # Query to fetch relevant data from both tables based on session_index
                query = """
                    SELECT e5390g37096_core.playerNr, 
                           e5390g37096_core.groupNrStart, 
                           e5390g37096_core.subjectNr, 
                           e5390g37096_core.onPage, 
                           e5390g37096_decisions.session_index, 
                           e5390g37096_decisions.transfer1,
                           e5390g37096_decisions.tripledAmount1,
                           e5390g37096_decisions.keptForSelf1,
                           e5390g37096_decisions.returned1,
                           e5390g37096_decisions.newCreditRound2,
                           e5390g37096_decisions.transfer2,
                           e5390g37096_decisions.tripledAmount2,
                           e5390g37096_decisions.keptForSelf2,
                           e5390g37096_decisions.returned2,
                           e5390g37096_decisions.results2rounds, 
                           e5390g37096_decisions.newCreditRound3, 
                           e5390g37096_decisions.transfer3, 
                           e5390g37096_decisions.tripledAmount3, 
                           e5390g37096_decisions.keptForSelf3, 
                           e5390g37096_decisions.returned3, 
                           e5390g37096_decisions.results3rounds
                    FROM e5390g37096_core
                    JOIN e5390g37096_decisions ON 
                        e5390g37096_core.playerNr = e5390g37096_decisions.playerNr
                    WHERE e5390g37096_decisions.session_index = %s
                    UNION ALL
                    SELECT e5390g37096_core.playerNr, 
                           e5390g37096_core.groupNrStart, 
                           e5390g37096_core.subjectNr, 
                           e5390g37096_core.onPage, 
                           e5390g37096_decisions.session_index, 
                           e5390g37096_decisions.transfer1,
                           e5390g37096_decisions.tripledAmount1,
                           e5390g37096_decisions.keptForSelf1,
                           e5390g37096_decisions.returned1,
                           e5390g37096_decisions.newCreditRound2,
                           e5390g37096_decisions.transfer2,
                           e5390g37096_decisions.tripledAmount2,
                           e5390g37096_decisions.keptForSelf2,
                           e5390g37096_decisions.returned2,
                           e5390g37096_decisions.results2rounds, 
                           e5390g37096_decisions.newCreditRound3, 
                           e5390g37096_decisions.transfer3, 
                           e5390g37096_decisions.tripledAmount3, 
                           e5390g37096_decisions.keptForSelf3, 
                           e5390g37096_decisions.returned3, 
                           e5390g37096_decisions.results3rounds
                    FROM e5390g37096_core
                    JOIN e5390g37096_decisions 
                        ON e5390g37096_core.playerNr = e5390g37096_decisions.playerNr
                    WHERE e5390g37096_core.groupNrStart IN (
                        SELECT DISTINCT groupNrStart
                        FROM e5390g37096_core
                        JOIN e5390g37096_decisions 
                            ON e5390g37096_core.playerNr = e5390g37096_decisions.playerNr
                        WHERE e5390g37096_decisions.session_index = %s
                    ) AND e5390g37096_decisions.session_index != %s
                """
                cursor.execute(query, (session_index, session_index, session_index))
                # Fetch data row by row
                data = [{
                    'playerNr': row[0], 
                    'groupNrStart': row[1], 
                    'subjectNr': row[2], 
                    'onPage': row[3],
                    'session_index': row[4],
                    'transfer1': row[5],
                    'tripledAmount1': row[6],
                    'keptForSelf1': row[7],
                    'returned1': row[8],
                    'newCreditRound2': row[9],
                    'transfer2': row[10],
                    'tripledAmount2': row[11],
                    'keptForSelf2': row[12],
                    'returned2': row[13],
                    'results2rounds': row[14],
                    'newCreditRound3': row[15],
                    'transfer3': row[16],
                    'tripledAmount3': row[17],
                    'keptForSelf3': row[18],
                    'returned3': row[19],
                    'results3rounds': row[20]
                } for row in cursor]
                print(data)
                return data
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
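
# Design note: mysql.connector can also return rows as dictionaries directly.
# A minimal sketch under that assumption, replacing the manual row -> dict
# mapping above (an open connection and the same 3-parameter query are assumed):
def fetch_rows_as_dicts(conn, query, session_index):
    # dictionary=True makes fetchall() return dicts keyed by column name
    with conn.cursor(dictionary=True) as cursor:
        cursor.execute(query, (session_index, session_index, session_index))
        return cursor.fetchall()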


js = """
    function() {
        const params = new URLSearchParams(window.location.search);
        const url_params = Object.fromEntries(params);
        return url_params;
        }
    """

def get_window_url_params():
    # Same snippet as the module-level `js` above; returning it avoids
    # duplicating the JS source.
    return js

## trust-game-llama-2-7b-chat
# app.py 
@spaces.GPU
def generate(
    request: gr.Request, 
    message: str,
    chat_history: list[tuple[str, str]],
    # system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:  # yields the accumulated response text as it streams

    # Log the request's URL query parameters for debugging
    params = request.query_params
    print(params)

    # Build the conversation in chat-template format. The system prompt is
    # supplied as a plain "system" message; tokenizer.apply_chat_template()
    # performs the Llama-2 [INST]/<<SYS>> wrapping itself, so the manually
    # formatted string from construct_input_prompt() is not needed here.
    # (get_default_system_prompt is defined below, inside the gr.Blocks
    # context, before the app launches.)
    conversation = [{"role": "system", "content": get_default_system_prompt()}]

    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    
    input_ids = input_ids.to(model.device)

    # Set up the TextIteratorStreamer
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    
    # Set up the generation arguments
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # Start the model generation thread
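    # (model.generate() blocks until it finishes, so it runs on a background
    # thread while the TextIteratorStreamer feeds partial text to this generator)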
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield generated text chunks
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
        
    #gr.Markdown(DESCRIPTION)
    #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

chat_interface = gr.ChatInterface(
    fn=generate,
    theme="soft",
    retry_btn=None,
    clear_btn=None,
    undo_btn=None,
    chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width=False),
    examples=[
        ["Can you explain the rules very briefly again?"],
        ["How much should I invest in order to win?"],
        ["What happened in the last round?"],
        ["What is my probability to win if I do not share anything?"],
    ],
)

with gr.Blocks(js=js, css="style.css") as demo:
    #url_params = gr.JSON({}, visible=False, label="URL Params")
    #session_index = get_session_index(url_params)
    # The session index is hardcoded for now; see the sketch below for
    # deriving it from the URL query parameters instead.
    session_index = 'eb3636167d3a63fbeee32934610e5b2f'
    personalized_data = fetch_personalized_data(session_index)
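
    # A minimal sketch of the commented-out get_session_index() above,
    # assuming the session index arrives as a "session" query parameter
    # (hypothetical key name; adjust it to whatever the experiment URL uses).
    def get_session_index(url_params):
        """Extract the session index from a dict of URL query parameters."""
        return url_params.get("session")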
        
    ## trust-game-llama-2-7b-chat
    # app.py 
    def get_default_system_prompt():
        #BOS, EOS = "<s>", "</s>" 
        #BINST, EINST = "[INST]", "[/INST]"
        BSYS, ESYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

        DEFAULT_SYSTEM_PROMPT = f""" You are an intelligent and fair game guide in a 2-player trust game.
        You are assisting players in making decisions to win. 
        Answer in a consistent style. Each of your answers should be at most 2 sentences long. 
        The players are called The Investor and The Dealer and keep their role throughout the whole game. 
        Both start with 10€ in round 1. The game consists of 3 rounds. In round 1, The Investor invests between 0€ and 10€. 
        This amount is tripled automatically, and The Dealer can then distribute the tripled amount. After that, round 1 is over. 
        Both go into the next round with their current assets: The Investor with 10€ minus what he invested plus what he received back from The Dealer,
        and The Dealer with 10€ plus what he kept from the tripled amount. 
        You will receive a JSON with information on who trusted whom with how much money after each round as context.
        Your goal is to guide players through the game, providing clear instructions and explanations. 
        If any question or action seems unclear, explain it rather than providing inaccurate information. 
        If you're unsure about an answer, it's better not to guess.

        Example JSON context after a round: {personalized_data}

        Few-shot example questions:
        {BSYS} Give an overview of the trust game. {ESYS}
        {BSYS} Explain how trust amounts are calculated. {ESYS}
        {BSYS} What happens if a player doesn't trust in a round? {ESYS}
        """
        print(DEFAULT_SYSTEM_PROMPT)
        return DEFAULT_SYSTEM_PROMPT
    
    ## trust-game-llama-2-7b-chat
    # app.py
    # Manual Llama-2 prompt construction, kept as an alternative to
    # tokenizer.apply_chat_template(). Each completed turn is closed with
    # </s> before the next <s>[INST] block opens.
    def construct_input_prompt(chat_history, message):
        input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt()}\n<</SYS>>\n\n"
        for user, assistant in chat_history:
            input_prompt += f"{user} [/INST] {assistant} </s><s>[INST] "
        input_prompt += f"{message} [/INST]"
        return input_prompt
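
    # With one completed turn, the constructed prompt has this shape:
    #   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user 1} [/INST] {assistant 1} </s><s>[INST] {user 2} [/INST]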

    chat_interface.render()
    #gr.Markdown(LICENSE)

if __name__ == "__main__":
    #demo.queue(max_size=20).launch()
    demo.queue(max_size=20)
    demo.launch(share=True, debug=True)

# Register the function to be called when the program exits
# atexit.register(save_chat_history)