前端框架修改
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
@@ -6,13 +8,20 @@ from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
from .config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
connect_args: dict[str, bool] = {}
|
||||
if settings.database_url.startswith("sqlite"):
|
||||
database_url = settings.resolved_database_url
|
||||
|
||||
connect_args: dict[str, Any] = {}
|
||||
if database_url.startswith("sqlite"):
|
||||
connect_args["check_same_thread"] = False
|
||||
elif database_url.startswith("postgresql"):
|
||||
schema = settings.resolved_db_schema
|
||||
if schema:
|
||||
connect_args["options"] = f"-csearch_path={schema}"
|
||||
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
database_url,
|
||||
pool_pre_ping=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
@@ -42,12 +51,17 @@ def init_db() -> None:
|
||||
from ..models import (
|
||||
audit_log,
|
||||
auth_session,
|
||||
calendar_event,
|
||||
chat,
|
||||
diary,
|
||||
file_storage,
|
||||
hot_search,
|
||||
life_countdown,
|
||||
menu,
|
||||
mermaid_diagram,
|
||||
mind_map,
|
||||
model_registry,
|
||||
object_group,
|
||||
question_bank,
|
||||
rbac,
|
||||
requirement,
|
||||
@@ -61,4 +75,20 @@ def init_db() -> None:
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
with SessionLocal() as db:
|
||||
seed_defaults(db)
|
||||
local_hosts = {"db", "localhost", "127.0.0.1", "::1"}
|
||||
database_url = (settings.database_url or "").strip().lower()
|
||||
database_url_targets_local = any(
|
||||
token in database_url for token in ("@db:", "@localhost:", "@127.0.0.1:", "@[::1]:")
|
||||
)
|
||||
should_seed_defaults = (
|
||||
settings.db_host.strip().lower() in local_hosts
|
||||
or database_url_targets_local
|
||||
)
|
||||
|
||||
if should_seed_defaults:
|
||||
seed_defaults(db)
|
||||
else:
|
||||
logger.info(
|
||||
"Skip seed defaults for non-local database target: host=%s",
|
||||
settings.db_host,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user