feat: add CI/CD workflow and sync latest workspace changes

This commit is contained in:
chengkai3
2026-04-12 00:03:30 +08:00
parent add7517a1d
commit 0eb656aaf2
42 changed files with 2055 additions and 96 deletions
+1
View File
@@ -0,0 +1 @@
"""API router package."""
+13
View File
@@ -0,0 +1,13 @@
from fastapi import APIRouter
from .v1.auth import router as auth_router
from .v1.users import router as users_router
api_router = APIRouter(prefix="/api/v1")
api_router.include_router(auth_router)
api_router.include_router(users_router)
@api_router.get("/ping")
def ping() -> dict[str, str]:
return {"message": "pong"}
+1
View File
@@ -0,0 +1 @@
"""Versioned API routes."""
+128
View File
@@ -0,0 +1,128 @@
from fastapi import APIRouter, Depends, Request, Response
from sqlalchemy.orm import Session
from ...core.config import get_settings
from ...core.database import get_db
from ...core.dependencies import CurrentUser, get_current_user
from ...schemas.auth import AuthTokenResponse, LoginRequest, MessageResponse, RegisterRequest
from ...schemas.user import UserPublic
from ...services.auth_service import (
AuthResult,
login_user,
logout_user_session,
refresh_user_session,
register_user,
)
from ...services.user_service import serialize_user
settings = get_settings()
router = APIRouter(prefix="/auth", tags=["auth"])
def _client_ip(request: Request) -> str | None:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
if request.client:
return request.client.host
return None
def _set_refresh_cookie(response: Response, token: str) -> None:
response.set_cookie(
key=settings.refresh_cookie_name,
value=token,
httponly=True,
secure=settings.refresh_cookie_secure,
samesite=settings.refresh_cookie_samesite,
max_age=settings.refresh_token_expire_days * 24 * 60 * 60,
path="/api/v1/auth",
)
def _clear_refresh_cookie(response: Response) -> None:
response.delete_cookie(
key=settings.refresh_cookie_name,
path="/api/v1/auth",
httponly=True,
secure=settings.refresh_cookie_secure,
samesite=settings.refresh_cookie_samesite,
)
def _to_auth_response(result: AuthResult) -> AuthTokenResponse:
return AuthTokenResponse(
access_token=result.access_token,
expires_in=result.expires_in,
user=serialize_user(result.user),
)
@router.post("/register", response_model=AuthTokenResponse)
def register(
payload: RegisterRequest,
request: Request,
response: Response,
db: Session = Depends(get_db),
) -> AuthTokenResponse:
result = register_user(
db,
payload,
user_agent=request.headers.get("user-agent"),
ip_address=_client_ip(request),
)
_set_refresh_cookie(response, result.refresh_token)
return _to_auth_response(result)
@router.post("/login", response_model=AuthTokenResponse)
def login(
payload: LoginRequest,
request: Request,
response: Response,
db: Session = Depends(get_db),
) -> AuthTokenResponse:
result = login_user(
db,
payload,
user_agent=request.headers.get("user-agent"),
ip_address=_client_ip(request),
)
_set_refresh_cookie(response, result.refresh_token)
return _to_auth_response(result)
@router.post("/refresh", response_model=AuthTokenResponse)
def refresh(
request: Request,
response: Response,
db: Session = Depends(get_db),
) -> AuthTokenResponse:
result = refresh_user_session(
db,
request.cookies.get(settings.refresh_cookie_name),
user_agent=request.headers.get("user-agent"),
ip_address=_client_ip(request),
)
_set_refresh_cookie(response, result.refresh_token)
return _to_auth_response(result)
@router.post("/logout", response_model=MessageResponse)
def logout(
request: Request,
response: Response,
db: Session = Depends(get_db),
) -> MessageResponse:
logout_user_session(
db,
request.cookies.get(settings.refresh_cookie_name),
user_id=None,
)
_clear_refresh_cookie(response)
return MessageResponse(message="Logged out")
@router.get("/me", response_model=UserPublic)
def me(current_user: CurrentUser = Depends(get_current_user)) -> UserPublic:
return serialize_user(current_user.user)
+76
View File
@@ -0,0 +1,76 @@
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from ...core.database import get_db
from ...core.dependencies import CurrentUser, get_current_user, require_permission
from ...schemas.user import UserListResponse, UserPublic, UserRoleUpdateRequest, UserUpdateRequest
from ...services.user_service import (
get_user_by_id,
list_users,
serialize_user,
set_user_roles,
update_user,
)
router = APIRouter(prefix="/users", tags=["users"])
@router.get("", response_model=UserListResponse)
def list_all_users(
limit: int = Query(default=50, ge=1, le=200),
offset: int = Query(default=0, ge=0),
_: CurrentUser = Depends(require_permission("user.manage")),
db: Session = Depends(get_db),
) -> UserListResponse:
return list_users(db, limit=limit, offset=offset)
@router.get("/{user_id}", response_model=UserPublic)
def get_user_detail(
user_id: str,
current_user: CurrentUser = Depends(get_current_user),
db: Session = Depends(get_db),
) -> UserPublic:
can_manage = "admin" in current_user.role_codes or "user.manage" in current_user.permission_codes
if current_user.user.id != user_id and not can_manage:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions",
)
user = get_user_by_id(db, user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
return serialize_user(user)
@router.patch("/{user_id}", response_model=UserPublic)
def update_user_profile(
user_id: str,
payload: UserUpdateRequest,
_: CurrentUser = Depends(require_permission("user.manage")),
db: Session = Depends(get_db),
) -> UserPublic:
updated = update_user(db, user_id, payload)
if not updated:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found or username exists",
)
return updated
@router.post("/{user_id}/roles", response_model=UserPublic)
def assign_roles(
user_id: str,
payload: UserRoleUpdateRequest,
_: CurrentUser = Depends(require_permission("user.manage")),
db: Session = Depends(get_db),
) -> UserPublic:
updated = set_user_roles(db, user_id, payload)
if not updated:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found or invalid roles",
)
return updated
+2 -29
View File
@@ -1,30 +1,3 @@
from functools import lru_cache
from .core.config import Settings, get_settings
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
api_name: str = "fquiz-api"
api_version: str = "0.1.0"
api_host: str = "0.0.0.0"
api_port: int = 8000
api_cors_origins: str = "http://localhost:3000,http://127.0.0.1:3000"
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
)
@property
def cors_origins(self) -> list[str]:
return [
origin.strip()
for origin in self.api_cors_origins.split(",")
if origin.strip()
]
@lru_cache
def get_settings() -> Settings:
return Settings()
__all__ = ["Settings", "get_settings"]
+1
View File
@@ -0,0 +1 @@
"""Core infrastructure for config, database, and security."""
+56
View File
@@ -0,0 +1,56 @@
from functools import lru_cache
from typing import Literal
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
api_name: str = "fquiz-api"
api_version: str = "0.1.0"
api_host: str = "0.0.0.0"
api_port: int = 8000
api_cors_origins: str = "http://localhost:3000,http://127.0.0.1:3000"
database_url: str = "sqlite:///./fquiz.db"
jwt_secret_key: str = "change-this-in-production"
jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 15
refresh_token_expire_days: int = 30
refresh_cookie_name: str = "refresh_token"
refresh_cookie_secure: bool = False
refresh_cookie_samesite: Literal["lax", "strict", "none"] = "lax"
initial_admin_email: str | None = None
initial_admin_username: str = "admin"
initial_admin_password: str | None = None
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
@field_validator("access_token_expire_minutes", "refresh_token_expire_days")
@classmethod
def validate_positive_numbers(cls, value: int) -> int:
if value <= 0:
msg = "Value must be greater than 0"
raise ValueError(msg)
return value
@property
def cors_origins(self) -> list[str]:
return [
origin.strip()
for origin in self.api_cors_origins.split(",")
if origin.strip()
]
@lru_cache
def get_settings() -> Settings:
return Settings()
+47
View File
@@ -0,0 +1,47 @@
from collections.abc import Generator
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from .config import get_settings
settings = get_settings()
connect_args: dict[str, bool] = {}
if settings.database_url.startswith("sqlite"):
connect_args["check_same_thread"] = False
engine = create_engine(
settings.database_url,
pool_pre_ping=True,
connect_args=connect_args,
)
SessionLocal = sessionmaker(
bind=engine,
autocommit=False,
autoflush=False,
expire_on_commit=False,
)
class Base(DeclarativeBase):
pass
def get_db() -> Generator[Session, None, None]:
db = SessionLocal()
try:
yield db
finally:
db.close()
def init_db() -> None:
# Import models so metadata includes every table before create_all.
from ..models import audit_log, auth_session, rbac, user # noqa: F401
from ..services.seed_service import seed_defaults
Base.metadata.create_all(bind=engine)
with SessionLocal() as db:
seed_defaults(db)
+73
View File
@@ -0,0 +1,73 @@
from dataclasses import dataclass
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from ..models.rbac import Role
from ..models.user import User
from .database import get_db
from .security import decode_access_token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@dataclass
class CurrentUser:
user: User
role_codes: set[str]
permission_codes: set[str]
def _load_user_with_rbac(db: Session, user_id: str) -> User | None:
stmt = (
select(User)
.options(joinedload(User.roles).joinedload(Role.permissions))
.where(User.id == user_id)
)
return db.execute(stmt).unique().scalar_one_or_none()
def _get_user_permissions(user: User) -> set[str]:
return {permission.code for role in user.roles for permission in role.permissions}
def get_current_user(
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme),
) -> CurrentUser:
payload = decode_access_token(token)
user_id = str(payload["sub"])
user = _load_user_with_rbac(db, user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
if user.status != "active":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User is disabled",
)
return CurrentUser(
user=user,
role_codes={role.code for role in user.roles},
permission_codes=_get_user_permissions(user),
)
def require_permission(permission_code: str):
def dependency(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
if "admin" in current_user.role_codes:
return current_user
if permission_code not in current_user.permission_codes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing permission: {permission_code}",
)
return current_user
return dependency
+81
View File
@@ -0,0 +1,81 @@
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any
import jwt
from argon2 import PasswordHasher
from argon2.exceptions import InvalidHash, VerifyMismatchError
from fastapi import HTTPException, status
from .config import get_settings
password_hasher = PasswordHasher()
def hash_password(password: str) -> str:
return password_hasher.hash(password)
def verify_password(password: str, password_hash: str) -> bool:
try:
return password_hasher.verify(password_hash, password)
except (VerifyMismatchError, InvalidHash):
return False
def create_access_token(
*,
user_id: str,
role_codes: list[str],
permission_codes: list[str],
expires_minutes: int | None = None,
) -> tuple[str, int]:
settings = get_settings()
now = datetime.now(timezone.utc)
minutes = expires_minutes or settings.access_token_expire_minutes
expires_at = now + timedelta(minutes=minutes)
payload = {
"sub": user_id,
"roles": role_codes,
"permissions": permission_codes,
"iat": int(now.timestamp()),
"exp": int(expires_at.timestamp()),
}
token = jwt.encode(
payload,
settings.jwt_secret_key,
algorithm=settings.jwt_algorithm,
)
return token, int((expires_at - now).total_seconds())
def decode_access_token(token: str) -> dict[str, Any]:
settings = get_settings()
try:
payload = jwt.decode(
token,
settings.jwt_secret_key,
algorithms=[settings.jwt_algorithm],
)
except jwt.PyJWTError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired access token",
) from exc
if not payload.get("sub"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid access token payload",
)
return payload
def create_refresh_token() -> str:
return secrets.token_urlsafe(48)
def hash_token(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
+14 -8
View File
@@ -1,13 +1,25 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .config import get_settings
from .api.router import api_router
from .core.config import get_settings
from .core.database import init_db
settings = get_settings()
@asynccontextmanager
async def lifespan(_: FastAPI):
init_db()
yield
app = FastAPI(
title=settings.api_name,
version=settings.api_version,
lifespan=lifespan,
)
app.add_middleware(
@@ -27,10 +39,4 @@ def health() -> dict[str, str]:
"version": settings.api_version,
}
@app.get("/api/v1/ping")
def ping() -> dict[str, str]:
return {
"message": "pong",
"service": settings.api_name,
}
app.include_router(api_router)
+1
View File
@@ -0,0 +1 @@
"""Database models."""
+29
View File
@@ -0,0 +1,29 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..core.database import Base
from .base import utcnow
if TYPE_CHECKING:
from .user import User
class AuditLog(Base):
__tablename__ = "audit_logs"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[str | None] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="SET NULL"),
index=True,
)
action: Mapped[str] = mapped_column(String(128), index=True)
detail: Mapped[str | None] = mapped_column(Text())
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
user: Mapped[User | None] = relationship("User", back_populates="audit_logs")
+37
View File
@@ -0,0 +1,37 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from uuid import uuid4
from sqlalchemy import DateTime, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..core.database import Base
from .base import utcnow
if TYPE_CHECKING:
from .user import User
class AuthSession(Base):
__tablename__ = "auth_sessions"
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=lambda: str(uuid4()),
)
user_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
index=True,
)
refresh_token_hash: Mapped[str] = mapped_column(String(128), unique=True, index=True)
user_agent: Mapped[str | None] = mapped_column(String(512))
ip_address: Mapped[str | None] = mapped_column(String(64))
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), index=True)
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
user: Mapped[User] = relationship("User", back_populates="sessions")
+5
View File
@@ -0,0 +1,5 @@
from datetime import datetime, timezone
def utcnow() -> datetime:
return datetime.now(timezone.utc)
+74
View File
@@ -0,0 +1,74 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..core.database import Base
if TYPE_CHECKING:
from .user import User
class Role(Base):
__tablename__ = "roles"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
code: Mapped[str] = mapped_column(String(64), unique=True, index=True)
name: Mapped[str] = mapped_column(String(128))
users: Mapped[list[User]] = relationship(
"User",
secondary="user_roles",
back_populates="roles",
lazy="selectin",
)
permissions: Mapped[list[Permission]] = relationship(
"Permission",
secondary="role_permissions",
back_populates="roles",
lazy="selectin",
)
class Permission(Base):
__tablename__ = "permissions"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
code: Mapped[str] = mapped_column(String(64), unique=True, index=True)
name: Mapped[str] = mapped_column(String(128))
roles: Mapped[list[Role]] = relationship(
"Role",
secondary="role_permissions",
back_populates="permissions",
lazy="selectin",
)
class UserRole(Base):
__tablename__ = "user_roles"
user_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
primary_key=True,
)
role_id: Mapped[int] = mapped_column(
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
)
class RolePermission(Base):
__tablename__ = "role_permissions"
role_id: Mapped[int] = mapped_column(
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permissions.id", ondelete="CASCADE"),
primary_key=True,
)
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from uuid import uuid4
from sqlalchemy import DateTime, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..core.database import Base
from .base import utcnow
if TYPE_CHECKING:
from .auth_session import AuthSession
from .audit_log import AuditLog
from .rbac import Role
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=lambda: str(uuid4()),
)
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
username: Mapped[str] = mapped_column(String(64), unique=True, index=True)
password_hash: Mapped[str] = mapped_column(String(255))
status: Mapped[str] = mapped_column(String(32), default="active", index=True)
last_login_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utcnow,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utcnow,
onupdate=utcnow,
)
roles: Mapped[list[Role]] = relationship(
"Role",
secondary="user_roles",
back_populates="users",
lazy="selectin",
)
sessions: Mapped[list[AuthSession]] = relationship(
"AuthSession",
back_populates="user",
lazy="selectin",
)
audit_logs: Mapped[list[AuditLog]] = relationship(
"AuditLog",
back_populates="user",
lazy="selectin",
)
+1
View File
@@ -0,0 +1 @@
"""Pydantic schemas."""
+25
View File
@@ -0,0 +1,25 @@
from pydantic import BaseModel, EmailStr, Field
from .user import UserPublic
class RegisterRequest(BaseModel):
email: EmailStr
username: str = Field(min_length=3, max_length=64)
password: str = Field(min_length=8, max_length=128)
class LoginRequest(BaseModel):
email: EmailStr
password: str = Field(min_length=8, max_length=128)
class AuthTokenResponse(BaseModel):
access_token: str
token_type: str = "bearer"
expires_in: int
user: UserPublic
class MessageResponse(BaseModel):
message: str
+29
View File
@@ -0,0 +1,29 @@
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, EmailStr, Field
class UserPublic(BaseModel):
id: str
email: EmailStr
username: str
status: str
role_codes: list[str]
permission_codes: list[str]
created_at: datetime
last_login_at: datetime | None = None
class UserListResponse(BaseModel):
items: list[UserPublic]
total: int
class UserUpdateRequest(BaseModel):
username: str | None = Field(default=None, min_length=3, max_length=64)
status: Literal["active", "disabled"] | None = None
class UserRoleUpdateRequest(BaseModel):
role_codes: list[str] = Field(min_length=1)
+1
View File
@@ -0,0 +1 @@
"""Business services."""
+238
View File
@@ -0,0 +1,238 @@
from dataclasses import dataclass
from datetime import timedelta
from fastapi import HTTPException, status
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from ..models.audit_log import AuditLog
from ..models.auth_session import AuthSession
from ..models.base import utcnow
from ..models.rbac import Role
from ..models.user import User
from ..schemas.auth import LoginRequest, RegisterRequest
from .user_service import get_user_by_email
from ..core.config import get_settings
from ..core.security import (
create_access_token,
create_refresh_token,
hash_password,
hash_token,
verify_password,
)
settings = get_settings()
@dataclass
class AuthResult:
access_token: str
expires_in: int
refresh_token: str
user: User
def register_user(
db: Session,
payload: RegisterRequest,
*,
user_agent: str | None,
ip_address: str | None,
) -> AuthResult:
email = payload.email.lower()
duplicate = db.scalar(
select(User.id).where(or_(User.email == email, User.username == payload.username))
)
if duplicate:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Email or username already exists",
)
role = db.scalar(select(Role).where(Role.code == "user"))
if not role:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Default role not initialized",
)
user = User(
email=email,
username=payload.username,
password_hash=hash_password(payload.password),
status="active",
)
user.roles.append(role)
db.add(user)
db.commit()
db.add(AuditLog(user_id=user.id, action="auth.register", detail="User registered"))
db.commit()
return issue_auth_result_for_user(
db,
user_id=user.id,
user_agent=user_agent,
ip_address=ip_address,
action="auth.login_after_register",
)
def login_user(
db: Session,
payload: LoginRequest,
*,
user_agent: str | None,
ip_address: str | None,
) -> AuthResult:
user = get_user_by_email(db, payload.email.lower())
if not user or not verify_password(payload.password, user.password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
if user.status != "active":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User is disabled",
)
return issue_auth_result_for_user(
db,
user_id=user.id,
user_agent=user_agent,
ip_address=ip_address,
action="auth.login",
)
def refresh_user_session(
db: Session,
refresh_token: str | None,
*,
user_agent: str | None,
ip_address: str | None,
) -> AuthResult:
if not refresh_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing refresh token",
)
now = utcnow()
token_hash = hash_token(refresh_token)
session = db.scalar(
select(AuthSession).where(
and_(
AuthSession.refresh_token_hash == token_hash,
AuthSession.revoked_at.is_(None),
AuthSession.expires_at > now,
)
)
)
if not session:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh session",
)
session.revoked_at = now
db.add(AuditLog(user_id=session.user_id, action="auth.refresh", detail="Session rotated"))
db.commit()
return issue_auth_result_for_user(
db,
user_id=session.user_id,
user_agent=user_agent,
ip_address=ip_address,
action="auth.refresh_issued",
)
def logout_user_session(db: Session, refresh_token: str | None, *, user_id: str | None) -> None:
if not refresh_token:
return
token_hash = hash_token(refresh_token)
now = utcnow()
session = db.scalar(
select(AuthSession).where(
and_(
AuthSession.refresh_token_hash == token_hash,
AuthSession.revoked_at.is_(None),
)
)
)
if not session:
return
if user_id and session.user_id != user_id:
return
session.revoked_at = now
db.add(AuditLog(user_id=session.user_id, action="auth.logout", detail="Session revoked"))
db.commit()
def issue_auth_result_for_user(
db: Session,
*,
user_id: str,
user_agent: str | None,
ip_address: str | None,
action: str,
) -> AuthResult:
user = get_user_by_id_with_rbac(db, user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
if user.status != "active":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User is disabled",
)
refresh_token = create_refresh_token()
refresh_expires_at = utcnow() + timedelta(days=settings.refresh_token_expire_days)
db.add(
AuthSession(
user_id=user.id,
refresh_token_hash=hash_token(refresh_token),
user_agent=user_agent,
ip_address=ip_address,
expires_at=refresh_expires_at,
)
)
user.last_login_at = utcnow()
db.add(AuditLog(user_id=user.id, action=action, detail="Access token issued"))
db.commit()
user = get_user_by_id_with_rbac(db, user_id)
role_codes = sorted({role.code for role in user.roles})
permission_codes = sorted(
{permission.code for role in user.roles for permission in role.permissions}
)
access_token, expires_in = create_access_token(
user_id=user.id,
role_codes=role_codes,
permission_codes=permission_codes,
)
return AuthResult(
access_token=access_token,
expires_in=expires_in,
refresh_token=refresh_token,
user=user,
)
def get_user_by_id_with_rbac(db: Session, user_id: str) -> User | None:
from .user_service import get_user_by_id
return get_user_by_id(db, user_id)
+90
View File
@@ -0,0 +1,90 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from ..core.config import get_settings
from ..core.security import hash_password
from ..models.rbac import Permission, Role
from ..models.user import User
settings = get_settings()
DEFAULT_PERMISSIONS: dict[str, str] = {
"user.read": "Read user profile",
"user.write": "Update user profile",
"user.manage": "Manage all users and roles",
}
DEFAULT_ROLES: dict[str, dict[str, object]] = {
"admin": {
"name": "Administrator",
"permissions": ["user.read", "user.write", "user.manage"],
},
"user": {
"name": "User",
"permissions": ["user.read"],
},
}
def seed_defaults(db: Session) -> None:
permissions = _seed_permissions(db)
_seed_roles(db, permissions)
_seed_initial_admin(db)
db.commit()
def _seed_permissions(db: Session) -> dict[str, Permission]:
permission_map: dict[str, Permission] = {}
for code, name in DEFAULT_PERMISSIONS.items():
permission = db.scalar(select(Permission).where(Permission.code == code))
if not permission:
permission = Permission(code=code, name=name)
db.add(permission)
permission_map[code] = permission
db.flush()
# Refresh map with persisted entities.
for code in DEFAULT_PERMISSIONS:
permission = db.scalar(select(Permission).where(Permission.code == code))
if not permission:
msg = f"Permission not found after seeding: {code}"
raise RuntimeError(msg)
permission_map[code] = permission
return permission_map
def _seed_roles(db: Session, permission_map: dict[str, Permission]) -> None:
for code, role_info in DEFAULT_ROLES.items():
role = db.scalar(select(Role).where(Role.code == code))
if not role:
role = Role(code=code, name=str(role_info["name"]))
db.add(role)
db.flush()
role.permissions = [permission_map[p] for p in role_info["permissions"]]
db.flush()
def _seed_initial_admin(db: Session) -> None:
if not settings.initial_admin_email or not settings.initial_admin_password:
return
admin_role = db.scalar(select(Role).where(Role.code == "admin"))
if not admin_role:
return
admin_email = settings.initial_admin_email.lower()
user = db.scalar(select(User).where(User.email == admin_email))
if not user:
user = User(
email=admin_email,
username=settings.initial_admin_username,
password_hash=hash_password(settings.initial_admin_password),
status="active",
)
db.add(user)
db.flush()
role_codes = {role.code for role in user.roles}
if "admin" not in role_codes:
user.roles.append(admin_role)
+98
View File
@@ -0,0 +1,98 @@
from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload
from ..models.rbac import Role
from ..models.user import User
from ..schemas.user import UserListResponse, UserPublic, UserRoleUpdateRequest, UserUpdateRequest
def _user_with_rbac_stmt():
return select(User).options(joinedload(User.roles).joinedload(Role.permissions))
def list_users(db: Session, *, limit: int, offset: int) -> UserListResponse:
total = db.scalar(select(func.count()).select_from(User)) or 0
stmt = (
_user_with_rbac_stmt()
.order_by(User.created_at.desc())
.offset(offset)
.limit(limit)
)
users = db.execute(stmt).unique().scalars().all()
return UserListResponse(items=[serialize_user(user) for user in users], total=total)
def get_user_by_id(db: Session, user_id: str) -> User | None:
stmt = _user_with_rbac_stmt().where(User.id == user_id)
return db.execute(stmt).unique().scalar_one_or_none()
def get_user_by_email(db: Session, email: str) -> User | None:
stmt = _user_with_rbac_stmt().where(User.email == email)
return db.execute(stmt).unique().scalar_one_or_none()
def update_user(
db: Session,
user_id: str,
payload: UserUpdateRequest,
) -> UserPublic | None:
user = get_user_by_id(db, user_id)
if not user:
return None
if payload.username and payload.username != user.username:
duplicate = db.scalar(
select(User.id).where(User.username == payload.username, User.id != user.id)
)
if duplicate:
return None
user.username = payload.username
if payload.status:
user.status = payload.status
db.commit()
updated = get_user_by_id(db, user_id)
return serialize_user(updated) if updated else None
def set_user_roles(
db: Session,
user_id: str,
payload: UserRoleUpdateRequest,
) -> UserPublic | None:
user = get_user_by_id(db, user_id)
if not user:
return None
role_codes = sorted(set(payload.role_codes))
roles = db.execute(select(Role).where(Role.code.in_(role_codes))).scalars().all()
if len(roles) != len(role_codes):
return None
user.roles = roles
db.commit()
updated = get_user_by_id(db, user_id)
return serialize_user(updated)
def serialize_user(user: User | None) -> UserPublic:
if user is None:
msg = "User is required"
raise ValueError(msg)
role_codes = sorted({role.code for role in user.roles})
permission_codes = sorted(
{permission.code for role in user.roles for permission in role.permissions}
)
return UserPublic(
id=user.id,
email=user.email,
username=user.username,
status=user.status,
role_codes=role_codes,
permission_codes=permission_codes,
created_at=user.created_at,
last_login_at=user.last_login_at,
)