2026-04-12 00:03:30 +08:00
|
|
|
from collections.abc import Generator
|
2026-04-23 09:41:54 +08:00
|
|
|
import logging
|
|
|
|
|
from typing import Any
|
2026-04-12 00:03:30 +08:00
|
|
|
|
|
|
|
|
from sqlalchemy import create_engine
|
|
|
|
|
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
|
|
|
|
|
|
|
|
|
from .config import get_settings
|
|
|
|
|
|
|
|
|
|
# Application settings and this module's logger.
settings = get_settings()
logger = logging.getLogger(__name__)

# The URL the engine connects to, as resolved by the settings object.
database_url = settings.resolved_database_url

# Extra keyword arguments forwarded to the DBAPI driver's connect() call,
# chosen per backend based on the URL scheme prefix.
connect_args: dict[str, Any] = {}
if database_url.startswith("sqlite"):
    # sqlite3 refuses to use a connection outside the thread that created it
    # unless this check is disabled.
    connect_args.update(check_same_thread=False)
elif database_url.startswith("postgresql"):
    schema = settings.resolved_db_schema
    if schema:
        # libpq `options`: set search_path at connection start so unqualified
        # table names resolve in the configured schema.
        connect_args.update(options=f"-csearch_path={schema}")

# pool_pre_ping tests pooled connections before handing them out.
engine = create_engine(database_url, connect_args=connect_args, pool_pre_ping=True)

# Session factory: no autoflush, and objects stay loaded after commit
# (expire_on_commit=False) so their attributes remain readable.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base for the application's ORM models.

    Tables of subclasses are registered on ``Base.metadata``.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db() -> Generator[Session, None, None]:
    """Yield a new database session and close it once the consumer is done.

    Written as a generator so it can act as a per-request dependency
    (presumably via a framework ``Depends``; confirm against callers). The
    session is closed whether the consumer finishes normally or raises.
    """
    # Session.__exit__ closes the session, mirroring try/finally: close().
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_db() -> None:
|
|
|
|
|
# Import models so metadata includes every table before create_all.
|
2026-04-17 21:55:27 +08:00
|
|
|
from ..models import (
|
2026-04-26 09:00:49 +08:00
|
|
|
atp_model,
|
2026-04-17 21:55:27 +08:00
|
|
|
audit_log,
|
|
|
|
|
auth_session,
|
2026-04-23 09:41:54 +08:00
|
|
|
calendar_event,
|
|
|
|
|
diary,
|
2026-04-17 21:55:27 +08:00
|
|
|
file_storage,
|
2026-04-19 07:48:34 +08:00
|
|
|
hot_search,
|
|
|
|
|
life_countdown,
|
2026-04-26 00:14:25 +08:00
|
|
|
lightning_event,
|
|
|
|
|
lightning_sample,
|
|
|
|
|
line,
|
|
|
|
|
line_tower,
|
2026-04-17 21:55:27 +08:00
|
|
|
menu,
|
|
|
|
|
model_registry,
|
2026-04-23 09:41:54 +08:00
|
|
|
object_group,
|
2026-04-19 07:48:34 +08:00
|
|
|
question_bank,
|
2026-04-17 21:55:27 +08:00
|
|
|
rbac,
|
|
|
|
|
requirement,
|
2026-04-19 07:48:34 +08:00
|
|
|
system_param,
|
2026-04-17 21:55:27 +08:00
|
|
|
todo,
|
|
|
|
|
user,
|
2026-04-19 07:48:34 +08:00
|
|
|
vocabulary_word,
|
2026-04-17 21:55:27 +08:00
|
|
|
) # noqa: F401
|
2026-04-12 00:03:30 +08:00
|
|
|
from ..services.seed_service import seed_defaults
|
|
|
|
|
|
|
|
|
|
Base.metadata.create_all(bind=engine)
|
|
|
|
|
with SessionLocal() as db:
|
2026-04-23 09:41:54 +08:00
|
|
|
local_hosts = {"db", "localhost", "127.0.0.1", "::1"}
|
|
|
|
|
database_url = (settings.database_url or "").strip().lower()
|
|
|
|
|
database_url_targets_local = any(
|
|
|
|
|
token in database_url for token in ("@db:", "@localhost:", "@127.0.0.1:", "@[::1]:")
|
|
|
|
|
)
|
|
|
|
|
should_seed_defaults = (
|
|
|
|
|
settings.db_host.strip().lower() in local_hosts
|
|
|
|
|
or database_url_targets_local
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if should_seed_defaults:
|
|
|
|
|
seed_defaults(db)
|
|
|
|
|
else:
|
|
|
|
|
logger.info(
|
|
|
|
|
"Skip seed defaults for non-local database target: host=%s",
|
|
|
|
|
settings.db_host,
|
|
|
|
|
)
|