feat:[FL-184][AI问答改成流式响应]

Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
chengkai3
2026-06-24 13:48:33 +08:00
parent ae8a2cb9b6
commit 5bf92a9ded
7 changed files with 486 additions and 25 deletions
+1 -1
View File
@@ -128,7 +128,7 @@
- 聊天模型选择规则固定为:`CAPABILITY: chat.default` 优先,未命中时回退 `GLOBAL: __global__` - 聊天模型选择规则固定为:`CAPABILITY: chat.default` 优先,未命中时回退 `GLOBAL: __global__`
- 仅允许 `ENABLED` 且具备激活密钥记录的模型参与路由命中;若不满足,接口返回 400。 - 仅允许 `ENABLED` 且具备激活密钥记录的模型参与路由命中;若不满足,接口返回 400。
- 运行时真实 Provider Key 不从数据库反解,统一从环境变量 `LLM_PROVIDER_API_KEYS` 注入(支持 `openai=sk-...` 或 JSON 字典字符串)。 - 运行时真实 Provider Key 不从数据库反解,统一从环境变量 `LLM_PROVIDER_API_KEYS` 注入(支持 `openai=sk-...` 或 JSON 字典字符串)。
- 一期模型调用采用非流式 OpenAI-compatible `POST /chat/completions`,后续如需流式再扩展 SSE/WS - AI 问答发送接口 `POST /api/v1/ai-chat/conversations/{id}/messages` 已改为 `application/x-ndjson` 流式响应,前端通过 `fetch` + `ReadableStream` 读取增量并实时刷新助手消息;如需旧同步行为,可用 `/messages/sync` 兼容接口
## Celery 监控口径(2026-05-01 ## Celery 监控口径(2026-05-01
+31 -1
View File
@@ -1,4 +1,8 @@
import json
from collections.abc import Iterator
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ...core.database import get_db from ...core.database import get_db
@@ -19,6 +23,7 @@ from ...services.ai_chat_service import (
list_conversations, list_conversations,
send_message, send_message,
serialize_conversation_detail, serialize_conversation_detail,
stream_message,
update_conversation, update_conversation,
) )
@@ -85,12 +90,37 @@ def delete_conversation_endpoint(
return {"success": True} return {"success": True}
@router.post("/conversations/{conversation_id}/messages", response_model=AiChatMessageResponse) def _encode_stream_events(events: Iterator[dict]) -> Iterator[str]:
for event in events:
yield json.dumps(event, ensure_ascii=False, default=str) + "\n"
@router.post("/conversations/{conversation_id}/messages")
def send_message_endpoint( def send_message_endpoint(
conversation_id: int, conversation_id: int,
payload: AiChatMessageSendRequest, payload: AiChatMessageSendRequest,
current_user: CurrentUser = Depends(get_current_user), current_user: CurrentUser = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> StreamingResponse:
events = stream_message(db, conversation_id, payload.content, user_id=current_user.user.id)
if events is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")
return StreamingResponse(
_encode_stream_events(events),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
@router.post("/conversations/{conversation_id}/messages/sync", response_model=AiChatMessageResponse)
def send_message_sync_endpoint(
conversation_id: int,
payload: AiChatMessageSendRequest,
current_user: CurrentUser = Depends(get_current_user),
db: Session = Depends(get_db),
) -> AiChatMessageResponse: ) -> AiChatMessageResponse:
result = send_message(db, conversation_id, payload.content, user_id=current_user.user.id) result = send_message(db, conversation_id, payload.content, user_id=current_user.user.id)
if not result: if not result:
+216 -2
View File
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator, Iterator
import json import json
import os import os
from typing import Any from typing import Any
@@ -258,6 +259,61 @@ def send_message(
return serialize_message(user_message), serialize_message(assistant_message) return serialize_message(user_message), serialize_message(assistant_message)
def stream_message(
db: Session,
conversation_id: int,
content: str,
*,
user_id: str,
) -> Iterator[dict[str, Any]] | None:
conv = get_conversation_by_id(db, conversation_id, user_id)
if not conv:
return None
user_message = AiChatMessage(
conversation_id=conversation_id,
role="user",
content=content.strip(),
)
db.add(user_message)
db.commit()
db.refresh(user_message)
history = sorted(conv.messages, key=lambda m: m.created_at)
history.append(user_message)
def event_stream() -> Iterator[dict[str, Any]]:
yield {"type": "message", "message": serialize_message(user_message).model_dump(mode="json")}
try:
reply_content = yield from _stream_assistant_reply(db, conversation_id, history)
assistant_message = AiChatMessage(
conversation_id=conversation_id,
role="assistant",
content=reply_content,
)
db.add(assistant_message)
db.commit()
db.refresh(assistant_message)
yield {"type": "done", "reply": serialize_message(assistant_message).model_dump(mode="json")}
except Exception as e:
reply_content = f"抱歉,AI服务暂时不可用:{str(e)}"
assistant_message = AiChatMessage(
conversation_id=conversation_id,
role="assistant",
content=reply_content,
)
db.add(assistant_message)
db.commit()
db.refresh(assistant_message)
yield {"type": "delta", "content": reply_content}
yield {"type": "done", "reply": serialize_message(assistant_message).model_dump(mode="json")}
return event_stream()
def _get_function_definitions() -> list[dict[str, Any]]: def _get_function_definitions() -> list[dict[str, Any]]:
"""Define available functions that the AI can call.""" """Define available functions that the AI can call."""
return [ return [
@@ -479,7 +535,12 @@ def _execute_function(db: Session, function_name: str, arguments: dict[str, Any]
return f"执行函数 {function_name} 时出错: {str(e)}" return f"执行函数 {function_name} 时出错: {str(e)}"
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: def _build_openai_chat_request(
db: Session,
history: list[AiChatMessage],
*,
stream: bool = False,
) -> tuple[str, dict[str, str], dict[str, Any]]:
api_key_param = get_system_param_by_key(db, "ai_chat.openai_api_key") api_key_param = get_system_param_by_key(db, "ai_chat.openai_api_key")
model_param = get_system_param_by_key(db, "ai_chat.model") model_param = get_system_param_by_key(db, "ai_chat.model")
base_url_param = get_system_param_by_key(db, "ai_chat.base_url") base_url_param = get_system_param_by_key(db, "ai_chat.base_url")
@@ -513,10 +574,18 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
"messages": messages, "messages": messages,
"tools": _get_function_definitions(), "tools": _get_function_definitions(),
} }
if stream:
payload["stream"] = True
return f"{base_url}/chat/completions", headers, payload
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> dict[str, Any]:
url, headers, payload = _build_openai_chat_request(db, history)
with httpx.Client(timeout=60.0) as client: with httpx.Client(timeout=60.0) as client:
response = client.post( response = client.post(
f"{base_url}/chat/completions", url,
headers=headers, headers=headers,
json=payload, json=payload,
) )
@@ -524,3 +593,148 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
data = response.json() data = response.json()
return data return data
def _stream_openai_api(db: Session, history: list[AiChatMessage]) -> Iterator[dict[str, Any]]:
url, headers, payload = _build_openai_chat_request(db, history, stream=True)
with httpx.Client(timeout=60.0) as client:
with client.stream("POST", url, headers=headers, json=payload) as response:
response.raise_for_status()
for raw_line in response.iter_lines():
if not raw_line:
continue
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
line = line.strip()
if line.startswith(":"):
continue
if line.startswith("data:"):
line = line.removeprefix("data:").strip()
if not line or line == "[DONE]":
continue
yield json.loads(line)
def _stream_assistant_reply(
db: Session,
conversation_id: int,
history: list[AiChatMessage],
) -> Generator[dict[str, Any], None, str]:
reply_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
finish_reason = yield from _consume_completion_stream(db, history, reply_parts, tool_calls)
normalized_tool_calls = _normalize_tool_calls(tool_calls)
if finish_reason == "tool_calls" and normalized_tool_calls:
assistant_message_with_tool_calls = AiChatMessage(
conversation_id=conversation_id,
role="assistant",
content="".join(reply_parts),
tool_calls=normalized_tool_calls,
)
db.add(assistant_message_with_tool_calls)
db.commit()
db.refresh(assistant_message_with_tool_calls)
history.append(assistant_message_with_tool_calls)
for tool_call in normalized_tool_calls:
function_name = tool_call["function"]["name"]
function_args_str = tool_call["function"].get("arguments") or "{}"
tool_call_id = tool_call["id"]
try:
function_args = json.loads(function_args_str)
except json.JSONDecodeError:
function_args = {}
function_result = _execute_function(db, function_name, function_args)
tool_message = AiChatMessage(
conversation_id=conversation_id,
role="tool",
content=function_result,
tool_call_id=tool_call_id,
)
db.add(tool_message)
db.commit()
db.refresh(tool_message)
history.append(tool_message)
reply_parts = []
yield from _consume_completion_stream(db, history, reply_parts, [])
return "".join(reply_parts)
def _consume_completion_stream(
db: Session,
history: list[AiChatMessage],
reply_parts: list[str],
tool_calls: list[dict[str, Any]],
) -> Generator[dict[str, Any], None, str | None]:
finish_reason: str | None = None
for chunk in _stream_openai_api(db, history):
choices = chunk.get("choices") or []
if not choices:
continue
choice = choices[0]
finish_reason = choice.get("finish_reason") or finish_reason
delta = choice.get("delta") or {}
content_delta = delta.get("content")
if content_delta:
reply_parts.append(content_delta)
yield {"type": "delta", "content": content_delta}
for tool_call_delta in delta.get("tool_calls") or []:
_merge_tool_call_delta(tool_calls, tool_call_delta)
return finish_reason
def _merge_tool_call_delta(tool_calls: list[dict[str, Any]], delta: dict[str, Any]) -> None:
index = delta.get("index")
if not isinstance(index, int):
index = len(tool_calls)
while len(tool_calls) <= index:
tool_calls.append(
{
"id": "",
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
tool_call = tool_calls[index]
if delta.get("id"):
tool_call["id"] = delta["id"]
if delta.get("type"):
tool_call["type"] = delta["type"]
function_delta = delta.get("function") or {}
function = tool_call.setdefault("function", {"name": "", "arguments": ""})
if function_delta.get("name"):
function["name"] = f"{function.get('name', '')}{function_delta['name']}"
if function_delta.get("arguments"):
function["arguments"] = f"{function.get('arguments', '')}{function_delta['arguments']}"
def _normalize_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for index, tool_call in enumerate(tool_calls):
function = tool_call.get("function") or {}
function_name = function.get("name")
if not function_name:
continue
normalized.append(
{
"id": tool_call.get("id") or f"tool_call_{index}",
"type": tool_call.get("type") or "function",
"function": {
"name": function_name,
"arguments": function.get("arguments") or "{}",
},
}
)
return normalized
+82
View File
@@ -0,0 +1,82 @@
from __future__ import annotations
import os
import unittest
from typing import Any
from unittest.mock import patch
os.environ.setdefault("DATABASE_URL", "sqlite+pysqlite:///:memory:")
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from api.app import models # noqa: F401
from api.app.core.database import Base
from api.app.models.ai_chat import AiChatConversation, AiChatMessage
from api.app.models.user import User
from api.app.services.ai_chat_service import stream_message
class AiChatStreamServiceTest(unittest.TestCase):
def setUp(self) -> None:
self.engine = create_engine(
"sqlite+pysqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
self.SessionLocal = sessionmaker(
bind=self.engine,
autocommit=False,
autoflush=False,
expire_on_commit=False,
)
Base.metadata.create_all(bind=self.engine)
self.session = self.SessionLocal()
self.session.add(
User(
id="user-1",
email="user@example.com",
username="tester",
password_hash="hash",
status="active",
)
)
self.session.add(AiChatConversation(id=1, title="测试对话", user_id="user-1"))
self.session.commit()
def tearDown(self) -> None:
self.session.close()
Base.metadata.drop_all(bind=self.engine)
self.engine.dispose()
def test_stream_message_yields_delta_and_persists_messages(self) -> None:
def fake_stream(_: Any, history: list[AiChatMessage]) -> list[dict[str, Any]]:
self.assertEqual(history[-1].content, "你好")
return [
{"choices": [{"delta": {"content": ""}, "finish_reason": None}]},
{"choices": [{"delta": {"content": ""}, "finish_reason": "stop"}]},
]
with patch("api.app.services.ai_chat_service._stream_openai_api", side_effect=fake_stream):
events = stream_message(self.session, 1, "你好", user_id="user-1")
self.assertIsNotNone(events)
event_list = list(events or [])
self.assertEqual([event["type"] for event in event_list], ["message", "delta", "delta", "done"])
self.assertEqual(event_list[1]["content"], "")
self.assertEqual(event_list[2]["content"], "")
self.assertEqual(event_list[3]["reply"]["content"], "你好")
stored_messages = (
self.session.query(AiChatMessage)
.filter(AiChatMessage.conversation_id == 1)
.order_by(AiChatMessage.id.asc())
.all()
)
self.assertEqual([message.role for message in stored_messages], ["user", "assistant"])
self.assertEqual([message.content for message in stored_messages], ["你好", "你好"])
if __name__ == "__main__":
unittest.main()
+22
View File
@@ -0,0 +1,22 @@
# Work Log - AI问答流式响应(FL-184
- 背景:
- AI 问答原接口等待 OpenAI-compatible `/chat/completions` 完整返回后再落库并返回 JSON,前端发送后只能等待整段回复。
- 本次处理:
- `/api/v1/ai-chat/conversations/{id}/messages` 改为 `application/x-ndjson` 流式响应,依次输出用户消息、回复增量和最终助手消息。
- 保留 `/api/v1/ai-chat/conversations/{id}/messages/sync` 同步接口,便于旧调用方或排障回退。
- 前端 AI 问答页面改为读取 `ReadableStream`,发送后立即追加用户消息和临时助手消息,并随增量实时刷新内容。
- 补充 AI 聊天流式服务单测,覆盖增量事件输出和最终消息落库。
- 验证:
- 基线:`npm --workspace web exec tsc --noEmit --pretty false` 通过。
- 基线:`python3 -m pytest api/tests/test_system_param_service.py` 因系统 Python 3.7 不满足项目 Python >=3.10 要求失败;后改用 uv 创建 Python 3.10 venv。
- 修改后:`python3 -m py_compile api/app/services/ai_chat_service.py api/app/api/v1/ai_chat.py api/tests/test_ai_chat_stream_service.py` 通过。
- 修改后:`npm --workspace web exec tsc --noEmit --pretty false` 通过。
- 修改后:`npm --workspace web exec eslint src/app/admin/ai-chat/page.tsx src/types/ai-chat.ts` 通过。
- 修改后:`PYTHONPATH=. api/.venv/bin/pytest api/tests/test_ai_chat_stream_service.py api/tests/test_system_param_service.py` 通过,3 passed;存在既有 SQLAlchemy relationship warning。
- 风险与关注点:
- 主消息发送接口响应格式由 JSON 改为 NDJSON 流;当前前端已同步适配,后端保留 sync 兼容接口。
- 流式工具调用会先聚合工具调用参数并执行,再继续流式输出工具后的最终回复。
+131 -20
View File
@@ -23,7 +23,7 @@ import {
SendOutlined, SendOutlined,
DeleteOutlined, DeleteOutlined,
} from "@ant-design/icons"; } from "@ant-design/icons";
import { useCallback, useEffect, useRef, useState, type ComponentType, type RefAttributes } from "react"; import { useCallback, useEffect, useMemo, useRef, useState, type ComponentType, type RefAttributes } from "react";
import { useAuth } from "@/components/auth-provider"; import { useAuth } from "@/components/auth-provider";
import { useToastFeedback } from "@/hooks/use-toast-feedback"; import { useToastFeedback } from "@/hooks/use-toast-feedback";
@@ -31,13 +31,18 @@ import { readApiError, API_BASE_URL } from "@/lib/api";
import type { import type {
AiChatConversation, AiChatConversation,
AiChatConversationListResponse, AiChatConversationListResponse,
AiChatMessageResponse, AiChatMessage,
} from "@/types/ai-chat"; } from "@/types/ai-chat";
const { TextArea } = Input; const { TextArea } = Input;
const { Text } = Typography; const { Text } = Typography;
const AntCard = Card as unknown as ComponentType<CardProps & RefAttributes<HTMLDivElement>>; const AntCard = Card as unknown as ComponentType<CardProps & RefAttributes<HTMLDivElement>>;
type ChatStreamEvent =
| { type: "message"; message: AiChatMessage }
| { type: "delta"; content: string }
| { type: "done"; reply: AiChatMessage };
export default function AiChatPage() { export default function AiChatPage() {
const { user, initializing, fetchWithAuth } = useAuth(); const { user, initializing, fetchWithAuth } = useAuth();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
@@ -48,6 +53,8 @@ export default function AiChatPage() {
const [newConvTitle, setNewConvTitle] = useState("新对话"); const [newConvTitle, setNewConvTitle] = useState("新对话");
const [error, setError] = useState(""); const [error, setError] = useState("");
const [success, setSuccess] = useState(""); const [success, setSuccess] = useState("");
const [streamingMessageId, setStreamingMessageId] = useState<number | null>(null);
const [optimisticMessages, setOptimisticMessages] = useState<Record<number, AiChatMessage[]>>({});
const messagesEndRef = useRef<HTMLDivElement>(null); const messagesEndRef = useRef<HTMLDivElement>(null);
useToastFeedback({ useToastFeedback({
@@ -134,34 +141,138 @@ export default function AiChatPage() {
}, },
}); });
const sendMessageMutation = useMutation({ const currentMessages = useMemo(
mutationFn: async ({ () => (selectedConvId ? optimisticMessages[selectedConvId] ?? currentConv?.messages ?? [] : []),
convId, [currentConv?.messages, optimisticMessages, selectedConvId],
content, );
}: {
convId: number; const updateOptimisticMessages = useCallback(
content: string; (convId: number, updater: (messages: AiChatMessage[]) => AiChatMessage[]) => {
}) => { setOptimisticMessages((current) => ({
...current,
[convId]: updater(current[convId] ?? currentConv?.messages ?? []),
}));
},
[currentConv?.messages],
);
const readChatStream = useCallback(
async (convId: number, content: string) => {
const now = new Date().toISOString();
const assistantTempId = -Date.now();
const userTempId = assistantTempId - 1;
const userMessage: AiChatMessage = {
id: userTempId,
conversation_id: convId,
role: "user",
content,
created_at: now,
};
const assistantMessage: AiChatMessage = {
id: assistantTempId,
conversation_id: convId,
role: "assistant",
content: "",
created_at: now,
};
setStreamingMessageId(assistantTempId);
updateOptimisticMessages(convId, (messages) => [...messages, userMessage, assistantMessage]);
const response = await fetchWithAuth( const response = await fetchWithAuth(
`${API_BASE_URL}/api/v1/ai-chat/conversations/${convId}/messages`, `${API_BASE_URL}/api/v1/ai-chat/conversations/${convId}/messages`,
{ {
method: "POST", method: "POST",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
body: JSON.stringify({ content }), body: JSON.stringify({ content }),
} },
); );
if (!response.ok) { if (!response.ok) {
throw new Error(await readApiError(response)); throw new Error(await readApiError(response));
} }
return (await response.json()) as AiChatMessageResponse; if (!response.body) {
throw new Error("浏览器未返回流式响应体");
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
const applyStreamEvent = (event: ChatStreamEvent) => {
if (event.type === "message") {
updateOptimisticMessages(convId, (messages) =>
messages.map((msg) => (msg.id === userTempId ? event.message : msg)),
);
return;
}
if (event.type === "delta") {
updateOptimisticMessages(convId, (messages) =>
messages.map((msg) =>
msg.id === assistantTempId
? { ...msg, content: `${msg.content}${event.content}` }
: msg,
),
);
return;
}
updateOptimisticMessages(convId, (messages) =>
messages.map((msg) => (msg.id === assistantTempId ? event.reply : msg)),
);
};
const parseBufferedEvents = () => {
const lines = buffer.split("\n");
buffer = lines.pop() ?? "";
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
applyStreamEvent(JSON.parse(trimmed) as ChatStreamEvent);
}
};
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
parseBufferedEvents();
}
buffer += decoder.decode();
if (buffer.trim()) {
applyStreamEvent(JSON.parse(buffer.trim()) as ChatStreamEvent);
}
await queryClient.invalidateQueries({ queryKey: ["ai-chat-conversations"] });
await queryClient.invalidateQueries({ queryKey: ["ai-chat-conversation", convId] });
}, },
onSuccess: () => { [fetchWithAuth, queryClient, updateOptimisticMessages],
queryClient.invalidateQueries({ );
queryKey: ["ai-chat-conversation", selectedConvId],
}); const sendMessageMutation = useMutation({
mutationFn: async ({ convId, content }: { convId: number; content: string }) => {
await readChatStream(convId, content);
},
onSuccess: (_, variables) => {
setMessageInput(""); setMessageInput("");
setStreamingMessageId(null);
setOptimisticMessages((current) => {
const next = { ...current };
delete next[variables.convId];
return next;
});
}, },
onError: (err: Error) => { onError: (err: Error, variables) => {
setStreamingMessageId(null);
setOptimisticMessages((current) => {
const next = { ...current };
delete next[variables.convId];
return next;
});
queryClient.invalidateQueries({
queryKey: ["ai-chat-conversation", variables.convId],
});
setError(`发送消息失败: ${err.message}`); setError(`发送消息失败: ${err.message}`);
}, },
}); });
@@ -176,7 +287,7 @@ export default function AiChatPage() {
useEffect(() => { useEffect(() => {
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
}, [currentConv?.messages]); }, [currentMessages]);
if (initializing) { if (initializing) {
return ( return (
@@ -324,7 +435,7 @@ export default function AiChatPage() {
</div> </div>
) : ( ) : (
<Space direction="vertical" style={{ width: "100%" }} size={16}> <Space direction="vertical" style={{ width: "100%" }} size={16}>
{currentConv?.messages?.map((msg) => ( {currentMessages.map((msg) => (
<div <div
key={msg.id} key={msg.id}
style={{ style={{
@@ -381,7 +492,7 @@ export default function AiChatPage() {
fontSize: 14, fontSize: 14,
}} }}
> >
{msg.content} {msg.content || (msg.id === streamingMessageId ? "正在回复..." : "")}
</div> </div>
<Text <Text
type="secondary" type="secondary"
+3 -1
View File
@@ -1,8 +1,10 @@
export interface AiChatMessage { export interface AiChatMessage {
id: number; id: number;
conversation_id: number; conversation_id: number;
role: "user" | "assistant"; role: "user" | "assistant" | "tool";
content: string; content: string;
tool_calls?: Record<string, unknown> | null;
tool_call_id?: string | null;
created_at: string; created_at: string;
} }