feat: add CI/CD workflow and sync latest workspace changes
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""API router package."""
|
||||
@@ -0,0 +1,13 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .v1.auth import router as auth_router
|
||||
from .v1.users import router as users_router
|
||||
|
||||
api_router = APIRouter(prefix="/api/v1")
|
||||
api_router.include_router(auth_router)
|
||||
api_router.include_router(users_router)
|
||||
|
||||
|
||||
@api_router.get("/ping")
|
||||
def ping() -> dict[str, str]:
|
||||
return {"message": "pong"}
|
||||
@@ -0,0 +1 @@
|
||||
"""Versioned API routes."""
|
||||
@@ -0,0 +1,128 @@
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...core.config import get_settings
|
||||
from ...core.database import get_db
|
||||
from ...core.dependencies import CurrentUser, get_current_user
|
||||
from ...schemas.auth import AuthTokenResponse, LoginRequest, MessageResponse, RegisterRequest
|
||||
from ...schemas.user import UserPublic
|
||||
from ...services.auth_service import (
|
||||
AuthResult,
|
||||
login_user,
|
||||
logout_user_session,
|
||||
refresh_user_session,
|
||||
register_user,
|
||||
)
|
||||
from ...services.user_service import serialize_user
|
||||
|
||||
settings = get_settings()
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str | None:
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return None
|
||||
|
||||
|
||||
def _set_refresh_cookie(response: Response, token: str) -> None:
|
||||
response.set_cookie(
|
||||
key=settings.refresh_cookie_name,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=settings.refresh_cookie_secure,
|
||||
samesite=settings.refresh_cookie_samesite,
|
||||
max_age=settings.refresh_token_expire_days * 24 * 60 * 60,
|
||||
path="/api/v1/auth",
|
||||
)
|
||||
|
||||
|
||||
def _clear_refresh_cookie(response: Response) -> None:
|
||||
response.delete_cookie(
|
||||
key=settings.refresh_cookie_name,
|
||||
path="/api/v1/auth",
|
||||
httponly=True,
|
||||
secure=settings.refresh_cookie_secure,
|
||||
samesite=settings.refresh_cookie_samesite,
|
||||
)
|
||||
|
||||
|
||||
def _to_auth_response(result: AuthResult) -> AuthTokenResponse:
|
||||
return AuthTokenResponse(
|
||||
access_token=result.access_token,
|
||||
expires_in=result.expires_in,
|
||||
user=serialize_user(result.user),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokenResponse)
|
||||
def register(
|
||||
payload: RegisterRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
) -> AuthTokenResponse:
|
||||
result = register_user(
|
||||
db,
|
||||
payload,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
ip_address=_client_ip(request),
|
||||
)
|
||||
_set_refresh_cookie(response, result.refresh_token)
|
||||
return _to_auth_response(result)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokenResponse)
|
||||
def login(
|
||||
payload: LoginRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
) -> AuthTokenResponse:
|
||||
result = login_user(
|
||||
db,
|
||||
payload,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
ip_address=_client_ip(request),
|
||||
)
|
||||
_set_refresh_cookie(response, result.refresh_token)
|
||||
return _to_auth_response(result)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokenResponse)
|
||||
def refresh(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
) -> AuthTokenResponse:
|
||||
result = refresh_user_session(
|
||||
db,
|
||||
request.cookies.get(settings.refresh_cookie_name),
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
ip_address=_client_ip(request),
|
||||
)
|
||||
_set_refresh_cookie(response, result.refresh_token)
|
||||
return _to_auth_response(result)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
def logout(
|
||||
request: Request,
|
||||
response: Response,
|
||||
db: Session = Depends(get_db),
|
||||
) -> MessageResponse:
|
||||
logout_user_session(
|
||||
db,
|
||||
request.cookies.get(settings.refresh_cookie_name),
|
||||
user_id=None,
|
||||
)
|
||||
_clear_refresh_cookie(response)
|
||||
return MessageResponse(message="Logged out")
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserPublic)
|
||||
def me(current_user: CurrentUser = Depends(get_current_user)) -> UserPublic:
|
||||
return serialize_user(current_user.user)
|
||||
@@ -0,0 +1,76 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...core.database import get_db
|
||||
from ...core.dependencies import CurrentUser, get_current_user, require_permission
|
||||
from ...schemas.user import UserListResponse, UserPublic, UserRoleUpdateRequest, UserUpdateRequest
|
||||
from ...services.user_service import (
|
||||
get_user_by_id,
|
||||
list_users,
|
||||
serialize_user,
|
||||
set_user_roles,
|
||||
update_user,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("", response_model=UserListResponse)
|
||||
def list_all_users(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
_: CurrentUser = Depends(require_permission("user.manage")),
|
||||
db: Session = Depends(get_db),
|
||||
) -> UserListResponse:
|
||||
return list_users(db, limit=limit, offset=offset)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserPublic)
|
||||
def get_user_detail(
|
||||
user_id: str,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> UserPublic:
|
||||
can_manage = "admin" in current_user.role_codes or "user.manage" in current_user.permission_codes
|
||||
if current_user.user.id != user_id and not can_manage:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions",
|
||||
)
|
||||
|
||||
user = get_user_by_id(db, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
return serialize_user(user)
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserPublic)
|
||||
def update_user_profile(
|
||||
user_id: str,
|
||||
payload: UserUpdateRequest,
|
||||
_: CurrentUser = Depends(require_permission("user.manage")),
|
||||
db: Session = Depends(get_db),
|
||||
) -> UserPublic:
|
||||
updated = update_user(db, user_id, payload)
|
||||
if not updated:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found or username exists",
|
||||
)
|
||||
return updated
|
||||
|
||||
|
||||
@router.post("/{user_id}/roles", response_model=UserPublic)
|
||||
def assign_roles(
|
||||
user_id: str,
|
||||
payload: UserRoleUpdateRequest,
|
||||
_: CurrentUser = Depends(require_permission("user.manage")),
|
||||
db: Session = Depends(get_db),
|
||||
) -> UserPublic:
|
||||
updated = set_user_roles(db, user_id, payload)
|
||||
if not updated:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found or invalid roles",
|
||||
)
|
||||
return updated
|
||||
+2
-29
@@ -1,30 +1,3 @@
|
||||
from functools import lru_cache
|
||||
from .core.config import Settings, get_settings
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
api_name: str = "fquiz-api"
|
||||
api_version: str = "0.1.0"
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8000
|
||||
api_cors_origins: str = "http://localhost:3000,http://127.0.0.1:3000"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def cors_origins(self) -> list[str]:
|
||||
return [
|
||||
origin.strip()
|
||||
for origin in self.api_cors_origins.split(",")
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
__all__ = ["Settings", "get_settings"]
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Core infrastructure for config, database, and security."""
|
||||
@@ -0,0 +1,56 @@
|
||||
from functools import lru_cache
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
api_name: str = "fquiz-api"
|
||||
api_version: str = "0.1.0"
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8000
|
||||
api_cors_origins: str = "http://localhost:3000,http://127.0.0.1:3000"
|
||||
|
||||
database_url: str = "sqlite:///./fquiz.db"
|
||||
|
||||
jwt_secret_key: str = "change-this-in-production"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 15
|
||||
refresh_token_expire_days: int = 30
|
||||
|
||||
refresh_cookie_name: str = "refresh_token"
|
||||
refresh_cookie_secure: bool = False
|
||||
refresh_cookie_samesite: Literal["lax", "strict", "none"] = "lax"
|
||||
|
||||
initial_admin_email: str | None = None
|
||||
initial_admin_username: str = "admin"
|
||||
initial_admin_password: str | None = None
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@field_validator("access_token_expire_minutes", "refresh_token_expire_days")
|
||||
@classmethod
|
||||
def validate_positive_numbers(cls, value: int) -> int:
|
||||
if value <= 0:
|
||||
msg = "Value must be greater than 0"
|
||||
raise ValueError(msg)
|
||||
return value
|
||||
|
||||
@property
|
||||
def cors_origins(self) -> list[str]:
|
||||
return [
|
||||
origin.strip()
|
||||
for origin in self.api_cors_origins.split(",")
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
@@ -0,0 +1,47 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from .config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
connect_args: dict[str, bool] = {}
|
||||
if settings.database_url.startswith("sqlite"):
|
||||
connect_args["check_same_thread"] = False
|
||||
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
pool_pre_ping=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
bind=engine,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
# Import models so metadata includes every table before create_all.
|
||||
from ..models import audit_log, auth_session, rbac, user # noqa: F401
|
||||
from ..services.seed_service import seed_defaults
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
with SessionLocal() as db:
|
||||
seed_defaults(db)
|
||||
@@ -0,0 +1,73 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from ..models.rbac import Role
|
||||
from ..models.user import User
|
||||
from .database import get_db
|
||||
from .security import decode_access_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CurrentUser:
|
||||
user: User
|
||||
role_codes: set[str]
|
||||
permission_codes: set[str]
|
||||
|
||||
|
||||
def _load_user_with_rbac(db: Session, user_id: str) -> User | None:
|
||||
stmt = (
|
||||
select(User)
|
||||
.options(joinedload(User.roles).joinedload(Role.permissions))
|
||||
.where(User.id == user_id)
|
||||
)
|
||||
return db.execute(stmt).unique().scalar_one_or_none()
|
||||
|
||||
|
||||
def _get_user_permissions(user: User) -> set[str]:
|
||||
return {permission.code for role in user.roles for permission in role.permissions}
|
||||
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme),
|
||||
) -> CurrentUser:
|
||||
payload = decode_access_token(token)
|
||||
user_id = str(payload["sub"])
|
||||
user = _load_user_with_rbac(db, user_id)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
)
|
||||
if user.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is disabled",
|
||||
)
|
||||
|
||||
return CurrentUser(
|
||||
user=user,
|
||||
role_codes={role.code for role in user.roles},
|
||||
permission_codes=_get_user_permissions(user),
|
||||
)
|
||||
|
||||
|
||||
def require_permission(permission_code: str):
|
||||
def dependency(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
|
||||
if "admin" in current_user.role_codes:
|
||||
return current_user
|
||||
if permission_code not in current_user.permission_codes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing permission: {permission_code}",
|
||||
)
|
||||
return current_user
|
||||
|
||||
return dependency
|
||||
@@ -0,0 +1,81 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import InvalidHash, VerifyMismatchError
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from .config import get_settings
|
||||
|
||||
password_hasher = PasswordHasher()
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return password_hasher.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
try:
|
||||
return password_hasher.verify(password_hash, password)
|
||||
except (VerifyMismatchError, InvalidHash):
|
||||
return False
|
||||
|
||||
|
||||
def create_access_token(
|
||||
*,
|
||||
user_id: str,
|
||||
role_codes: list[str],
|
||||
permission_codes: list[str],
|
||||
expires_minutes: int | None = None,
|
||||
) -> tuple[str, int]:
|
||||
settings = get_settings()
|
||||
now = datetime.now(timezone.utc)
|
||||
minutes = expires_minutes or settings.access_token_expire_minutes
|
||||
expires_at = now + timedelta(minutes=minutes)
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"roles": role_codes,
|
||||
"permissions": permission_codes,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(expires_at.timestamp()),
|
||||
}
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
settings.jwt_secret_key,
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
return token, int((expires_at - now).total_seconds())
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict[str, Any]:
|
||||
settings = get_settings()
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.jwt_secret_key,
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
except jwt.PyJWTError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired access token",
|
||||
) from exc
|
||||
|
||||
if not payload.get("sub"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid access token payload",
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
def create_refresh_token() -> str:
|
||||
return secrets.token_urlsafe(48)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
+14
-8
@@ -1,13 +1,25 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .config import get_settings
|
||||
from .api.router import api_router
|
||||
from .core.config import get_settings
|
||||
from .core.database import init_db
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
init_db()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.api_name,
|
||||
version=settings.api_version,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -27,10 +39,4 @@ def health() -> dict[str, str]:
|
||||
"version": settings.api_version,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/ping")
|
||||
def ping() -> dict[str, str]:
|
||||
return {
|
||||
"message": "pong",
|
||||
"service": settings.api_name,
|
||||
}
|
||||
app.include_router(api_router)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Database models."""
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..core.database import Base
|
||||
from .base import utcnow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
index=True,
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(128), index=True)
|
||||
detail: Mapped[str | None] = mapped_column(Text())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="audit_logs")
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..core.database import Base
|
||||
from .base import utcnow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class AuthSession(Base):
|
||||
__tablename__ = "auth_sessions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid4()),
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
index=True,
|
||||
)
|
||||
refresh_token_hash: Mapped[str] = mapped_column(String(128), unique=True, index=True)
|
||||
user_agent: Mapped[str | None] = mapped_column(String(512))
|
||||
ip_address: Mapped[str | None] = mapped_column(String(64))
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="sessions")
|
||||
@@ -0,0 +1,5 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..core.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class Role(Base):
|
||||
__tablename__ = "roles"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
code: Mapped[str] = mapped_column(String(64), unique=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(128))
|
||||
|
||||
users: Mapped[list[User]] = relationship(
|
||||
"User",
|
||||
secondary="user_roles",
|
||||
back_populates="roles",
|
||||
lazy="selectin",
|
||||
)
|
||||
permissions: Mapped[list[Permission]] = relationship(
|
||||
"Permission",
|
||||
secondary="role_permissions",
|
||||
back_populates="roles",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
|
||||
class Permission(Base):
|
||||
__tablename__ = "permissions"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
code: Mapped[str] = mapped_column(String(64), unique=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(128))
|
||||
|
||||
roles: Mapped[list[Role]] = relationship(
|
||||
"Role",
|
||||
secondary="role_permissions",
|
||||
back_populates="permissions",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
|
||||
class UserRole(Base):
|
||||
__tablename__ = "user_roles"
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
role_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
|
||||
class RolePermission(Base):
|
||||
__tablename__ = "role_permissions"
|
||||
|
||||
role_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
permission_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("permissions.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..core.database import Base
|
||||
from .base import utcnow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .auth_session import AuthSession
|
||||
from .audit_log import AuditLog
|
||||
from .rbac import Role
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid4()),
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255))
|
||||
status: Mapped[str] = mapped_column(String(32), default="active", index=True)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utcnow,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utcnow,
|
||||
onupdate=utcnow,
|
||||
)
|
||||
|
||||
roles: Mapped[list[Role]] = relationship(
|
||||
"Role",
|
||||
secondary="user_roles",
|
||||
back_populates="users",
|
||||
lazy="selectin",
|
||||
)
|
||||
sessions: Mapped[list[AuthSession]] = relationship(
|
||||
"AuthSession",
|
||||
back_populates="user",
|
||||
lazy="selectin",
|
||||
)
|
||||
audit_logs: Mapped[list[AuditLog]] = relationship(
|
||||
"AuditLog",
|
||||
back_populates="user",
|
||||
lazy="selectin",
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
"""Pydantic schemas."""
|
||||
@@ -0,0 +1,25 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
from .user import UserPublic
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
username: str = Field(min_length=3, max_length=64)
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(min_length=8, max_length=128)
|
||||
|
||||
|
||||
class AuthTokenResponse(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
user: UserPublic
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
@@ -0,0 +1,29 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class UserPublic(BaseModel):
|
||||
id: str
|
||||
email: EmailStr
|
||||
username: str
|
||||
status: str
|
||||
role_codes: list[str]
|
||||
permission_codes: list[str]
|
||||
created_at: datetime
|
||||
last_login_at: datetime | None = None
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
items: list[UserPublic]
|
||||
total: int
|
||||
|
||||
|
||||
class UserUpdateRequest(BaseModel):
|
||||
username: str | None = Field(default=None, min_length=3, max_length=64)
|
||||
status: Literal["active", "disabled"] | None = None
|
||||
|
||||
|
||||
class UserRoleUpdateRequest(BaseModel):
|
||||
role_codes: list[str] = Field(min_length=1)
|
||||
@@ -0,0 +1 @@
|
||||
"""Business services."""
|
||||
@@ -0,0 +1,238 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models.audit_log import AuditLog
|
||||
from ..models.auth_session import AuthSession
|
||||
from ..models.base import utcnow
|
||||
from ..models.rbac import Role
|
||||
from ..models.user import User
|
||||
from ..schemas.auth import LoginRequest, RegisterRequest
|
||||
from .user_service import get_user_by_email
|
||||
from ..core.config import get_settings
|
||||
from ..core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
hash_password,
|
||||
hash_token,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
access_token: str
|
||||
expires_in: int
|
||||
refresh_token: str
|
||||
user: User
|
||||
|
||||
|
||||
def register_user(
|
||||
db: Session,
|
||||
payload: RegisterRequest,
|
||||
*,
|
||||
user_agent: str | None,
|
||||
ip_address: str | None,
|
||||
) -> AuthResult:
|
||||
email = payload.email.lower()
|
||||
|
||||
duplicate = db.scalar(
|
||||
select(User.id).where(or_(User.email == email, User.username == payload.username))
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Email or username already exists",
|
||||
)
|
||||
|
||||
role = db.scalar(select(Role).where(Role.code == "user"))
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Default role not initialized",
|
||||
)
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
username=payload.username,
|
||||
password_hash=hash_password(payload.password),
|
||||
status="active",
|
||||
)
|
||||
user.roles.append(role)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
|
||||
db.add(AuditLog(user_id=user.id, action="auth.register", detail="User registered"))
|
||||
db.commit()
|
||||
|
||||
return issue_auth_result_for_user(
|
||||
db,
|
||||
user_id=user.id,
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
action="auth.login_after_register",
|
||||
)
|
||||
|
||||
|
||||
def login_user(
|
||||
db: Session,
|
||||
payload: LoginRequest,
|
||||
*,
|
||||
user_agent: str | None,
|
||||
ip_address: str | None,
|
||||
) -> AuthResult:
|
||||
user = get_user_by_email(db, payload.email.lower())
|
||||
if not user or not verify_password(payload.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
if user.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is disabled",
|
||||
)
|
||||
|
||||
return issue_auth_result_for_user(
|
||||
db,
|
||||
user_id=user.id,
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
action="auth.login",
|
||||
)
|
||||
|
||||
|
||||
def refresh_user_session(
|
||||
db: Session,
|
||||
refresh_token: str | None,
|
||||
*,
|
||||
user_agent: str | None,
|
||||
ip_address: str | None,
|
||||
) -> AuthResult:
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing refresh token",
|
||||
)
|
||||
|
||||
now = utcnow()
|
||||
token_hash = hash_token(refresh_token)
|
||||
session = db.scalar(
|
||||
select(AuthSession).where(
|
||||
and_(
|
||||
AuthSession.refresh_token_hash == token_hash,
|
||||
AuthSession.revoked_at.is_(None),
|
||||
AuthSession.expires_at > now,
|
||||
)
|
||||
)
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh session",
|
||||
)
|
||||
|
||||
session.revoked_at = now
|
||||
db.add(AuditLog(user_id=session.user_id, action="auth.refresh", detail="Session rotated"))
|
||||
db.commit()
|
||||
|
||||
return issue_auth_result_for_user(
|
||||
db,
|
||||
user_id=session.user_id,
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
action="auth.refresh_issued",
|
||||
)
|
||||
|
||||
|
||||
def logout_user_session(db: Session, refresh_token: str | None, *, user_id: str | None) -> None:
|
||||
if not refresh_token:
|
||||
return
|
||||
|
||||
token_hash = hash_token(refresh_token)
|
||||
now = utcnow()
|
||||
session = db.scalar(
|
||||
select(AuthSession).where(
|
||||
and_(
|
||||
AuthSession.refresh_token_hash == token_hash,
|
||||
AuthSession.revoked_at.is_(None),
|
||||
)
|
||||
)
|
||||
)
|
||||
if not session:
|
||||
return
|
||||
|
||||
if user_id and session.user_id != user_id:
|
||||
return
|
||||
|
||||
session.revoked_at = now
|
||||
db.add(AuditLog(user_id=session.user_id, action="auth.logout", detail="Session revoked"))
|
||||
db.commit()
|
||||
|
||||
|
||||
def issue_auth_result_for_user(
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
user_agent: str | None,
|
||||
ip_address: str | None,
|
||||
action: str,
|
||||
) -> AuthResult:
|
||||
user = get_user_by_id_with_rbac(db, user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
if user.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is disabled",
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token()
|
||||
refresh_expires_at = utcnow() + timedelta(days=settings.refresh_token_expire_days)
|
||||
|
||||
db.add(
|
||||
AuthSession(
|
||||
user_id=user.id,
|
||||
refresh_token_hash=hash_token(refresh_token),
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
expires_at=refresh_expires_at,
|
||||
)
|
||||
)
|
||||
|
||||
user.last_login_at = utcnow()
|
||||
db.add(AuditLog(user_id=user.id, action=action, detail="Access token issued"))
|
||||
db.commit()
|
||||
|
||||
user = get_user_by_id_with_rbac(db, user_id)
|
||||
role_codes = sorted({role.code for role in user.roles})
|
||||
permission_codes = sorted(
|
||||
{permission.code for role in user.roles for permission in role.permissions}
|
||||
)
|
||||
access_token, expires_in = create_access_token(
|
||||
user_id=user.id,
|
||||
role_codes=role_codes,
|
||||
permission_codes=permission_codes,
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
access_token=access_token,
|
||||
expires_in=expires_in,
|
||||
refresh_token=refresh_token,
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
def get_user_by_id_with_rbac(db: Session, user_id: str) -> User | None:
|
||||
from .user_service import get_user_by_id
|
||||
|
||||
return get_user_by_id(db, user_id)
|
||||
@@ -0,0 +1,90 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..core.config import get_settings
|
||||
from ..core.security import hash_password
|
||||
from ..models.rbac import Permission, Role
|
||||
from ..models.user import User
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
DEFAULT_PERMISSIONS: dict[str, str] = {
|
||||
"user.read": "Read user profile",
|
||||
"user.write": "Update user profile",
|
||||
"user.manage": "Manage all users and roles",
|
||||
}
|
||||
|
||||
DEFAULT_ROLES: dict[str, dict[str, object]] = {
|
||||
"admin": {
|
||||
"name": "Administrator",
|
||||
"permissions": ["user.read", "user.write", "user.manage"],
|
||||
},
|
||||
"user": {
|
||||
"name": "User",
|
||||
"permissions": ["user.read"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def seed_defaults(db: Session) -> None:
|
||||
permissions = _seed_permissions(db)
|
||||
_seed_roles(db, permissions)
|
||||
_seed_initial_admin(db)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _seed_permissions(db: Session) -> dict[str, Permission]:
|
||||
permission_map: dict[str, Permission] = {}
|
||||
for code, name in DEFAULT_PERMISSIONS.items():
|
||||
permission = db.scalar(select(Permission).where(Permission.code == code))
|
||||
if not permission:
|
||||
permission = Permission(code=code, name=name)
|
||||
db.add(permission)
|
||||
permission_map[code] = permission
|
||||
|
||||
db.flush()
|
||||
# Refresh map with persisted entities.
|
||||
for code in DEFAULT_PERMISSIONS:
|
||||
permission = db.scalar(select(Permission).where(Permission.code == code))
|
||||
if not permission:
|
||||
msg = f"Permission not found after seeding: {code}"
|
||||
raise RuntimeError(msg)
|
||||
permission_map[code] = permission
|
||||
return permission_map
|
||||
|
||||
|
||||
def _seed_roles(db: Session, permission_map: dict[str, Permission]) -> None:
|
||||
for code, role_info in DEFAULT_ROLES.items():
|
||||
role = db.scalar(select(Role).where(Role.code == code))
|
||||
if not role:
|
||||
role = Role(code=code, name=str(role_info["name"]))
|
||||
db.add(role)
|
||||
db.flush()
|
||||
|
||||
role.permissions = [permission_map[p] for p in role_info["permissions"]]
|
||||
db.flush()
|
||||
|
||||
|
||||
def _seed_initial_admin(db: Session) -> None:
|
||||
if not settings.initial_admin_email or not settings.initial_admin_password:
|
||||
return
|
||||
|
||||
admin_role = db.scalar(select(Role).where(Role.code == "admin"))
|
||||
if not admin_role:
|
||||
return
|
||||
|
||||
admin_email = settings.initial_admin_email.lower()
|
||||
user = db.scalar(select(User).where(User.email == admin_email))
|
||||
if not user:
|
||||
user = User(
|
||||
email=admin_email,
|
||||
username=settings.initial_admin_username,
|
||||
password_hash=hash_password(settings.initial_admin_password),
|
||||
status="active",
|
||||
)
|
||||
db.add(user)
|
||||
db.flush()
|
||||
|
||||
role_codes = {role.code for role in user.roles}
|
||||
if "admin" not in role_codes:
|
||||
user.roles.append(admin_role)
|
||||
@@ -0,0 +1,98 @@
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from ..models.rbac import Role
|
||||
from ..models.user import User
|
||||
from ..schemas.user import UserListResponse, UserPublic, UserRoleUpdateRequest, UserUpdateRequest
|
||||
|
||||
|
||||
def _user_with_rbac_stmt():
|
||||
return select(User).options(joinedload(User.roles).joinedload(Role.permissions))
|
||||
|
||||
|
||||
def list_users(db: Session, *, limit: int, offset: int) -> UserListResponse:
|
||||
total = db.scalar(select(func.count()).select_from(User)) or 0
|
||||
stmt = (
|
||||
_user_with_rbac_stmt()
|
||||
.order_by(User.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
users = db.execute(stmt).unique().scalars().all()
|
||||
return UserListResponse(items=[serialize_user(user) for user in users], total=total)
|
||||
|
||||
|
||||
def get_user_by_id(db: Session, user_id: str) -> User | None:
|
||||
stmt = _user_with_rbac_stmt().where(User.id == user_id)
|
||||
return db.execute(stmt).unique().scalar_one_or_none()
|
||||
|
||||
|
||||
def get_user_by_email(db: Session, email: str) -> User | None:
|
||||
stmt = _user_with_rbac_stmt().where(User.email == email)
|
||||
return db.execute(stmt).unique().scalar_one_or_none()
|
||||
|
||||
|
||||
def update_user(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
payload: UserUpdateRequest,
|
||||
) -> UserPublic | None:
|
||||
user = get_user_by_id(db, user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if payload.username and payload.username != user.username:
|
||||
duplicate = db.scalar(
|
||||
select(User.id).where(User.username == payload.username, User.id != user.id)
|
||||
)
|
||||
if duplicate:
|
||||
return None
|
||||
user.username = payload.username
|
||||
|
||||
if payload.status:
|
||||
user.status = payload.status
|
||||
|
||||
db.commit()
|
||||
updated = get_user_by_id(db, user_id)
|
||||
return serialize_user(updated) if updated else None
|
||||
|
||||
|
||||
def set_user_roles(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
payload: UserRoleUpdateRequest,
|
||||
) -> UserPublic | None:
|
||||
user = get_user_by_id(db, user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
role_codes = sorted(set(payload.role_codes))
|
||||
roles = db.execute(select(Role).where(Role.code.in_(role_codes))).scalars().all()
|
||||
if len(roles) != len(role_codes):
|
||||
return None
|
||||
|
||||
user.roles = roles
|
||||
db.commit()
|
||||
updated = get_user_by_id(db, user_id)
|
||||
return serialize_user(updated)
|
||||
|
||||
|
||||
def serialize_user(user: User | None) -> UserPublic:
|
||||
if user is None:
|
||||
msg = "User is required"
|
||||
raise ValueError(msg)
|
||||
|
||||
role_codes = sorted({role.code for role in user.roles})
|
||||
permission_codes = sorted(
|
||||
{permission.code for role in user.roles for permission in role.permissions}
|
||||
)
|
||||
return UserPublic(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
username=user.username,
|
||||
status=user.status,
|
||||
role_codes=role_codes,
|
||||
permission_codes=permission_codes,
|
||||
created_at=user.created_at,
|
||||
last_login_at=user.last_login_at,
|
||||
)
|
||||
Reference in New Issue
Block a user