ZeroShirayuki commited on
Commit
8023b76
·
verified ·
1 Parent(s): 6a3d3b0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -15
app.py CHANGED
@@ -1,25 +1,42 @@
1
  # app.py
2
- from transformers import pipeline
3
  from fastapi import FastAPI, HTTPException
4
  import torch
 
 
 
5
  from typing import List, Dict
6
- import json
7
 
8
  app = FastAPI()
9
 
10
- # Load model and mappings
11
- checkpoint = torch.load('manga_recommender.pt')
12
- model = MangaRecommender(
13
- num_users=len(checkpoint['user_mapping']),
14
- num_items=len(checkpoint['manga_mapping'])
15
- )
16
- model.load_state_dict(checkpoint['model_state_dict'])
17
- user_mapping = checkpoint['user_mapping']
18
- manga_mapping = checkpoint['manga_mapping']
19
- reverse_manga_mapping = {v: k for k, v in manga_mapping.items()}
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @app.post("/predict")
22
  async def predict(user_id: str, top_k: int = 10):
 
 
 
23
  try:
24
  # Get user index
25
  user_idx = user_mapping.get(user_id)
@@ -43,10 +60,16 @@ async def predict(user_id: str, top_k: int = 10):
43
  "scores": scores
44
  }
45
  except Exception as e:
46
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
47
 
48
  @app.post("/update")
49
  async def update_model(ratings: List[Dict]):
 
 
 
50
  try:
51
  # Convert ratings to training format
52
  df = pd.DataFrame(ratings)
@@ -62,12 +85,14 @@ async def update_model(ratings: List[Dict]):
62
  criterion = nn.MSELoss()
63
 
64
  model.train()
 
65
  for user, item, rating in loader:
66
  optimizer.zero_grad()
67
  pred = model(user, item)
68
  loss = criterion(pred, rating)
69
  loss.backward()
70
  optimizer.step()
 
71
 
72
  # Save updated model
73
  torch.save({
@@ -76,6 +101,28 @@ async def update_model(ratings: List[Dict]):
76
  'manga_mapping': manga_mapping
77
  }, 'manga_recommender.pt')
78
 
79
- return {"message": "Model updated successfully"}
 
 
 
80
  except Exception as e:
81
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
 
2
  from fastapi import FastAPI, HTTPException
3
  import torch
4
+ from torch import nn
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import pandas as pd
7
  from typing import List, Dict
8
+ from train import MangaRecommender, MangaDataset # Import from train.py
9
 
10
  app = FastAPI()
11
 
12
+ try:
13
+ # Load model and mappings
14
+ checkpoint = torch.load('manga_recommender.pt')
15
+ model = MangaRecommender(
16
+ num_users=len(checkpoint['user_mapping']),
17
+ num_items=len(checkpoint['manga_mapping'])
18
+ )
19
+ model.load_state_dict(checkpoint['model_state_dict'])
20
+ user_mapping = checkpoint['user_mapping']
21
+ manga_mapping = checkpoint['manga_mapping']
22
+ reverse_manga_mapping = {v: k for k, v in manga_mapping.items()}
23
+ print("Model loaded successfully")
24
+ except Exception as e:
25
+ print(f"Error loading model: {e}")
26
+ model = None
27
+ user_mapping = {}
28
+ manga_mapping = {}
29
+ reverse_manga_mapping = {}
30
+
31
+ @app.get("/")
32
+ async def root():
33
+ return {"status": "running", "model_loaded": model is not None}
34
 
35
  @app.post("/predict")
36
  async def predict(user_id: str, top_k: int = 10):
37
+ if model is None:
38
+ raise HTTPException(status_code=500, detail="Model not loaded")
39
+
40
  try:
41
  # Get user index
42
  user_idx = user_mapping.get(user_id)
 
60
  "scores": scores
61
  }
62
  except Exception as e:
63
+ raise HTTPException(
64
+ status_code=500,
65
+ detail=f"Prediction error: {str(e)}"
66
+ )
67
 
68
  @app.post("/update")
69
  async def update_model(ratings: List[Dict]):
70
+ if model is None:
71
+ raise HTTPException(status_code=500, detail="Model not loaded")
72
+
73
  try:
74
  # Convert ratings to training format
75
  df = pd.DataFrame(ratings)
 
85
  criterion = nn.MSELoss()
86
 
87
  model.train()
88
+ total_loss = 0
89
  for user, item, rating in loader:
90
  optimizer.zero_grad()
91
  pred = model(user, item)
92
  loss = criterion(pred, rating)
93
  loss.backward()
94
  optimizer.step()
95
+ total_loss += loss.item()
96
 
97
  # Save updated model
98
  torch.save({
 
101
  'manga_mapping': manga_mapping
102
  }, 'manga_recommender.pt')
103
 
104
+ return {
105
+ "message": "Model updated successfully",
106
+ "average_loss": total_loss / len(loader)
107
+ }
108
  except Exception as e:
109
+ raise HTTPException(
110
+ status_code=500,
111
+ detail=f"Update error: {str(e)}"
112
+ )
113
+
114
+ @app.get("/model-info")
115
+ async def model_info():
116
+ if model is None:
117
+ raise HTTPException(status_code=500, detail="Model not loaded")
118
+
119
+ return {
120
+ "num_users": len(user_mapping),
121
+ "num_manga": len(manga_mapping),
122
+ "embedding_size": model.user_factors.embedding_dim
123
+ }
124
+
125
+ if __name__ == "__main__":
126
+ import uvicorn
127
+ uvicorn.run(app, host="0.0.0.0", port=8000)
128
+ raise HTTPException(status_code=500, detail=str(e))