feat:[FL-181][AI问答要支持function call]
实现AI问答的function call功能,支持调用系统接口进行查询。 改动内容: 1. 数据库扩展: - 在ai_chat_messages表增加tool_calls和tool_call_id字段 - 创建数据库迁移文件 2. 模型和Schema更新: - AiChatMessage模型增加tool_calls(JSON)和tool_call_id字段 - AiChatMessageSummary schema增加对应字段 3. Function Call实现: - 定义4个可调用函数:query_tower_models、query_lines、query_users、query_system_params - 实现_execute_function处理函数调用并返回格式化结果 - 更新_call_openai_api支持tools参数 4. 消息流程更新: - 重构send_message支持完整的function call流程 - 流程:用户消息 -> AI请求function call -> 执行函数 -> AI基于结果回复 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from ..core.database import Base
|
||||
@@ -50,6 +50,8 @@ class AiChatMessage(Base):
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(16), index=True)
|
||||
content: Mapped[str] = mapped_column(Text())
|
||||
tool_calls: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
tool_call_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
|
||||
conversation: Mapped[AiChatConversation] = relationship(
|
||||
|
||||
@@ -12,6 +12,8 @@ class AiChatMessageSummary(BaseModel):
|
||||
conversation_id: int
|
||||
role: str
|
||||
content: str
|
||||
tool_calls: dict | None = None
|
||||
tool_call_id: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
@@ -26,6 +27,8 @@ def serialize_message(msg: AiChatMessage) -> AiChatMessageSummary:
|
||||
conversation_id=msg.conversation_id,
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
tool_calls=msg.tool_calls,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
created_at=msg.created_at,
|
||||
)
|
||||
|
||||
@@ -173,20 +176,307 @@ def send_message(
|
||||
history.append(user_message)
|
||||
|
||||
try:
|
||||
reply_content = _call_openai_api(db, history)
|
||||
response_data = _call_openai_api(db, history)
|
||||
choice = response_data["choices"][0]
|
||||
message = choice["message"]
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
if finish_reason == "tool_calls" and message.get("tool_calls"):
|
||||
assistant_message_with_tool_calls = AiChatMessage(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=message.get("content") or "",
|
||||
tool_calls=message["tool_calls"],
|
||||
)
|
||||
db.add(assistant_message_with_tool_calls)
|
||||
db.commit()
|
||||
db.refresh(assistant_message_with_tool_calls)
|
||||
history.append(assistant_message_with_tool_calls)
|
||||
|
||||
for tool_call in message["tool_calls"]:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args_str = tool_call["function"]["arguments"]
|
||||
tool_call_id = tool_call["id"]
|
||||
|
||||
try:
|
||||
function_args = json.loads(function_args_str)
|
||||
except json.JSONDecodeError:
|
||||
function_args = {}
|
||||
|
||||
function_result = _execute_function(db, function_name, function_args)
|
||||
|
||||
tool_message = AiChatMessage(
|
||||
conversation_id=conversation_id,
|
||||
role="tool",
|
||||
content=function_result,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
db.add(tool_message)
|
||||
db.commit()
|
||||
db.refresh(tool_message)
|
||||
history.append(tool_message)
|
||||
|
||||
final_response_data = _call_openai_api(db, history)
|
||||
final_message = final_response_data["choices"][0]["message"]
|
||||
reply_content = final_message.get("content", "")
|
||||
|
||||
assistant_message = AiChatMessage(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=reply_content,
|
||||
)
|
||||
db.add(assistant_message)
|
||||
db.commit()
|
||||
db.refresh(assistant_message)
|
||||
|
||||
return serialize_message(user_message), serialize_message(assistant_message)
|
||||
|
||||
else:
|
||||
reply_content = message.get("content", "")
|
||||
assistant_message = AiChatMessage(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=reply_content,
|
||||
)
|
||||
db.add(assistant_message)
|
||||
db.commit()
|
||||
db.refresh(assistant_message)
|
||||
|
||||
return serialize_message(user_message), serialize_message(assistant_message)
|
||||
|
||||
except Exception as e:
|
||||
reply_content = f"抱歉,AI服务暂时不可用:{str(e)}"
|
||||
assistant_message = AiChatMessage(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=reply_content,
|
||||
)
|
||||
db.add(assistant_message)
|
||||
db.commit()
|
||||
db.refresh(assistant_message)
|
||||
|
||||
assistant_message = AiChatMessage(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=reply_content,
|
||||
)
|
||||
db.add(assistant_message)
|
||||
db.commit()
|
||||
db.refresh(assistant_message)
|
||||
return serialize_message(user_message), serialize_message(assistant_message)
|
||||
|
||||
return serialize_message(user_message), serialize_message(assistant_message)
|
||||
|
||||
def _get_function_definitions() -> list[dict[str, Any]]:
|
||||
"""Define available functions that the AI can call."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "query_tower_models",
|
||||
"description": "查询杆塔型号列表。可以根据关键字搜索杆塔型号名称。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keyword": {
|
||||
"type": "string",
|
||||
"description": "搜索关键字,用于匹配杆塔型号名称",
|
||||
},
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"description": "是否只查询启用的杆塔型号",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "返回结果数量限制,默认10",
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "query_lines",
|
||||
"description": "查询线路列表。可以根据关键字搜索线路名称。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keyword": {
|
||||
"type": "string",
|
||||
"description": "搜索关键字,用于匹配线路名称",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "返回结果数量限制,默认10",
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "query_users",
|
||||
"description": "查询系统用户列表。可以根据关键字搜索用户名称或邮箱。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keyword": {
|
||||
"type": "string",
|
||||
"description": "搜索关键字,用于匹配用户名称或邮箱",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "返回结果数量限制,默认10",
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "query_system_params",
|
||||
"description": "查询系统参数配置。可以根据参数键或名称搜索。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keyword": {
|
||||
"type": "string",
|
||||
"description": "搜索关键字,用于匹配参数键或名称",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": "参数分类",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "返回结果数量限制,默认10",
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _execute_function(db: Session, function_name: str, arguments: dict[str, Any]) -> str:
|
||||
"""Execute a function call and return the result as a string."""
|
||||
try:
|
||||
if function_name == "query_tower_models":
|
||||
from ..services.tower_model_service import list_tower_models
|
||||
|
||||
keyword = arguments.get("keyword")
|
||||
enabled = arguments.get("enabled")
|
||||
limit = arguments.get("limit", 10)
|
||||
|
||||
result = list_tower_models(
|
||||
db,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
keyword=keyword,
|
||||
enabled=enabled,
|
||||
)
|
||||
|
||||
if not result.items:
|
||||
return "未找到匹配的杆塔型号。"
|
||||
|
||||
items_text = []
|
||||
for item in result.items:
|
||||
items_text.append(
|
||||
f"- 型号: {item.name}, "
|
||||
f"高度: {item.height}m, "
|
||||
f"状态: {'启用' if item.enabled else '禁用'}"
|
||||
)
|
||||
|
||||
return f"找到 {result.total} 个杆塔型号(显示前{len(result.items)}个):\n" + "\n".join(items_text)
|
||||
|
||||
elif function_name == "query_lines":
|
||||
from ..services.line_service import list_lines
|
||||
|
||||
keyword = arguments.get("keyword")
|
||||
limit = arguments.get("limit", 10)
|
||||
|
||||
result = list_lines(
|
||||
db,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
keyword=keyword,
|
||||
)
|
||||
|
||||
if not result.items:
|
||||
return "未找到匹配的线路。"
|
||||
|
||||
items_text = []
|
||||
for item in result.items:
|
||||
items_text.append(
|
||||
f"- 线路: {item.name}, "
|
||||
f"电压等级: {item.voltage_level or '未设置'}, "
|
||||
f"长度: {item.length or '未设置'}km"
|
||||
)
|
||||
|
||||
return f"找到 {result.total} 条线路(显示前{len(result.items)}个):\n" + "\n".join(items_text)
|
||||
|
||||
elif function_name == "query_users":
|
||||
from ..services.user_service import list_users
|
||||
|
||||
keyword = arguments.get("keyword")
|
||||
limit = arguments.get("limit", 10)
|
||||
|
||||
result = list_users(
|
||||
db,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
keyword=keyword,
|
||||
)
|
||||
|
||||
if not result.items:
|
||||
return "未找到匹配的用户。"
|
||||
|
||||
items_text = []
|
||||
for item in result.items:
|
||||
items_text.append(
|
||||
f"- 用户ID: {item.id}, "
|
||||
f"用户名: {item.username}, "
|
||||
f"邮箱: {item.email or '未设置'}"
|
||||
)
|
||||
|
||||
return f"找到 {result.total} 个用户(显示前{len(result.items)}个):\n" + "\n".join(items_text)
|
||||
|
||||
elif function_name == "query_system_params":
|
||||
from ..services.system_param_service import list_system_params
|
||||
|
||||
keyword = arguments.get("keyword")
|
||||
category = arguments.get("category")
|
||||
limit = arguments.get("limit", 10)
|
||||
|
||||
result = list_system_params(
|
||||
db,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
keyword=keyword,
|
||||
category=category,
|
||||
)
|
||||
|
||||
if not result.items:
|
||||
return "未找到匹配的系统参数。"
|
||||
|
||||
items_text = []
|
||||
for item in result.items:
|
||||
items_text.append(
|
||||
f"- 参数键: {item.param_key}, "
|
||||
f"名称: {item.param_name}, "
|
||||
f"值: {item.param_value or '未设置'}, "
|
||||
f"状态: {item.status}"
|
||||
)
|
||||
|
||||
return f"找到 {result.total} 个系统参数(显示前{len(result.items)}个):\n" + "\n".join(items_text)
|
||||
|
||||
else:
|
||||
return f"未知的函数: {function_name}"
|
||||
|
||||
except Exception as e:
|
||||
return f"执行函数 {function_name} 时出错: {str(e)}"
|
||||
|
||||
|
||||
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
|
||||
@@ -204,7 +494,14 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
|
||||
if not api_key:
|
||||
raise ValueError("AI模型配置缺失:请在系统参数中配置 ai_chat.openai_api_key")
|
||||
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in history]
|
||||
messages = []
|
||||
for msg in history:
|
||||
msg_dict: dict[str, Any] = {"role": msg.role, "content": msg.content or ""}
|
||||
if msg.tool_calls:
|
||||
msg_dict["tool_calls"] = msg.tool_calls
|
||||
if msg.tool_call_id:
|
||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
||||
messages.append(msg_dict)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
@@ -214,6 +511,7 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": _get_function_definitions(),
|
||||
}
|
||||
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
@@ -225,4 +523,4 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return data["choices"][0]["message"]["content"]
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user