Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import random | |
| import pandas as pd | |
| import requests | |
| from io import BytesIO | |
| from PIL import Image | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
| import re | |
| import time | |
# --------------------------- Configuration & Session State ---------------------------
# Maximum (width, height) in pixels for displayed fortune images; images are
# shrunk to fit inside this box while keeping their aspect ratio.
MAX_SIZE = (400, 400)

# Streamlit documents set_page_config as the first Streamlit command in the
# script, so it runs before any session-state access or UI output.
st.set_page_config(page_title="Fortune Stick Enquiry", layout="wide")
st.title("Fortune Stick Enquiry")

# Initial values for the session-state keys the app relies on. Each key is set
# only once per browser session; later reruns keep whatever value is stored.
_SESSION_DEFAULTS = {
    # 0 = next submit click is the first one, 1 = waiting for the confirmation click.
    "button_count_temp": 0,
    # True once a question has passed validation and a fortune number was drawn.
    "submitted_text": False,
    # Fortune stick number drawn by the submit callback (random 1-100).
    "fortune_number": None,
    # Dict with Header / Luck / Description / Detail / HeaderLink for the drawn fortune.
    "fortune_row": None,
    # Validation or guidance message shown under the submit button.
    "error_message": "",
    # Explanation text produced by the "Cfu Explain" callback.
    "cfu_explain_text": "",
    # True once the "Cfu Explain" button has produced an explanation.
    "stick_clicked": False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Load the fortune table once per session. On failure, surface the error in the
# page and store None so the callbacks can detect that no data is available.
if "fortune_data" not in st.session_state:
    try:
        st.session_state.fortune_data = pd.read_csv("/home/user/app/resources/detail.csv")
    except Exception as e:
        st.error(f"Error loading CSV: {e}")
        st.session_state.fortune_data = None
# --------------------------- Model Functions ---------------------------
# Categories produced by the fine-tuned question classifier. The model emits
# "LABEL_<i>", which corresponds to _QUESTION_CATEGORIES[i].
_QUESTION_CATEGORIES = ["Geomancy", "Lost Property", "Personal Well-Being", "Future Prospect", "Traveling"]
_LABEL_TO_CATEGORY = {f"LABEL_{i}": label for i, label in enumerate(_QUESTION_CATEGORIES)}


@st.cache_resource
def _get_question_classifier():
    """Build the fine-tuned category classifier once and reuse it across reruns."""
    return pipeline("text-classification", model="tonyhui2234/CustomModel_classifier_model_10")


def load_finetuned_classifier_model(question):
    """Classify the user's question into one of the fortune categories.

    Args:
        question: The user's question text.

    Returns:
        The category name (e.g. "Geomancy"). If the model returns a label that
        is not one of the known "LABEL_<i>" values, that raw label is returned.
    """
    # The pipeline is cached: building it downloads/loads the model, which is far
    # too slow to repeat on every button click.
    prediction = _get_question_classifier()(question)[0]["label"]
    return _LABEL_TO_CATEGORY.get(prediction, prediction)
# Hugging Face repo of the fine-tuned seq2seq model that writes explanations.
_TEXT_GEN_MODEL_ID = "tonyhui2234/finetuned_model_text_gen"


@st.cache_resource
def _get_text_gen_model():
    """Load the explanation tokenizer and model once; return (tokenizer, model)."""
    tokenizer = AutoTokenizer.from_pretrained(_TEXT_GEN_MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(_TEXT_GEN_MODEL_ID, device_map="auto")
    return tokenizer, model


def generate_answer(question, fortune):
    """Generate an explanation relating a fortune detail to the user's question.

    Args:
        question: The user's question text.
        fortune: The fortune detail text selected for the question's category.

    Returns:
        The decoded model output, with special tokens removed.

    The elapsed time is printed to stdout. The first call also includes loading
    the cached model; later calls measure generation only.
    """
    start_time = time.perf_counter()
    tokenizer, model = _get_text_gen_model()
    input_text = "Question: " + question + " Fortune: " + fortune
    # device_map="auto" may place the model on a GPU; the encoded inputs must be
    # moved to the same device or generate() fails with a device mismatch.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=256,
        num_beams=4,
        early_stopping=True,
        repetition_penalty=2.0,
        no_repeat_ngram_size=3,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    run_time = time.perf_counter() - start_time
    print(f"Runtime: {run_time:.4f} seconds")
    return answer
# Reply used when the fortune has no usable detail for the question's category.
_NO_ANSWER = "Heaven's secret cannot be revealed."


def analysis(row_detail, classifiy, question):
    """Explain the part of a fortune's detail that matches the question's category.

    Args:
        row_detail: The fortune's "Detail" text. The category's segment is found
            with the pattern "<category>: <text>" ending at the first "." or at
            the end of the string (case-insensitive). Non-string values, such as
            NaN from an empty CSV cell, are treated as having no detail.
        classifiy: Category name returned by the question classifier.
        question: The user's question, passed through to generate_answer().

    Returns:
        The generated explanation, or _NO_ANSWER when no non-empty segment for
        the category is found.
    """
    # pandas reads an empty cell as NaN (a float); re.search on it would raise
    # TypeError, so treat anything that is not text as "no detail".
    if not isinstance(row_detail, str):
        return _NO_ANSWER
    pattern = re.compile(re.escape(classifiy) + r":\s*(.*?)(?:\.|$)", re.IGNORECASE)
    match = pattern.search(row_detail)
    segment = match.group(1).strip() if match else ""
    # An empty segment (e.g. "Geomancy: .") gives the generator nothing to explain.
    if not segment:
        return _NO_ANSWER
    return generate_answer(question, segment)
@st.cache_resource
def _get_language_detector():
    """Build the language-detection pipeline once and reuse it across reruns."""
    return pipeline("text-classification", model="eleldar/language-detection")


def check_sentence_is_english_model(question):
    """Return True if the language-detection model labels the question as "en".

    Args:
        question: The user's question text.
    """
    # Cached pipeline: rebuilding it on every submit reloads the whole model.
    return _get_language_detector()(question)[0]["label"] == "en"
@st.cache_resource
def _get_question_detector():
    """Build the question-vs-statement pipeline once and reuse it across reruns."""
    return pipeline("text-classification", model="shahrukhx01/question-vs-statement-classifier")


def check_sentence_is_question_model(question):
    """Return True if the classifier labels the text as a question ("LABEL_1").

    Args:
        question: The user's input text.
    """
    # Cached pipeline: rebuilding it on every submit reloads the whole model.
    return _get_question_detector()(question)[0]["label"] == "LABEL_1"
# --------------------------- Callback Functions ---------------------------
# Fortune fields shown when the drawn number has no matching row in the CSV.
_EMPTY_FORTUNE_ROW = {
    "Header": "N/A",
    "Luck": "N/A",
    "Description": "No description available.",
    "Detail": "No detail available.",
    "HeaderLink": None,
}


def _cell_value(row, key, default):
    """Return row[key], or *default* when the column is absent or the cell is empty (NaN)."""
    value = row.get(key, default)
    # pandas reads empty cells as NaN, which is truthy and not a string.
    return default if pd.isna(value) else value


def _lookup_fortune_row(df, fortune_number):
    """Return the display fields for *fortune_number*, or None when no table is loaded."""
    if df is None:
        return None
    matching_row = df[df["CNumber"] == fortune_number]
    if matching_row.empty:
        return dict(_EMPTY_FORTUNE_ROW)
    row = matching_row.iloc[0]
    return {
        "Header": _cell_value(row, "Header", "N/A"),
        "Luck": _cell_value(row, "Luck", "N/A"),
        "Description": _cell_value(row, "Description", "No description available."),
        "Detail": _cell_value(row, "Detail", "No detail available."),
        "HeaderLink": _cell_value(row, "link", None),
    }


def submit_text_callback():
    """Validate the question and, on the confirming second click, draw a fortune.

    on_click handler for the "submit" button. The question must be non-empty,
    detected as English, and classified as a question. The first valid click
    only asks the user to reflect; the second draws a random fortune number
    (1-100), looks up its row in st.session_state.fortune_data, and clears any
    explanation generated for a previous fortune.
    """
    question = st.session_state.get("user_sentence", "") or ""
    st.session_state.error_message = ""

    # Validation failures reset the confirmation step.
    if not question.strip():
        st.session_state.error_message = "Please enter your question!"
        st.session_state.button_count_temp = 0
        return
    if not check_sentence_is_english_model(question):
        st.session_state.error_message = "Please enter in English!"
        st.session_state.button_count_temp = 0
        return
    if not check_sentence_is_question_model(question):
        st.session_state.error_message = "This is not a question. Please enter again!"
        st.session_state.button_count_temp = 0
        return

    # Require a second confirmation click to proceed.
    if st.session_state.button_count_temp == 0:
        st.session_state.error_message = "Please take a moment to quietly reflect on your question in your mind, then click submit again!"
        st.session_state.button_count_temp = 1
        return

    st.session_state.submitted_text = True
    st.session_state.button_count_temp = 0
    # A new fortune invalidates any explanation generated for the previous one.
    st.session_state.stick_clicked = False
    st.session_state.cfu_explain_text = ""

    st.session_state.fortune_number = random.randint(1, 100)
    # Always replace fortune_row so a previous fortune is never shown with the new number.
    st.session_state.fortune_row = _lookup_fortune_row(
        st.session_state.fortune_data, st.session_state.fortune_number
    )
def load_and_resize_image(path, max_size=MAX_SIZE):
    """Open a local image and shrink it in place to fit inside *max_size*.

    Args:
        path: Filesystem path of the image.
        max_size: (width, height) bounding box; aspect ratio is preserved.

    Returns:
        The resized PIL image, or None after reporting the failure in the page.
    """
    try:
        picture = Image.open(path)
        picture.thumbnail(max_size, Image.Resampling.LANCZOS)
    except Exception as exc:
        st.error(f"Error loading image: {exc}")
        return None
    return picture
def download_and_resize_image(url, max_size=MAX_SIZE, timeout=10):
    """Download an image and shrink it in place to fit inside *max_size*.

    Args:
        url: HTTP(S) URL of the image.
        max_size: (width, height) bounding box; aspect ratio is preserved.
        timeout: Seconds to wait for the server. requests has no default
            timeout, so without one a stalled host would freeze the page.

    Returns:
        The resized PIL image, or None after reporting the failure in the page.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        img.thumbnail(max_size, Image.Resampling.LANCZOS)
        return img
    except Exception as e:
        # UI boundary: any network/decoding failure is shown rather than crashing the page.
        st.error(f"Error loading image from URL: {e}")
        return None
def stick_enquiry_callback():
    """Build and store an explanation of the drawn fortune for the user's question.

    on_click handler for the "Cfu Explain" button. If no fortune has been drawn,
    an error is shown and nothing is stored. Otherwise the current question is
    classified, analysis() produces the explanation, and the result is saved to
    st.session_state.cfu_explain_text with st.session_state.stick_clicked set.
    """
    state = st.session_state
    user_question = state.get("user_sentence", "")
    if state.fortune_row:
        detail = state.fortune_row.get("Detail", "No detail available.")
        category = load_finetuned_classifier_model(user_question)
        state.cfu_explain_text = analysis(detail, category, user_question)
        state.stick_clicked = True
    else:
        st.error("Fortune data is not available. Please submit your question first.")
# --------------------------- Layout & Display ---------------------------
# Three page columns: user input (left), an empty spacer, and the fortune display (right).
left_col, _, right_col = st.columns([3, 1, 5])

# ---- Left Column: User Input and Interaction ----
with left_col:
    # Both areas are created up front so their on-page order is fixed:
    # the question form on top, the explanation section beneath it.
    question_area, explanation_area = st.container(), st.container()

    with question_area:
        st.text_area("Enter your question in English", key="user_sentence", height=150)
        st.button("submit", key="submit_button", on_click=submit_text_callback)
        # Message left behind by submit_text_callback, if any.
        if st.session_state.error_message:
            st.error(st.session_state.error_message)

    # The explanation section appears only after a question has been accepted.
    if st.session_state.submitted_text:
        with explanation_area:
            # Blank lines push the button down for visual separation.
            for _blank in range(5):
                st.write("")
            # Center the button by placing it in the middle of three columns.
            _, button_slot, _ = st.columns(3)
            with button_slot:
                st.button("Cfu Explain", key="stick_button", on_click=stick_enquiry_callback)
            if st.session_state.stick_clicked:
                st.text_area(' ', value=st.session_state.cfu_explain_text, height=300, disabled=True)
# ---- Right Column: Fortune Display (Image and Details) ----
# Bundled images: a neutral placeholder before any fortune is drawn, and a
# fallback shown when the drawn fortune's image cannot be obtained.
_DEFAULT_FORTUNE_IMAGE = "/home/user/app/resources/fortune.png"
_ERROR_FORTUNE_IMAGE = "/home/user/app/resources/error.png"


def _as_text(value, default):
    """Return *value* as a string, or *default* when it is missing (None or NaN)."""
    return default if pd.isna(value) else str(value)


with right_col:
    with st.container():
        # Center the image by placing it in the wide middle column.
        _, col_center, _ = st.columns([1, 2, 1])
        with col_center:
            if st.session_state.submitted_text and st.session_state.fortune_row:
                header_link = st.session_state.fortune_row.get("HeaderLink")
                fortune_img = None
                # An absent CSV link can arrive as None or NaN (truthy); only real
                # URL strings are fetched, so a missing link falls back quietly
                # instead of showing a bogus download error.
                if isinstance(header_link, str) and header_link.strip():
                    fortune_img = download_and_resize_image(header_link)
                if fortune_img is None:
                    fortune_img = load_and_resize_image(_ERROR_FORTUNE_IMAGE)
                if fortune_img:
                    st.image(fortune_img, use_container_width=False)
            else:
                placeholder_img = load_and_resize_image(_DEFAULT_FORTUNE_IMAGE)
                if placeholder_img:
                    st.image(placeholder_img, caption="Your Fortune", use_container_width=False)

    with st.container():
        # Fortune details: number and luck as a headline, then description and detail.
        if st.session_state.fortune_row:
            fortune_row = st.session_state.fortune_row
            luck_text = _as_text(fortune_row.get("Luck"), "N/A")
            description_text = _as_text(fortune_row.get("Description"), "No description available.")
            detail_text = _as_text(fortune_row.get("Detail"), "No detail available.")
            summary = f"""
            <div style="font-size: 28px; font-weight: bold;">
            Fortune stick number: {st.session_state.fortune_number}<br>
            Luck: {luck_text}
            </div>
            """
            st.markdown(summary, unsafe_allow_html=True)
            st.text_area("Description", value=description_text, height=150, disabled=True)
            st.text_area("Detail", value=detail_text, height=150, disabled=True)