Files
fquiz/api/app/services/ws_manager.py
T
2026-04-12 16:00:20 +08:00

220 lines
7.9 KiB
Python

from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from uuid import uuid4
from fastapi import WebSocket
from ..schemas.ws import WsEventEnvelope
from .topic_registry import get_auto_topics, validate_topic_subscription
@dataclass
class WsConnection:
    """State kept for one registered WebSocket client."""

    # The underlying socket used for sending payloads and closing.
    websocket: WebSocket
    # Identifier under which the connection manager indexes this connection.
    connection_id: str
    # Owner of the connection; several connections may share one user id.
    user_id: str
    # Role and permission codes currently associated with the user.
    role_codes: set[str]
    permission_codes: set[str]
    # Topics this connection is subscribed to; each instance gets its own set.
    subscribed_topics: set[str] = field(default_factory=set)
class WsConnectionManager:
    """Track live WebSocket connections and their topic subscriptions.

    Two in-memory maps are kept in sync under one ``asyncio.Lock``:

    * ``_connections``: connection id -> ``WsConnection``
    * ``_topic_connections``: topic -> ids of the connections subscribed to it

    Socket I/O (sending and closing) is always done after the lock has been
    released. A connection whose send raises is treated as stale and is
    unregistered.
    """

    def __init__(self) -> None:
        self._connections: dict[str, WsConnection] = {}
        self._topic_connections: dict[str, set[str]] = {}
        self._lock = asyncio.Lock()

    async def register(
        self,
        websocket: WebSocket,
        *,
        user_id: str,
        role_codes: set[str],
        permission_codes: set[str],
    ) -> WsConnection:
        """Create, index and return a new connection for ``websocket``.

        The connection receives a fresh UUID4 id, private copies of the given
        role/permission sets, and is subscribed to every topic returned by
        ``get_auto_topics()``.
        """
        connection = WsConnection(
            websocket=websocket,
            connection_id=str(uuid4()),
            user_id=user_id,
            role_codes=set(role_codes),
            permission_codes=set(permission_codes),
            subscribed_topics=set(get_auto_topics()),
        )
        async with self._lock:
            self._connections[connection.connection_id] = connection
            for topic in connection.subscribed_topics:
                self._topic_connections.setdefault(topic, set()).add(
                    connection.connection_id
                )
        return connection

    async def unregister(self, connection_id: str) -> None:
        """Forget ``connection_id``; unknown ids are ignored.

        The socket itself is not closed here.
        """
        async with self._lock:
            self._remove_connection_locked(connection_id)

    async def subscribe(self, connection_id: str, topics: list[str]) -> list[str]:
        """Subscribe a connection to ``topics``.

        Returns the topics that were newly added, in request order. Topics
        the connection already holds (including duplicates within ``topics``)
        are skipped. An unknown ``connection_id`` yields an empty list.

        NOTE(review): no authorization check happens here; presumably the
        caller validates topics before subscribing -- confirm.
        """
        accepted: list[str] = []
        async with self._lock:
            connection = self._connections.get(connection_id)
            if not connection:
                return accepted
            for topic in topics:
                if topic in connection.subscribed_topics:
                    continue
                connection.subscribed_topics.add(topic)
                self._topic_connections.setdefault(topic, set()).add(connection_id)
                accepted.append(topic)
        return accepted

    async def unsubscribe(self, connection_id: str, topics: list[str]) -> list[str]:
        """Unsubscribe a connection from ``topics``.

        Returns the topics actually removed, in request order. Topics the
        connection does not hold, and auto topics, are left alone. An unknown
        ``connection_id`` yields an empty list.
        """
        removed: list[str] = []
        async with self._lock:
            connection = self._connections.get(connection_id)
            if not connection:
                return removed
            # Looked up once instead of once per requested topic.
            auto_topics = set(get_auto_topics())
            for topic in topics:
                if topic not in connection.subscribed_topics or topic in auto_topics:
                    continue
                connection.subscribed_topics.discard(topic)
                self._discard_subscriber_locked(topic, connection_id)
                removed.append(topic)
        return removed

    async def refresh_user_authorization(
        self,
        user_id: str,
        *,
        role_codes: set[str],
        permission_codes: set[str],
    ) -> None:
        """Apply new role/permission codes to every connection of ``user_id``.

        Each non-auto topic a connection holds is re-checked with
        ``validate_topic_subscription``; topics that are no longer allowed are
        dropped. Every affected connection is then sent an ``"unsubscribed"``
        message listing its dropped topics (sorted) with reason
        ``"permission_changed"``. Connections whose send fails are
        unregistered.
        """
        deliveries: list[tuple[WsConnection, dict[str, object]]] = []
        async with self._lock:
            connections = [
                connection
                for connection in self._connections.values()
                if connection.user_id == user_id
            ]
            if not connections:
                return
            # Looked up once instead of once per topic per connection.
            auto_topics = set(get_auto_topics())
            for connection in connections:
                connection.role_codes = set(role_codes)
                connection.permission_codes = set(permission_codes)
                revoked: list[str] = []
                # Iterate over a copy: the set is mutated inside the loop.
                for topic in list(connection.subscribed_topics):
                    if topic in auto_topics:
                        continue
                    is_allowed, _ = validate_topic_subscription(
                        topic,
                        role_codes=connection.role_codes,
                        permission_codes=connection.permission_codes,
                    )
                    if is_allowed:
                        continue
                    connection.subscribed_topics.discard(topic)
                    self._discard_subscriber_locked(topic, connection.connection_id)
                    revoked.append(topic)
                if revoked:
                    deliveries.append(
                        (
                            connection,
                            {
                                "type": "unsubscribed",
                                "topics": sorted(revoked),
                                "reason": "permission_changed",
                            },
                        )
                    )
        await self._send_and_prune(deliveries)

    async def disconnect_user(
        self,
        user_id: str,
        *,
        code: int = 4403,
        reason: str = "user_not_allowed",
    ) -> int:
        """Drop and close every connection owned by ``user_id``.

        Connections are removed from the registry first, then closed with the
        given ``code`` and ``reason``. Closing is best effort: an error from
        one socket does not prevent the others from being closed.

        Returns the number of connections that were removed.
        """
        connections_to_close: list[WsConnection] = []
        async with self._lock:
            # Snapshot the items: entries are removed while scanning.
            for connection_id, connection in list(self._connections.items()):
                if connection.user_id != user_id:
                    continue
                connections_to_close.append(connection)
                self._remove_connection_locked(connection_id)
        for connection in connections_to_close:
            try:
                await connection.websocket.close(code=code, reason=reason)
            except Exception:
                # Deliberate best effort: the connection is already forgotten.
                continue
        return len(connections_to_close)

    async def publish(self, topic: str, event: WsEventEnvelope) -> None:
        """Send ``event`` to every connection subscribed to ``topic``.

        The event is serialised only when there is at least one recipient.
        Connections whose send fails are unregistered.
        """
        async with self._lock:
            subscriber_ids = self._topic_connections.get(topic, ())
            connections = [
                self._connections[connection_id]
                for connection_id in subscriber_ids
                if connection_id in self._connections
            ]
        if not connections:
            return
        payload = {"type": "event", "event": event.model_dump(mode="json")}
        await self._send_and_prune([(connection, payload) for connection in connections])

    async def publish_to_user(self, user_id: str, event: WsEventEnvelope) -> None:
        """Send ``event`` to every connection of ``user_id``, regardless of topic.

        The event is serialised only when there is at least one recipient.
        Connections whose send fails are unregistered.
        """
        async with self._lock:
            connections = [
                connection
                for connection in self._connections.values()
                if connection.user_id == user_id
            ]
        if not connections:
            return
        payload = {"type": "event", "event": event.model_dump(mode="json")}
        await self._send_and_prune([(connection, payload) for connection in connections])

    async def _send_and_prune(
        self,
        deliveries: list[tuple[WsConnection, dict[str, object]]],
    ) -> None:
        """Send each payload to its connection, then unregister failed ones.

        Must be called WITHOUT the lock held, because ``unregister`` acquires
        it. All sends are attempted before any connection is unregistered.
        """
        stale_ids: list[str] = []
        for connection, payload in deliveries:
            try:
                await connection.websocket.send_json(payload)
            except Exception:
                # Any send failure marks the connection as stale.
                stale_ids.append(connection.connection_id)
        for connection_id in stale_ids:
            await self.unregister(connection_id)

    def _discard_subscriber_locked(self, topic: str, connection_id: str) -> None:
        """Remove ``connection_id`` from the index entry of ``topic``.

        The caller must hold the lock. The topic key is deleted once it has
        no subscribers left, so the index never accumulates empty sets.
        """
        subscribers = self._topic_connections.get(topic)
        if subscribers:
            subscribers.discard(connection_id)
        if not subscribers:
            self._topic_connections.pop(topic, None)

    def _remove_connection_locked(self, connection_id: str) -> None:
        """Remove a connection and its topic index entries.

        The caller must hold the lock. Unknown ids are ignored. The removed
        connection's own ``subscribed_topics`` set is left untouched.
        """
        connection = self._connections.pop(connection_id, None)
        if not connection:
            return
        for topic in connection.subscribed_topics:
            self._discard_subscriber_locked(topic, connection_id)
# Module-level WsConnectionManager instance, created once when this module is
# imported. NOTE(review): presumably the single shared manager used by the
# WebSocket endpoints -- confirm against importers.
ws_connection_manager = WsConnectionManager()