# Streamlit chat app for Google's Gemma 3 1B instruction-tuned model
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

st.set_page_config(page_title="Gemma 3 1B Chat", page_icon="💬")

if "messages" not in st.session_state: |
|
|
st.session_state.messages = [] |
|
|
if "model" not in st.session_state: |
|
|
st.session_state.model = None |
|
|
if "tokenizer" not in st.session_state: |
|
|
st.session_state.tokenizer = None |
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the Gemma 3 1B model and tokenizer (cached across reruns)."""
    model_name = "google/gemma-3-1b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return model, tokenizer
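
# On GPUs with limited memory, the model could instead be loaded in 4-bit.
# This is an optional sketch, not wired into the app; it assumes the
# `bitsandbytes` package is installed alongside transformers:
#
#     from transformers import BitsAndBytesConfig
#
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#         device_map="auto",
#     )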


# Load the model once; the spinner only appears on the first run
if st.session_state.model is None:
    with st.spinner("Loading Gemma 3 1B model..."):
        st.session_state.model, st.session_state.tokenizer = load_model()

st.title("💬 Gemma 3 1B Chat")

# Replay the conversation so far
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle a new user turn
if prompt := st.chat_input("Type your message here..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Format the whole conversation, not just the latest prompt,
            # so the model keeps context across turns
            chat_prompt = st.session_state.tokenizer.apply_chat_template(
                st.session_state.messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = st.session_state.tokenizer(
                chat_prompt,
                return_tensors="pt",
            ).to(st.session_state.model.device)

            with torch.no_grad():
                outputs = st.session_state.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.7,
                    do_sample=True,
                )

            # Decode only the newly generated tokens, not the echoed prompt
            response = st.session_state.tokenizer.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True,
            )
            st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})
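
# Optional: generate() blocks until the full reply is ready. For token-by-
# token rendering, a minimal sketch using transformers' TextIteratorStreamer
# with Streamlit's st.write_stream (replacing the generate/decode step above):
#
#     from threading import Thread
#     from transformers import TextIteratorStreamer
#
#     streamer = TextIteratorStreamer(
#         st.session_state.tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     Thread(
#         target=st.session_state.model.generate,
#         kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer),
#     ).start()
#     response = st.write_stream(streamer)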


# Sidebar controls
with st.sidebar:
    st.header("Settings")
    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.rerun()
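
# To run (assuming this file is saved as app.py): `streamlit run app.py`.
# The Gemma weights on Hugging Face are gated, so accept the license on the
# model page and authenticate first, e.g. with `huggingface-cli login`.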