Files
fquiz/api/app/services/llm_gateway.py
T

253 lines
7.9 KiB
Python

from __future__ import annotations
import json
import time
from dataclasses import dataclass
import httpx
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.orm import Session
from ..core.config import get_settings
from ..models.model_registry import ModelApiKey, ModelRegistry, ModelRouteRule
# Settings object obtained once, at import time. This module reads
# `llm_request_timeout_seconds` and `llm_provider_key_map` from it.
settings = get_settings()
# Route key tried first (with route type "CAPABILITY") when choosing the chat model.
CHAT_CAPABILITY_ROUTE_KEY = "chat.default"
# Route key of the fallback lookup (with route type "GLOBAL").
GLOBAL_ROUTE_KEY = "__global__"
# Base URL used by `_build_endpoint` when a model has no base URL of its own.
DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
@dataclass
class LlmCompletionResult:
    """Result of one chat-completion request, as built by ``create_reply_with_model``."""

    # Assistant text extracted from the provider response body.
    content: str
    # ``code`` of the registry model that handled the request.
    model_code: str
    # ``provider`` of that registry model.
    provider: str
    # ``provider_model`` of that registry model (the value sent as "model").
    provider_model: str
    # Token counts taken from the response's "usage" object; None when the
    # object or the individual field is missing or not integer-like.
    prompt_tokens: int | None
    completion_tokens: int | None
    total_tokens: int | None
    # Elapsed time of the HTTP request in whole milliseconds.
    latency_ms: int
def create_assistant_reply(
    db: Session,
    *,
    user_message: str,
    context_messages: list[tuple[str, str]],
    system_prompt: str,
) -> LlmCompletionResult:
    """Pick the chat model from the routing tables and request a reply from it.

    Args:
        db: Session used to look up the model route.
        user_message: Text of the new user turn.
        context_messages: Earlier ``(role, content)`` turns.
        system_prompt: System prompt text.

    Returns:
        The completion produced by ``create_reply_with_model``.

    Raises:
        HTTPException: Propagated from model resolution or from the LLM call.
    """
    return create_reply_with_model(
        model=_resolve_chat_model(db),
        system_prompt=system_prompt,
        context_messages=context_messages,
        user_message=user_message,
    )
def create_reply_with_model(
    *,
    model: ModelRegistry,
    user_message: str,
    context_messages: list[tuple[str, str]],
    system_prompt: str,
) -> LlmCompletionResult:
    """Send one chat-completion request to ``model`` and return the parsed reply.

    Args:
        model: Registry entry supplying ``provider``, ``base_url``,
            ``provider_model`` and ``code``.
        user_message: Text of the new user turn.
        context_messages: Earlier ``(role, content)`` turns.
        system_prompt: System prompt text.

    Returns:
        The reply text plus model identification, token usage (when the
        provider reports it) and request latency in milliseconds.

    Raises:
        HTTPException: 504 when the request times out; 502 when the request
            cannot be sent, the provider answers with status >= 400, the body
            is not valid JSON, or no reply text can be extracted. Errors from
            ``_resolve_provider_key`` propagate unchanged.
    """
    provider_key = _resolve_provider_key(model.provider)
    endpoint = _build_endpoint(model.base_url)
    payload = {
        "model": model.provider_model,
        "messages": _build_messages(
            system_prompt=system_prompt,
            context_messages=context_messages,
            user_message=user_message,
        ),
    }

    started = time.perf_counter()
    try:
        with httpx.Client(timeout=settings.llm_request_timeout_seconds) as client:
            response = client.post(
                endpoint,
                headers={
                    "Authorization": f"Bearer {provider_key}",
                    "Content-Type": "application/json",
                },
                json=payload,
            )
    except httpx.TimeoutException as exc:
        raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="LLM request timeout") from exc
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        # InvalidURL does not derive from HTTPError, so a malformed base URL
        # would otherwise escape as an unhandled exception.
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM request failed: {exc.__class__.__name__}") from exc
    latency_ms = int((time.perf_counter() - started) * 1000)

    if response.status_code >= 400:
        detail = _extract_http_error_detail(response)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response error: {detail}")

    try:
        body = response.json()
    except ValueError as exc:
        # A success status with a non-JSON body (e.g. an HTML page from a
        # proxy) is an upstream fault, not an internal error.
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM returned invalid JSON") from exc

    content = _extract_content(body)
    if not content:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM returned empty content")

    usage = body.get("usage") if isinstance(body, dict) else None
    if not isinstance(usage, dict):
        # Missing or malformed usage: every lookup below yields None.
        usage = {}

    return LlmCompletionResult(
        content=content,
        model_code=model.code,
        provider=model.provider,
        provider_model=model.provider_model,
        prompt_tokens=_to_int(usage.get("prompt_tokens")),
        completion_tokens=_to_int(usage.get("completion_tokens")),
        total_tokens=_to_int(usage.get("total_tokens")),
        latency_ms=latency_ms,
    )
