diff --git a/MEMORY.md b/MEMORY.md index 9b71d0f..e3bb489 100644 --- a/MEMORY.md +++ b/MEMORY.md @@ -128,7 +128,7 @@ - 聊天模型选择规则固定为:`CAPABILITY: chat.default` 优先,未命中时回退 `GLOBAL: __global__`。 - 仅允许 `ENABLED` 且具备激活密钥记录的模型参与路由命中;若不满足,接口返回 400。 - 运行时真实 Provider Key 不从数据库反解,统一从环境变量 `LLM_PROVIDER_API_KEYS` 注入(支持 `openai=sk-...` 或 JSON 字典字符串)。 -- 一期模型调用采用非流式 OpenAI-compatible `POST /chat/completions`,后续如需流式再扩展 SSE/WS。 +- AI 问答发送接口 `POST /api/v1/ai-chat/conversations/{id}/messages` 已改为 `application/x-ndjson` 流式响应,前端通过 `fetch` + `ReadableStream` 读取增量并实时刷新助手消息;如需旧同步行为,可用 `/messages/sync` 兼容接口。 ## Celery 监控口径(2026-05-01) diff --git a/api/app/api/v1/ai_chat.py b/api/app/api/v1/ai_chat.py index 45fddd7..6cb0bcf 100644 --- a/api/app/api/v1/ai_chat.py +++ b/api/app/api/v1/ai_chat.py @@ -1,4 +1,8 @@ +import json +from collections.abc import Iterator + from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from ...core.database import get_db @@ -19,6 +23,7 @@ from ...services.ai_chat_service import ( list_conversations, send_message, serialize_conversation_detail, + stream_message, update_conversation, ) @@ -85,12 +90,37 @@ def delete_conversation_endpoint( return {"success": True} -@router.post("/conversations/{conversation_id}/messages", response_model=AiChatMessageResponse) +def _encode_stream_events(events: Iterator[dict]) -> Iterator[str]: + for event in events: + yield json.dumps(event, ensure_ascii=False, default=str) + "\n" + + +@router.post("/conversations/{conversation_id}/messages") def send_message_endpoint( conversation_id: int, payload: AiChatMessageSendRequest, current_user: CurrentUser = Depends(get_current_user), db: Session = Depends(get_db), +) -> StreamingResponse: + events = stream_message(db, conversation_id, payload.content, user_id=current_user.user.id) + if events is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found") + return StreamingResponse( + _encode_stream_events(events), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + + +@router.post("/conversations/{conversation_id}/messages/sync", response_model=AiChatMessageResponse) +def send_message_sync_endpoint( + conversation_id: int, + payload: AiChatMessageSendRequest, + current_user: CurrentUser = Depends(get_current_user), + db: Session = Depends(get_db), ) -> AiChatMessageResponse: result = send_message(db, conversation_id, payload.content, user_id=current_user.user.id) if not result: diff --git a/api/app/services/ai_chat_service.py b/api/app/services/ai_chat_service.py index 985f555..372fd7c 100644 --- a/api/app/services/ai_chat_service.py +++ b/api/app/services/ai_chat_service.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Generator, Iterator import json import os from typing import Any @@ -258,6 +259,61 @@ def send_message( return serialize_message(user_message), serialize_message(assistant_message) +def stream_message( + db: Session, + conversation_id: int, + content: str, + *, + user_id: str, +) -> Iterator[dict[str, Any]] | None: + conv = get_conversation_by_id(db, conversation_id, user_id) + if not conv: + return None + + user_message = AiChatMessage( + conversation_id=conversation_id, + role="user", + content=content.strip(), + ) + db.add(user_message) + db.commit() + db.refresh(user_message) + + history = sorted(conv.messages, key=lambda m: m.created_at) + history.append(user_message) + + def event_stream() -> Iterator[dict[str, Any]]: + yield {"type": "message", "message": serialize_message(user_message).model_dump(mode="json")} + + try: + reply_content = yield from _stream_assistant_reply(db, conversation_id, history) + assistant_message = AiChatMessage( + conversation_id=conversation_id, + role="assistant", + content=reply_content, + ) + db.add(assistant_message) + db.commit() + db.refresh(assistant_message) + + yield {"type": "done", "reply": serialize_message(assistant_message).model_dump(mode="json")} + except Exception as e: + reply_content = f"抱歉,AI服务暂时不可用:{str(e)}" + assistant_message = AiChatMessage( + conversation_id=conversation_id, + role="assistant", + content=reply_content, + ) + db.add(assistant_message) + db.commit() + db.refresh(assistant_message) + + yield {"type": "delta", "content": reply_content} + yield {"type": "done", "reply": serialize_message(assistant_message).model_dump(mode="json")} + + return event_stream() + + def _get_function_definitions() -> list[dict[str, Any]]: """Define available functions that the AI can call.""" return [ @@ -479,7 +535,12 @@ def _execute_function(db: Session, function_name: str, arguments: dict[str, Any] return f"执行函数 {function_name} 时出错: {str(e)}" -def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: +def _build_openai_chat_request( + db: Session, + history: list[AiChatMessage], + *, + stream: bool = False, +) -> tuple[str, dict[str, str], dict[str, Any]]: api_key_param = get_system_param_by_key(db, "ai_chat.openai_api_key") model_param = get_system_param_by_key(db, "ai_chat.model") base_url_param = get_system_param_by_key(db, "ai_chat.base_url") @@ -513,10 +574,18 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: "messages": messages, "tools": _get_function_definitions(), } + if stream: + payload["stream"] = True + + return f"{base_url}/chat/completions", headers, payload + + +def _call_openai_api(db: Session, history: list[AiChatMessage]) -> dict[str, Any]: + url, headers, payload = _build_openai_chat_request(db, history) with httpx.Client(timeout=60.0) as client: response = client.post( - f"{base_url}/chat/completions", + url, headers=headers, json=payload, ) @@ -524,3 +593,148 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: data = response.json() return data + + +def _stream_openai_api(db: Session, history: list[AiChatMessage]) -> Iterator[dict[str, Any]]: + url, headers, payload = _build_openai_chat_request(db, history, stream=True) + + with httpx.Client(timeout=60.0) as client: + with client.stream("POST", url, headers=headers, json=payload) as response: + response.raise_for_status() + for raw_line in response.iter_lines(): + if not raw_line: + continue + line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line + line = line.strip() + if line.startswith(":"): + continue + if line.startswith("data:"): + line = line.removeprefix("data:").strip() + if not line or line == "[DONE]": + continue + yield json.loads(line) + + +def _stream_assistant_reply( + db: Session, + conversation_id: int, + history: list[AiChatMessage], +) -> Generator[dict[str, Any], None, str]: + reply_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + finish_reason = yield from _consume_completion_stream(db, history, reply_parts, tool_calls) + normalized_tool_calls = _normalize_tool_calls(tool_calls) + if finish_reason == "tool_calls" and normalized_tool_calls: + assistant_message_with_tool_calls = AiChatMessage( + conversation_id=conversation_id, + role="assistant", + content="".join(reply_parts), + tool_calls=normalized_tool_calls, + ) + db.add(assistant_message_with_tool_calls) + db.commit() + db.refresh(assistant_message_with_tool_calls) + history.append(assistant_message_with_tool_calls) + + for tool_call in normalized_tool_calls: + function_name = tool_call["function"]["name"] + function_args_str = tool_call["function"].get("arguments") or "{}" + tool_call_id = tool_call["id"] + + try: + function_args = json.loads(function_args_str) + except json.JSONDecodeError: + function_args = {} + + function_result = _execute_function(db, function_name, function_args) + tool_message = AiChatMessage( + conversation_id=conversation_id, + role="tool", + content=function_result, + tool_call_id=tool_call_id, + ) + db.add(tool_message) + db.commit() + db.refresh(tool_message) + history.append(tool_message) + + reply_parts = [] + yield from _consume_completion_stream(db, history, reply_parts, []) + + return "".join(reply_parts) + + +def _consume_completion_stream( + db: Session, + history: list[AiChatMessage], + reply_parts: list[str], + tool_calls: list[dict[str, Any]], +) -> Generator[dict[str, Any], None, str | None]: + finish_reason: str | None = None + for chunk in _stream_openai_api(db, history): + choices = chunk.get("choices") or [] + if not choices: + continue + + choice = choices[0] + finish_reason = choice.get("finish_reason") or finish_reason + delta = choice.get("delta") or {} + + content_delta = delta.get("content") + if content_delta: + reply_parts.append(content_delta) + yield {"type": "delta", "content": content_delta} + + for tool_call_delta in delta.get("tool_calls") or []: + _merge_tool_call_delta(tool_calls, tool_call_delta) + + return finish_reason + + +def _merge_tool_call_delta(tool_calls: list[dict[str, Any]], delta: dict[str, Any]) -> None: + index = delta.get("index") + if not isinstance(index, int): + index = len(tool_calls) + + while len(tool_calls) <= index: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + + tool_call = tool_calls[index] + if delta.get("id"): + tool_call["id"] = delta["id"] + if delta.get("type"): + tool_call["type"] = delta["type"] + + function_delta = delta.get("function") or {} + function = tool_call.setdefault("function", {"name": "", "arguments": ""}) + if function_delta.get("name"): + function["name"] = f"{function.get('name', '')}{function_delta['name']}" + if function_delta.get("arguments"): + function["arguments"] = f"{function.get('arguments', '')}{function_delta['arguments']}" + + +def _normalize_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for index, tool_call in enumerate(tool_calls): + function = tool_call.get("function") or {} + function_name = function.get("name") + if not function_name: + continue + normalized.append( + { + "id": tool_call.get("id") or f"tool_call_{index}", + "type": tool_call.get("type") or "function", + "function": { + "name": function_name, + "arguments": function.get("arguments") or "{}", + }, + } + ) + return normalized diff --git a/api/tests/test_ai_chat_stream_service.py b/api/tests/test_ai_chat_stream_service.py new file mode 100644 index 0000000..37eb152 --- /dev/null +++ b/api/tests/test_ai_chat_stream_service.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import os +import unittest +from typing import Any +from unittest.mock import patch + +os.environ.setdefault("DATABASE_URL", "sqlite+pysqlite:///:memory:") + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from api.app import models # noqa: F401 +from api.app.core.database import Base +from api.app.models.ai_chat import AiChatConversation, AiChatMessage +from api.app.models.user import User +from api.app.services.ai_chat_service import stream_message + + +class AiChatStreamServiceTest(unittest.TestCase): + def setUp(self) -> None: + self.engine = create_engine( + "sqlite+pysqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + self.SessionLocal = sessionmaker( + bind=self.engine, + autocommit=False, + autoflush=False, + expire_on_commit=False, + ) + Base.metadata.create_all(bind=self.engine) + self.session = self.SessionLocal() + self.session.add( + User( + id="user-1", + email="user@example.com", + username="tester", + password_hash="hash", + status="active", + ) + ) + self.session.add(AiChatConversation(id=1, title="测试对话", user_id="user-1")) + self.session.commit() + + def tearDown(self) -> None: + self.session.close() + Base.metadata.drop_all(bind=self.engine) + self.engine.dispose() + + def test_stream_message_yields_delta_and_persists_messages(self) -> None: + def fake_stream(_: Any, history: list[AiChatMessage]) -> list[dict[str, Any]]: + self.assertEqual(history[-1].content, "你好") + return [ + {"choices": [{"delta": {"content": "你"}, "finish_reason": None}]}, + {"choices": [{"delta": {"content": "好"}, "finish_reason": "stop"}]}, + ] + + with patch("api.app.services.ai_chat_service._stream_openai_api", side_effect=fake_stream): + events = stream_message(self.session, 1, "你好", user_id="user-1") + self.assertIsNotNone(events) + event_list = list(events or []) + + self.assertEqual([event["type"] for event in event_list], ["message", "delta", "delta", "done"]) + self.assertEqual(event_list[1]["content"], "你") + self.assertEqual(event_list[2]["content"], "好") + self.assertEqual(event_list[3]["reply"]["content"], "你好") + + stored_messages = ( + self.session.query(AiChatMessage) + .filter(AiChatMessage.conversation_id == 1) + .order_by(AiChatMessage.id.asc()) + .all() + ) + self.assertEqual([message.role for message in stored_messages], ["user", "assistant"]) + self.assertEqual([message.content for message in stored_messages], ["你好", "你好"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/memory/2026-06-24.md b/memory/2026-06-24.md new file mode 100644 index 0000000..8817076 --- /dev/null +++ b/memory/2026-06-24.md @@ -0,0 +1,22 @@ +# Work Log - AI问答流式响应(FL-184) + +- 背景: + - AI 问答原接口等待 OpenAI-compatible `/chat/completions` 完整返回后再落库并返回 JSON,前端发送后只能等待整段回复。 + +- 本次处理: + - `/api/v1/ai-chat/conversations/{id}/messages` 改为 `application/x-ndjson` 流式响应,依次输出用户消息、回复增量和最终助手消息。 + - 保留 `/api/v1/ai-chat/conversations/{id}/messages/sync` 同步接口,便于旧调用方或排障回退。 + - 前端 AI 问答页面改为读取 `ReadableStream`,发送后立即追加用户消息和临时助手消息,并随增量实时刷新内容。 + - 补充 AI 聊天流式服务单测,覆盖增量事件输出和最终消息落库。 + +- 验证: + - 基线:`npm --workspace web exec tsc --noEmit --pretty false` 通过。 + - 基线:`python3 -m pytest api/tests/test_system_param_service.py` 因系统 Python 3.7 不满足项目 Python >=3.10 要求失败;后改用 uv 创建 Python 3.10 venv。 + - 修改后:`python3 -m py_compile api/app/services/ai_chat_service.py api/app/api/v1/ai_chat.py api/tests/test_ai_chat_stream_service.py` 通过。 + - 修改后:`npm --workspace web exec tsc --noEmit --pretty false` 通过。 + - 修改后:`npm --workspace web exec eslint src/app/admin/ai-chat/page.tsx src/types/ai-chat.ts` 通过。 + - 修改后:`PYTHONPATH=. api/.venv/bin/pytest api/tests/test_ai_chat_stream_service.py api/tests/test_system_param_service.py` 通过,3 passed;存在既有 SQLAlchemy relationship warning。 + +- 风险与关注点: + - 主消息发送接口响应格式由 JSON 改为 NDJSON 流;当前前端已同步适配,后端保留 sync 兼容接口。 + - 流式工具调用会先聚合工具调用参数并执行,再继续流式输出工具后的最终回复。 diff --git a/web/src/app/admin/ai-chat/page.tsx b/web/src/app/admin/ai-chat/page.tsx index 92bdfea..9029f96 100644 --- a/web/src/app/admin/ai-chat/page.tsx +++ b/web/src/app/admin/ai-chat/page.tsx @@ -23,7 +23,7 @@ import { SendOutlined, DeleteOutlined, } from "@ant-design/icons"; -import { useCallback, useEffect, useRef, useState, type ComponentType, type RefAttributes } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState, type ComponentType, type RefAttributes } from "react"; import { useAuth } from "@/components/auth-provider"; import { useToastFeedback } from "@/hooks/use-toast-feedback"; @@ -31,13 +31,18 @@ import { readApiError, API_BASE_URL } from "@/lib/api"; import type { AiChatConversation, AiChatConversationListResponse, - AiChatMessageResponse, + AiChatMessage, } from "@/types/ai-chat"; const { TextArea } = Input; const { Text } = Typography; const AntCard = Card as unknown as ComponentType>; +type ChatStreamEvent = + | { type: "message"; message: AiChatMessage } + | { type: "delta"; content: string } + | { type: "done"; reply: AiChatMessage }; + export default function AiChatPage() { const { user, initializing, fetchWithAuth } = useAuth(); const queryClient = useQueryClient(); @@ -48,6 +53,8 @@ export default function AiChatPage() { const [newConvTitle, setNewConvTitle] = useState("新对话"); const [error, setError] = useState(""); const [success, setSuccess] = useState(""); + const [streamingMessageId, setStreamingMessageId] = useState(null); + const [optimisticMessages, setOptimisticMessages] = useState>({}); const messagesEndRef = useRef(null); useToastFeedback({ @@ -134,34 +141,138 @@ export default function AiChatPage() { }, }); - const sendMessageMutation = useMutation({ - mutationFn: async ({ - convId, - content, - }: { - convId: number; - content: string; - }) => { + const currentMessages = useMemo( + () => (selectedConvId ? optimisticMessages[selectedConvId] ?? currentConv?.messages ?? [] : []), + [currentConv?.messages, optimisticMessages, selectedConvId], + ); + + const updateOptimisticMessages = useCallback( + (convId: number, updater: (messages: AiChatMessage[]) => AiChatMessage[]) => { + setOptimisticMessages((current) => ({ + ...current, + [convId]: updater(current[convId] ?? currentConv?.messages ?? []), + })); + }, + [currentConv?.messages], + ); + + const readChatStream = useCallback( + async (convId: number, content: string) => { + const now = new Date().toISOString(); + const assistantTempId = -Date.now(); + const userTempId = assistantTempId - 1; + const userMessage: AiChatMessage = { + id: userTempId, + conversation_id: convId, + role: "user", + content, + created_at: now, + }; + const assistantMessage: AiChatMessage = { + id: assistantTempId, + conversation_id: convId, + role: "assistant", + content: "", + created_at: now, + }; + + setStreamingMessageId(assistantTempId); + updateOptimisticMessages(convId, (messages) => [...messages, userMessage, assistantMessage]); + const response = await fetchWithAuth( `${API_BASE_URL}/api/v1/ai-chat/conversations/${convId}/messages`, { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ content }), - } + }, ); if (!response.ok) { throw new Error(await readApiError(response)); } - return (await response.json()) as AiChatMessageResponse; + if (!response.body) { + throw new Error("浏览器未返回流式响应体"); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + const applyStreamEvent = (event: ChatStreamEvent) => { + if (event.type === "message") { + updateOptimisticMessages(convId, (messages) => + messages.map((msg) => (msg.id === userTempId ? event.message : msg)), + ); + return; + } + + if (event.type === "delta") { + updateOptimisticMessages(convId, (messages) => + messages.map((msg) => + msg.id === assistantTempId + ? { ...msg, content: `${msg.content}${event.content}` } + : msg, + ), + ); + return; + } + + updateOptimisticMessages(convId, (messages) => + messages.map((msg) => (msg.id === assistantTempId ? event.reply : msg)), + ); + }; + + const parseBufferedEvents = () => { + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed) continue; + applyStreamEvent(JSON.parse(trimmed) as ChatStreamEvent); + } + }; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + parseBufferedEvents(); + } + + buffer += decoder.decode(); + if (buffer.trim()) { + applyStreamEvent(JSON.parse(buffer.trim()) as ChatStreamEvent); + } + + await queryClient.invalidateQueries({ queryKey: ["ai-chat-conversations"] }); + await queryClient.invalidateQueries({ queryKey: ["ai-chat-conversation", convId] }); }, - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: ["ai-chat-conversation", selectedConvId], - }); + [fetchWithAuth, queryClient, updateOptimisticMessages], + ); + + const sendMessageMutation = useMutation({ + mutationFn: async ({ convId, content }: { convId: number; content: string }) => { + await readChatStream(convId, content); + }, + onSuccess: (_, variables) => { setMessageInput(""); + setStreamingMessageId(null); + setOptimisticMessages((current) => { + const next = { ...current }; + delete next[variables.convId]; + return next; + }); }, - onError: (err: Error) => { + onError: (err: Error, variables) => { + setStreamingMessageId(null); + setOptimisticMessages((current) => { + const next = { ...current }; + delete next[variables.convId]; + return next; + }); + queryClient.invalidateQueries({ + queryKey: ["ai-chat-conversation", variables.convId], + }); setError(`发送消息失败: ${err.message}`); }, }); @@ -176,7 +287,7 @@ export default function AiChatPage() { useEffect(() => { messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); - }, [currentConv?.messages]); + }, [currentMessages]); if (initializing) { return ( @@ -324,7 +435,7 @@ export default function AiChatPage() { ) : ( - {currentConv?.messages?.map((msg) => ( + {currentMessages.map((msg) => (
- {msg.content} + {msg.content || (msg.id === streamingMessageId ? "正在回复..." : "")}
| null; + tool_call_id?: string | null; created_at: string; }