feat:[FL-184][AI问答改成流式响应]

Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
chengkai3
2026-06-24 13:48:33 +08:00
parent ae8a2cb9b6
commit 5bf92a9ded
7 changed files with 486 additions and 25 deletions
+31 -1
View File
@@ -1,4 +1,8 @@
import json
from collections.abc import Iterator
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from ...core.database import get_db
@@ -19,6 +23,7 @@ from ...services.ai_chat_service import (
list_conversations,
send_message,
serialize_conversation_detail,
stream_message,
update_conversation,
)
@@ -85,12 +90,37 @@ def delete_conversation_endpoint(
return {"success": True}
@router.post("/conversations/{conversation_id}/messages", response_model=AiChatMessageResponse)
def _encode_stream_events(events: Iterator[dict]) -> Iterator[str]:
    """Render each event dict as one newline-terminated JSON line.

    Non-ASCII text is emitted verbatim (``ensure_ascii=False``) and any value
    the JSON encoder cannot handle natively is passed through ``str``.

    Args:
        events: Iterator of event dictionaries to serialize.

    Yields:
        One JSON document per event, each followed by ``"\\n"``.
    """
    for payload in events:
        encoded = json.dumps(payload, ensure_ascii=False, default=str)
        yield f"{encoded}\n"
@router.post("/conversations/{conversation_id}/messages")
def send_message_endpoint(
    conversation_id: int,
    payload: AiChatMessageSendRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> StreamingResponse:
    """Send a user message and stream the reply events back as NDJSON.

    Args:
        conversation_id: Conversation the message is posted to.
        payload: Request body; only ``payload.content`` is used.
        current_user: Authenticated caller; its ``user.id`` scopes the lookup.
        db: Database session handed to the chat service.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson`` whose
        body is one JSON line per event.

    Raises:
        HTTPException: 404 when ``stream_message`` returns ``None``.
    """
    event_iter = stream_message(
        db,
        conversation_id,
        payload.content,
        user_id=current_user.user.id,
    )
    if event_iter is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found")

    # Discourage caches and nginx-style reverse proxies from buffering the
    # body, so events can reach the client as they are produced.
    # NOTE(review): `db` comes from a request-scoped dependency while the body
    # is generated lazily; confirm the session stays open for the whole stream
    # on the FastAPI version in use.
    streaming_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        _encode_stream_events(event_iter),
        media_type="application/x-ndjson",
        headers=streaming_headers,
    )
@router.post("/conversations/{conversation_id}/messages/sync", response_model=AiChatMessageResponse)
def send_message_sync_endpoint(
conversation_id: int,
payload: AiChatMessageSendRequest,
current_user: CurrentUser = Depends(get_current_user),
db: Session = Depends(get_db),
) -> AiChatMessageResponse:
result = send_message(db, conversation_id, payload.content, user_id=current_user.user.id)
if not result:
+216 -2
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import Generator, Iterator
import json
import os
from typing import Any
@@ -258,6 +259,61 @@ def send_message(
return serialize_message(user_message), serialize_message(assistant_message)
def stream_message(
    db: Session,
    conversation_id: int,
    content: str,
    *,
    user_id: str,
) -> Iterator[dict[str, Any]] | None:
    """Store a user message and return a lazy stream of reply events.

    The user message is committed before this function returns. The assistant
    reply is only requested and stored while the returned iterator is consumed.

    Events are yielded in this order:

    * ``{"type": "message", "message": ...}`` - the stored user message.
    * ``{"type": "delta", "content": ...}`` - zero or more reply fragments. If
      generating the reply fails, one extra delta carries the apology text.
    * ``{"type": "done", "reply": ...}`` - the stored assistant message.

    Args:
        db: Session used for every read and write.
        conversation_id: Conversation the message belongs to.
        content: Raw user text; surrounding whitespace is stripped.
        user_id: Owner passed to the conversation lookup.

    Returns:
        An iterator of event dicts, or ``None`` when the conversation lookup
        returns nothing.
    """
    conv = get_conversation_by_id(db, conversation_id, user_id)
    if not conv:
        return None

    user_message = AiChatMessage(
        conversation_id=conversation_id,
        role="user",
        content=content.strip(),
    )
    db.add(user_message)
    db.commit()
    db.refresh(user_message)

    history = sorted(conv.messages, key=lambda m: m.created_at)
    # conv.messages is read after the commit, so depending on how the
    # relationship was loaded it may already include the new row. Append only
    # when it is missing so the model never receives the user's text twice.
    if user_message not in history:
        history.append(user_message)

    def persist_reply(text: str) -> dict[str, Any]:
        """Store ``text`` as the assistant message and build the "done" event."""
        assistant_message = AiChatMessage(
            conversation_id=conversation_id,
            role="assistant",
            content=text,
        )
        db.add(assistant_message)
        db.commit()
        db.refresh(assistant_message)
        return {"type": "done", "reply": serialize_message(assistant_message).model_dump(mode="json")}

    def event_stream() -> Iterator[dict[str, Any]]:
        yield {"type": "message", "message": serialize_message(user_message).model_dump(mode="json")}

        try:
            reply_content = yield from _stream_assistant_reply(db, conversation_id, history)
            done_event = persist_reply(reply_content)
        except Exception as e:
            # The failure may have come from a flush/commit, which leaves the
            # session in a failed transaction. Reset it before writing the
            # fallback message, otherwise that write fails as well.
            db.rollback()
            reply_content = f"抱歉,AI服务暂时不可用:{str(e)}"
            done_event = persist_reply(reply_content)
            yield {"type": "delta", "content": reply_content}

        yield done_event

    return event_stream()
def _get_function_definitions() -> list[dict[str, Any]]:
"""Define available functions that the AI can call."""
return [
@@ -479,7 +535,12 @@ def _execute_function(db: Session, function_name: str, arguments: dict[str, Any]
return f"执行函数 {function_name} 时出错: {str(e)}"
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
def _build_openai_chat_request(
db: Session,
history: list[AiChatMessage],
*,
stream: bool = False,
) -> tuple[str, dict[str, str], dict[str, Any]]:
api_key_param = get_system_param_by_key(db, "ai_chat.openai_api_key")
model_param = get_system_param_by_key(db, "ai_chat.model")
base_url_param = get_system_param_by_key(db, "ai_chat.base_url")
@@ -513,10 +574,18 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
"messages": messages,
"tools": _get_function_definitions(),
}
if stream:
payload["stream"] = True
return f"{base_url}/chat/completions", headers, payload
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> dict[str, Any]:
url, headers, payload = _build_openai_chat_request(db, history)
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{base_url}/chat/completions",
url,
headers=headers,
json=payload,
)
@@ -524,3 +593,148 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
data = response.json()
return data
def _stream_openai_api(db: Session, history: list[AiChatMessage]) -> Iterator[dict[str, Any]]:
    """POST a streaming chat-completion request and yield each parsed chunk.

    The response body is read line by line. Both server-sent-event framing
    (``data: {...}``) and bare JSON lines are accepted.

    Args:
        db: Session passed to ``_build_openai_chat_request``.
        history: Conversation messages passed to ``_build_openai_chat_request``.

    Yields:
        One decoded JSON object per payload line.

    Raises:
        httpx.HTTPStatusError: When the endpoint answers with an error status.
        json.JSONDecodeError: When a payload line is not valid JSON.
    """
    url, headers, payload = _build_openai_chat_request(db, history, stream=True)

    with httpx.Client(timeout=60.0) as client:
        with client.stream("POST", url, headers=headers, json=payload) as response:
            response.raise_for_status()
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue

                line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
                line = line.strip()

                # SSE comments (":") and the non-data fields carry no chunk
                # payload; feeding them to json.loads would abort the reply.
                if line.startswith((":", "event:", "id:", "retry:")):
                    continue

                if line.startswith("data:"):
                    line = line.removeprefix("data:").strip()
                if not line:
                    continue

                # "[DONE]" is the end-of-stream sentinel: stop reading instead
                # of waiting for the server to close the connection.
                if line == "[DONE]":
                    return

                yield json.loads(line)
def _stream_assistant_reply(
    db: Session,
    conversation_id: int,
    history: list[AiChatMessage],
) -> Generator[dict[str, Any], None, str]:
    """Stream one assistant turn, running at most one round of tool calls.

    A first completion is streamed. If it finishes with ``"tool_calls"`` and at
    least one usable call was collected, the request and every tool result are
    stored as messages and appended to ``history``, and a second completion is
    streamed. Tool calls requested by that second completion are not executed.

    Args:
        db: Session used to store intermediate messages and run tools.
        conversation_id: Conversation the intermediate messages belong to.
        history: Message list sent to the model; mutated in place.

    Yields:
        ``{"type": "delta", "content": ...}`` events from the completions.

    Returns:
        The text of the first completion when no tool round happens, otherwise
        only the text of the follow-up completion.
    """
    text_parts: list[str] = []
    raw_tool_calls: list[dict[str, Any]] = []
    finish_reason = yield from _consume_completion_stream(db, history, text_parts, raw_tool_calls)
    requested_calls = _normalize_tool_calls(raw_tool_calls)

    # Plain answer: nothing to execute, the streamed text is the reply.
    if finish_reason != "tool_calls" or not requested_calls:
        return "".join(text_parts)

    # Record the assistant's tool request, including any text streamed so far.
    request_message = AiChatMessage(
        conversation_id=conversation_id,
        role="assistant",
        content="".join(text_parts),
        tool_calls=requested_calls,
    )
    db.add(request_message)
    db.commit()
    db.refresh(request_message)
    history.append(request_message)

    for call in requested_calls:
        call_function = call["function"]
        name = call_function["name"]
        raw_arguments = call_function.get("arguments") or "{}"
        call_id = call["id"]

        # Arguments that are not valid JSON fall back to an empty dict.
        try:
            parsed_arguments = json.loads(raw_arguments)
        except json.JSONDecodeError:
            parsed_arguments = {}

        outcome = _execute_function(db, name, parsed_arguments)

        result_message = AiChatMessage(
            conversation_id=conversation_id,
            role="tool",
            content=outcome,
            tool_call_id=call_id,
        )
        db.add(result_message)
        db.commit()
        db.refresh(result_message)
        history.append(result_message)

    # The follow-up completion starts from an empty buffer, and the throwaway
    # list means any tool calls it requests are discarded.
    followup_parts: list[str] = []
    yield from _consume_completion_stream(db, history, followup_parts, [])
    return "".join(followup_parts)
def _consume_completion_stream(
    db: Session,
    history: list[AiChatMessage],
    reply_parts: list[str],
    tool_calls: list[dict[str, Any]],
) -> Generator[dict[str, Any], None, str | None]:
    """Drain one streamed completion, forwarding text and collecting tool calls.

    Only the first entry of each chunk's ``choices`` is inspected; chunks with
    no choices are ignored.

    Args:
        db: Session passed to ``_stream_openai_api``.
        history: Messages passed to ``_stream_openai_api``.
        reply_parts: Output list; every non-empty text fragment is appended.
        tool_calls: Output list; tool-call fragments are merged into it.

    Yields:
        ``{"type": "delta", "content": ...}`` for each non-empty text fragment.

    Returns:
        The last non-empty ``finish_reason`` seen, or ``None`` if there was none.
    """
    last_finish_reason: str | None = None

    for chunk in _stream_openai_api(db, history):
        choices = chunk.get("choices") or []
        if not choices:
            continue
        first_choice = choices[0]

        # Keep the previous value when this chunk reports no finish reason.
        reported_reason = first_choice.get("finish_reason")
        if reported_reason:
            last_finish_reason = reported_reason

        delta = first_choice.get("delta") or {}

        text = delta.get("content")
        if text:
            reply_parts.append(text)
            yield {"type": "delta", "content": text}

        for fragment in delta.get("tool_calls") or []:
            _merge_tool_call_delta(tool_calls, fragment)

    return last_finish_reason
def _merge_tool_call_delta(tool_calls: list[dict[str, Any]], delta: dict[str, Any]) -> None:
    """Fold one streamed tool-call fragment into the accumulated ``tool_calls``.

    The fragment's ``index`` selects the entry to update; missing entries up to
    that index are created as empty placeholders. ``id`` and ``type`` overwrite
    the stored value when present, while the function ``name`` and
    ``arguments`` fragments are concatenated onto what is already there.

    Args:
        tool_calls: Accumulator list, mutated in place.
        delta: One tool-call fragment from a streamed chunk.
    """
    index = delta.get("index")
    # Treat anything that is not a non-negative integer as "no index" and start
    # a new entry. A negative value would otherwise address the list from its
    # end (or raise IndexError when it is empty), and bool is an int subclass.
    if isinstance(index, bool) or not isinstance(index, int) or index < 0:
        index = len(tool_calls)

    # Pad with empty placeholders so tool_calls[index] exists.
    while len(tool_calls) <= index:
        tool_calls.append(
            {
                "id": "",
                "type": "function",
                "function": {"name": "", "arguments": ""},
            }
        )

    tool_call = tool_calls[index]
    if delta.get("id"):
        tool_call["id"] = delta["id"]
    if delta.get("type"):
        tool_call["type"] = delta["type"]

    function_delta = delta.get("function") or {}
    function = tool_call.setdefault("function", {"name": "", "arguments": ""})
    if function_delta.get("name"):
        function["name"] = f"{function.get('name', '')}{function_delta['name']}"
    if function_delta.get("arguments"):
        function["arguments"] = f"{function.get('arguments', '')}{function_delta['arguments']}"
def _normalize_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return cleaned copies of the accumulated tool calls.

    Entries without a function name are dropped. For the rest, a missing ``id``
    becomes ``tool_call_<position>`` (position in the input list), a missing
    ``type`` becomes ``"function"`` and missing arguments become ``"{}"``.

    Args:
        tool_calls: Tool calls as accumulated from the stream.

    Returns:
        A new list of normalized tool-call dicts; the input is not modified.
    """
    return [
        {
            "id": call.get("id") or f"tool_call_{position}",
            "type": call.get("type") or "function",
            "function": {
                "name": fn["name"],
                "arguments": fn.get("arguments") or "{}",
            },
        }
        for position, call in enumerate(tool_calls)
        # Bind the function payload once and keep only the named entries.
        if (fn := call.get("function") or {}).get("name")
    ]
+82
View File
@@ -0,0 +1,82 @@
from __future__ import annotations
import os
import unittest
from typing import Any
from unittest.mock import patch
os.environ.setdefault("DATABASE_URL", "sqlite+pysqlite:///:memory:")
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from api.app import models # noqa: F401
from api.app.core.database import Base
from api.app.models.ai_chat import AiChatConversation, AiChatMessage
from api.app.models.user import User
from api.app.services.ai_chat_service import stream_message
class AiChatStreamServiceTest(unittest.TestCase):
    """Exercise ``stream_message`` against an in-memory SQLite database."""

    def setUp(self) -> None:
        """Create the schema plus one user and one conversation."""
        # StaticPool reuses a single connection, so every session sees the
        # same in-memory database.
        self.engine = create_engine(
            "sqlite+pysqlite://",
            connect_args={"check_same_thread": False},
            poolclass=StaticPool,
        )
        self.SessionLocal = sessionmaker(
            bind=self.engine,
            autocommit=False,
            autoflush=False,
            expire_on_commit=False,
        )
        Base.metadata.create_all(bind=self.engine)
        self.session = self.SessionLocal()
        self.session.add(
            User(
                id="user-1",
                email="user@example.com",
                username="tester",
                password_hash="hash",
                status="active",
            )
        )
        self.session.add(AiChatConversation(id=1, title="测试对话", user_id="user-1"))
        self.session.commit()

    def tearDown(self) -> None:
        """Close the session and drop the schema."""
        self.session.close()
        Base.metadata.drop_all(bind=self.engine)
        self.engine.dispose()

    def test_stream_message_yields_delta_and_persists_messages(self) -> None:
        """Two text chunks produce two delta events and one stored reply."""

        def fake_stream(_: Any, history: list[AiChatMessage]) -> list[dict[str, Any]]:
            self.assertEqual(history[-1].content, "你好")
            # Each fragment must be non-empty: empty content produces no delta
            # event and would leave the stored reply blank.
            return [
                {"choices": [{"delta": {"content": "你"}, "finish_reason": None}]},
                {"choices": [{"delta": {"content": "好"}, "finish_reason": "stop"}]},
            ]

        with patch("api.app.services.ai_chat_service._stream_openai_api", side_effect=fake_stream):
            events = stream_message(self.session, 1, "你好", user_id="user-1")
            self.assertIsNotNone(events)
            # The stream is lazy, so it has to be drained while the patch is
            # still active.
            event_list = list(events or [])

        self.assertEqual([event["type"] for event in event_list], ["message", "delta", "delta", "done"])
        self.assertEqual(event_list[1]["content"], "你")
        self.assertEqual(event_list[2]["content"], "好")
        self.assertEqual(event_list[3]["reply"]["content"], "你好")

        stored_messages = (
            self.session.query(AiChatMessage)
            .filter(AiChatMessage.conversation_id == 1)
            .order_by(AiChatMessage.id.asc())
            .all()
        )
        self.assertEqual([message.role for message in stored_messages], ["user", "assistant"])
        self.assertEqual([message.content for message in stored_messages], ["你好", "你好"])
if __name__ == "__main__":
unittest.main()