diff --git a/README.md b/README.md index b62a03c..75dd51f 100644 --- a/README.md +++ b/README.md @@ -75,10 +75,6 @@ - `POST /api/v1/users/{id}/password`(需要 `user.manage`,重置用户密码) - `DELETE /api/v1/users/{id}`(需要 `user.manage`,删除用户) - `POST /api/v1/users/{id}/roles`(需要 `user.manage`) -- `GET /api/v1/chat/sessions`(需要 `chat.use`) -- `POST /api/v1/chat/sessions`(需要 `chat.use`) -- `GET /api/v1/chat/sessions/{id}/messages`(需要 `chat.use`) -- `POST /api/v1/chat/sessions/{id}/messages`(需要 `chat.use`) 初始化管理员(可选): - 在 `.env` 设置 `INITIAL_ADMIN_EMAIL`、`INITIAL_ADMIN_USER_ID`、`INITIAL_ADMIN_USERNAME`、`INITIAL_ADMIN_PASSWORD` @@ -130,7 +126,4 @@ - `API_CORS_ORIGINS`:精确来源列表(逗号分隔),如 `https://admin.example.com,http://localhost:3000` - `API_CORS_ORIGIN_REGEX`:来源正则(可选),如 `https://.*\\.example\\.com` - 支持在 `API_CORS_ORIGINS` 中使用通配符(如 `https://*.example.com`)或 `*`(仅建议开发调试) -- AI 聊天依赖模型路由与 Provider Key: - - 路由优先级:`CAPABILITY: chat.default` -> `GLOBAL: __global__` - - 在 `deploy/dev-deploy/.env.dev` 配置 `LLM_PROVIDER_API_KEYS`(示例:`openai=sk-xxx`) - 默认镜像源已配置为 `docker.m.daocloud.io`,并默认使用 `pgvector` 镜像;如你网络环境可直连 Docker Hub,可在 `deploy/dev-deploy/.env` 中覆盖 `POSTGRES_IMAGE / PYTHON_BASE_IMAGE / NODE_BASE_IMAGE`。 diff --git a/api/app/core/config.py b/api/app/core/config.py index 17a888e..2d38d66 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -1,5 +1,4 @@ from functools import lru_cache -import json import re from typing import Literal from urllib.parse import quote_plus @@ -43,11 +42,6 @@ class Settings(BaseSettings): refresh_cookie_secure: bool = False refresh_cookie_samesite: Literal["lax", "strict", "none"] = "lax" - llm_provider_api_keys: str = "" - llm_request_timeout_seconds: int = 60 - chat_context_message_limit: int = 12 - chat_default_system_prompt: str = "You are a helpful assistant." - celery_broker_url: str | None = None celery_result_backend: str | None = None celery_timezone: str = "Asia/Shanghai" @@ -84,8 +78,6 @@ class Settings(BaseSettings): @field_validator( "access_token_expire_minutes", "refresh_token_expire_days", - "llm_request_timeout_seconds", - "chat_context_message_limit", "db_port", "scheduler_expire_interval_seconds", "flower_api_timeout_seconds", @@ -137,41 +129,6 @@ class Settings(BaseSettings): return None return "|".join(f"(?:{part})" for part in regex_parts) - @property - def llm_provider_key_map(self) -> dict[str, str]: - raw = self.llm_provider_api_keys.strip() - if not raw: - return {} - - if raw.startswith("{"): - try: - data = json.loads(raw) - except json.JSONDecodeError: - return {} - if not isinstance(data, dict): - return {} - normalized: dict[str, str] = {} - for provider, value in data.items(): - if not isinstance(provider, str) or not isinstance(value, str): - continue - provider_key = provider.strip().lower() - secret = value.strip() - if provider_key and secret: - normalized[provider_key] = secret - return normalized - - mapping: dict[str, str] = {} - for token in re.split(r"[,\n;]+", raw): - pair = token.strip() - if not pair or "=" not in pair: - continue - provider, value = pair.split("=", 1) - provider_key = provider.strip().lower() - secret = value.strip() - if provider_key and secret: - mapping[provider_key] = secret - return mapping - @property def resolved_database_url(self) -> str: explicit_database_url = (self.database_url or "").strip() diff --git a/api/app/core/database.py b/api/app/core/database.py index fe0cd10..aedf7f9 100644 --- a/api/app/core/database.py +++ b/api/app/core/database.py @@ -390,7 +390,6 @@ def init_db() -> None: line, line_tower, menu, - model_registry, object_group, question_bank, rbac, diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 9657951..bbaab79 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -4,7 +4,7 @@ Import all model modules during package initialization so SQLAlchemy can resolve string-based relationships regardless of route/service import order. """ -from . import atp_model, audit_log, auth_session, elevation, file_storage, fl_analysis, hot_search, lightning_event, lightning_sample, line, line_tower, menu, model_registry, object_group, question_bank, rbac, system_param, tower_model, tower_profile, user, worker_registry +from . import atp_model, audit_log, auth_session, elevation, file_storage, fl_analysis, hot_search, lightning_event, lightning_sample, line, line_tower, menu, object_group, question_bank, rbac, system_param, tower_model, tower_profile, user, worker_registry __all__ = [ "atp_model", @@ -19,7 +19,6 @@ __all__ = [ "line", "line_tower", "menu", - "model_registry", "object_group", "question_bank", "rbac", diff --git a/api/app/models/model_registry.py b/api/app/models/model_registry.py deleted file mode 100644 index e1cac0d..0000000 --- a/api/app/models/model_registry.py +++ /dev/null @@ -1,187 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from decimal import Decimal -from typing import TYPE_CHECKING, Any - -from sqlalchemy import ( - JSON, - Boolean, - DateTime, - ForeignKey, - Integer, - Numeric, - String, - Text, - UniqueConstraint, -) -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..core.database import Base -from .base import utcnow - -if TYPE_CHECKING: - from .user import User - - -class ModelRegistry(Base): - __tablename__ = "llm_models" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - code: Mapped[str] = mapped_column(String(64), unique=True, index=True) - name: Mapped[str] = mapped_column(String(128), index=True) - provider: Mapped[str] = mapped_column(String(64), index=True) - provider_model: Mapped[str] = mapped_column(String(128), index=True) - status: Mapped[str] = mapped_column(String(16), default="DRAFT", index=True) - capabilities: Mapped[list[str]] = mapped_column(JSON, default=list) - description: Mapped[str] = mapped_column(Text(), default="") - base_url: Mapped[str | None] = mapped_column(String(255)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow) - updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - default=utcnow, - onupdate=utcnow, - ) - - route_rules: Mapped[list[ModelRouteRule]] = relationship( - "ModelRouteRule", - back_populates="target_model", - lazy="selectin", - primaryjoin="ModelRegistry.code == ModelRouteRule.target_model_code", - ) - api_keys: Mapped[list[ModelApiKey]] = relationship( - "ModelApiKey", - back_populates="model", - lazy="selectin", - cascade="all, delete-orphan", - order_by="ModelApiKey.version.desc()", - ) - health_checks: Mapped[list[ModelHealthCheck]] = relationship( - "ModelHealthCheck", - back_populates="model", - lazy="selectin", - cascade="all, delete-orphan", - order_by="ModelHealthCheck.created_at.desc()", - ) - test_runs: Mapped[list[ModelTestRun]] = relationship( - "ModelTestRun", - back_populates="model", - lazy="selectin", - cascade="all, delete-orphan", - order_by="ModelTestRun.created_at.desc()", - ) - - -class ModelRouteRule(Base): - __tablename__ = "model_route_rules" - __table_args__ = ( - UniqueConstraint("route_type", "route_key", name="uq_model_route_type_key"), - ) - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - route_type: Mapped[str] = mapped_column(String(16), index=True) - route_key: Mapped[str] = mapped_column(String(128), index=True) - target_model_code: Mapped[str] = mapped_column( - String(64), - ForeignKey("llm_models.code", ondelete="RESTRICT"), - index=True, - ) - priority: Mapped[int] = mapped_column(Integer, default=100, index=True) - enabled: Mapped[bool] = mapped_column(Boolean, default=True, index=True) - note: Mapped[str | None] = mapped_column(String(255)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow) - updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - default=utcnow, - onupdate=utcnow, - ) - - target_model: Mapped[ModelRegistry] = relationship( - "ModelRegistry", - back_populates="route_rules", - lazy="selectin", - primaryjoin="ModelRouteRule.target_model_code == ModelRegistry.code", - ) - - -class ModelApiKey(Base): - __tablename__ = "model_api_keys" - __table_args__ = ( - UniqueConstraint("model_id", "version", name="uq_model_key_model_version"), - ) - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - model_id: Mapped[int] = mapped_column( - ForeignKey("llm_models.id", ondelete="CASCADE"), - index=True, - ) - version: Mapped[int] = mapped_column(Integer, index=True) - secret_hash: Mapped[str] = mapped_column(String(128)) - secret_masked: Mapped[str] = mapped_column(String(64)) - secret_fingerprint: Mapped[str] = mapped_column(String(32), index=True) - is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True) - rotation_note: Mapped[str | None] = mapped_column(String(255)) - created_by_user_id: Mapped[str | None] = mapped_column( - String(36), - ForeignKey("users.user_id", ondelete="SET NULL"), - index=True, - ) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow) - - model: Mapped[ModelRegistry] = relationship("ModelRegistry", back_populates="api_keys", lazy="selectin") - created_by: Mapped[User | None] = relationship("User", lazy="selectin") - - -class ModelHealthCheck(Base): - __tablename__ = "model_health_checks" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - model_id: Mapped[int] = mapped_column( - ForeignKey("llm_models.id", ondelete="CASCADE"), - index=True, - ) - status: Mapped[str] = mapped_column(String(16), index=True) - reason: Mapped[str] = mapped_column(String(255)) - latency_ms: Mapped[int | None] = mapped_column(Integer) - detail_json: Mapped[dict[str, Any] | None] = mapped_column(JSON) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow) - - model: Mapped[ModelRegistry] = relationship("ModelRegistry", back_populates="health_checks", lazy="selectin") - - -class ModelTestRun(Base): - __tablename__ = "model_test_runs" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - model_id: Mapped[int] = mapped_column( - ForeignKey("llm_models.id", ondelete="CASCADE"), - index=True, - ) - kind: Mapped[str] = mapped_column(String(32), default="SMOKE", index=True) - status: Mapped[str] = mapped_column(String(16), index=True) - input_tokens: Mapped[int] = mapped_column(Integer, default=0) - output_tokens: Mapped[int] = mapped_column(Integer, default=0) - latency_ms: Mapped[int | None] = mapped_column(Integer) - error_message: Mapped[str | None] = mapped_column(Text()) - created_by_user_id: Mapped[str | None] = mapped_column( - String(36), - ForeignKey("users.user_id", ondelete="SET NULL"), - index=True, - ) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow) - - model: Mapped[ModelRegistry] = relationship("ModelRegistry", back_populates="test_runs", lazy="selectin") - created_by: Mapped[User | None] = relationship("User", lazy="selectin") - - -class ModelUsageLog(Base): - __tablename__ = "model_usage_logs" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - model_code: Mapped[str] = mapped_column(String(64), index=True) - source: Mapped[str] = mapped_column(String(32), default="RUNTIME", index=True) - request_count: Mapped[int] = mapped_column(Integer, default=1) - success_count: Mapped[int] = mapped_column(Integer, default=1) - total_tokens: Mapped[int] = mapped_column(Integer, default=0) - total_cost_usd: Mapped[Decimal] = mapped_column(Numeric(12, 6), default=Decimal("0")) - recorded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow, index=True) diff --git a/api/app/schemas/model_registry.py b/api/app/schemas/model_registry.py deleted file mode 100644 index f1cb431..0000000 --- a/api/app/schemas/model_registry.py +++ /dev/null @@ -1,216 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from typing import Literal - -from pydantic import BaseModel, Field - -ModelStatus = Literal["DRAFT", "ENABLED", "DISABLED", "DEPRECATED"] -ModelRouteType = Literal["GLOBAL", "CAPABILITY", "BUSINESS", "AGENT"] -ModelHealthStatus = Literal["HEALTHY", "DEGRADED", "UNHEALTHY"] -ModelTestStatus = Literal["PASSED", "FAILED"] - - -class ModelUsageSummary(BaseModel): - request_count: int = 0 - success_count: int = 0 - total_tokens: int = 0 - total_cost_usd: float = 0.0 - success_rate: float | None = None - - -class ModelTestSummary(BaseModel): - total_runs: int = 0 - passed_runs: int = 0 - failed_runs: int = 0 - pass_rate: float | None = None - - -class ModelRegistryPublic(BaseModel): - id: int - code: str - name: str - provider: str - provider_model: str - status: ModelStatus - capabilities: list[str] = Field(default_factory=list) - description: str = "" - base_url: str | None = None - active_key_masked: str | None = None - active_key_version: int | None = None - active_key_fingerprint: str | None = None - active_key_rotated_at: datetime | None = None - latest_health_status: ModelHealthStatus | None = None - latest_health_reason: str | None = None - latest_health_at: datetime | None = None - route_bindings_count: int = 0 - usage_7d: ModelUsageSummary = Field(default_factory=ModelUsageSummary) - tests_7d: ModelTestSummary = Field(default_factory=ModelTestSummary) - created_at: datetime - updated_at: datetime - - -class ModelListResponse(BaseModel): - items: list[ModelRegistryPublic] - total: int - - -class ModelCreateRequest(BaseModel): - code: str = Field(min_length=2, max_length=64, pattern=r"^[a-z0-9][a-z0-9._-]{1,63}$") - name: str = Field(min_length=2, max_length=128) - provider: str = Field(min_length=2, max_length=64) - provider_model: str = Field(min_length=1, max_length=128) - status: ModelStatus = "DRAFT" - capabilities: list[str] = Field(default_factory=list) - description: str = Field(default="", max_length=2000) - base_url: str | None = Field(default=None, max_length=255) - api_key: str | None = Field(default=None, min_length=8, max_length=1024) - - -class ModelUpdateRequest(BaseModel): - name: str | None = Field(default=None, min_length=2, max_length=128) - provider: str | None = Field(default=None, min_length=2, max_length=64) - provider_model: str | None = Field(default=None, min_length=1, max_length=128) - capabilities: list[str] | None = None - description: str | None = Field(default=None, max_length=2000) - base_url: str | None = Field(default=None, max_length=255) - - -class ModelTransitionRequest(BaseModel): - status: ModelStatus - note: str | None = Field(default=None, max_length=255) - - -class ModelRouteRulePublic(BaseModel): - id: int - route_type: ModelRouteType - route_key: str - target_model_code: str - priority: int - enabled: bool - note: str | None = None - created_at: datetime - updated_at: datetime - - -class ModelRouteRuleListResponse(BaseModel): - items: list[ModelRouteRulePublic] - total: int - - -class ModelRouteRuleCreateRequest(BaseModel): - route_type: ModelRouteType - route_key: str | None = Field(default=None, max_length=128) - target_model_code: str = Field(min_length=2, max_length=64) - priority: int = 100 - enabled: bool = True - note: str | None = Field(default=None, max_length=255) - - -class ModelRouteRuleUpdateRequest(BaseModel): - route_type: ModelRouteType | None = None - route_key: str | None = Field(default=None, max_length=128) - target_model_code: str | None = Field(default=None, min_length=2, max_length=64) - priority: int | None = None - enabled: bool | None = None - note: str | None = Field(default=None, max_length=255) - - -class ModelApiKeyPublic(BaseModel): - id: int - model_id: int - version: int - secret_masked: str - secret_fingerprint: str - is_active: bool - rotation_note: str | None = None - created_by_user_id: str | None = None - created_at: datetime - - -class ModelApiKeyListResponse(BaseModel): - items: list[ModelApiKeyPublic] - total: int - - -class ModelRotateKeyRequest(BaseModel): - api_key: str = Field(min_length=8, max_length=1024) - note: str | None = Field(default=None, max_length=255) - - -class ModelHealthCheckPublic(BaseModel): - id: int - model_id: int - status: ModelHealthStatus - reason: str - latency_ms: int | None = None - detail_json: dict | None = None - created_at: datetime - - -class ModelHealthCheckListResponse(BaseModel): - items: list[ModelHealthCheckPublic] - total: int - - -class ModelTestRunRequest(BaseModel): - kind: str = Field(default="SMOKE", min_length=2, max_length=32) - input_tokens: int = Field(default=0, ge=0) - output_tokens: int = Field(default=0, ge=0) - - -class ModelTestChatRequest(BaseModel): - message: str = Field(min_length=1, max_length=8000) - system_prompt: str | None = Field(default=None, max_length=4000) - - -class ModelTestRunPublic(BaseModel): - id: int - model_id: int - model_code: str - kind: str - status: ModelTestStatus - input_tokens: int - output_tokens: int - latency_ms: int | None = None - error_message: str | None = None - created_by_user_id: str | None = None - created_at: datetime - - -class ModelTestChatResponse(BaseModel): - model_id: int - model_code: str - provider: str - provider_model: str - reply: str | None = None - latency_ms: int | None = None - prompt_tokens: int | None = None - completion_tokens: int | None = None - total_tokens: int | None = None - test_status: ModelTestStatus - error_message: str | None = None - - -class ModelTestRunListResponse(BaseModel): - items: list[ModelTestRunPublic] - total: int - - -class ModelUsageIngestRequest(BaseModel): - model_code: str = Field(min_length=2, max_length=64) - source: str = Field(default="RUNTIME", min_length=2, max_length=32) - request_count: int = Field(default=1, ge=1) - success_count: int = Field(default=1, ge=0) - total_tokens: int = Field(default=0, ge=0) - total_cost_usd: float = Field(default=0.0, ge=0) - - -class ModelSummaryResponse(BaseModel): - total_models: int - status_counts: dict[str, int] - total_route_rules: int - route_type_counts: dict[str, int] - enabled_without_healthy_check: int - usage_7d: ModelUsageSummary - tests_7d: ModelTestSummary diff --git a/api/app/services/admin_service.py b/api/app/services/admin_service.py index a6afbd5..5f1c0a7 100644 --- a/api/app/services/admin_service.py +++ b/api/app/services/admin_service.py @@ -44,9 +44,7 @@ REMOVED_MENU_CODES = { "admin.schedule", "admin.mindmap", "admin.mermaid_mgr", - "admin.chat", "admin.api_tester", - "admin.models", "admin.orchestration", "admin.mdresolve", "admin.data_query", diff --git a/api/app/services/legacy_admin_rbac_service.py b/api/app/services/legacy_admin_rbac_service.py index c320273..8cece29 100644 --- a/api/app/services/legacy_admin_rbac_service.py +++ b/api/app/services/legacy_admin_rbac_service.py @@ -42,9 +42,7 @@ REMOVED_MENU_CODES = { "admin.schedule", "admin.mindmap", "admin.mermaid_mgr", - "admin.chat", "admin.api_tester", - "admin.models", "admin.orchestration", "admin.mdresolve", "admin.data_query", diff --git a/api/app/services/legacy_authz_service.py b/api/app/services/legacy_authz_service.py index 5783b4f..bef15b0 100644 --- a/api/app/services/legacy_authz_service.py +++ b/api/app/services/legacy_authz_service.py @@ -61,9 +61,7 @@ DISABLED_MENU_CODES: set[str] = { "admin.schedule", "admin.mindmap", "admin.mermaid_mgr", - "admin.chat", "admin.api_tester", - "admin.models", "admin.orchestration", "admin.mdresolve", "admin.data_query", diff --git a/api/app/services/llm_gateway.py b/api/app/services/llm_gateway.py deleted file mode 100644 index 3a9040b..0000000 --- a/api/app/services/llm_gateway.py +++ /dev/null @@ -1,252 +0,0 @@ -from __future__ import annotations - -import json -import time -from dataclasses import dataclass - -import httpx -from fastapi import HTTPException, status -from sqlalchemy import select -from sqlalchemy.orm import Session - -from ..core.config import get_settings -from ..models.model_registry import ModelApiKey, ModelRegistry, ModelRouteRule - -settings = get_settings() -CHAT_CAPABILITY_ROUTE_KEY = "chat.default" -GLOBAL_ROUTE_KEY = "__global__" -DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1" - - -@dataclass -class LlmCompletionResult: - content: str - model_code: str - provider: str - provider_model: str - prompt_tokens: int | None - completion_tokens: int | None - total_tokens: int | None - latency_ms: int - - -def create_assistant_reply( - db: Session, - *, - user_message: str, - context_messages: list[tuple[str, str]], - system_prompt: str, -) -> LlmCompletionResult: - model = _resolve_chat_model(db) - return create_reply_with_model( - model=model, - user_message=user_message, - context_messages=context_messages, - system_prompt=system_prompt, - ) - - -def create_reply_with_model( - *, - model: ModelRegistry, - user_message: str, - context_messages: list[tuple[str, str]], - system_prompt: str, -) -> LlmCompletionResult: - provider_key = _resolve_provider_key(model.provider) - endpoint = _build_endpoint(model.base_url) - payload = { - "model": model.provider_model, - "messages": _build_messages( - system_prompt=system_prompt, - context_messages=context_messages, - user_message=user_message, - ), - } - - started = time.perf_counter() - try: - with httpx.Client(timeout=settings.llm_request_timeout_seconds) as client: - response = client.post( - endpoint, - headers={ - "Authorization": f"Bearer {provider_key}", - "Content-Type": "application/json", - }, - json=payload, - ) - except httpx.TimeoutException as exc: - raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="LLM request timeout") from exc - except httpx.HTTPError as exc: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM request failed: {exc.__class__.__name__}") from exc - - latency_ms = int((time.perf_counter() - started) * 1000) - if response.status_code >= 400: - detail = _extract_http_error_detail(response) - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response error: {detail}") - - body = response.json() - content = _extract_content(body) - if not content: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM returned empty content") - - usage = body.get("usage") if isinstance(body, dict) else None - prompt_tokens = _to_int(usage.get("prompt_tokens")) if isinstance(usage, dict) else None - completion_tokens = _to_int(usage.get("completion_tokens")) if isinstance(usage, dict) else None - total_tokens = _to_int(usage.get("total_tokens")) if isinstance(usage, dict) else None - - return LlmCompletionResult( - content=content, - model_code=model.code, - provider=model.provider, - provider_model=model.provider_model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - latency_ms=latency_ms, - ) - - -def _resolve_chat_model(db: Session) -> ModelRegistry: - capability_model = _resolve_model_from_route(db, route_type="CAPABILITY", route_key=CHAT_CAPABILITY_ROUTE_KEY) - if capability_model: - return capability_model - - global_model = _resolve_model_from_route(db, route_type="GLOBAL", route_key=GLOBAL_ROUTE_KEY) - if global_model: - return global_model - - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="No enabled model route for chat (CAPABILITY:chat.default or GLOBAL)", - ) - - -def _resolve_model_from_route( - db: Session, - *, - route_type: str, - route_key: str, -) -> ModelRegistry | None: - rows = db.execute( - select(ModelRouteRule, ModelRegistry) - .join(ModelRegistry, ModelRouteRule.target_model_code == ModelRegistry.code) - .where( - ModelRouteRule.route_type == route_type, - ModelRouteRule.route_key == route_key, - ModelRouteRule.enabled.is_(True), - ModelRegistry.status == "ENABLED", - ) - .order_by(ModelRouteRule.priority.asc(), ModelRouteRule.id.asc()) - ).all() - if not rows: - return None - - for _, model in rows: - active_key_exists = db.scalar( - select(ModelApiKey.id).where( - ModelApiKey.model_id == model.id, - ModelApiKey.is_active.is_(True), - ) - ) - if active_key_exists is not None: - return model - return None - - -def _resolve_provider_key(provider: str) -> str: - key = settings.llm_provider_key_map.get(provider.strip().lower()) - if key: - return key - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Missing provider key for {provider}. Configure LLM_PROVIDER_API_KEYS.", - ) - - -def _build_messages( - *, - system_prompt: str, - context_messages: list[tuple[str, str]], - user_message: str, -) -> list[dict[str, str]]: - messages: list[dict[str, str]] = [] - normalized_system_prompt = system_prompt.strip() - if normalized_system_prompt: - messages.append({"role": "system", "content": normalized_system_prompt}) - - for role, content in context_messages: - if role not in {"user", "assistant"}: - continue - normalized_content = content.strip() - if not normalized_content: - continue - messages.append({"role": role, "content": normalized_content}) - - messages.append({"role": "user", "content": user_message.strip()}) - return messages - - -def _build_endpoint(base_url: str | None) -> str: - normalized = (base_url or "").strip().rstrip("/") - if not normalized: - return f"{DEFAULT_OPENAI_BASE_URL}/chat/completions" - if normalized.endswith("/chat/completions"): - return normalized - return f"{normalized}/chat/completions" - - -def _extract_content(body: object) -> str: - if not isinstance(body, dict): - return "" - - choices = body.get("choices") - if not isinstance(choices, list) or not choices: - return "" - first = choices[0] - if not isinstance(first, dict): - return "" - message = first.get("message") - if not isinstance(message, dict): - return "" - - content = message.get("content") - if isinstance(content, str): - return content.strip() - if isinstance(content, list): - texts: list[str] = [] - for item in content: - if isinstance(item, dict): - text = item.get("text") - if isinstance(text, str) and text.strip(): - texts.append(text.strip()) - return "\n".join(texts).strip() - return "" - - -def _extract_http_error_detail(response: httpx.Response) -> str: - try: - payload = response.json() - except json.JSONDecodeError: - return f"HTTP {response.status_code}" - if isinstance(payload, dict): - detail = payload.get("error") - if isinstance(detail, dict): - message = detail.get("message") - if isinstance(message, str) and message.strip(): - return message.strip() - message = payload.get("message") - if isinstance(message, str) and message.strip(): - return message.strip() - detail_field = payload.get("detail") - if isinstance(detail_field, str) and detail_field.strip(): - return detail_field.strip() - return f"HTTP {response.status_code}" - - -def _to_int(value: object) -> int | None: - if isinstance(value, int): - return value - if isinstance(value, str) and value.isdigit(): - return int(value) - return None diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py deleted file mode 100644 index 28bc107..0000000 --- a/api/app/services/model_service.py +++ /dev/null @@ -1,1254 +0,0 @@ -from __future__ import annotations - -import asyncio -import hashlib -import time -from datetime import timedelta -from decimal import Decimal, InvalidOperation - -from fastapi import HTTPException, status -from sqlalchemy import case, func, or_, select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session - -from ..models.base import utcnow -from ..models.model_registry import ( - ModelApiKey, - ModelHealthCheck, - ModelRegistry, - ModelRouteRule, - ModelTestRun, - ModelUsageLog, -) -from ..models.user import User -from ..schemas.model_registry import ( - ModelApiKeyListResponse, - ModelApiKeyPublic, - ModelCreateRequest, - ModelHealthCheckListResponse, - ModelHealthCheckPublic, - ModelListResponse, - ModelRegistryPublic, - ModelRotateKeyRequest, - ModelRouteRuleCreateRequest, - ModelRouteRuleListResponse, - ModelRouteRulePublic, - ModelRouteRuleUpdateRequest, - ModelSummaryResponse, - ModelTestChatRequest, - ModelTestChatResponse, - ModelTestRunListResponse, - ModelTestRunPublic, - ModelTestRunRequest, - ModelTestSummary, - ModelTransitionRequest, - ModelUpdateRequest, - ModelUsageIngestRequest, - ModelUsageSummary, -) -from .llm_gateway import create_reply_with_model -from .push_service import publish_topic - -MODEL_TOPIC = "model.registry" -GLOBAL_ROUTE_KEY = "__global__" -VALID_STATUSES = ("DRAFT", "ENABLED", "DISABLED", "DEPRECATED") -VALID_ROUTE_TYPES = ("GLOBAL", "CAPABILITY", "BUSINESS", "AGENT") -MODEL_STATUS_TRANSITIONS: dict[str, set[str]] = { - "DRAFT": {"ENABLED", "DISABLED", "DEPRECATED"}, - "ENABLED": {"DISABLED", "DEPRECATED"}, - "DISABLED": {"ENABLED", "DEPRECATED"}, - "DEPRECATED": {"DISABLED"}, -} - - -def list_models(db: Session, *, status_filter: str | None, keyword: str | None) -> ModelListResponse: - stmt = select(ModelRegistry) - - if status_filter: - normalized_status = status_filter.strip().upper() - if normalized_status not in VALID_STATUSES: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid status filter: {status_filter}") - stmt = stmt.where(ModelRegistry.status == normalized_status) - - if keyword: - like = f"%{keyword.strip()}%" - stmt = stmt.where( - or_( - ModelRegistry.code.ilike(like), - ModelRegistry.name.ilike(like), - ModelRegistry.provider.ilike(like), - ModelRegistry.provider_model.ilike(like), - ) - ) - - models = db.execute(stmt.order_by(ModelRegistry.updated_at.desc(), ModelRegistry.id.desc())).scalars().all() - metrics = _collect_model_metrics(db, models) - return ModelListResponse( - items=[_serialize_model(model, metrics) for model in models], - total=len(models), - ) - - -def get_model_summary(db: Session) -> ModelSummaryResponse: - status_counts = {status_code: 0 for status_code in VALID_STATUSES} - for row in db.execute(select(ModelRegistry.status, func.count()).group_by(ModelRegistry.status)).all(): - status_counts[str(row[0])] = int(row[1]) - - route_type_counts = {route_type: 0 for route_type in VALID_ROUTE_TYPES} - for row in db.execute(select(ModelRouteRule.route_type, func.count()).group_by(ModelRouteRule.route_type)).all(): - route_type_counts[str(row[0])] = int(row[1]) - - usage_7d = _aggregate_usage_7d( - db, - model_codes=None, - ) - tests_7d = _aggregate_tests_7d( - db, - model_ids=None, - ) - - enabled_models = db.execute( - select(ModelRegistry.id).where(ModelRegistry.status == "ENABLED") - ).scalars().all() - enabled_without_healthy_check = _count_enabled_without_healthy(db, enabled_models) - - total_models = db.scalar(select(func.count()).select_from(ModelRegistry)) or 0 - total_route_rules = db.scalar(select(func.count()).select_from(ModelRouteRule)) or 0 - - return ModelSummaryResponse( - total_models=int(total_models), - status_counts=status_counts, - total_route_rules=int(total_route_rules), - route_type_counts=route_type_counts, - enabled_without_healthy_check=enabled_without_healthy_check, - usage_7d=usage_7d, - tests_7d=tests_7d, - ) - - -def get_model_detail(db: Session, model_id: int) -> ModelRegistryPublic: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - metrics = _collect_model_metrics(db, [model]) - return _serialize_model(model, metrics) - - -def create_model(db: Session, payload: ModelCreateRequest, *, actor: User) -> ModelRegistryPublic: - code = _normalize_code(payload.code) - existing = db.scalar(select(ModelRegistry.id).where(ModelRegistry.code == code)) - if existing is not None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Duplicate model code") - - model = ModelRegistry( - code=code, - name=payload.name.strip(), - provider=payload.provider.strip(), - provider_model=payload.provider_model.strip(), - status=payload.status, - capabilities=_normalize_capabilities(payload.capabilities), - description=payload.description.strip(), - base_url=_normalize_nullable_str(payload.base_url), - ) - - db.add(model) - db.flush() - - if payload.api_key: - _rotate_model_key_internal(db, model=model, raw_key=payload.api_key, actor_user_id=actor.id, note="initial") - - try: - db.commit() - except IntegrityError: - db.rollback() - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Model create failed due to duplicate or invalid data") - - saved = _get_model_by_id(db, model.id) - if not saved: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Model save failed") - - _publish_model_changed("created", model=saved) - return get_model_detail(db, saved.id) - - -def update_model(db: Session, model_id: int, payload: ModelUpdateRequest) -> ModelRegistryPublic: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - update_data = payload.model_dump(exclude_unset=True) - if not update_data: - return get_model_detail(db, model_id) - - if "name" in update_data: - model.name = str(update_data["name"]).strip() - if "provider" in update_data: - model.provider = str(update_data["provider"]).strip() - if "provider_model" in update_data: - model.provider_model = str(update_data["provider_model"]).strip() - if "capabilities" in update_data: - model.capabilities = _normalize_capabilities(update_data["capabilities"]) - if "description" in update_data: - model.description = str(update_data["description"] or "").strip() - if "base_url" in update_data: - model.base_url = _normalize_nullable_str(update_data["base_url"]) - - db.commit() - - saved = _get_model_by_id(db, model_id) - if not saved: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Model load failed") - - _publish_model_changed("updated", model=saved) - return get_model_detail(db, model_id) - - -def transition_model_status( - db: Session, - model_id: int, - payload: ModelTransitionRequest, - *, - actor: User, -) -> ModelRegistryPublic: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - current_status = model.status - target_status = payload.status - if current_status == target_status: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Status is unchanged") - - allowed = MODEL_STATUS_TRANSITIONS.get(current_status, set()) - if target_status not in allowed: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Transition not allowed: {current_status} -> {target_status}", - ) - - if target_status == "ENABLED": - active_key_exists = db.scalar( - select(ModelApiKey.id).where( - ModelApiKey.model_id == model.id, - ModelApiKey.is_active.is_(True), - ) - ) - if active_key_exists is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Active API key is required before enabling model") - - model.status = target_status - db.commit() - - saved = _get_model_by_id(db, model_id) - if not saved: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Model load failed") - - _publish_model_changed( - "status_transitioned", - model=saved, - extra_payload={ - "from_status": current_status, - "to_status": target_status, - "note": _normalize_nullable_str(payload.note), - "actor_user_id": actor.id, - }, - ) - return get_model_detail(db, model_id) - - -def delete_model(db: Session, model_id: int) -> None: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - if model.status == "ENABLED": - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Enabled model cannot be deleted; disable or deprecate first") - - bound_rules = db.execute( - select(ModelRouteRule) - .where(ModelRouteRule.target_model_code == model.code) - .order_by(ModelRouteRule.route_type.asc(), ModelRouteRule.route_key.asc()) - ).scalars().all() - if bound_rules: - refs = [f"{rule.route_type}:{rule.route_key}" for rule in bound_rules[:5]] - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Model is referenced by {len(bound_rules)} route rule(s): {', '.join(refs)}", - ) - - deleted_model_id = model.id - deleted_model_code = model.code - db.delete(model) - db.commit() - - _fire_and_forget( - publish_topic( - MODEL_TOPIC, - name="models.changed", - payload={"action": "deleted", "model_id": deleted_model_id, "model_code": deleted_model_code}, - requires_refetch=[], - dedupe_key=f"models:deleted:{deleted_model_id}", - ) - ) - - -def list_model_keys(db: Session, model_id: int) -> ModelApiKeyListResponse: - _require_model_exists(db, model_id) - keys = db.execute( - select(ModelApiKey) - .where(ModelApiKey.model_id == model_id) - .order_by(ModelApiKey.version.desc(), ModelApiKey.id.desc()) - ).scalars().all() - return ModelApiKeyListResponse(items=[_serialize_key(item) for item in keys], total=len(keys)) - - -def rotate_model_key( - db: Session, - model_id: int, - payload: ModelRotateKeyRequest, - *, - actor: User, -) -> ModelApiKeyPublic: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - rotated = _rotate_model_key_internal( - db, - model=model, - raw_key=payload.api_key, - actor_user_id=actor.id, - note=_normalize_nullable_str(payload.note), - ) - db.commit() - - _publish_model_changed( - "key_rotated", - model=model, - extra_payload={"key_version": rotated.version, "actor_user_id": actor.id}, - ) - return _serialize_key(rotated) - - -def run_model_health_check(db: Session, model_id: int) -> ModelHealthCheckPublic: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - start = time.perf_counter() - active_key = _get_active_key(db, model.id) - route_count = int( - db.scalar( - select(func.count()) - .select_from(ModelRouteRule) - .where( - ModelRouteRule.target_model_code == model.code, - ModelRouteRule.enabled.is_(True), - ) - ) - or 0 - ) - status_value, reason, detail = _evaluate_health(model=model, has_active_key=active_key is not None, route_count=route_count) - latency_ms = int((time.perf_counter() - start) * 1000) - - check = ModelHealthCheck( - model_id=model.id, - status=status_value, - reason=reason, - latency_ms=latency_ms, - detail_json=detail, - ) - db.add(check) - db.commit() - - _publish_model_changed( - "health_checked", - model=model, - extra_payload={"health_status": check.status, "health_reason": check.reason}, - ) - return _serialize_health_check(check) - - -def list_model_health_checks(db: Session, model_id: int, *, limit: int = 20) -> ModelHealthCheckListResponse: - _require_model_exists(db, model_id) - checks = db.execute( - select(ModelHealthCheck) - .where(ModelHealthCheck.model_id == model_id) - .order_by(ModelHealthCheck.id.desc()) - .limit(max(1, min(limit, 100))) - ).scalars().all() - return ModelHealthCheckListResponse( - items=[_serialize_health_check(item) for item in checks], - total=len(checks), - ) - - -def run_model_test( - db: Session, - model_id: int, - payload: ModelTestRunRequest, - *, - actor: User, -) -> ModelTestRunPublic: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - start = time.perf_counter() - active_key = _get_active_key(db, model.id) - - passed = model.status == "ENABLED" and active_key is not None - error_message = None - if not passed: - if model.status != "ENABLED": - error_message = f"Model status is {model.status}; expected ENABLED" - elif active_key is None: - error_message = "No active API key" - else: - error_message = "Unknown test failure" - - latency_ms = int((time.perf_counter() - start) * 1000) - status_value = "PASSED" if passed else "FAILED" - - test_run = ModelTestRun( - model_id=model.id, - kind=payload.kind.strip().upper(), - status=status_value, - input_tokens=payload.input_tokens, - output_tokens=payload.output_tokens, - latency_ms=latency_ms, - error_message=error_message, - created_by_user_id=actor.id, - ) - db.add(test_run) - - total_tokens = payload.input_tokens + payload.output_tokens - usage_log = ModelUsageLog( - model_code=model.code, - source="TEST", - request_count=1, - success_count=1 if passed else 0, - total_tokens=total_tokens, - total_cost_usd=Decimal("0"), - ) - db.add(usage_log) - - db.commit() - - _publish_model_changed( - "tested", - model=model, - extra_payload={ - "test_status": status_value, - "test_id": test_run.id, - "actor_user_id": actor.id, - }, - ) - return _serialize_test_run(test_run, model_code=model.code) - - -def run_model_test_chat( - db: Session, - model_id: int, - payload: ModelTestChatRequest, - *, - actor: User, -) -> ModelTestChatResponse: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - normalized_message = payload.message.strip() - normalized_system_prompt = (payload.system_prompt or "").strip() - if not normalized_message: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="message cannot be empty") - - active_key = _get_active_key(db, model.id) - - reply: str | None = None - prompt_tokens: int | None = None - completion_tokens: int | None = None - total_tokens: int | None = None - latency_ms: int | None = None - error_message: str | None = None - test_status = "FAILED" - - if model.status != "ENABLED": - error_message = f"Model status is {model.status}; expected ENABLED" - elif active_key is None: - error_message = "No active API key" - else: - started = time.perf_counter() - try: - llm_result = create_reply_with_model( - model=model, - user_message=normalized_message, - context_messages=[], - system_prompt=normalized_system_prompt, - ) - reply = llm_result.content - prompt_tokens = llm_result.prompt_tokens - completion_tokens = llm_result.completion_tokens - total_tokens = llm_result.total_tokens - latency_ms = llm_result.latency_ms - test_status = "PASSED" - except HTTPException as exc: - latency_ms = int((time.perf_counter() - started) * 1000) - error_message = str(exc.detail) - except Exception as exc: # pragma: no cover - defensive fallback - latency_ms = int((time.perf_counter() - started) * 1000) - error_message = str(exc) - - if prompt_tokens is None: - prompt_tokens = _estimate_text_tokens(normalized_message + ("\n" + normalized_system_prompt if normalized_system_prompt else "")) - if completion_tokens is None: - completion_tokens = _estimate_text_tokens(reply or "") - if total_tokens is None: - total_tokens = int(prompt_tokens or 0) + int(completion_tokens or 0) - - test_run = ModelTestRun( - model_id=model.id, - kind="CHAT", - status=test_status, - input_tokens=int(prompt_tokens or 0), - output_tokens=int(completion_tokens or 0), - latency_ms=latency_ms, - error_message=error_message, - created_by_user_id=actor.id, - ) - db.add(test_run) - - usage_log = ModelUsageLog( - model_code=model.code, - source="TEST_CHAT", - request_count=1, - success_count=1 if test_status == "PASSED" else 0, - total_tokens=int(total_tokens or 0), - total_cost_usd=Decimal("0"), - ) - db.add(usage_log) - - db.commit() - - _publish_model_changed( - "tested", - model=model, - extra_payload={ - "test_status": test_status, - "test_id": test_run.id, - "test_kind": "CHAT", - "actor_user_id": actor.id, - }, - ) - - return ModelTestChatResponse( - model_id=model.id, - model_code=model.code, - provider=model.provider, - provider_model=model.provider_model, - reply=reply, - latency_ms=latency_ms, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - test_status=test_status, - error_message=error_message, - ) - - -def list_model_tests(db: Session, model_id: int, *, limit: int = 20) -> ModelTestRunListResponse: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - - tests = db.execute( - select(ModelTestRun) - .where(ModelTestRun.model_id == model.id) - .order_by(ModelTestRun.id.desc()) - .limit(max(1, min(limit, 100))) - ).scalars().all() - return ModelTestRunListResponse( - items=[_serialize_test_run(item, model_code=model.code) for item in tests], - total=len(tests), - ) - - -def ingest_model_usage(db: Session, payload: ModelUsageIngestRequest) -> dict[str, bool]: - if payload.success_count > payload.request_count: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="success_count cannot be greater than request_count") - - model_code = _normalize_code(payload.model_code) - exists = db.scalar(select(ModelRegistry.id).where(ModelRegistry.code == model_code)) - if exists is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model code not found") - - total_cost = _to_decimal(payload.total_cost_usd) - log = ModelUsageLog( - model_code=model_code, - source=payload.source.strip().upper(), - request_count=payload.request_count, - success_count=payload.success_count, - total_tokens=payload.total_tokens, - total_cost_usd=total_cost, - ) - db.add(log) - db.commit() - - _fire_and_forget( - publish_topic( - MODEL_TOPIC, - name="models.usage_ingested", - payload={"model_code": model_code, "request_count": payload.request_count}, - requires_refetch=[], - dedupe_key=f"models:usage:{model_code}", - ) - ) - return {"success": True} - - -def list_route_rules(db: Session) -> ModelRouteRuleListResponse: - rules = db.execute( - select(ModelRouteRule) - .order_by(ModelRouteRule.route_type.asc(), ModelRouteRule.priority.asc(), ModelRouteRule.id.asc()) - ).scalars().all() - return ModelRouteRuleListResponse(items=[_serialize_route_rule(item) for item in rules], total=len(rules)) - - -def create_route_rule(db: Session, payload: ModelRouteRuleCreateRequest) -> ModelRouteRulePublic: - route_type = payload.route_type - route_key = _normalize_route_key(route_type, payload.route_key) - target_model_code = _normalize_code(payload.target_model_code) - - _require_target_model_routable(db, target_model_code) - - existing = db.scalar( - select(ModelRouteRule.id).where( - ModelRouteRule.route_type == route_type, - ModelRouteRule.route_key == route_key, - ) - ) - if existing is not None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Route rule already exists") - - rule = ModelRouteRule( - route_type=route_type, - route_key=route_key, - target_model_code=target_model_code, - priority=payload.priority, - enabled=payload.enabled, - note=_normalize_nullable_str(payload.note), - ) - db.add(rule) - - try: - db.commit() - except IntegrityError: - db.rollback() - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Route rule create failed") - - _publish_route_changed("created", rule=rule) - return _serialize_route_rule(rule) - - -def update_route_rule(db: Session, route_rule_id: int, payload: ModelRouteRuleUpdateRequest) -> ModelRouteRulePublic: - rule = db.execute(select(ModelRouteRule).where(ModelRouteRule.id == route_rule_id)).scalar_one_or_none() - if not rule: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Route rule not found") - - update_data = payload.model_dump(exclude_unset=True) - if not update_data: - return _serialize_route_rule(rule) - - next_route_type = update_data.get("route_type", rule.route_type) - next_route_key_raw = update_data.get("route_key", rule.route_key) - next_route_key = _normalize_route_key(next_route_type, next_route_key_raw) - - existing = db.scalar( - select(ModelRouteRule.id).where( - ModelRouteRule.route_type == next_route_type, - ModelRouteRule.route_key == next_route_key, - ModelRouteRule.id != rule.id, - ) - ) - if existing is not None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Route rule already exists") - - rule.route_type = next_route_type - rule.route_key = next_route_key - - if "target_model_code" in update_data: - target_model_code = _normalize_code(str(update_data["target_model_code"])) - _require_target_model_routable(db, target_model_code) - rule.target_model_code = target_model_code - - if "priority" in update_data: - rule.priority = int(update_data["priority"]) - if "enabled" in update_data: - rule.enabled = bool(update_data["enabled"]) - if "note" in update_data: - rule.note = _normalize_nullable_str(update_data["note"]) - - db.commit() - _publish_route_changed("updated", rule=rule) - return _serialize_route_rule(rule) - - -def delete_route_rule(db: Session, route_rule_id: int) -> dict[str, bool]: - rule = db.execute(select(ModelRouteRule).where(ModelRouteRule.id == route_rule_id)).scalar_one_or_none() - if not rule: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Route rule not found") - - deleted_rule_id = rule.id - deleted_route_type = rule.route_type - deleted_route_key = rule.route_key - db.delete(rule) - db.commit() - - _fire_and_forget( - publish_topic( - MODEL_TOPIC, - name="model_routes.changed", - payload={ - "action": "deleted", - "route_rule_id": deleted_rule_id, - "route_type": deleted_route_type, - "route_key": deleted_route_key, - }, - requires_refetch=[], - dedupe_key=f"model_routes:deleted:{deleted_rule_id}", - ) - ) - return {"success": True} - - -def _collect_model_metrics(db: Session, models: list[ModelRegistry]) -> dict[str, dict]: - if not models: - return { - "active_keys": {}, - "latest_health": {}, - "route_counts": {}, - "usage_7d": {}, - "tests_7d": {}, - } - - model_ids = [model.id for model in models] - model_codes = [model.code for model in models] - - active_key_map: dict[int, ModelApiKey] = {} - for key in db.execute( - select(ModelApiKey).where( - ModelApiKey.model_id.in_(model_ids), - ModelApiKey.is_active.is_(True), - ) - ).scalars().all(): - current = active_key_map.get(key.model_id) - if current is None or key.version > current.version: - active_key_map[key.model_id] = key - - latest_health_map: dict[int, ModelHealthCheck] = {} - latest_health_subq = ( - select( - ModelHealthCheck.model_id, - func.max(ModelHealthCheck.id).label("max_id"), - ) - .where(ModelHealthCheck.model_id.in_(model_ids)) - .group_by(ModelHealthCheck.model_id) - .subquery() - ) - for check in db.execute( - select(ModelHealthCheck).join( - latest_health_subq, - ModelHealthCheck.id == latest_health_subq.c.max_id, - ) - ).scalars().all(): - latest_health_map[check.model_id] = check - - route_counts = { - str(row[0]): int(row[1]) - for row in db.execute( - select(ModelRouteRule.target_model_code, func.count()) - .where(ModelRouteRule.target_model_code.in_(model_codes)) - .group_by(ModelRouteRule.target_model_code) - ).all() - } - - usage_7d_map = _aggregate_usage_7d(db, model_codes=model_codes, by_model=True) - tests_7d_map = _aggregate_tests_7d(db, model_ids=model_ids, by_model=True) - - return { - "active_keys": active_key_map, - "latest_health": latest_health_map, - "route_counts": route_counts, - "usage_7d": usage_7d_map, - "tests_7d": tests_7d_map, - } - - -def _serialize_model(model: ModelRegistry, metrics: dict[str, dict]) -> ModelRegistryPublic: - active_key = metrics["active_keys"].get(model.id) - latest_health = metrics["latest_health"].get(model.id) - usage_7d = metrics["usage_7d"].get(model.code, ModelUsageSummary()) - tests_7d = metrics["tests_7d"].get(model.id, ModelTestSummary()) - - return ModelRegistryPublic( - id=model.id, - code=model.code, - name=model.name, - provider=model.provider, - provider_model=model.provider_model, - status=model.status, - capabilities=list(model.capabilities or []), - description=model.description, - base_url=model.base_url, - active_key_masked=active_key.secret_masked if active_key else None, - active_key_version=active_key.version if active_key else None, - active_key_fingerprint=active_key.secret_fingerprint if active_key else None, - active_key_rotated_at=active_key.created_at if active_key else None, - latest_health_status=latest_health.status if latest_health else None, - latest_health_reason=latest_health.reason if latest_health else None, - latest_health_at=latest_health.created_at if latest_health else None, - route_bindings_count=int(metrics["route_counts"].get(model.code, 0)), - usage_7d=usage_7d, - tests_7d=tests_7d, - created_at=model.created_at, - updated_at=model.updated_at, - ) - - -def _aggregate_usage_7d( - db: Session, - *, - model_codes: list[str] | None, - by_model: bool = False, -) -> ModelUsageSummary | dict[str, ModelUsageSummary]: - since = utcnow() - timedelta(days=7) - if by_model: - if not model_codes: - return {} - - rows = db.execute( - select( - ModelUsageLog.model_code, - func.coalesce(func.sum(ModelUsageLog.request_count), 0), - func.coalesce(func.sum(ModelUsageLog.success_count), 0), - func.coalesce(func.sum(ModelUsageLog.total_tokens), 0), - func.coalesce(func.sum(ModelUsageLog.total_cost_usd), Decimal("0")), - ) - .where( - ModelUsageLog.model_code.in_(model_codes), - ModelUsageLog.recorded_at >= since, - ) - .group_by(ModelUsageLog.model_code) - ).all() - - result: dict[str, ModelUsageSummary] = {} - for row in rows: - model_code = str(row[0]) - request_count = int(row[1] or 0) - success_count = int(row[2] or 0) - total_tokens = int(row[3] or 0) - total_cost = float(row[4] or 0) - success_rate = round(success_count / request_count, 4) if request_count > 0 else None - result[model_code] = ModelUsageSummary( - request_count=request_count, - success_count=success_count, - total_tokens=total_tokens, - total_cost_usd=total_cost, - success_rate=success_rate, - ) - return result - - row = db.execute( - select( - func.coalesce(func.sum(ModelUsageLog.request_count), 0), - func.coalesce(func.sum(ModelUsageLog.success_count), 0), - func.coalesce(func.sum(ModelUsageLog.total_tokens), 0), - func.coalesce(func.sum(ModelUsageLog.total_cost_usd), Decimal("0")), - ).where(ModelUsageLog.recorded_at >= since) - ).one() - - request_count = int(row[0] or 0) - success_count = int(row[1] or 0) - total_tokens = int(row[2] or 0) - total_cost = float(row[3] or 0) - return ModelUsageSummary( - request_count=request_count, - success_count=success_count, - total_tokens=total_tokens, - total_cost_usd=total_cost, - success_rate=round(success_count / request_count, 4) if request_count > 0 else None, - ) - - -def _aggregate_tests_7d( - db: Session, - *, - model_ids: list[int] | None, - by_model: bool = False, -) -> ModelTestSummary | dict[int, ModelTestSummary]: - since = utcnow() - timedelta(days=7) - - passed_case = case((ModelTestRun.status == "PASSED", 1), else_=0) - failed_case = case((ModelTestRun.status == "FAILED", 1), else_=0) - - if by_model: - if not model_ids: - return {} - - rows = db.execute( - select( - ModelTestRun.model_id, - func.count(ModelTestRun.id), - func.coalesce(func.sum(passed_case), 0), - func.coalesce(func.sum(failed_case), 0), - ) - .where( - ModelTestRun.model_id.in_(model_ids), - ModelTestRun.created_at >= since, - ) - .group_by(ModelTestRun.model_id) - ).all() - - result: dict[int, ModelTestSummary] = {} - for row in rows: - model_id = int(row[0]) - total_runs = int(row[1] or 0) - passed_runs = int(row[2] or 0) - failed_runs = int(row[3] or 0) - result[model_id] = ModelTestSummary( - total_runs=total_runs, - passed_runs=passed_runs, - failed_runs=failed_runs, - pass_rate=round(passed_runs / total_runs, 4) if total_runs > 0 else None, - ) - return result - - row = db.execute( - select( - func.count(ModelTestRun.id), - func.coalesce(func.sum(passed_case), 0), - func.coalesce(func.sum(failed_case), 0), - ).where(ModelTestRun.created_at >= since) - ).one() - - total_runs = int(row[0] or 0) - passed_runs = int(row[1] or 0) - failed_runs = int(row[2] or 0) - return ModelTestSummary( - total_runs=total_runs, - passed_runs=passed_runs, - failed_runs=failed_runs, - pass_rate=round(passed_runs / total_runs, 4) if total_runs > 0 else None, - ) - - -def _count_enabled_without_healthy(db: Session, enabled_model_ids: list[int]) -> int: - if not enabled_model_ids: - return 0 - - latest_subq = ( - select( - ModelHealthCheck.model_id, - func.max(ModelHealthCheck.id).label("max_id"), - ) - .where(ModelHealthCheck.model_id.in_(enabled_model_ids)) - .group_by(ModelHealthCheck.model_id) - .subquery() - ) - - latest_checks = db.execute( - select(ModelHealthCheck).join(latest_subq, ModelHealthCheck.id == latest_subq.c.max_id) - ).scalars().all() - latest_map = {item.model_id: item for item in latest_checks} - - total = 0 - for model_id in enabled_model_ids: - latest = latest_map.get(model_id) - if not latest or latest.status != "HEALTHY": - total += 1 - return total - - -def _serialize_key(item: ModelApiKey) -> ModelApiKeyPublic: - return ModelApiKeyPublic( - id=item.id, - model_id=item.model_id, - version=item.version, - secret_masked=item.secret_masked, - secret_fingerprint=item.secret_fingerprint, - is_active=item.is_active, - rotation_note=item.rotation_note, - created_by_user_id=item.created_by_user_id, - created_at=item.created_at, - ) - - -def _serialize_health_check(item: ModelHealthCheck) -> ModelHealthCheckPublic: - return ModelHealthCheckPublic( - id=item.id, - model_id=item.model_id, - status=item.status, - reason=item.reason, - latency_ms=item.latency_ms, - detail_json=item.detail_json, - created_at=item.created_at, - ) - - -def _serialize_test_run(item: ModelTestRun, *, model_code: str) -> ModelTestRunPublic: - return ModelTestRunPublic( - id=item.id, - model_id=item.model_id, - model_code=model_code, - kind=item.kind, - status=item.status, - input_tokens=item.input_tokens, - output_tokens=item.output_tokens, - latency_ms=item.latency_ms, - error_message=item.error_message, - created_by_user_id=item.created_by_user_id, - created_at=item.created_at, - ) - - -def _serialize_route_rule(item: ModelRouteRule) -> ModelRouteRulePublic: - return ModelRouteRulePublic( - id=item.id, - route_type=item.route_type, - route_key=item.route_key, - target_model_code=item.target_model_code, - priority=item.priority, - enabled=item.enabled, - note=item.note, - created_at=item.created_at, - updated_at=item.updated_at, - ) - - -def _get_model_by_id(db: Session, model_id: int) -> ModelRegistry | None: - return db.execute(select(ModelRegistry).where(ModelRegistry.id == model_id)).scalar_one_or_none() - - -def _require_model_exists(db: Session, model_id: int) -> ModelRegistry: - model = _get_model_by_id(db, model_id) - if not model: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model not found") - return model - - -def _require_target_model_routable(db: Session, model_code: str) -> ModelRegistry: - model = db.execute(select(ModelRegistry).where(ModelRegistry.code == model_code)).scalar_one_or_none() - if not model: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Target model code not found: {model_code}") - if model.status == "DEPRECATED": - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Deprecated model cannot be used in route rule") - return model - - -def _normalize_code(value: str) -> str: - return value.strip().lower() - - -def _normalize_capabilities(values: list[str]) -> list[str]: - normalized: list[str] = [] - seen: set[str] = set() - for value in values: - candidate = value.strip().lower() - if not candidate or candidate in seen: - continue - seen.add(candidate) - normalized.append(candidate) - normalized.sort() - return normalized - - -def _normalize_nullable_str(value: str | None) -> str | None: - if value is None: - return None - normalized = value.strip() - return normalized or None - - -def _normalize_route_key(route_type: str, route_key: str | None) -> str: - if route_type == "GLOBAL": - return GLOBAL_ROUTE_KEY - - normalized = (route_key or "").strip() - if not normalized: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"route_key is required for route_type={route_type}") - if normalized == GLOBAL_ROUTE_KEY: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"route_key {GLOBAL_ROUTE_KEY} is reserved") - return normalized - - -def _hash_secret(raw_key: str) -> str: - return hashlib.sha256(raw_key.encode("utf-8")).hexdigest() - - -def _mask_secret(raw_key: str) -> str: - value = raw_key.strip() - if len(value) <= 2: - return "**" - if len(value) <= 8: - return f"{value[0]}***{value[-1]}" - return f"{value[:4]}***{value[-4:]}" - - -def _fingerprint_secret(raw_key: str) -> str: - return _hash_secret(raw_key)[:12] - - -def _rotate_model_key_internal( - db: Session, - *, - model: ModelRegistry, - raw_key: str, - actor_user_id: str, - note: str | None, -) -> ModelApiKey: - value = raw_key.strip() - if not value: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="api_key cannot be empty") - - active_keys = db.execute( - select(ModelApiKey).where( - ModelApiKey.model_id == model.id, - ModelApiKey.is_active.is_(True), - ) - ).scalars().all() - for key in active_keys: - key.is_active = False - - current_max_version = db.scalar( - select(func.max(ModelApiKey.version)).where(ModelApiKey.model_id == model.id) - ) - next_version = int(current_max_version or 0) + 1 - - next_key = ModelApiKey( - model_id=model.id, - version=next_version, - secret_hash=_hash_secret(value), - secret_masked=_mask_secret(value), - secret_fingerprint=_fingerprint_secret(value), - is_active=True, - rotation_note=note, - created_by_user_id=actor_user_id, - ) - db.add(next_key) - db.flush() - return next_key - - -def _get_active_key(db: Session, model_id: int) -> ModelApiKey | None: - return db.execute( - select(ModelApiKey) - .where( - ModelApiKey.model_id == model_id, - ModelApiKey.is_active.is_(True), - ) - .order_by(ModelApiKey.version.desc(), ModelApiKey.id.desc()) - ).scalars().first() - - -def _evaluate_health(*, model: ModelRegistry, has_active_key: bool, route_count: int) -> tuple[str, str, dict[str, object]]: - if model.status != "ENABLED": - return ( - "UNHEALTHY", - f"Model status is {model.status}; expected ENABLED", - { - "model_status": model.status, - "has_active_key": has_active_key, - "route_count": route_count, - }, - ) - if not has_active_key: - return ( - "UNHEALTHY", - "No active API key", - { - "model_status": model.status, - "has_active_key": has_active_key, - "route_count": route_count, - }, - ) - if route_count == 0: - return ( - "DEGRADED", - "No enabled route rule bound to this model", - { - "model_status": model.status, - "has_active_key": has_active_key, - "route_count": route_count, - }, - ) - return ( - "HEALTHY", - "Model is enabled with active key and route bindings", - { - "model_status": model.status, - "has_active_key": has_active_key, - "route_count": route_count, - }, - ) - - -def _estimate_text_tokens(text: str) -> int: - value = text.strip() - if not value: - return 0 - # 粗估:1 token ≈ 4 chars,至少返回 1 - return max(1, (len(value) + 3) // 4) - - -def _to_decimal(value: float) -> Decimal: - try: - return Decimal(str(value)).quantize(Decimal("0.000001")) - except (InvalidOperation, ValueError): - return Decimal("0") - - -def _publish_model_changed(action: str, *, model: ModelRegistry, extra_payload: dict | None = None) -> None: - payload = { - "action": action, - "model_id": model.id, - "model_code": model.code, - "model_status": model.status, - } - if extra_payload: - payload.update(extra_payload) - _fire_and_forget( - publish_topic( - MODEL_TOPIC, - name="models.changed", - payload=payload, - requires_refetch=[], - dedupe_key=f"models:{action}:{model.id}", - ) - ) - - -def _publish_route_changed(action: str, *, rule: ModelRouteRule) -> None: - _fire_and_forget( - publish_topic( - MODEL_TOPIC, - name="model_routes.changed", - payload={ - "action": action, - "route_rule_id": rule.id, - "route_type": rule.route_type, - "route_key": rule.route_key, - "target_model_code": rule.target_model_code, - }, - requires_refetch=[], - dedupe_key=f"model_routes:{action}:{rule.id}", - ) - ) - - -def _fire_and_forget(coro: object) -> None: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return - loop.create_task(coro) diff --git a/api/tests/test_legacy_llm_cleanup.py b/api/tests/test_legacy_llm_cleanup.py new file mode 100644 index 0000000..2d60b33 --- /dev/null +++ b/api/tests/test_legacy_llm_cleanup.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import importlib.util +import os +import unittest + +os.environ.setdefault("DATABASE_URL", "sqlite+pysqlite:///:memory:") +os.environ.setdefault("MINIO_ENABLED", "false") + +from api.app import models # noqa: F401 +from api.app.core.config import Settings +from api.app.core.database import Base +from api.app.services import admin_service, legacy_admin_rbac_service, legacy_authz_service + + +class LegacyLlmCleanupTest(unittest.TestCase): + def test_llm_registry_tables_removed_from_metadata(self) -> None: + removed_tables = { + "llm_models", + "model_route_rules", + "model_api_keys", + "model_health_checks", + "model_test_runs", + "model_usage_logs", + } + + self.assertNotIn("model_registry", models.__all__) + self.assertTrue(removed_tables.isdisjoint(Base.metadata.tables)) + + def test_llm_config_fields_removed(self) -> None: + removed_fields = { + "llm_provider_api_keys", + "llm_request_timeout_seconds", + "chat_context_message_limit", + "chat_default_system_prompt", + } + + self.assertTrue(removed_fields.isdisjoint(Settings.model_fields)) + + def test_legacy_menu_filters_no_longer_reference_chat_and_models(self) -> None: + for codes in ( + admin_service.REMOVED_MENU_CODES, + legacy_admin_rbac_service.REMOVED_MENU_CODES, + legacy_authz_service.DISABLED_MENU_CODES, + ): + self.assertNotIn("admin.chat", codes) + self.assertNotIn("admin.models", codes) + + def test_calendar_service_module_removed(self) -> None: + self.assertIsNone(importlib.util.find_spec("api.app.services.calendar_event_service")) + + +if __name__ == "__main__": + unittest.main()