from datetime import datetime, timezone
from typing import TYPE_CHECKING

from sqlalchemy import func, or_, select, update
from sqlalchemy.orm import Session

from ..models.system_message import SystemMessage
from ..schemas.system_message import SystemMessageCreateRequest

if TYPE_CHECKING:
    from ..schemas.system_message import SystemMessagePublic


def _visible_to(user_id: str):
    """Build the WHERE clause for messages visible to ``user_id``.

    A message is visible when it targets this user directly or is a
    broadcast (``target_user_id`` is NULL).
    """
    return or_(
        SystemMessage.target_user_id == user_id,
        SystemMessage.target_user_id.is_(None),
    )


def _unread():
    """Build the WHERE clause for messages not yet marked as read."""
    # `== False` is intentional: this builds a SQL expression; Python's `is`
    # would compare the Column object itself and always be False.
    return SystemMessage.is_read == False  # noqa: E712


def _count(db: Session, *criteria) -> int:
    """Return the number of SystemMessage rows matching all ``criteria``."""
    stmt = select(func.count()).select_from(SystemMessage).where(*criteria)
    return db.scalar(stmt) or 0


def create_system_message(
    db: Session,
    payload: SystemMessageCreateRequest,
) -> SystemMessage:
    """Create and persist a system message.

    Args:
        db: Active database session; committed by this function.
        payload: Title, content, type and optional target user of the message.

    Returns:
        The persisted ``SystemMessage``, refreshed from the database.
    """
    message = SystemMessage(
        title=payload.title,
        content=payload.content,
        message_type=payload.message_type,
        target_user_id=payload.target_user_id,
    )
    db.add(message)
    db.commit()
    # Reload DB-generated values (e.g. primary key / defaults) onto the object.
    db.refresh(message)
    return message


def list_user_messages(
    db: Session,
    user_id: str,
    limit: int = 50,
    offset: int = 0,
    unread_only: bool = False,
) -> tuple[list[SystemMessage], int, int]:
    """List the system messages visible to a user, newest first.

    Visible messages are those targeted at ``user_id`` plus broadcasts.

    Args:
        db: Active database session.
        user_id: The user whose messages are listed.
        limit: Maximum number of messages to return.
        offset: Number of messages to skip (for pagination).
        unread_only: When True, only unread messages are returned.

    Returns:
        ``(messages, total, unread_count)`` where ``total`` counts ALL visible
        messages (independent of ``unread_only``) and ``unread_count`` counts
        the visible unread ones.
    """
    visible = _visible_to(user_id)
    unread = _unread()

    query = select(SystemMessage).where(visible)
    if unread_only:
        query = query.where(unread)
    # `id` breaks ties between equal timestamps so pages don't overlap/skip.
    query = (
        query.order_by(SystemMessage.created_at.desc(), SystemMessage.id.desc())
        .limit(limit)
        .offset(offset)
    )
    messages = list(db.scalars(query).all())

    total = _count(db, visible)
    unread_count = _count(db, visible, unread)
    return messages, total, unread_count


def mark_messages_as_read(
    db: Session,
    user_id: str,
    message_ids: list[str],
) -> int:
    """Mark the given messages as read for ``user_id``.

    Only messages visible to the user (targeted or broadcast) and not already
    read are updated; ``read_at`` is set to the current UTC time (naive).

    NOTE(review): ``is_read`` is a column on the message row itself, so marking
    a broadcast message as read here flips it for every user. Per-user read
    state would need a separate table — confirm this is intended.

    Args:
        db: Active database session; committed when there is work to do.
        user_id: The user performing the action.
        message_ids: IDs of the messages to mark.

    Returns:
        The number of rows actually updated.
    """
    if not message_ids:
        return 0

    stmt = (
        update(SystemMessage)
        .where(
            SystemMessage.id.in_(message_ids),
            _visible_to(user_id),
            _unread(),
        )
        # Naive UTC timestamp, identical to the deprecated datetime.utcnow().
        .values(
            is_read=True,
            read_at=datetime.now(timezone.utc).replace(tzinfo=None),
        )
    )
    result = db.execute(stmt)
    db.commit()
    return result.rowcount or 0


def get_unread_count(db: Session, user_id: str) -> int:
    """Return the number of unread messages visible to ``user_id``."""
    return _count(db, _visible_to(user_id), _unread())