feat:[FL-181][AI问答要支持function call]

实现AI问答的function call功能,支持调用系统接口进行查询。

改动内容:
1. 数据库扩展:
   - 在ai_chat_messages表增加tool_calls和tool_call_id字段
   - 创建数据库迁移文件

2. 模型和Schema更新:
   - AiChatMessage模型增加tool_calls(JSON)和tool_call_id字段
   - AiChatMessageSummary schema增加对应字段

3. Function Call实现:
   - 定义4个可调用函数:query_tower_models、query_lines、query_users、query_system_params
   - 实现_execute_function处理函数调用并返回格式化结果
   - 更新_call_openai_api支持tools参数

4. 消息流程更新:
   - 重构send_message支持完整的function call流程
   - 流程:用户消息 -> AI请求function call -> 执行函数 -> AI基于结果回复

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
chengkai3
2026-06-22 23:58:58 +08:00
parent 3c500d1397
commit 7ef266e4a0
4 changed files with 323 additions and 13 deletions
+3 -1
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
from sqlalchemy import DateTime, ForeignKey, Integer, JSON, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from ..core.database import Base
@@ -50,6 +50,8 @@ class AiChatMessage(Base):
)
role: Mapped[str] = mapped_column(String(16), index=True)
content: Mapped[str] = mapped_column(Text())
tool_calls: Mapped[dict | None] = mapped_column(JSON, nullable=True)
tool_call_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
conversation: Mapped[AiChatConversation] = relationship(
+2
View File
@@ -12,6 +12,8 @@ class AiChatMessageSummary(BaseModel):
conversation_id: int
role: str
content: str
tool_calls: dict | None = None
tool_call_id: str | None = None
created_at: datetime
+303 -5
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import json
import os
from typing import Any
@@ -26,6 +27,8 @@ def serialize_message(msg: AiChatMessage) -> AiChatMessageSummary:
conversation_id=msg.conversation_id,
role=msg.role,
content=msg.content,
tool_calls=msg.tool_calls,
tool_call_id=msg.tool_call_id,
created_at=msg.created_at,
)
@@ -173,9 +176,49 @@ def send_message(
history.append(user_message)
try:
reply_content = _call_openai_api(db, history)
except Exception as e:
reply_content = f"抱歉,AI服务暂时不可用:{str(e)}"
response_data = _call_openai_api(db, history)
choice = response_data["choices"][0]
message = choice["message"]
finish_reason = choice.get("finish_reason")
if finish_reason == "tool_calls" and message.get("tool_calls"):
assistant_message_with_tool_calls = AiChatMessage(
conversation_id=conversation_id,
role="assistant",
content=message.get("content") or "",
tool_calls=message["tool_calls"],
)
db.add(assistant_message_with_tool_calls)
db.commit()
db.refresh(assistant_message_with_tool_calls)
history.append(assistant_message_with_tool_calls)
for tool_call in message["tool_calls"]:
function_name = tool_call["function"]["name"]
function_args_str = tool_call["function"]["arguments"]
tool_call_id = tool_call["id"]
try:
function_args = json.loads(function_args_str)
except json.JSONDecodeError:
function_args = {}
function_result = _execute_function(db, function_name, function_args)
tool_message = AiChatMessage(
conversation_id=conversation_id,
role="tool",
content=function_result,
tool_call_id=tool_call_id,
)
db.add(tool_message)
db.commit()
db.refresh(tool_message)
history.append(tool_message)
final_response_data = _call_openai_api(db, history)
final_message = final_response_data["choices"][0]["message"]
reply_content = final_message.get("content", "")
assistant_message = AiChatMessage(
conversation_id=conversation_id,
@@ -188,6 +231,253 @@ def send_message(
return serialize_message(user_message), serialize_message(assistant_message)
else:
reply_content = message.get("content", "")
assistant_message = AiChatMessage(
conversation_id=conversation_id,
role="assistant",
content=reply_content,
)
db.add(assistant_message)
db.commit()
db.refresh(assistant_message)
return serialize_message(user_message), serialize_message(assistant_message)
except Exception as e:
reply_content = f"抱歉,AI服务暂时不可用:{str(e)}"
assistant_message = AiChatMessage(
conversation_id=conversation_id,
role="assistant",
content=reply_content,
)
db.add(assistant_message)
db.commit()
db.refresh(assistant_message)
return serialize_message(user_message), serialize_message(assistant_message)
def _get_function_definitions() -> list[dict[str, Any]]:
"""Define available functions that the AI can call."""
return [
{
"type": "function",
"function": {
"name": "query_tower_models",
"description": "查询杆塔型号列表。可以根据关键字搜索杆塔型号名称。",
"parameters": {
"type": "object",
"properties": {
"keyword": {
"type": "string",
"description": "搜索关键字,用于匹配杆塔型号名称",
},
"enabled": {
"type": "boolean",
"description": "是否只查询启用的杆塔型号",
},
"limit": {
"type": "integer",
"description": "返回结果数量限制,默认10",
"default": 10,
},
},
"required": [],
},
},
},
{
"type": "function",
"function": {
"name": "query_lines",
"description": "查询线路列表。可以根据关键字搜索线路名称。",
"parameters": {
"type": "object",
"properties": {
"keyword": {
"type": "string",
"description": "搜索关键字,用于匹配线路名称",
},
"limit": {
"type": "integer",
"description": "返回结果数量限制,默认10",
"default": 10,
},
},
"required": [],
},
},
},
{
"type": "function",
"function": {
"name": "query_users",
"description": "查询系统用户列表。可以根据关键字搜索用户名称或邮箱。",
"parameters": {
"type": "object",
"properties": {
"keyword": {
"type": "string",
"description": "搜索关键字,用于匹配用户名称或邮箱",
},
"limit": {
"type": "integer",
"description": "返回结果数量限制,默认10",
"default": 10,
},
},
"required": [],
},
},
},
{
"type": "function",
"function": {
"name": "query_system_params",
"description": "查询系统参数配置。可以根据参数键或名称搜索。",
"parameters": {
"type": "object",
"properties": {
"keyword": {
"type": "string",
"description": "搜索关键字,用于匹配参数键或名称",
},
"category": {
"type": "string",
"description": "参数分类",
},
"limit": {
"type": "integer",
"description": "返回结果数量限制,默认10",
"default": 10,
},
},
"required": [],
},
},
},
]
def _execute_function(db: Session, function_name: str, arguments: dict[str, Any]) -> str:
"""Execute a function call and return the result as a string."""
try:
if function_name == "query_tower_models":
from ..services.tower_model_service import list_tower_models
keyword = arguments.get("keyword")
enabled = arguments.get("enabled")
limit = arguments.get("limit", 10)
result = list_tower_models(
db,
limit=limit,
offset=0,
keyword=keyword,
enabled=enabled,
)
if not result.items:
return "未找到匹配的杆塔型号。"
items_text = []
for item in result.items:
items_text.append(
f"- 型号: {item.name}, "
f"高度: {item.height}m, "
f"状态: {'启用' if item.enabled else '禁用'}"
)
return f"找到 {result.total} 个杆塔型号(显示前{len(result.items)}个):\n" + "\n".join(items_text)
elif function_name == "query_lines":
from ..services.line_service import list_lines
keyword = arguments.get("keyword")
limit = arguments.get("limit", 10)
result = list_lines(
db,
limit=limit,
offset=0,
keyword=keyword,
)
if not result.items:
return "未找到匹配的线路。"
items_text = []
for item in result.items:
items_text.append(
f"- 线路: {item.name}, "
f"电压等级: {item.voltage_level or '未设置'}, "
f"长度: {item.length or '未设置'}km"
)
return f"找到 {result.total} 条线路(显示前{len(result.items)}个):\n" + "\n".join(items_text)
elif function_name == "query_users":
from ..services.user_service import list_users
keyword = arguments.get("keyword")
limit = arguments.get("limit", 10)
result = list_users(
db,
limit=limit,
offset=0,
keyword=keyword,
)
if not result.items:
return "未找到匹配的用户。"
items_text = []
for item in result.items:
items_text.append(
f"- 用户ID: {item.id}, "
f"用户名: {item.username}, "
f"邮箱: {item.email or '未设置'}"
)
return f"找到 {result.total} 个用户(显示前{len(result.items)}个):\n" + "\n".join(items_text)
elif function_name == "query_system_params":
from ..services.system_param_service import list_system_params
keyword = arguments.get("keyword")
category = arguments.get("category")
limit = arguments.get("limit", 10)
result = list_system_params(
db,
limit=limit,
offset=0,
keyword=keyword,
category=category,
)
if not result.items:
return "未找到匹配的系统参数。"
items_text = []
for item in result.items:
items_text.append(
f"- 参数键: {item.param_key}, "
f"名称: {item.param_name}, "
f"值: {item.param_value or '未设置'}, "
f"状态: {item.status}"
)
return f"找到 {result.total} 个系统参数(显示前{len(result.items)}个):\n" + "\n".join(items_text)
else:
return f"未知的函数: {function_name}"
except Exception as e:
return f"执行函数 {function_name} 时出错: {str(e)}"
def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
api_key_param = get_system_param_by_key(db, "ai_chat.openai_api_key")
@@ -204,7 +494,14 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
if not api_key:
raise ValueError("AI模型配置缺失:请在系统参数中配置 ai_chat.openai_api_key")
messages = [{"role": msg.role, "content": msg.content} for msg in history]
messages = []
for msg in history:
msg_dict: dict[str, Any] = {"role": msg.role, "content": msg.content or ""}
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages.append(msg_dict)
headers = {
"Authorization": f"Bearer {api_key}",
@@ -214,6 +511,7 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
payload: dict[str, Any] = {
"model": model,
"messages": messages,
"tools": _get_function_definitions(),
}
with httpx.Client(timeout=60.0) as client:
@@ -225,4 +523,4 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str:
response.raise_for_status()
data = response.json()
return data["choices"][0]["message"]["content"]
return data
@@ -0,0 +1,8 @@
-- Add function calling support fields to ai_chat_messages table
ALTER TABLE ai_chat_messages
ADD COLUMN tool_calls JSON,
ADD COLUMN tool_call_id VARCHAR(64);
COMMENT ON COLUMN ai_chat_messages.tool_calls IS 'Stores function/tool calls made by the assistant';
COMMENT ON COLUMN ai_chat_messages.tool_call_id IS 'ID for tool/function result messages';