From 7ef266e4a061e9eff7614a07c6c3dbed91ce0afb Mon Sep 17 00:00:00 2001 From: chengkai3 Date: Mon, 22 Jun 2026 23:58:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:[FL-181][AI=E9=97=AE=E7=AD=94=E8=A6=81?= =?UTF-8?q?=E6=94=AF=E6=8C=81function=20call]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现AI问答的function call功能,支持调用系统接口进行查询。 改动内容: 1. 数据库扩展: - 在ai_chat_messages表增加tool_calls和tool_call_id字段 - 创建数据库迁移文件 2. 模型和Schema更新: - AiChatMessage模型增加tool_calls(JSON)和tool_call_id字段 - AiChatMessageSummary schema增加对应字段 3. Function Call实现: - 定义4个可调用函数:query_tower_models、query_lines、query_users、query_system_params - 实现_execute_function处理函数调用并返回格式化结果 - 更新_call_openai_api支持tools参数 4. 消息流程更新: - 重构send_message支持完整的function call流程 - 流程:用户消息 -> AI请求function call -> 执行函数 -> AI基于结果回复 Co-Authored-By: Claude Sonnet 4.6 Co-authored-by: multica-agent --- api/app/models/ai_chat.py | 4 +- api/app/schemas/ai_chat.py | 2 + api/app/services/ai_chat_service.py | 322 +++++++++++++++++- .../add_ai_chat_function_call_fields.sql | 8 + 4 files changed, 323 insertions(+), 13 deletions(-) create mode 100644 migrations/add_ai_chat_function_call_fields.sql diff --git a/api/app/models/ai_chat.py b/api/app/models/ai_chat.py index 879a4ed..dd06bfe 100644 --- a/api/app/models/ai_chat.py +++ b/api/app/models/ai_chat.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime from typing import TYPE_CHECKING -from sqlalchemy import DateTime, ForeignKey, Integer, String, Text +from sqlalchemy import DateTime, ForeignKey, Integer, JSON, String, Text from sqlalchemy.orm import Mapped, mapped_column, relationship from ..core.database import Base @@ -50,6 +50,8 @@ class AiChatMessage(Base): ) role: Mapped[str] = mapped_column(String(16), index=True) content: Mapped[str] = mapped_column(Text()) + tool_calls: Mapped[dict | None] = mapped_column(JSON, nullable=True) + tool_call_id: Mapped[str | None] = mapped_column(String(64), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow) conversation: Mapped[AiChatConversation] = relationship( diff --git a/api/app/schemas/ai_chat.py b/api/app/schemas/ai_chat.py index 89717c0..27dcd74 100644 --- a/api/app/schemas/ai_chat.py +++ b/api/app/schemas/ai_chat.py @@ -12,6 +12,8 @@ class AiChatMessageSummary(BaseModel): conversation_id: int role: str content: str + tool_calls: dict | None = None + tool_call_id: str | None = None created_at: datetime diff --git a/api/app/services/ai_chat_service.py b/api/app/services/ai_chat_service.py index 134a8b8..985f555 100644 --- a/api/app/services/ai_chat_service.py +++ b/api/app/services/ai_chat_service.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import os from typing import Any @@ -26,6 +27,8 @@ def serialize_message(msg: AiChatMessage) -> AiChatMessageSummary: conversation_id=msg.conversation_id, role=msg.role, content=msg.content, + tool_calls=msg.tool_calls, + tool_call_id=msg.tool_call_id, created_at=msg.created_at, ) @@ -173,20 +176,307 @@ def send_message( history.append(user_message) try: - reply_content = _call_openai_api(db, history) + response_data = _call_openai_api(db, history) + choice = response_data["choices"][0] + message = choice["message"] + finish_reason = choice.get("finish_reason") + + if finish_reason == "tool_calls" and message.get("tool_calls"): + assistant_message_with_tool_calls = AiChatMessage( + conversation_id=conversation_id, + role="assistant", + content=message.get("content") or "", + tool_calls=message["tool_calls"], + ) + db.add(assistant_message_with_tool_calls) + db.commit() + db.refresh(assistant_message_with_tool_calls) + history.append(assistant_message_with_tool_calls) + + for tool_call in message["tool_calls"]: + function_name = tool_call["function"]["name"] + function_args_str = tool_call["function"]["arguments"] + tool_call_id = tool_call["id"] + + try: + function_args = json.loads(function_args_str) + except json.JSONDecodeError: + function_args = {} + + function_result = _execute_function(db, function_name, function_args) + + tool_message = AiChatMessage( + conversation_id=conversation_id, + role="tool", + content=function_result, + tool_call_id=tool_call_id, + ) + db.add(tool_message) + db.commit() + db.refresh(tool_message) + history.append(tool_message) + + final_response_data = _call_openai_api(db, history) + final_message = final_response_data["choices"][0]["message"] + reply_content = final_message.get("content", "") + + assistant_message = AiChatMessage( + conversation_id=conversation_id, + role="assistant", + content=reply_content, + ) + db.add(assistant_message) + db.commit() + db.refresh(assistant_message) + + return serialize_message(user_message), serialize_message(assistant_message) + + else: + reply_content = message.get("content", "") + assistant_message = AiChatMessage( + conversation_id=conversation_id, + role="assistant", + content=reply_content, + ) + db.add(assistant_message) + db.commit() + db.refresh(assistant_message) + + return serialize_message(user_message), serialize_message(assistant_message) + except Exception as e: reply_content = f"抱歉,AI服务暂时不可用:{str(e)}" + assistant_message = AiChatMessage( + conversation_id=conversation_id, + role="assistant", + content=reply_content, + ) + db.add(assistant_message) + db.commit() + db.refresh(assistant_message) - assistant_message = AiChatMessage( - conversation_id=conversation_id, - role="assistant", - content=reply_content, - ) - db.add(assistant_message) - db.commit() - db.refresh(assistant_message) + return serialize_message(user_message), serialize_message(assistant_message) - return serialize_message(user_message), serialize_message(assistant_message) + +def _get_function_definitions() -> list[dict[str, Any]]: + """Define available functions that the AI can call.""" + return [ + { + "type": "function", + "function": { + "name": "query_tower_models", + "description": "查询杆塔型号列表。可以根据关键字搜索杆塔型号名称。", + "parameters": { + "type": "object", + "properties": { + "keyword": { + "type": "string", + "description": "搜索关键字,用于匹配杆塔型号名称", + }, + "enabled": { + "type": "boolean", + "description": "是否只查询启用的杆塔型号", + }, + "limit": { + "type": "integer", + "description": "返回结果数量限制,默认10", + "default": 10, + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "query_lines", + "description": "查询线路列表。可以根据关键字搜索线路名称。", + "parameters": { + "type": "object", + "properties": { + "keyword": { + "type": "string", + "description": "搜索关键字,用于匹配线路名称", + }, + "limit": { + "type": "integer", + "description": "返回结果数量限制,默认10", + "default": 10, + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "query_users", + "description": "查询系统用户列表。可以根据关键字搜索用户名称或邮箱。", + "parameters": { + "type": "object", + "properties": { + "keyword": { + "type": "string", + "description": "搜索关键字,用于匹配用户名称或邮箱", + }, + "limit": { + "type": "integer", + "description": "返回结果数量限制,默认10", + "default": 10, + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "query_system_params", + "description": "查询系统参数配置。可以根据参数键或名称搜索。", + "parameters": { + "type": "object", + "properties": { + "keyword": { + "type": "string", + "description": "搜索关键字,用于匹配参数键或名称", + }, + "category": { + "type": "string", + "description": "参数分类", + }, + "limit": { + "type": "integer", + "description": "返回结果数量限制,默认10", + "default": 10, + }, + }, + "required": [], + }, + }, + }, + ] + + +def _execute_function(db: Session, function_name: str, arguments: dict[str, Any]) -> str: + """Execute a function call and return the result as a string.""" + try: + if function_name == "query_tower_models": + from ..services.tower_model_service import list_tower_models + + keyword = arguments.get("keyword") + enabled = arguments.get("enabled") + limit = arguments.get("limit", 10) + + result = list_tower_models( + db, + limit=limit, + offset=0, + keyword=keyword, + enabled=enabled, + ) + + if not result.items: + return "未找到匹配的杆塔型号。" + + items_text = [] + for item in result.items: + items_text.append( + f"- 型号: {item.name}, " + f"高度: {item.height}m, " + f"状态: {'启用' if item.enabled else '禁用'}" + ) + + return f"找到 {result.total} 个杆塔型号(显示前{len(result.items)}个):\n" + "\n".join(items_text) + + elif function_name == "query_lines": + from ..services.line_service import list_lines + + keyword = arguments.get("keyword") + limit = arguments.get("limit", 10) + + result = list_lines( + db, + limit=limit, + offset=0, + keyword=keyword, + ) + + if not result.items: + return "未找到匹配的线路。" + + items_text = [] + for item in result.items: + items_text.append( + f"- 线路: {item.name}, " + f"电压等级: {item.voltage_level or '未设置'}, " + f"长度: {item.length or '未设置'}km" + ) + + return f"找到 {result.total} 条线路(显示前{len(result.items)}个):\n" + "\n".join(items_text) + + elif function_name == "query_users": + from ..services.user_service import list_users + + keyword = arguments.get("keyword") + limit = arguments.get("limit", 10) + + result = list_users( + db, + limit=limit, + offset=0, + keyword=keyword, + ) + + if not result.items: + return "未找到匹配的用户。" + + items_text = [] + for item in result.items: + items_text.append( + f"- 用户ID: {item.id}, " + f"用户名: {item.username}, " + f"邮箱: {item.email or '未设置'}" + ) + + return f"找到 {result.total} 个用户(显示前{len(result.items)}个):\n" + "\n".join(items_text) + + elif function_name == "query_system_params": + from ..services.system_param_service import list_system_params + + keyword = arguments.get("keyword") + category = arguments.get("category") + limit = arguments.get("limit", 10) + + result = list_system_params( + db, + limit=limit, + offset=0, + keyword=keyword, + category=category, + ) + + if not result.items: + return "未找到匹配的系统参数。" + + items_text = [] + for item in result.items: + items_text.append( + f"- 参数键: {item.param_key}, " + f"名称: {item.param_name}, " + f"值: {item.param_value or '未设置'}, " + f"状态: {item.status}" + ) + + return f"找到 {result.total} 个系统参数(显示前{len(result.items)}个):\n" + "\n".join(items_text) + + else: + return f"未知的函数: {function_name}" + + except Exception as e: + return f"执行函数 {function_name} 时出错: {str(e)}" def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: @@ -204,7 +494,14 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: if not api_key: raise ValueError("AI模型配置缺失:请在系统参数中配置 ai_chat.openai_api_key") - messages = [{"role": msg.role, "content": msg.content} for msg in history] + messages = [] + for msg in history: + msg_dict: dict[str, Any] = {"role": msg.role, "content": msg.content or ""} + if msg.tool_calls: + msg_dict["tool_calls"] = msg.tool_calls + if msg.tool_call_id: + msg_dict["tool_call_id"] = msg.tool_call_id + messages.append(msg_dict) headers = { "Authorization": f"Bearer {api_key}", @@ -214,6 +511,7 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: payload: dict[str, Any] = { "model": model, "messages": messages, + "tools": _get_function_definitions(), } with httpx.Client(timeout=60.0) as client: @@ -225,4 +523,4 @@ def _call_openai_api(db: Session, history: list[AiChatMessage]) -> str: response.raise_for_status() data = response.json() - return data["choices"][0]["message"]["content"] + return data diff --git a/migrations/add_ai_chat_function_call_fields.sql b/migrations/add_ai_chat_function_call_fields.sql new file mode 100644 index 0000000..6768ae3 --- /dev/null +++ b/migrations/add_ai_chat_function_call_fields.sql @@ -0,0 +1,8 @@ +-- Add function calling support fields to ai_chat_messages table + +ALTER TABLE ai_chat_messages +ADD COLUMN tool_calls JSON, +ADD COLUMN tool_call_id VARCHAR(64); + +COMMENT ON COLUMN ai_chat_messages.tool_calls IS 'Stores function/tool calls made by the assistant'; +COMMENT ON COLUMN ai_chat_messages.tool_call_id IS 'ID for tool/function result messages';