| from fastapi import FastAPI, HTTPException, Header |
| from pydantic import BaseModel |
| import numpy as np |
| from tensorflow.keras.models import load_model |
| from fastapi.middleware.cors import CORSMiddleware |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| |
# Keras classifier for the numeric check-in served by /predict.
# NOTE(review): the filename is spelled "stresss" (three s) — confirm it
# matches the artifact actually shipped with the service.
model = load_model("model_stresss.h5")

# Human-readable class names, indexed by the position of the model's
# arg-max output (see predict()).
labels = [
    "Tidak Stress",
    "Sedikit Stress",
    "Normal",
    "Stress",
    "Sangat Stress",
]
|
|
| |
# Download/cache directory for Hugging Face artifacts.
# NOTE(review): /tmp is often wiped on restart, so the weights may be
# re-downloaded on every cold start — confirm this is intended.
cache_path = "/tmp/huggingface"

# IndoBERT emotion classifier used by /analyze.
model_dir = "Chipan/indobert-emotion"
tokenizer = AutoTokenizer.from_pretrained(model_dir, cache_dir=cache_path)
# nn.Module.eval() returns the module itself, so switching to inference
# mode can be chained directly onto the load.
model_bert = AutoModelForSequenceClassification.from_pretrained(
    model_dir, cache_dir=cache_path
).eval()
|
|
| |
# Output index -> emotion name for model_bert, consumed by analyze_emotion().
# NOTE(review): hard-coded; verify it agrees with model_bert.config.id2label.
label_map = dict(enumerate(["Bersyukur", "Marah", "Sedih", "Senang", "Stress"]))
|
|
| |
app = FastAPI()

# CORS policy: every origin, method and header is accepted.
# NOTE(review): the CORS spec forbids a wildcard origin together with
# credentials, and FastAPI's docs say origins must be listed explicitly when
# allow_credentials=True. Starlette works around the combination by echoing
# the caller's Origin in some cases, which effectively lets any site make
# credentialed requests. List explicit origins before exposing this publicly.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
| |
class CheckInData(BaseModel):
    """Request body for /predict: the five numeric check-in answers.

    predict() feeds them to the Keras model as
    [mood, sleep, anxiety, exercise, support].
    NOTE(review): no value range is enforced here — confirm the scale the
    model was trained on and consider adding Field constraints.
    """

    mood: float
    sleep: float
    anxiety: float
    exercise: float
    support: float
|
|
class TextInput(BaseModel):
    """Request body for /analyze: the free text to classify by emotion."""

    # Tokenized with truncation to 128 tokens by analyze_emotion().
    text: str
|
|
@app.post("/predict")
def predict(data: CheckInData, authorization: str = Header(None)):
    """Classify a numeric check-in into one of the stress ``labels``.

    Args:
        data: the five check-in answers.
        authorization: value of the ``Authorization`` header, if sent.
            NOTE(review): accepted but never checked — enforce it or drop it.

    Returns:
        dict with ``predicted_index`` (arg-max class), ``predicted_label``
        (its entry in ``labels``) and ``raw_prediction`` (the model output as
        nested lists).

    Raises:
        HTTPException: 422 if any input value is NaN or infinite;
            500 if model inference or label lookup fails.
    """
    # Presumably the feature order used during training — confirm.
    raw = np.array([[data.mood, data.sleep, data.anxiety, data.exercise, data.support]])

    # Pydantic float fields accept NaN/inf by default. Such values produce a
    # meaningless arg-max and a raw_prediction the JSON response cannot
    # encode (that would fail after return, outside the try below), so
    # reject them up front as a client error.
    if not np.isfinite(raw).all():
        raise HTTPException(
            status_code=422,
            detail="All check-in values must be finite numbers",
        )

    try:
        # verbose=0 stops Keras printing a progress bar on every request.
        prediction = model.predict(raw, verbose=0)
        idx = int(np.argmax(prediction))
        return {
            "predicted_index": idx,
            "predicted_label": labels[idx],
            "raw_prediction": prediction.tolist(),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
|
|
@app.post("/analyze")
def analyze_emotion(input: TextInput):
    """Classify the emotion expressed in ``input.text`` with IndoBERT.

    Returns:
        dict with ``emotion`` (label from ``label_map``, or "unknown" for an
        unmapped index) and ``confidence`` (softmax probability of that
        class, rounded to 4 decimals).

    Raises:
        HTTPException: 500 wrapping any tokenization or inference error.
    """
    try:
        encoded = tokenizer(
            input.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        with torch.no_grad():
            output = model_bert(**encoded)
        # logits were produced under no_grad, so no graph is built here.
        distribution = F.softmax(output.logits, dim=1)
        best = int(torch.argmax(distribution))
        confidence = distribution[0, best].item()
        return {
            "emotion": label_map.get(best, "unknown"),
            "confidence": round(confidence, 4),
        }
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Emotion analysis error: {str(err)}")
|
|
@app.get("/")
def root():
    """Root endpoint; always answers with a fixed status payload.

    Usable as a simple liveness check.
    """
    payload = {"status": "ok"}
    return payload
|
|