import os

# ✅ Set Hugging Face cache & offload dirs *before* importing transformers,
# so the library picks up the environment variables when it is loaded
cache_dir = "/tmp/huggingface"
offload_dir = os.path.join(cache_dir, "offload")
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(offload_dir, exist_ok=True)

os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ✅ FastAPI app setup
app = FastAPI()
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
def read_index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


# ✅ Input schema
class TinyLlamaInput(BaseModel):
    prompt: str


# ✅ Load only one model (QLoRA)
model_dir = "lora-tinyllama-igenrate"
base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Prefer the tokenizer saved with the adapter; fall back to the base model's
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
except OSError:
    tokenizer = AutoTokenizer.from_pretrained(base_model)

# TinyLlama ships without a dedicated pad token, so reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

base = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.float16,
    cache_dir=cache_dir,
    offload_folder=offload_dir,
)

# Apply the LoRA adapter, then merge it into the base weights so inference
# runs on a single plain model with no adapter overhead
model = PeftModel.from_pretrained(base, model_dir)
model = model.merge_and_unload()
model.eval()


# ✅ Inference logic
def generate_response(prompt, tokenizer, model):
    # Wrap the raw prompt in the Alpaca-style template the adapter was trained on
    full_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text after the response marker, dropping the echoed prompt
    return decoded.split("### Response:")[-1].strip()


# 🔹 Single endpoint for QLoRA
@app.post("/predict/qlora")
def predict_qlora(input_data: TinyLlamaInput):
    answer = generate_response(input_data.prompt, tokenizer, model)
    return {"model": "QLoRA - lora-tinyllama-igenrate", "response": answer}
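
# --- Usage sketch (assumption: this file is saved as app.py) ---
# One way to serve and query the endpoint; the module name, port, and
# example prompt below are illustrative, not part of the original code:
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl -X POST http://localhost:8000/predict/qlora \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain what a LoRA adapter is."}'

if __name__ == "__main__":
    # Minimal sketch for launching directly with `python app.py`;
    # assumes uvicorn is installed alongside fastapi
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)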