import gradio as gr
from huggingface_hub import InferenceClient
import os

# RAG imports
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
client = InferenceClient(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    # provider="groq",
    token=hf_token,  # required: the client must authenticate to access this gated Llama model
)
# We'll load the existing FAISS index at the start
INDEX_FOLDER = "faiss_index"
_vectorstore = None
def load_vectorstore():
    """Loads the FAISS index from the local folder (cached after the first call)."""
    global _vectorstore
    if _vectorstore is None:
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        _vectorstore = FAISS.load_local(INDEX_FOLDER, embeddings, allow_dangerous_deserialization=True)
    return _vectorstore
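
# For reference, a minimal sketch of how a compatible "faiss_index" folder could
# be built offline with the same embedding model. The `texts` argument is a
# hypothetical placeholder for the real KnowledgeBase documents, which are not
# part of this file; the function is never called by the app itself.
def build_vectorstore(texts: list[str]) -> FAISS:
    """Embed the given texts and save a FAISS index to INDEX_FOLDER."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = FAISS.from_texts(texts, embeddings)
    vectorstore.save_local(INDEX_FOLDER)
    return vectorstore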
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Called on each user message. We'll do a retrieval step (RAG)
    to get relevant context, then feed it into the system message
    before calling the InferenceClient.
    """
    # 1. Retrieve the top documents from FAISS
    vectorstore = load_vectorstore()
    top_docs = vectorstore.similarity_search(message, k=3)

    # Build the context string ("KnowledgeBase") from the retrieved documents
    knowledge_base = "\n".join(doc.page_content for doc in top_docs)

    # 2. Augment the original system message with the retrieved context
    augmented_system_message = f"{system_message}\n\nRelevant context:\n{knowledge_base}"
    # 3. Convert the chat history into chat-completion messages
    messages = [{"role": "system", "content": augmented_system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Finally, add the new user message
    messages.append({"role": "user", "content": message})
    # 4. Stream the completion from the InferenceClient
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token is None:
            continue
        response += token
        yield response
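
# Hypothetical local test of respond() without the UI (example argument values
# only); each yielded value is the accumulated response so far:
#
#   for partial in respond("Who is Prakash?", [], "You are helpful.", 256, 0.7, 0.95):
#       print(partial)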
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly, knowledgeable assistant acting as Prakash Naikade. "
            "You have access to a rich set of documents and references collectively called KnowledgeBase, which you should treat as your current knowledge base. "
            "Always use the facts, details, and stories from KnowledgeBase to ground your answers. "
            "If a question goes beyond what KnowledgeBase covers, politely explain that you don’t have enough information to answer. "
            "Remain friendly, empathetic, and helpful, providing clear, concise, and context-driven responses. "
            "Stay consistent with any personal or professional details found in KnowledgeBase. "
            "If KnowledgeBase lacks a relevant detail, avoid making up new information; be honest about the gap. "
            "Your goal is to accurately represent Prakash Naikade: his background, expertise, and experiences, using only the data from KnowledgeBase to support your answers.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
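
# The ChatInterface can be customized further (see the gradio docs linked above);
# for example, with a hypothetical title and description, not part of this app:
#
#   demo = gr.ChatInterface(
#       respond,
#       title="Chat with Prakash Naikade",
#       description="A RAG-backed personal assistant.",
#       additional_inputs=[...],
#   )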
if __name__ == "__main__":
    demo.launch()