Mark-Lasfar commited on
Commit
dd3c39b
·
1 Parent(s): 01237fb

Update Model

Browse files
Files changed (4) hide show
  1. api/auth.py +86 -3
  2. api/database.py +64 -3
  3. api/models.py +10 -2
  4. main.py +5 -116
api/auth.py CHANGED
@@ -1,10 +1,15 @@
 
1
  from fastapi_users import FastAPIUsers
2
  from fastapi_users.authentication import CookieTransport, JWTStrategy, AuthenticationBackend
3
  from fastapi_users.db import SQLAlchemyUserDatabase
4
  from httpx_oauth.clients.google import GoogleOAuth2
5
  from httpx_oauth.clients.github import GitHubOAuth2
6
- from api.database import SessionLocal
7
- from api.models import User, OAuthAccount
 
 
 
 
8
  import os
9
  import logging
10
 
@@ -46,9 +51,87 @@ if not all([GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET, GITHUB_CLIENT_ID, GITHUB_CLI
46
  google_oauth_client = GoogleOAuth2(GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET)
47
  github_oauth_client = GitHubOAuth2(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET)
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  fastapi_users = FastAPIUsers[User, int](
50
- lambda: SQLAlchemyUserDatabase(User, SessionLocal(), oauth_account_table=OAuthAccount),
51
  [auth_backend],
52
  )
53
 
54
  current_active_user = fastapi_users.current_user(active=True, optional=True)
 
 
 
 
 
 
 
 
 
 
1
+ # api/auth.py
2
  from fastapi_users import FastAPIUsers
3
  from fastapi_users.authentication import CookieTransport, JWTStrategy, AuthenticationBackend
4
  from fastapi_users.db import SQLAlchemyUserDatabase
5
  from httpx_oauth.clients.google import GoogleOAuth2
6
  from httpx_oauth.clients.github import GitHubOAuth2
7
+ from api.database import User, OAuthAccount, get_user_db
8
+ from fastapi_users.manager import BaseUserManager, IntegerIDMixin
9
+ from fastapi import Depends
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+ from fastapi_users.models import UP
12
+ from typing import Optional
13
  import os
14
  import logging
15
 
 
51
  google_oauth_client = GoogleOAuth2(GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET)
52
  github_oauth_client = GitHubOAuth2(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET)
53
 
54
+ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
55
+ reset_password_token_secret = SECRET
56
+ verification_token_secret = SECRET
57
+
58
+ async def oauth_callback(
59
+ self,
60
+ oauth_name: str,
61
+ access_token: str,
62
+ account_id: str,
63
+ account_email: str,
64
+ expires_at: Optional[int] = None,
65
+ refresh_token: Optional[str] = None,
66
+ request: Optional[Request] = None,
67
+ *,
68
+ associate_by_email: bool = False,
69
+ is_verified_by_default: bool = False,
70
+ ) -> UP:
71
+ oauth_account_dict = {
72
+ "oauth_name": oauth_name,
73
+ "access_token": access_token,
74
+ "account_id": account_id,
75
+ "account_email": account_email,
76
+ "expires_at": expires_at,
77
+ "refresh_token": refresh_token,
78
+ }
79
+ oauth_account = OAuthAccount(**oauth_account_dict)
80
+ existing_oauth_account = await self.user_db.get_by_oauth_account(oauth_name, account_id)
81
+ if existing_oauth_account is not None:
82
+ return await self.on_after_login(existing_oauth_account.user, request)
83
+
84
+ if associate_by_email:
85
+ user = await self.user_db.get_by_email(account_email)
86
+ if user is not None:
87
+ oauth_account.user_id = user.id
88
+ await self.user_db.add_oauth_account(oauth_account)
89
+ return await self.on_after_login(user, request)
90
+
91
+ user_dict = {
92
+ "email": account_email,
93
+ "hashed_password": self.password_helper.hash("dummy_password"),
94
+ "is_active": True,
95
+ "is_verified": is_verified_by_default,
96
+ }
97
+ user = await self.user_db.create(user_dict)
98
+ oauth_account.user_id = user.id
99
+ await self.user_db.add_oauth_account(oauth_account)
100
+ return await self.on_after_login(user, request)
101
+
102
+ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
103
+ yield UserManager(user_db)
104
+
105
+ from fastapi_users.router.oauth import get_oauth_router
106
+
107
+ google_oauth_router = get_oauth_router(
108
+ google_oauth_client,
109
+ auth_backend,
110
+ get_user_manager,
111
+ associate_by_email=True,
112
+ redirect_url="https://mgzon-mgzon-app.hf.space/auth/google/callback",
113
+ )
114
+
115
+ github_oauth_router = get_oauth_router(
116
+ github_oauth_client,
117
+ auth_backend,
118
+ get_user_manager,
119
+ associate_by_email=True,
120
+ redirect_url="https://mgzon-mgzon-app.hf.space/auth/github/callback",
121
+ )
122
+
123
  fastapi_users = FastAPIUsers[User, int](
124
+ get_user_db,
125
  [auth_backend],
126
  )
127
 
128
  current_active_user = fastapi_users.current_user(active=True, optional=True)
129
+
130
+ def get_auth_router(app: FastAPI):
131
+ app.include_router(google_oauth_router, prefix="/auth/google", tags=["auth"])
132
+ app.include_router(github_oauth_router, prefix="/auth/github", tags=["auth"])
133
+ app.include_router(fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"])
134
+ app.include_router(fastapi_users.get_register_router(UserRead, UserCreate), prefix="/auth", tags=["auth"])
135
+ app.include_router(fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"])
136
+ app.include_router(fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"])
137
+ app.include_router(fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", tags=["users"])
api/database.py CHANGED
@@ -1,13 +1,20 @@
1
  # api/database.py
2
  import os
 
3
  from sqlalchemy import create_engine
4
  from sqlalchemy.ext.declarative import declarative_base
5
- from sqlalchemy.orm import sessionmaker
 
 
 
 
 
 
6
 
7
  # جلب URL قاعدة البيانات من المتغيرات البيئية
8
  SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL")
9
  if not SQLALCHEMY_DATABASE_URL:
10
- raise ValueError("SQLALCHEMY_DATABASE_URL is not set in environment variables.")
11
 
12
  # إنشاء المحرك
13
  engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
@@ -18,10 +25,64 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
18
  # قاعدة أساسية للنماذج
19
  Base = declarative_base()
20
 
21
- # دالة للحصول على الجلسة
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def get_db():
23
  db = SessionLocal()
24
  try:
25
  yield db
26
  finally:
27
  db.close()
 
 
 
 
 
 
1
  # api/database.py
2
  import os
3
+ from sqlalchemy import Column, String, Integer, ForeignKey, DateTime, Boolean
4
  from sqlalchemy import create_engine
5
  from sqlalchemy.ext.declarative import declarative_base
6
+ from sqlalchemy.orm import sessionmaker, relationship
7
+ from sqlalchemy.sql import func
8
+ from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase
9
+ from sqlalchemy.orm import Session
10
+ from typing import AsyncGenerator
11
+ from fastapi import Depends
12
+ from datetime import datetime
13
 
14
  # جلب URL قاعدة البيانات من المتغيرات البيئية
15
  SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL")
16
  if not SQLALCHEMY_DATABASE_URL:
17
+ raise ValueValue("SQLALCHEMY_DATABASE_URL is not set in environment variables.")
18
 
19
  # إنشاء المحرك
20
  engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
 
25
  # قاعدة أساسية للنماذج
26
  Base = declarative_base()
27
 
28
+ class OAuthAccount(Base):
29
+ __tablename__ = "oauth_account"
30
+
31
+ id = Column(Integer, primary_key=True, index=True)
32
+ user_id = Column(Integer, ForeignKey("user.id"), nullable=False)
33
+ oauth_name = Column(String, nullable=False)
34
+ access_token = Column(String, nullable=False)
35
+ expires_at = Column(Integer, nullable=True)
36
+ refresh_token = Column(String, nullable=True)
37
+ account_id = Column(String, index=True, nullable=False)
38
+ account_email = Column(String, nullable=False)
39
+
40
+ user = relationship("User", back_populates="oauth_accounts")
41
+
42
+ class User(SQLAlchemyBaseUserTable[int], Base):
43
+ __tablename__ = "user"
44
+
45
+ id = Column(Integer, primary_key=True, index=True)
46
+ email = Column(String, unique=True, index=True, nullable=False)
47
+ hashed_password = Column(String, nullable=False)
48
+ is_active = Column(Boolean, default=True)
49
+ is_superuser = Column(Boolean, default=False)
50
+ is_verified = Column(Boolean, default=False)
51
+ display_name = Column(String, nullable=True)
52
+ preferred_model = Column(String, nullable=True)
53
+ job_title = Column(String, nullable=True)
54
+ education = Column(String, nullable=True)
55
+ interests = Column(String, nullable=True)
56
+ additional_info = Column(String, nullable=True)
57
+ conversation_style = Column(String, nullable=True)
58
+ oauth_accounts = relationship("OAuthAccount", back_populates="user", cascade="all, delete-orphan")
59
+
60
+ class Conversation(Base):
61
+ __tablename__ = "conversation"
62
+
63
+ id = Column(Integer, primary_key=True, index=True)
64
+ conversation_id = Column(String, unique=True, index=True)
65
+ user_id = Column(Integer, ForeignKey("user.id"))
66
+ title = Column(String)
67
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
68
+ updated_at = Column(DateTime(timezone=True), onupdate=func.now())
69
+
70
+ class Message(Base):
71
+ __tablename__ = "message"
72
+
73
+ id = Column(Integer, primary_key=True, index=True)
74
+ conversation_id = Column(Integer, ForeignKey("conversation.id"))
75
+ role = Column(String)
76
+ content = Column(String)
77
+
78
  def get_db():
79
  db = SessionLocal()
80
  try:
81
  yield db
82
  finally:
83
  db.close()
84
+
85
+ Base.metadata.create_all(bind=engine)
86
+
87
+ async def get_user_db(session: Session = Depends(get_db)):
88
+ yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
api/models.py CHANGED
@@ -1,4 +1,5 @@
1
- from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyBaseOAuthAccountTable
 
2
  from sqlalchemy import Column, Integer, String, Boolean, Text, ForeignKey, DateTime
3
  from sqlalchemy.orm import relationship
4
  from sqlalchemy.ext.declarative import declarative_base
@@ -11,9 +12,16 @@ import uuid
11
  Base = declarative_base()
12
 
13
  # جدول OAuth Accounts لتخزين بيانات تسجيل الدخول الخارجي
14
- class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base):
15
  __tablename__ = "oauth_accounts"
16
  id = Column(Integer, primary_key=True)
 
 
 
 
 
 
 
17
  user = relationship("User", back_populates="oauth_accounts")
18
 
19
  # نموذج المستخدم
 
1
+ # models.py
2
+ from fastapi_users.db import SQLAlchemyBaseUserTable
3
  from sqlalchemy import Column, Integer, String, Boolean, Text, ForeignKey, DateTime
4
  from sqlalchemy.orm import relationship
5
  from sqlalchemy.ext.declarative import declarative_base
 
12
  Base = declarative_base()
13
 
14
  # جدول OAuth Accounts لتخزين بيانات تسجيل الدخول الخارجي
15
+ class OAuthAccount(Base):
16
  __tablename__ = "oauth_accounts"
17
  id = Column(Integer, primary_key=True)
18
+ user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
19
+ oauth_name = Column(String, nullable=False)
20
+ access_token = Column(String, nullable=False)
21
+ expires_at = Column(Integer, nullable=True)
22
+ refresh_token = Column(String, nullable=True)
23
+ account_id = Column(String, index=True, nullable=False)
24
+ account_email = Column(String, nullable=False)
25
  user = relationship("User", back_populates="oauth_accounts")
26
 
27
  # نموذج المستخدم
main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  # SPDX-FileCopyrightText: Hadad <[email protected]>
2
  # SPDX-License-Identifier: Apache-2.0
3
 
@@ -12,9 +13,8 @@ from starlette.middleware.sessions import SessionMiddleware
12
  from fastapi.openapi.docs import get_swagger_ui_html
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from api.endpoints import router as api_router
15
- from api.auth import fastapi_users, auth_backend, current_active_user, google_oauth_client, github_oauth_client
16
  from api.database import get_db, engine, Base
17
- from api.models import User, UserRead, UserCreate, Conversation
18
  from motor.motor_asyncio import AsyncIOMotorClient
19
  from pydantic import BaseModel
20
  from typing import List
@@ -25,8 +25,8 @@ from sqlalchemy.orm import Session
25
  from pathlib import Path
26
  from hashlib import md5
27
  from datetime import datetime
28
- import re
29
  from httpx_oauth.exceptions import GetIdEmailError
 
30
 
31
  # Setup logging
32
  logging.basicConfig(level=logging.INFO)
@@ -79,6 +79,7 @@ CONCURRENCY_LIMIT = int(os.getenv("CONCURRENCY_LIMIT", 20))
79
  @asynccontextmanager
80
  async def lifespan(app: FastAPI):
81
  await setup_mongo_index()
 
82
  yield
83
 
84
  app = FastAPI(title="MGZon Chatbot API", lifespan=lifespan)
@@ -86,9 +87,6 @@ app = FastAPI(title="MGZon Chatbot API", lifespan=lifespan)
86
  # Add SessionMiddleware
87
  app.add_middleware(SessionMiddleware, secret_key=JWT_SECRET)
88
 
89
- # Create SQLAlchemy tables
90
- Base.metadata.create_all(bind=engine)
91
-
92
  # Mount static files
93
  os.makedirs("static", exist_ok=True)
94
  app.mount("/static", StaticFiles(directory="static"), name="static")
@@ -106,42 +104,8 @@ app.add_middleware(
106
  )
107
 
108
  # Include routers
109
- app.include_router(
110
- fastapi_users.get_auth_router(auth_backend),
111
- prefix="/auth/jwt",
112
- tags=["auth"],
113
- )
114
- app.include_router(
115
- fastapi_users.get_register_router(UserRead, UserCreate),
116
- prefix="/auth",
117
- tags=["auth"],
118
- )
119
- app.include_router(
120
- fastapi_users.get_users_router(UserRead, UserCreate),
121
- prefix="/users",
122
- tags=["users"],
123
- )
124
- app.include_router(
125
- fastapi_users.get_oauth_router(
126
- google_oauth_client,
127
- auth_backend,
128
- JWT_SECRET,
129
- redirect_url="https://mgzon-mgzon-app.hf.space/auth/google/callback"
130
- ),
131
- prefix="/auth/google",
132
- tags=["auth"],
133
- )
134
- app.include_router(
135
- fastapi_users.get_oauth_router(
136
- github_oauth_client,
137
- auth_backend,
138
- JWT_SECRET,
139
- redirect_url="https://mgzon-mgzon-app.hf.space/auth/github/callback"
140
- ),
141
- prefix="/auth/github",
142
- tags=["auth"],
143
- )
144
  app.include_router(api_router)
 
145
 
146
  # Debug routes endpoint
147
  @app.get("/debug/routes", response_class=PlainTextResponse)
@@ -153,81 +117,6 @@ async def debug_routes():
153
  routes.append(f"{methods} {path}")
154
  return "\n".join(sorted(routes))
155
 
156
- # OAuth callbacks
157
- @app.get("/auth/google/callback", response_class=RedirectResponse)
158
- async def google_oauth_callback(
159
- request: Request,
160
- code: str = Query(...),
161
- state: str = Query(...),
162
- db: Session = Depends(get_db)
163
- ):
164
- try:
165
- logger.info("Processing Google OAuth callback")
166
- # Exchange code for access token
167
- token_data = await google_oauth_client.get_access_token(code, "https://mgzon-mgzon-app.hf.space/auth/google/callback")
168
- logger.info(f"Google OAuth token received: {token_data}")
169
- # Get user info
170
- user_info = await google_oauth_client.get_id_email(token_data["access_token"])
171
- logger.info(f"Google user info: {user_info}")
172
- # Create or update user
173
- user = await fastapi_users.oauth_callback(
174
- oauth_name="google",
175
- access_token=token_data["access_token"],
176
- account_id=user_info["id"],
177
- account_email=user_info["email"],
178
- expires_at=token_data.get("expires_at"),
179
- refresh_token=token_data.get("refresh_token"),
180
- request=request,
181
- db=db
182
- )
183
- logger.info("Google OAuth user processed, creating session")
184
- # Create JWT token
185
- token = await auth_backend.get_login_response(user, request)
186
- logger.info("Google OAuth callback processed, redirecting to /chat")
187
- response = RedirectResponse(url="/chat", status_code=302)
188
- response.set_cookie("Authorization", f"Bearer {token.access_token}", httponly=True)
189
- return response
190
- except Exception as e:
191
- logger.error(f"Google OAuth callback error: {str(e)}")
192
- return RedirectResponse(url=f"/login?error=Google%20OAuth%20failed:%20{str(e)}", status_code=302)
193
-
194
- @app.get("/auth/github/callback", response_class=RedirectResponse)
195
- async def github_oauth_callback(
196
- request: Request,
197
- code: str = Query(...),
198
- state: str = Query(...),
199
- db: Session = Depends(get_db)
200
- ):
201
- try:
202
- logger.info("Processing GitHub OAuth callback")
203
- # Exchange code for access token
204
- token_data = await github_oauth_client.get_access_token(code, "https://mgzon-mgzon-app.hf.space/auth/github/callback")
205
- logger.info(f"GitHub OAuth token received: {token_data}")
206
- # Get user info
207
- user_info = await github_oauth_client.get_id_email(token_data["access_token"])
208
- logger.info(f"GitHub user info: {user_info}")
209
- # Create or update user
210
- user = await fastapi_users.oauth_callback(
211
- oauth_name="github",
212
- access_token=token_data["access_token"],
213
- account_id=user_info["id"],
214
- account_email=user_info["email"],
215
- expires_at=token_data.get("expires_at"),
216
- refresh_token=token_data.get("refresh_token"),
217
- request=request,
218
- db=db
219
- )
220
- logger.info("GitHub OAuth user processed, creating session")
221
- # Create JWT token
222
- token = await auth_backend.get_login_response(user, request)
223
- logger.info("GitHub OAuth callback processed, redirecting to /chat")
224
- response = RedirectResponse(url="/chat", status_code=302)
225
- response.set_cookie("Authorization", f"Bearer {token.access_token}", httponly=True)
226
- return response
227
- except Exception as e:
228
- logger.error(f"GitHub OAuth callback error: {str(e)}")
229
- return RedirectResponse(url=f"/login?error=GitHub%20OAuth%20failed:%20{str(e)}", status_code=302)
230
-
231
  # Custom middleware for 404 and 500 errors
232
  class NotFoundMiddleware(BaseHTTPMiddleware):
233
  async def dispatch(self, request: Request, call_next):
 
1
+ # main.py
2
  # SPDX-FileCopyrightText: Hadad <[email protected]>
3
  # SPDX-License-Identifier: Apache-2.0
4
 
 
13
  from fastapi.openapi.docs import get_swagger_ui_html
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from api.endpoints import router as api_router
16
+ from api.auth import fastapi_users, auth_backend, current_active_user, get_auth_router
17
  from api.database import get_db, engine, Base
 
18
  from motor.motor_asyncio import AsyncIOMotorClient
19
  from pydantic import BaseModel
20
  from typing import List
 
25
  from pathlib import Path
26
  from hashlib import md5
27
  from datetime import datetime
 
28
  from httpx_oauth.exceptions import GetIdEmailError
29
+ from api.models import UserRead, UserCreate, Conversation, UserUpdate
30
 
31
  # Setup logging
32
  logging.basicConfig(level=logging.INFO)
 
79
  @asynccontextmanager
80
  async def lifespan(app: FastAPI):
81
  await setup_mongo_index()
82
+ Base.metadata.create_all(bind=engine) # Create tables on startup
83
  yield
84
 
85
  app = FastAPI(title="MGZon Chatbot API", lifespan=lifespan)
 
87
  # Add SessionMiddleware
88
  app.add_middleware(SessionMiddleware, secret_key=JWT_SECRET)
89
 
 
 
 
90
  # Mount static files
91
  os.makedirs("static", exist_ok=True)
92
  app.mount("/static", StaticFiles(directory="static"), name="static")
 
104
  )
105
 
106
  # Include routers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  app.include_router(api_router)
108
+ get_auth_router(app) # Add OAuth and auth routers
109
 
110
  # Debug routes endpoint
111
  @app.get("/debug/routes", response_class=PlainTextResponse)
 
117
  routes.append(f"{methods} {path}")
118
  return "\n".join(sorted(routes))
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  # Custom middleware for 404 and 500 errors
121
  class NotFoundMiddleware(BaseHTTPMiddleware):
122
  async def dispatch(self, request: Request, call_next):