import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Page config
st.set_page_config(page_title="Gemma 3 1B Chat", page_icon="💬")

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "model" not in st.session_state:
    st.session_state.model = None
if "tokenizer" not in st.session_state:
    st.session_state.tokenizer = None

@st.cache_resource
def load_model():
    """Load the Gemma model and tokenizer."""
    # Gated checkpoint: accept the license on Hugging Face and authenticate
    # (e.g. `huggingface-cli login`) before the first run.
    model_name = "google/gemma-3-1b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    return model, tokenizer

# Load model (cached across reruns by st.cache_resource)
if st.session_state.model is None:
    with st.spinner("Loading Gemma 3 1B model..."):
        st.session_state.model, st.session_state.tokenizer = load_model()

# Title
st.title("💬 Gemma 3 1B Chat")

# Display chat history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Chat input
if prompt := st.chat_input("Type your message here..."):
    # Add user message
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Format the full conversation so the model sees prior turns
            chat_prompt = st.session_state.tokenizer.apply_chat_template(
                st.session_state.messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize
            inputs = st.session_state.tokenizer(
                chat_prompt,
                return_tensors="pt"
            ).to(st.session_state.model.device)

            # Generate
            with torch.no_grad():
                outputs = st.session_state.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.7,
                    do_sample=True
                )

            # Decode only the newly generated tokens
            response = st.session_state.tokenizer.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True
            )
            st.markdown(response)

    # Add assistant response to the history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Sidebar with clear button
with st.sidebar:
    st.header("Settings")
    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.rerun()
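
# Usage (a sketch; the filename app.py is only an assumption, adjust to your setup):
#   pip install streamlit torch transformers accelerate   # accelerate is required for device_map="auto"
#   huggingface-cli login                                  # Gemma weights are gated on Hugging Face
#   streamlit run app.py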