def _resolve_chat_model(db: Session) -> ModelRegistry:
    """Return the model that should serve chat requests.

    The capability route ``chat.default`` is tried first, then the global
    route; the first lookup that yields a model wins.

    Raises:
        HTTPException: 400 when neither route resolves to a model.
    """
    route_candidates = (
        ("CAPABILITY", CHAT_CAPABILITY_ROUTE_KEY),
        ("GLOBAL", GLOBAL_ROUTE_KEY),
    )
    for route_type, route_key in route_candidates:
        resolved = _resolve_model_from_route(db, route_type=route_type, route_key=route_key)
        if resolved:
            return resolved
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="No enabled model route for chat (CAPABILITY:chat.default or GLOBAL)",
    )
def _resolve_model_from_route(
    db: Session,
    *,
    route_type: str,
    route_key: str,
) -> ModelRegistry | None:
    """Find the first usable model for a route.

    Enabled rules matching ``route_type``/``route_key`` are joined to models
    whose status is ``"ENABLED"`` and walked in ascending priority (ties
    broken by rule id). The first model that has at least one active
    ``ModelApiKey`` row is returned.

    Returns:
        The selected model, or ``None`` when no rule leads to such a model.
    """
    candidates_stmt = (
        select(ModelRouteRule, ModelRegistry)
        .join(ModelRegistry, ModelRouteRule.target_model_code == ModelRegistry.code)
        .where(
            ModelRouteRule.route_type == route_type,
            ModelRouteRule.route_key == route_key,
            ModelRouteRule.enabled.is_(True),
            ModelRegistry.status == "ENABLED",
        )
        .order_by(ModelRouteRule.priority.asc(), ModelRouteRule.id.asc())
    )
    # An empty result simply skips the loop and falls through to None.
    for _rule, candidate in db.execute(candidates_stmt).all():
        active_key_stmt = select(ModelApiKey.id).where(
            ModelApiKey.model_id == candidate.id,
            ModelApiKey.is_active.is_(True),
        )
        if db.scalar(active_key_stmt) is not None:
            return candidate
    return None
def _resolve_provider_key(provider: str) -> str:
    """Return the API key configured for ``provider``.

    The provider name is trimmed and lower-cased before it is looked up in
    ``settings.llm_provider_key_map``.

    Raises:
        HTTPException: 400 when the map has no non-empty key for the provider.
    """
    lookup_name = provider.strip().lower()
    api_key = settings.llm_provider_key_map.get(lookup_name)
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Missing provider key for {provider}. Configure LLM_PROVIDER_API_KEYS.",
        )
    return api_key
def _build_messages(
*,
system_prompt: str,
context_messages: list[tuple[str, str]],
user_message: str,
) -> list[dict[str, str]]:
messages: list[dict[str, str]] = []
normalized_system_prompt = system_prompt.strip()
if normalized_system_prompt:
messages.append({"role": "system", "content": normalized_system_prompt})
for role, content in context_messages:
if role not in {"user", "assistant"}:
continue
normalized_content = content.strip()
if not normalized_content:
continue
messages.append({"role": role, "content": normalized_content})
messages.append({"role": "user", "content": user_message.strip()})
return messages
def _build_endpoint(base_url: str | None) -> str:
normalized = (base_url or "").strip().rstrip("/")
if not normalized:
return f"{DEFAULT_OPENAI_BASE_URL}/chat/completions"
if normalized.endswith("/chat/completions"):
return normalized
return f"{normalized}/chat/completions"
def _extract_content(body: object) -> str:
if not isinstance(body, dict):
return ""
choices = body.get("choices")
if not isinstance(choices, list) or not choices:
return ""
first = choices[0]
if not isinstance(first, dict):
return ""
message = first.get("message")
if not isinstance(message, dict):
return ""
content = message.get("content")
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
texts: list[str] = []
for item in content:
if isinstance(item, dict):
text = item.get("text")
if isinstance(text, str) and text.strip():
texts.append(text.strip())
return "\n".join(texts).strip()
return ""
def _extract_http_error_detail(response: httpx.Response) -> str:
try:
payload = response.json()
except json.JSONDecodeError:
return f"HTTP {response.status_code}"
if isinstance(payload, dict):
detail = payload.get("error")
if isinstance(detail, dict):
message = detail.get("message")
if isinstance(message, str) and message.strip():
return message.strip()
message = payload.get("message")
if isinstance(message, str) and message.strip():
return message.strip()
detail_field = payload.get("detail")
if isinstance(detail_field, str) and detail_field.strip():
return detail_field.strip()
return f"HTTP {response.status_code}"
def _to_int(value: object) -> int | None:
if isinstance(value, int):
return value
if isinstance(value, str) and value.isdigit():
return int(value)
return None