# igenrate/fastapi_qlora_app.py: serve a QLoRA fine-tuned TinyLlama via FastAPI
import os

# ✅ Set Hugging Face cache & offload dirs (before importing transformers,
# so the cache environment variables are picked up at import time)
cache_dir = "/tmp/huggingface"
offload_dir = os.path.join(cache_dir, "offload")
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(offload_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ✅ FastAPI app setup
app = FastAPI()
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
def read_index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


# ✅ Input schema
class TinyLlamaInput(BaseModel):
    prompt: str
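
# Example payload accepted by the /predict/qlora endpoint below
# (the prompt text here is purely illustrative):
#   {"prompt": "Summarize the benefits of LoRA fine-tuning."}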

# ✅ Load only one model (QLoRA): the TinyLlama base plus the LoRA adapter
model_dir = "/content/lora-tinyllama-igenrate"  # trained LoRA adapter weights
base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
except Exception:
    # Fall back to the base model's tokenizer if the adapter dir has none
    tokenizer = AutoTokenizer.from_pretrained(base_model)
# TinyLlama's tokenizer ships without a pad token; reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

base = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.float16,
    cache_dir=cache_dir,
    offload_folder=offload_dir,
)
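
# Attach the LoRA adapter, then merge its weights into the base model so
# inference runs as a plain transformers model (no PEFT wrapper overhead)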
model = PeftModel.from_pretrained(base, model_dir)
model = model.merge_and_unload()
model.eval()


# ✅ Inference logic
def generate_response(prompt, tokenizer, model):
    full_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text generated after the "### Response:" marker
    return decoded.split("### Response:")[-1].strip()


# 🔹 Single endpoint for QLoRA
@app.post("/predict/qlora")
def predict_qlora(input_data: TinyLlamaInput):
    answer = generate_response(input_data.prompt, tokenizer, model)
    return {"model": "QLoRA - lora-tinyllama-igenrate", "response": answer}