# app.py — Streamlit chat interface for Gemma (Hugging Face Space upload
# metadata removed; it was not valid Python).
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Page config — must be the first Streamlit call in the script.
# FIX: page_icon was the mojibake sequence "πŸ’¬" (UTF-8 💬 mis-decoded
# as Latin-1); restore the intended speech-balloon emoji.
st.set_page_config(page_title="Gemma 3 1B Chat", page_icon="💬")

# Initialize session state on first run; reruns keep existing values.
for _key, _default in (("messages", []), ("model", None), ("tokenizer", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
@st.cache_resource
def load_model():
    """Load and cache the Gemma chat model and its tokenizer.

    Returns:
        tuple: (model, tokenizer) — an ``AutoModelForCausalLM`` placed via
        ``device_map="auto"`` (GPU if available, else CPU) in float16, and
        the matching ``AutoTokenizer``.
    """
    # FIX: "google/gemma-2-1b-it" does not exist — Gemma 2 ships only in
    # 2b/9b/27b sizes. The 1B instruction-tuned checkpoint (matching the
    # app's "Gemma 3 1B" title) is google/gemma-3-1b-it.
    model_name = "google/gemma-3-1b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # halves memory; fine for inference
        device_map="auto",
    )
    return model, tokenizer
# Load the model once per session.
# FIX: previously the spinner wrapped the None-check, so "Loading..." was
# shown on every rerun even when the model was already in session state;
# check first, then spin only for the actual load.
if st.session_state.model is None:
    with st.spinner("Loading Gemma 3 1B model..."):
        st.session_state.model, st.session_state.tokenizer = load_model()

# Title (FIX: restore mis-encoded 💬 emoji, was "πŸ’¬")
st.title("💬 Gemma 3 1B Chat")
# Replay the stored conversation so it survives Streamlit reruns.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# Chat input — one turn per script rerun.
if prompt := st.chat_input("Type your message here..."):
    # Record and echo the user's message before generating, so the full
    # history (including this turn) is available to the template below.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # FIX: previously only [{"role": "user", "content": prompt}]
            # was passed, silently discarding all earlier turns despite the
            # app maintaining a history. Pass the whole message list so the
            # model keeps multi-turn context.
            chat_prompt = st.session_state.tokenizer.apply_chat_template(
                st.session_state.messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            # Tokenize on the model's device (CPU or GPU via device_map)
            inputs = st.session_state.tokenizer(
                chat_prompt,
                return_tensors="pt",
            ).to(st.session_state.model.device)
            # Generate without tracking gradients (inference only)
            with torch.no_grad():
                outputs = st.session_state.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.7,
                    do_sample=True,
                )
            # Decode only the newly generated tokens, skipping the prompt
            response = st.session_state.tokenizer.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True,
            )
            st.markdown(response)

    # Persist the assistant's reply for the next rerun
    st.session_state.messages.append({"role": "assistant", "content": response})
# Sidebar: chat controls.
with st.sidebar:
    st.header("Settings")
    clear_clicked = st.button("Clear Chat")
    if clear_clicked:
        # Drop the history and rerun so the chat pane redraws empty.
        st.session_state.messages = []
        st.rerun()