feat:[FL-184][AI问答改成流式响应]
Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
@@ -128,7 +128,7 @@
|
|||||||
- 聊天模型选择规则固定为:`CAPABILITY: chat.default` 优先,未命中时回退 `GLOBAL: __global__`。
|
- 聊天模型选择规则固定为:`CAPABILITY: chat.default` 优先,未命中时回退 `GLOBAL: __global__`。
|
||||||
- 仅允许 `ENABLED` 且具备激活密钥记录的模型参与路由命中;若不满足,接口返回 400。
|
- 仅允许 `ENABLED` 且具备激活密钥记录的模型参与路由命中;若不满足,接口返回 400。
|
||||||
- 运行时真实 Provider Key 不从数据库反解,统一从环境变量 `LLM_PROVIDER_API_KEYS` 注入(支持 `openai=sk-...` 或 JSON 字典字符串)。
|
- 运行时真实 Provider Key 不从数据库反解,统一从环境变量 `LLM_PROVIDER_API_KEYS` 注入(支持 `openai=sk-...` 或 JSON 字典字符串)。
|
||||||
- 一期模型调用采用非流式 OpenAI-compatible `POST /chat/completions`,后续如需流式再扩展 SSE/WS。
|
- AI 问答发送接口 `POST /api/v1/ai-chat/conversations/{id}/messages` 已改为 `application/x-ndjson` 流式响应,前端通过 `fetch` + `ReadableStream` 读取增量并实时刷新助手消息;如需旧同步行为,可用 `/messages/sync` 兼容接口。
|
||||||
|
|
||||||
## Celery 监控口径(2026-05-01)
|
## Celery 监控口径(2026-05-01)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ...core.database import get_db
|
from ...core.database import get_db
|
||||||
@@ -19,6 +23,7 @@ from ...services.ai_chat_service import (
|
|||||||
list_conversations,
|
list_conversations,
|
||||||
send_message,
|
send_message,
|
||||||
serialize_conversation_detail,
|
serialize_conversation_detail,
|
||||||
|
stream_message,
|
||||||
update_conversation,
|
update_conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,12 +90,37 @@ def delete_conversation_endpoint(
|
|||||||
return {"success": True}
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/conversations/{conversation_id}/messages", response_model=AiChatMessageResponse)
|
def _encode_stream_events(events: Iterator[dict]) -> Iterator[str]:
|
||||||
|
for event in events:
|
||||||
|
yield json.dumps(event, ensure_ascii=False, default=str) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/conversations/{conversation_id}/messages")
|
||||||
def send_message_endpoint(
|
def send_message_endpoint(
|
||||||
conversation_id: int,
|
conversation_id: int,
|
||||||
payload: AiChatMessageSendRequest,
|
payload: AiChatMessageSendRequest,
|
||||||
current_user: CurrentUser = Depends(get_current_user),
|
current_user: CurrentUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
) -> StreamingResponse:
|
||||||
|
events = stream_message(db, conversation_id, payload.content, user_id=current_user.user.id)
|
||||||
|
if events is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")
|
||||||
|
return StreamingResponse(
|
||||||
|
_encode_stream_events(events),
|
||||||
|
media_type="application/x-ndjson",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/conversations/{conversation_id}/messages/sync", response_model=AiChatMessageResponse)
|
||||||
|
def send_message_sync_endpoint(
|
||||||
|
conversation_id: int,
|
||||||
|
payload: AiChatMessageSendRequest,
|
||||||
|
current_user: CurrentUser = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
) -> AiChatMessageResponse:
|
) -> AiChatMessageResponse:
|
||||||
result = send_message(db, conversation_id, payload.content, user_id=current_user.user.id)
|
result = send_message(db, conversation_id, payload.content, user_id=current_user.user.id)
|
||||||
if not result:
|
if not result:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator, Iterator
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -258,6 +259,61 @@ def send_message(
|
|||||||
return serialize_message(user_message), serialize_message(assistant_message)
|
return serialize_message(user_message), serialize_message(assistant_message)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_message(
|
||||||
|
db: Session,
|
||||||
|
conversation_id: int,
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
) -> Iterator[dict[str, Any]] | None:
|
||||||
|
conv = get_conversation_by_id(db, conversation_id, user_id)
|
||||||
|
if not conv:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_message = AiChatMessage(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=content.strip(),
|
||||||
|
)
|
||||||
|
db.add(user_message)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user_message)
|
||||||
|
|
||||||
|
history = sorted(conv.messages, key=lambda m: m.created_at)
|
||||||
|
history.append(user_message)
|
||||||
|
|
||||||
|
def event_stream() -> Iterator[dict[str, Any]]:
|
||||||
|
yield {"type": "message", "message": serialize_message(user_message).model_dump(mode="json")}
|
||||||
|
|
||||||
|
try:
|
||||||
|
reply_content = yield from _stream_assistant_reply(db, conversation_id, history)
|
||||||
|
assistant_message = AiChatMessage(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=reply_content,
|
||||||
|
)
|
||||||
|
db.add(assistant_message)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(assistant_message)
|
||||||
|
|
||||||
|
yield {"type": "done", "reply": serialize_message(assistant_message).model_dump(mode="json")}
|
||||||
|
except Exception as e:
|
||||||
|
reply_content = f"抱歉,AI服务暂时不可用:{str(e)}"
|
||||||
|
assistant_message = AiChatMessage(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=reply_content,
|
||||||
|
)
|
||||||
|
db.add(assistant_message)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(assistant_message)
|
||||||
|
|
||||||
|
yield {"type": "delta", "content": reply_content}
|
||||||
|
yield {"type": "done", "reply": serialize_message(assistant_message).model_dump(mode="json")}
|
||||||
|
|
||||||
|
return event_stream()
|
||||||
|
|
||||||
|
|
||||||
def _get_function_definitions() -> list[dict[str, Any]]:
|
def _get_function_definitions() -> list[dict[str, Any]]:
|
||||||
"""Define available functions that the AI can call."""
|
"""Define available functions that the AI can call."""
|
||||||
return [
|
return [
|
||||||
@@ -479,7 +535,12 @@ def _execute_function(db: Session, function_name: str, arguments: dict[str, Any]
|
|||||||
return f"执行函数 {function_name} 时出错: {str(e)}"
|
return f"执行函数 {function_name} 时出错: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
|
def _build_openai_chat_request(
|
||||||
|
db: Session,
|
||||||
|
history: list[AiChatMessage],
|
||||||
|
*,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> tuple[str, dict[str, str], dict[str, Any]]:
|
||||||
api_key_param = get_system_param_by_key(db, "ai_chat.openai_api_key")
|
api_key_param = get_system_param_by_key(db, "ai_chat.openai_api_key")
|
||||||
model_param = get_system_param_by_key(db, "ai_chat.model")
|
model_param = get_system_param_by_key(db, "ai_chat.model")
|
||||||
base_url_param = get_system_param_by_key(db, "ai_chat.base_url")
|
base_url_param = get_system_param_by_key(db, "ai_chat.base_url")
|
||||||
@@ -513,10 +574,18 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
|
|||||||
"messages": messages,
|
"messages": messages,
|
||||||
"tools": _get_function_definitions(),
|
"tools": _get_function_definitions(),
|
||||||
}
|
}
|
||||||
|
if stream:
|
||||||
|
payload["stream"] = True
|
||||||
|
|
||||||
|
return f"{base_url}/chat/completions", headers, payload
|
||||||
|
|
||||||
|
|
||||||
|
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> dict[str, Any]:
|
||||||
|
url, headers, payload = _build_openai_chat_request(db, history)
|
||||||
|
|
||||||
with httpx.Client(timeout=60.0) as client:
|
with httpx.Client(timeout=60.0) as client:
|
||||||
response = client.post(
|
response = client.post(
|
||||||
f"{base_url}/chat/completions",
|
url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
)
|
)
|
||||||
@@ -524,3 +593,148 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_openai_api(db: Session, history: list[AiChatMessage]) -> Iterator[dict[str, Any]]:
|
||||||
|
url, headers, payload = _build_openai_chat_request(db, history, stream=True)
|
||||||
|
|
||||||
|
with httpx.Client(timeout=60.0) as client:
|
||||||
|
with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
for raw_line in response.iter_lines():
|
||||||
|
if not raw_line:
|
||||||
|
continue
|
||||||
|
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith(":"):
|
||||||
|
continue
|
||||||
|
if line.startswith("data:"):
|
||||||
|
line = line.removeprefix("data:").strip()
|
||||||
|
if not line or line == "[DONE]":
|
||||||
|
continue
|
||||||
|
yield json.loads(line)
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_assistant_reply(
|
||||||
|
db: Session,
|
||||||
|
conversation_id: int,
|
||||||
|
history: list[AiChatMessage],
|
||||||
|
) -> Generator[dict[str, Any], None, str]:
|
||||||
|
reply_parts: list[str] = []
|
||||||
|
tool_calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
finish_reason = yield from _consume_completion_stream(db, history, reply_parts, tool_calls)
|
||||||
|
normalized_tool_calls = _normalize_tool_calls(tool_calls)
|
||||||
|
if finish_reason == "tool_calls" and normalized_tool_calls:
|
||||||
|
assistant_message_with_tool_calls = AiChatMessage(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content="".join(reply_parts),
|
||||||
|
tool_calls=normalized_tool_calls,
|
||||||
|
)
|
||||||
|
db.add(assistant_message_with_tool_calls)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(assistant_message_with_tool_calls)
|
||||||
|
history.append(assistant_message_with_tool_calls)
|
||||||
|
|
||||||
|
for tool_call in normalized_tool_calls:
|
||||||
|
function_name = tool_call["function"]["name"]
|
||||||
|
function_args_str = tool_call["function"].get("arguments") or "{}"
|
||||||
|
tool_call_id = tool_call["id"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
function_args = json.loads(function_args_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
function_args = {}
|
||||||
|
|
||||||
|
function_result = _execute_function(db, function_name, function_args)
|
||||||
|
tool_message = AiChatMessage(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="tool",
|
||||||
|
content=function_result,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
db.add(tool_message)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(tool_message)
|
||||||
|
history.append(tool_message)
|
||||||
|
|
||||||
|
reply_parts = []
|
||||||
|
yield from _consume_completion_stream(db, history, reply_parts, [])
|
||||||
|
|
||||||
|
return "".join(reply_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _consume_completion_stream(
|
||||||
|
db: Session,
|
||||||
|
history: list[AiChatMessage],
|
||||||
|
reply_parts: list[str],
|
||||||
|
tool_calls: list[dict[str, Any]],
|
||||||
|
) -> Generator[dict[str, Any], None, str | None]:
|
||||||
|
finish_reason: str | None = None
|
||||||
|
for chunk in _stream_openai_api(db, history):
|
||||||
|
choices = chunk.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = choices[0]
|
||||||
|
finish_reason = choice.get("finish_reason") or finish_reason
|
||||||
|
delta = choice.get("delta") or {}
|
||||||
|
|
||||||
|
content_delta = delta.get("content")
|
||||||
|
if content_delta:
|
||||||
|
reply_parts.append(content_delta)
|
||||||
|
yield {"type": "delta", "content": content_delta}
|
||||||
|
|
||||||
|
for tool_call_delta in delta.get("tool_calls") or []:
|
||||||
|
_merge_tool_call_delta(tool_calls, tool_call_delta)
|
||||||
|
|
||||||
|
return finish_reason
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_tool_call_delta(tool_calls: list[dict[str, Any]], delta: dict[str, Any]) -> None:
|
||||||
|
index = delta.get("index")
|
||||||
|
if not isinstance(index, int):
|
||||||
|
index = len(tool_calls)
|
||||||
|
|
||||||
|
while len(tool_calls) <= index:
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": "",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "", "arguments": ""},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_call = tool_calls[index]
|
||||||
|
if delta.get("id"):
|
||||||
|
tool_call["id"] = delta["id"]
|
||||||
|
if delta.get("type"):
|
||||||
|
tool_call["type"] = delta["type"]
|
||||||
|
|
||||||
|
function_delta = delta.get("function") or {}
|
||||||
|
function = tool_call.setdefault("function", {"name": "", "arguments": ""})
|
||||||
|
if function_delta.get("name"):
|
||||||
|
function["name"] = f"{function.get('name', '')}{function_delta['name']}"
|
||||||
|
if function_delta.get("arguments"):
|
||||||
|
function["arguments"] = f"{function.get('arguments', '')}{function_delta['arguments']}"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
normalized: list[dict[str, Any]] = []
|
||||||
|
for index, tool_call in enumerate(tool_calls):
|
||||||
|
function = tool_call.get("function") or {}
|
||||||
|
function_name = function.get("name")
|
||||||
|
if not function_name:
|
||||||
|
continue
|
||||||
|
normalized.append(
|
||||||
|
{
|
||||||
|
"id": tool_call.get("id") or f"tool_call_{index}",
|
||||||
|
"type": tool_call.get("type") or "function",
|
||||||
|
"function": {
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": function.get("arguments") or "{}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return normalized
|
||||||
|
|||||||
@@ -0,0 +1,82 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
os.environ.setdefault("DATABASE_URL", "sqlite+pysqlite:///:memory:")
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from api.app import models # noqa: F401
|
||||||
|
from api.app.core.database import Base
|
||||||
|
from api.app.models.ai_chat import AiChatConversation, AiChatMessage
|
||||||
|
from api.app.models.user import User
|
||||||
|
from api.app.services.ai_chat_service import stream_message
|
||||||
|
|
||||||
|
|
||||||
|
class AiChatStreamServiceTest(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.engine = create_engine(
|
||||||
|
"sqlite+pysqlite://",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
self.SessionLocal = sessionmaker(
|
||||||
|
bind=self.engine,
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
Base.metadata.create_all(bind=self.engine)
|
||||||
|
self.session = self.SessionLocal()
|
||||||
|
self.session.add(
|
||||||
|
User(
|
||||||
|
id="user-1",
|
||||||
|
email="user@example.com",
|
||||||
|
username="tester",
|
||||||
|
password_hash="hash",
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.session.add(AiChatConversation(id=1, title="测试对话", user_id="user-1"))
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.session.close()
|
||||||
|
Base.metadata.drop_all(bind=self.engine)
|
||||||
|
self.engine.dispose()
|
||||||
|
|
||||||
|
def test_stream_message_yields_delta_and_persists_messages(self) -> None:
|
||||||
|
def fake_stream(_: Any, history: list[AiChatMessage]) -> list[dict[str, Any]]:
|
||||||
|
self.assertEqual(history[-1].content, "你好")
|
||||||
|
return [
|
||||||
|
{"choices": [{"delta": {"content": "你"}, "finish_reason": None}]},
|
||||||
|
{"choices": [{"delta": {"content": "好"}, "finish_reason": "stop"}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("api.app.services.ai_chat_service._stream_openai_api", side_effect=fake_stream):
|
||||||
|
events = stream_message(self.session, 1, "你好", user_id="user-1")
|
||||||
|
self.assertIsNotNone(events)
|
||||||
|
event_list = list(events or [])
|
||||||
|
|
||||||
|
self.assertEqual([event["type"] for event in event_list], ["message", "delta", "delta", "done"])
|
||||||
|
self.assertEqual(event_list[1]["content"], "你")
|
||||||
|
self.assertEqual(event_list[2]["content"], "好")
|
||||||
|
self.assertEqual(event_list[3]["reply"]["content"], "你好")
|
||||||
|
|
||||||
|
stored_messages = (
|
||||||
|
self.session.query(AiChatMessage)
|
||||||
|
.filter(AiChatMessage.conversation_id == 1)
|
||||||
|
.order_by(AiChatMessage.id.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
self.assertEqual([message.role for message in stored_messages], ["user", "assistant"])
|
||||||
|
self.assertEqual([message.content for message in stored_messages], ["你好", "你好"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
# Work Log - AI问答流式响应(FL-184)
|
||||||
|
|
||||||
|
- 背景:
|
||||||
|
- AI 问答原接口等待 OpenAI-compatible `/chat/completions` 完整返回后再落库并返回 JSON,前端发送后只能等待整段回复。
|
||||||
|
|
||||||
|
- 本次处理:
|
||||||
|
- `/api/v1/ai-chat/conversations/{id}/messages` 改为 `application/x-ndjson` 流式响应,依次输出用户消息、回复增量和最终助手消息。
|
||||||
|
- 保留 `/api/v1/ai-chat/conversations/{id}/messages/sync` 同步接口,便于旧调用方或排障回退。
|
||||||
|
- 前端 AI 问答页面改为读取 `ReadableStream`,发送后立即追加用户消息和临时助手消息,并随增量实时刷新内容。
|
||||||
|
- 补充 AI 聊天流式服务单测,覆盖增量事件输出和最终消息落库。
|
||||||
|
|
||||||
|
- 验证:
|
||||||
|
- 基线:`npm --workspace web exec tsc --noEmit --pretty false` 通过。
|
||||||
|
- 基线:`python3 -m pytest api/tests/test_system_param_service.py` 因系统 Python 3.7 不满足项目 Python >=3.10 要求失败;后改用 uv 创建 Python 3.10 venv。
|
||||||
|
- 修改后:`python3 -m py_compile api/app/services/ai_chat_service.py api/app/api/v1/ai_chat.py api/tests/test_ai_chat_stream_service.py` 通过。
|
||||||
|
- 修改后:`npm --workspace web exec tsc --noEmit --pretty false` 通过。
|
||||||
|
- 修改后:`npm --workspace web exec eslint src/app/admin/ai-chat/page.tsx src/types/ai-chat.ts` 通过。
|
||||||
|
- 修改后:`PYTHONPATH=. api/.venv/bin/pytest api/tests/test_ai_chat_stream_service.py api/tests/test_system_param_service.py` 通过,3 passed;存在既有 SQLAlchemy relationship warning。
|
||||||
|
|
||||||
|
- 风险与关注点:
|
||||||
|
- 主消息发送接口响应格式由 JSON 改为 NDJSON 流;当前前端已同步适配,后端保留 sync 兼容接口。
|
||||||
|
- 流式工具调用会先聚合工具调用参数并执行,再继续流式输出工具后的最终回复。
|
||||||
@@ -23,7 +23,7 @@ import {
|
|||||||
SendOutlined,
|
SendOutlined,
|
||||||
DeleteOutlined,
|
DeleteOutlined,
|
||||||
} from "@ant-design/icons";
|
} from "@ant-design/icons";
|
||||||
import { useCallback, useEffect, useRef, useState, type ComponentType, type RefAttributes } from "react";
|
import { useCallback, useEffect, useMemo, useRef, useState, type ComponentType, type RefAttributes } from "react";
|
||||||
|
|
||||||
import { useAuth } from "@/components/auth-provider";
|
import { useAuth } from "@/components/auth-provider";
|
||||||
import { useToastFeedback } from "@/hooks/use-toast-feedback";
|
import { useToastFeedback } from "@/hooks/use-toast-feedback";
|
||||||
@@ -31,13 +31,18 @@ import { readApiError, API_BASE_URL } from "@/lib/api";
|
|||||||
import type {
|
import type {
|
||||||
AiChatConversation,
|
AiChatConversation,
|
||||||
AiChatConversationListResponse,
|
AiChatConversationListResponse,
|
||||||
AiChatMessageResponse,
|
AiChatMessage,
|
||||||
} from "@/types/ai-chat";
|
} from "@/types/ai-chat";
|
||||||
|
|
||||||
const { TextArea } = Input;
|
const { TextArea } = Input;
|
||||||
const { Text } = Typography;
|
const { Text } = Typography;
|
||||||
const AntCard = Card as unknown as ComponentType<CardProps & RefAttributes<HTMLDivElement>>;
|
const AntCard = Card as unknown as ComponentType<CardProps & RefAttributes<HTMLDivElement>>;
|
||||||
|
|
||||||
|
type ChatStreamEvent =
|
||||||
|
| { type: "message"; message: AiChatMessage }
|
||||||
|
| { type: "delta"; content: string }
|
||||||
|
| { type: "done"; reply: AiChatMessage };
|
||||||
|
|
||||||
export default function AiChatPage() {
|
export default function AiChatPage() {
|
||||||
const { user, initializing, fetchWithAuth } = useAuth();
|
const { user, initializing, fetchWithAuth } = useAuth();
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
@@ -48,6 +53,8 @@ export default function AiChatPage() {
|
|||||||
const [newConvTitle, setNewConvTitle] = useState("新对话");
|
const [newConvTitle, setNewConvTitle] = useState("新对话");
|
||||||
const [error, setError] = useState("");
|
const [error, setError] = useState("");
|
||||||
const [success, setSuccess] = useState("");
|
const [success, setSuccess] = useState("");
|
||||||
|
const [streamingMessageId, setStreamingMessageId] = useState<number | null>(null);
|
||||||
|
const [optimisticMessages, setOptimisticMessages] = useState<Record<number, AiChatMessage[]>>({});
|
||||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
useToastFeedback({
|
useToastFeedback({
|
||||||
@@ -134,34 +141,138 @@ export default function AiChatPage() {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const sendMessageMutation = useMutation({
|
const currentMessages = useMemo(
|
||||||
mutationFn: async ({
|
() => (selectedConvId ? optimisticMessages[selectedConvId] ?? currentConv?.messages ?? [] : []),
|
||||||
convId,
|
[currentConv?.messages, optimisticMessages, selectedConvId],
|
||||||
content,
|
);
|
||||||
}: {
|
|
||||||
convId: number;
|
const updateOptimisticMessages = useCallback(
|
||||||
content: string;
|
(convId: number, updater: (messages: AiChatMessage[]) => AiChatMessage[]) => {
|
||||||
}) => {
|
setOptimisticMessages((current) => ({
|
||||||
|
...current,
|
||||||
|
[convId]: updater(current[convId] ?? currentConv?.messages ?? []),
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
[currentConv?.messages],
|
||||||
|
);
|
||||||
|
|
||||||
|
const readChatStream = useCallback(
|
||||||
|
async (convId: number, content: string) => {
|
||||||
|
const now = new Date().toISOString();
|
||||||
|
const assistantTempId = -Date.now();
|
||||||
|
const userTempId = assistantTempId - 1;
|
||||||
|
const userMessage: AiChatMessage = {
|
||||||
|
id: userTempId,
|
||||||
|
conversation_id: convId,
|
||||||
|
role: "user",
|
||||||
|
content,
|
||||||
|
created_at: now,
|
||||||
|
};
|
||||||
|
const assistantMessage: AiChatMessage = {
|
||||||
|
id: assistantTempId,
|
||||||
|
conversation_id: convId,
|
||||||
|
role: "assistant",
|
||||||
|
content: "",
|
||||||
|
created_at: now,
|
||||||
|
};
|
||||||
|
|
||||||
|
setStreamingMessageId(assistantTempId);
|
||||||
|
updateOptimisticMessages(convId, (messages) => [...messages, userMessage, assistantMessage]);
|
||||||
|
|
||||||
const response = await fetchWithAuth(
|
const response = await fetchWithAuth(
|
||||||
`${API_BASE_URL}/api/v1/ai-chat/conversations/${convId}/messages`,
|
`${API_BASE_URL}/api/v1/ai-chat/conversations/${convId}/messages`,
|
||||||
{
|
{
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({ content }),
|
body: JSON.stringify({ content }),
|
||||||
}
|
},
|
||||||
);
|
);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(await readApiError(response));
|
throw new Error(await readApiError(response));
|
||||||
}
|
}
|
||||||
return (await response.json()) as AiChatMessageResponse;
|
if (!response.body) {
|
||||||
|
throw new Error("浏览器未返回流式响应体");
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let buffer = "";
|
||||||
|
|
||||||
|
const applyStreamEvent = (event: ChatStreamEvent) => {
|
||||||
|
if (event.type === "message") {
|
||||||
|
updateOptimisticMessages(convId, (messages) =>
|
||||||
|
messages.map((msg) => (msg.id === userTempId ? event.message : msg)),
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (event.type === "delta") {
|
||||||
|
updateOptimisticMessages(convId, (messages) =>
|
||||||
|
messages.map((msg) =>
|
||||||
|
msg.id === assistantTempId
|
||||||
|
? { ...msg, content: `${msg.content}${event.content}` }
|
||||||
|
: msg,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
updateOptimisticMessages(convId, (messages) =>
|
||||||
|
messages.map((msg) => (msg.id === assistantTempId ? event.reply : msg)),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const parseBufferedEvents = () => {
|
||||||
|
const lines = buffer.split("\n");
|
||||||
|
buffer = lines.pop() ?? "";
|
||||||
|
for (const line of lines) {
|
||||||
|
const trimmed = line.trim();
|
||||||
|
if (!trimmed) continue;
|
||||||
|
applyStreamEvent(JSON.parse(trimmed) as ChatStreamEvent);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
buffer += decoder.decode(value, { stream: true });
|
||||||
|
parseBufferedEvents();
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer += decoder.decode();
|
||||||
|
if (buffer.trim()) {
|
||||||
|
applyStreamEvent(JSON.parse(buffer.trim()) as ChatStreamEvent);
|
||||||
|
}
|
||||||
|
|
||||||
|
await queryClient.invalidateQueries({ queryKey: ["ai-chat-conversations"] });
|
||||||
|
await queryClient.invalidateQueries({ queryKey: ["ai-chat-conversation", convId] });
|
||||||
},
|
},
|
||||||
onSuccess: () => {
|
[fetchWithAuth, queryClient, updateOptimisticMessages],
|
||||||
queryClient.invalidateQueries({
|
);
|
||||||
queryKey: ["ai-chat-conversation", selectedConvId],
|
|
||||||
});
|
const sendMessageMutation = useMutation({
|
||||||
|
mutationFn: async ({ convId, content }: { convId: number; content: string }) => {
|
||||||
|
await readChatStream(convId, content);
|
||||||
|
},
|
||||||
|
onSuccess: (_, variables) => {
|
||||||
setMessageInput("");
|
setMessageInput("");
|
||||||
|
setStreamingMessageId(null);
|
||||||
|
setOptimisticMessages((current) => {
|
||||||
|
const next = { ...current };
|
||||||
|
delete next[variables.convId];
|
||||||
|
return next;
|
||||||
|
});
|
||||||
},
|
},
|
||||||
onError: (err: Error) => {
|
onError: (err: Error, variables) => {
|
||||||
|
setStreamingMessageId(null);
|
||||||
|
setOptimisticMessages((current) => {
|
||||||
|
const next = { ...current };
|
||||||
|
delete next[variables.convId];
|
||||||
|
return next;
|
||||||
|
});
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: ["ai-chat-conversation", variables.convId],
|
||||||
|
});
|
||||||
setError(`发送消息失败: ${err.message}`);
|
setError(`发送消息失败: ${err.message}`);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@@ -176,7 +287,7 @@ export default function AiChatPage() {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||||
}, [currentConv?.messages]);
|
}, [currentMessages]);
|
||||||
|
|
||||||
if (initializing) {
|
if (initializing) {
|
||||||
return (
|
return (
|
||||||
@@ -324,7 +435,7 @@ export default function AiChatPage() {
|
|||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<Space direction="vertical" style={{ width: "100%" }} size={16}>
|
<Space direction="vertical" style={{ width: "100%" }} size={16}>
|
||||||
{currentConv?.messages?.map((msg) => (
|
{currentMessages.map((msg) => (
|
||||||
<div
|
<div
|
||||||
key={msg.id}
|
key={msg.id}
|
||||||
style={{
|
style={{
|
||||||
@@ -381,7 +492,7 @@ export default function AiChatPage() {
|
|||||||
fontSize: 14,
|
fontSize: 14,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{msg.content}
|
{msg.content || (msg.id === streamingMessageId ? "正在回复..." : "")}
|
||||||
</div>
|
</div>
|
||||||
<Text
|
<Text
|
||||||
type="secondary"
|
type="secondary"
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
export interface AiChatMessage {
|
export interface AiChatMessage {
|
||||||
id: number;
|
id: number;
|
||||||
conversation_id: number;
|
conversation_id: number;
|
||||||
role: "user" | "assistant";
|
role: "user" | "assistant" | "tool";
|
||||||
content: string;
|
content: string;
|
||||||
|
tool_calls?: Record<string, unknown> | null;
|
||||||
|
tool_call_id?: string | null;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user