[fix/feat]:[FL-39][移除 seed 启动自动调用,改为管理员手动接口触发]
Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
@@ -14,6 +14,7 @@ from ...schemas.admin import (
|
||||
RoleListResponse,
|
||||
RoleMenuUpdateRequest,
|
||||
RolePublic,
|
||||
SeedDefaultsResponse,
|
||||
RoleUpdateRequest,
|
||||
)
|
||||
from ...services.admin_service import (
|
||||
@@ -35,6 +36,7 @@ from ...services.legacy_admin_rbac_service import (
|
||||
update_menu,
|
||||
update_role,
|
||||
)
|
||||
from ...services.seed_service import seed_defaults
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
@@ -135,6 +137,16 @@ def get_audit_logs(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/system/seed-defaults", response_model=SeedDefaultsResponse)
|
||||
def seed_defaults_endpoint(
|
||||
force: bool = Query(default=False),
|
||||
_: CurrentUser = Depends(require_permission("menu.manage")),
|
||||
db: Session = Depends(get_db),
|
||||
) -> SeedDefaultsResponse:
|
||||
result = seed_defaults(db, force=force)
|
||||
return SeedDefaultsResponse.model_validate(result.to_response())
|
||||
|
||||
|
||||
@router.get("/menus", response_model=MenuListResponse)
|
||||
def get_menus(
|
||||
_: CurrentUser = Depends(require_any_permission("menu.read", "menu.manage")),
|
||||
|
||||
@@ -403,8 +403,6 @@ def init_db() -> None:
|
||||
user,
|
||||
worker_registry,
|
||||
) # noqa: F401
|
||||
from ..services.seed_service import seed_defaults
|
||||
|
||||
_ensure_user_pk_column_compatibility()
|
||||
_ensure_user_timestamp_column_compatibility()
|
||||
_ensure_user_audit_column_compatibility()
|
||||
@@ -412,21 +410,3 @@ def init_db() -> None:
|
||||
_ensure_tower_model_column_compatibility()
|
||||
_ensure_tower_profile_column_compatibility()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
with SessionLocal() as db:
|
||||
local_hosts = {"db", "localhost", "127.0.0.1", "::1"}
|
||||
database_url = (settings.database_url or "").strip().lower()
|
||||
database_url_targets_local = any(
|
||||
token in database_url for token in ("@db:", "@localhost:", "@127.0.0.1:", "@[::1]:")
|
||||
)
|
||||
should_seed_defaults = (
|
||||
settings.db_host.strip().lower() in local_hosts
|
||||
or database_url_targets_local
|
||||
)
|
||||
|
||||
if should_seed_defaults:
|
||||
seed_defaults(db)
|
||||
else:
|
||||
logger.info(
|
||||
"Skip seed defaults for non-local database target: host=%s",
|
||||
settings.db_host,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -111,4 +112,20 @@ class AuditLogListResponse(BaseModel):
|
||||
offset: int
|
||||
|
||||
|
||||
class SeedCategorySummary(BaseModel):
|
||||
created: int = 0
|
||||
updated: int = 0
|
||||
linked: int = 0
|
||||
unchanged: int = 0
|
||||
overwritten: int = 0
|
||||
|
||||
|
||||
class SeedDefaultsResponse(BaseModel):
|
||||
success: bool
|
||||
force: bool
|
||||
mode: Literal["missing_only", "force_overwrite"]
|
||||
overwrote_existing: bool
|
||||
summary: dict[str, SeedCategorySummary] = Field(default_factory=dict)
|
||||
|
||||
|
||||
MenuTreeItem.model_rebuild()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -12,6 +14,73 @@ from .tower_model_service import seed_tower_models_from_legacy
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
SEED_RESULT_CATEGORIES = (
|
||||
"permissions",
|
||||
"roles",
|
||||
"role_permissions",
|
||||
"menus",
|
||||
"role_menus",
|
||||
"file_storage_backends",
|
||||
"file_storage_mounts",
|
||||
"admin_users",
|
||||
"legacy_tower_models",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SeedCategorySummary:
|
||||
created: int = 0
|
||||
updated: int = 0
|
||||
linked: int = 0
|
||||
unchanged: int = 0
|
||||
overwritten: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, int]:
|
||||
return {
|
||||
"created": self.created,
|
||||
"updated": self.updated,
|
||||
"linked": self.linked,
|
||||
"unchanged": self.unchanged,
|
||||
"overwritten": self.overwritten,
|
||||
}
|
||||
|
||||
|
||||
def _build_seed_summary() -> dict[str, SeedCategorySummary]:
|
||||
return {
|
||||
category: SeedCategorySummary()
|
||||
for category in SEED_RESULT_CATEGORIES
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SeedDefaultsResult:
|
||||
force: bool
|
||||
summary: dict[str, SeedCategorySummary] = field(default_factory=_build_seed_summary)
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def mode(self) -> str:
|
||||
return "force_overwrite" if self.force else "missing_only"
|
||||
|
||||
@property
|
||||
def overwrote_existing(self) -> bool:
|
||||
return any(item.overwritten > 0 for item in self.summary.values())
|
||||
|
||||
def to_response(self) -> dict[str, object]:
|
||||
return {
|
||||
"success": self.success,
|
||||
"force": self.force,
|
||||
"mode": self.mode,
|
||||
"overwrote_existing": self.overwrote_existing,
|
||||
"summary": {
|
||||
category: item.to_dict()
|
||||
for category, item in self.summary.items()
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_PERMISSIONS: dict[str, str] = {
|
||||
"user.read": "Read user profile",
|
||||
"user.write": "Update user profile",
|
||||
@@ -359,24 +428,55 @@ def _default_file_storage_mounts() -> list[dict[str, object]]:
|
||||
]
|
||||
|
||||
|
||||
def seed_defaults(db: Session) -> None:
|
||||
permissions = _seed_permissions(db)
|
||||
roles = _seed_roles(db, permissions)
|
||||
menus = _seed_menus(db)
|
||||
_seed_role_menus(db, roles, menus)
|
||||
_seed_file_storage(db)
|
||||
_seed_initial_admin(db)
|
||||
db.commit()
|
||||
_seed_legacy_tower_models_if_empty(db)
|
||||
def seed_defaults(db: Session, *, force: bool = False) -> SeedDefaultsResult:
|
||||
result = SeedDefaultsResult(force=force)
|
||||
|
||||
try:
|
||||
permissions = _seed_permissions(db, result=result, force=force)
|
||||
roles = _seed_roles(db, result=result, force=force)
|
||||
_seed_role_permissions(db, roles, permissions, result=result, force=force)
|
||||
menus = _seed_menus(db, result=result, force=force)
|
||||
_seed_role_menus(db, roles, menus, result=result, force=force)
|
||||
_seed_file_storage(db, result=result, force=force)
|
||||
_seed_initial_admin(db, roles, result=result)
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
legacy_seed_result = _seed_legacy_tower_models_if_empty(db)
|
||||
if legacy_seed_result is not None:
|
||||
legacy_summary = result.summary["legacy_tower_models"]
|
||||
legacy_summary.created += legacy_seed_result.imported_models
|
||||
legacy_summary.updated += legacy_seed_result.updated_models
|
||||
legacy_summary.unchanged += legacy_seed_result.skipped_models
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _seed_permissions(db: Session) -> dict[str, Permission]:
|
||||
def _seed_permissions(
|
||||
db: Session,
|
||||
*,
|
||||
result: SeedDefaultsResult,
|
||||
force: bool,
|
||||
) -> dict[str, Permission]:
|
||||
category = result.summary["permissions"]
|
||||
permission_map: dict[str, Permission] = {}
|
||||
for code, name in DEFAULT_PERMISSIONS.items():
|
||||
permission = db.scalar(select(Permission).where(Permission.code == code))
|
||||
if not permission:
|
||||
permission = Permission(code=code, name=name)
|
||||
db.add(permission)
|
||||
category.created += 1
|
||||
elif force:
|
||||
if permission.name != name:
|
||||
permission.name = name
|
||||
category.updated += 1
|
||||
category.overwritten += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
permission_map[code] = permission
|
||||
|
||||
db.flush()
|
||||
@@ -389,23 +489,88 @@ def _seed_permissions(db: Session) -> dict[str, Permission]:
|
||||
return permission_map
|
||||
|
||||
|
||||
def _seed_roles(db: Session, permission_map: dict[str, Permission]) -> dict[str, Role]:
|
||||
def _seed_roles(
|
||||
db: Session,
|
||||
*,
|
||||
result: SeedDefaultsResult,
|
||||
force: bool,
|
||||
) -> dict[str, Role]:
|
||||
category = result.summary["roles"]
|
||||
role_map: dict[str, Role] = {}
|
||||
for code, role_info in DEFAULT_ROLES.items():
|
||||
role_name = str(role_info["name"])
|
||||
role = db.scalar(select(Role).where(Role.code == code))
|
||||
if not role:
|
||||
role = Role(code=code, name=str(role_info["name"]))
|
||||
role = Role(code=code, name=role_name)
|
||||
db.add(role)
|
||||
db.flush()
|
||||
category.created += 1
|
||||
elif force:
|
||||
if role.name != role_name:
|
||||
role.name = role_name
|
||||
category.updated += 1
|
||||
category.overwritten += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
|
||||
role.permissions = [permission_map[p] for p in role_info["permissions"]]
|
||||
role_map[code] = role
|
||||
db.flush()
|
||||
return role_map
|
||||
|
||||
|
||||
def _seed_menus(db: Session) -> dict[str, Menu]:
|
||||
def _seed_role_permissions(
|
||||
db: Session,
|
||||
role_map: dict[str, Role],
|
||||
permission_map: dict[str, Permission],
|
||||
*,
|
||||
result: SeedDefaultsResult,
|
||||
force: bool,
|
||||
) -> None:
|
||||
category = result.summary["role_permissions"]
|
||||
for role_code, role_info in DEFAULT_ROLES.items():
|
||||
role = role_map.get(role_code)
|
||||
if not role:
|
||||
continue
|
||||
|
||||
default_permissions = [permission_map[code] for code in role_info["permissions"]]
|
||||
existing_codes = {permission.code for permission in role.permissions}
|
||||
desired_codes = {permission.code for permission in default_permissions}
|
||||
|
||||
if force:
|
||||
if existing_codes != desired_codes:
|
||||
role.permissions = default_permissions
|
||||
category.updated += 1
|
||||
if existing_codes:
|
||||
category.overwritten += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
continue
|
||||
|
||||
missing_permissions = [
|
||||
permission
|
||||
for permission in default_permissions
|
||||
if permission.code not in existing_codes
|
||||
]
|
||||
if missing_permissions:
|
||||
role.permissions = [*role.permissions, *missing_permissions]
|
||||
category.linked += len(missing_permissions)
|
||||
else:
|
||||
category.unchanged += 1
|
||||
|
||||
db.flush()
|
||||
|
||||
|
||||
def _seed_menus(
|
||||
db: Session,
|
||||
*,
|
||||
result: SeedDefaultsResult,
|
||||
force: bool,
|
||||
) -> dict[str, Menu]:
|
||||
category = result.summary["menus"]
|
||||
menu_map: dict[str, Menu] = {}
|
||||
created_codes: set[str] = set()
|
||||
|
||||
for menu_info in DEFAULT_MENUS:
|
||||
code = str(menu_info["code"])
|
||||
@@ -414,41 +579,128 @@ def _seed_menus(db: Session) -> dict[str, Menu]:
|
||||
menu = Menu(code=code, name=str(menu_info["name"]))
|
||||
db.add(menu)
|
||||
db.flush()
|
||||
created_codes.add(code)
|
||||
category.created += 1
|
||||
menu_map[code] = menu
|
||||
|
||||
for menu_info in DEFAULT_MENUS:
|
||||
code = str(menu_info["code"])
|
||||
parent_code = menu_info["parent_code"]
|
||||
menu = menu_map[code]
|
||||
menu.name = str(menu_info["name"])
|
||||
menu.path = menu_info["path"] if isinstance(menu_info["path"], str) else None
|
||||
menu.icon = menu_info["icon"] if isinstance(menu_info["icon"], str) else None
|
||||
menu.parent_id = menu_map[str(parent_code)].id if parent_code else None
|
||||
menu.type = str(menu_info["type"])
|
||||
menu.sort_order = int(menu_info["sort_order"])
|
||||
menu.status = str(menu_info["status"])
|
||||
menu.visible = bool(menu_info["visible"])
|
||||
menu.cacheable = bool(menu_info["cacheable"])
|
||||
menu.permission_code = (
|
||||
str(menu_info["permission_code"])
|
||||
if menu_info.get("permission_code") is not None
|
||||
else None
|
||||
)
|
||||
if code in created_codes:
|
||||
_apply_menu_defaults(menu, menu_info=menu_info, menu_map=menu_map)
|
||||
continue
|
||||
|
||||
if force:
|
||||
if _apply_menu_defaults(menu, menu_info=menu_info, menu_map=menu_map):
|
||||
category.updated += 1
|
||||
category.overwritten += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
|
||||
db.flush()
|
||||
return menu_map
|
||||
|
||||
|
||||
def _seed_role_menus(db: Session, role_map: dict[str, Role], menu_map: dict[str, Menu]) -> None:
|
||||
def _apply_menu_defaults(
|
||||
menu: Menu,
|
||||
*,
|
||||
menu_info: dict[str, object],
|
||||
menu_map: dict[str, Menu],
|
||||
) -> bool:
|
||||
parent_code = menu_info["parent_code"]
|
||||
desired_parent_id = menu_map[str(parent_code)].id if parent_code else None
|
||||
desired_name = str(menu_info["name"])
|
||||
desired_path = menu_info["path"] if isinstance(menu_info["path"], str) else None
|
||||
desired_icon = menu_info["icon"] if isinstance(menu_info["icon"], str) else None
|
||||
desired_type = str(menu_info["type"])
|
||||
desired_sort_order = int(menu_info["sort_order"])
|
||||
desired_status = str(menu_info["status"])
|
||||
desired_visible = bool(menu_info["visible"])
|
||||
desired_cacheable = bool(menu_info["cacheable"])
|
||||
desired_permission_code = (
|
||||
str(menu_info["permission_code"])
|
||||
if menu_info.get("permission_code") is not None
|
||||
else None
|
||||
)
|
||||
|
||||
changed = (
|
||||
menu.name != desired_name
|
||||
or menu.path != desired_path
|
||||
or menu.icon != desired_icon
|
||||
or menu.parent_id != desired_parent_id
|
||||
or menu.type != desired_type
|
||||
or menu.sort_order != desired_sort_order
|
||||
or menu.status != desired_status
|
||||
or menu.visible is not desired_visible
|
||||
or menu.cacheable is not desired_cacheable
|
||||
or menu.permission_code != desired_permission_code
|
||||
)
|
||||
|
||||
menu.name = desired_name
|
||||
menu.path = desired_path
|
||||
menu.icon = desired_icon
|
||||
menu.parent_id = desired_parent_id
|
||||
menu.type = desired_type
|
||||
menu.sort_order = desired_sort_order
|
||||
menu.status = desired_status
|
||||
menu.visible = desired_visible
|
||||
menu.cacheable = desired_cacheable
|
||||
menu.permission_code = desired_permission_code
|
||||
return changed
|
||||
|
||||
|
||||
def _seed_role_menus(
|
||||
db: Session,
|
||||
role_map: dict[str, Role],
|
||||
menu_map: dict[str, Menu],
|
||||
*,
|
||||
result: SeedDefaultsResult,
|
||||
force: bool,
|
||||
) -> None:
|
||||
category = result.summary["role_menus"]
|
||||
for role_code, menu_codes in ROLE_MENU_BINDINGS.items():
|
||||
role = role_map.get(role_code)
|
||||
if not role:
|
||||
continue
|
||||
role.menus = [menu_map[menu_code] for menu_code in menu_codes if menu_code in menu_map]
|
||||
|
||||
desired_menus = [menu_map[menu_code] for menu_code in menu_codes if menu_code in menu_map]
|
||||
existing_codes = {menu.code for menu in role.menus}
|
||||
desired_codes = {menu.code for menu in desired_menus}
|
||||
|
||||
if force:
|
||||
if existing_codes != desired_codes:
|
||||
role.menus = desired_menus
|
||||
category.updated += 1
|
||||
if existing_codes:
|
||||
category.overwritten += 1
|
||||
else:
|
||||
category.unchanged += 1
|
||||
continue
|
||||
|
||||
missing_menus = [
|
||||
menu
|
||||
for menu in desired_menus
|
||||
if menu.code not in existing_codes
|
||||
]
|
||||
if missing_menus:
|
||||
role.menus = [*role.menus, *missing_menus]
|
||||
category.linked += len(missing_menus)
|
||||
else:
|
||||
category.unchanged += 1
|
||||
|
||||
db.flush()
|
||||
|
||||
|
||||
def _seed_file_storage(db: Session) -> None:
|
||||
def _seed_file_storage(
|
||||
db: Session,
|
||||
*,
|
||||
result: SeedDefaultsResult,
|
||||
force: bool,
|
||||
) -> None:
|
||||
backend_category = result.summary["file_storage_backends"]
|
||||
mount_category = result.summary["file_storage_mounts"]
|
||||
backend_map: dict[str, FileStorageBackend] = {}
|
||||
|
||||
for backend_info in _default_file_storage_backends():
|
||||
@@ -468,12 +720,25 @@ def _seed_file_storage(db: Session) -> None:
|
||||
)
|
||||
db.add(backend)
|
||||
db.flush()
|
||||
backend_category.created += 1
|
||||
else:
|
||||
backend.name = str(backend_info["name"])
|
||||
backend.driver_type = str(backend_info["driver_type"])
|
||||
backend.status = str(backend_info["status"])
|
||||
backend.is_default = bool(backend_info["is_default"])
|
||||
backend.config_json = normalized_config
|
||||
if force:
|
||||
changed = _apply_backend_defaults(
|
||||
backend,
|
||||
name=str(backend_info["name"]),
|
||||
driver_type=str(backend_info["driver_type"]),
|
||||
status=str(backend_info["status"]),
|
||||
is_default=bool(backend_info["is_default"]),
|
||||
config_json=normalized_config,
|
||||
)
|
||||
if changed:
|
||||
backend_category.updated += 1
|
||||
backend_category.overwritten += 1
|
||||
else:
|
||||
backend_category.unchanged += 1
|
||||
else:
|
||||
backend_category.unchanged += 1
|
||||
|
||||
backend_map[code] = backend
|
||||
|
||||
for mount_info in _default_file_storage_mounts():
|
||||
@@ -495,21 +760,86 @@ def _seed_file_storage(db: Session) -> None:
|
||||
)
|
||||
db.add(mount)
|
||||
db.flush()
|
||||
mount_category.created += 1
|
||||
continue
|
||||
|
||||
mount.name = str(mount_info["name"])
|
||||
mount.backend_id = backend.id
|
||||
mount.mount_path = str(mount_info["mount_path"])
|
||||
mount.root_path = str(mount_info["root_path"])
|
||||
if mount_info.get("is_enabled") is not None:
|
||||
mount.is_enabled = bool(mount_info["is_enabled"])
|
||||
if force:
|
||||
changed = _apply_mount_defaults(
|
||||
mount,
|
||||
backend_id=backend.id,
|
||||
name=str(mount_info["name"]),
|
||||
mount_path=str(mount_info["mount_path"]),
|
||||
root_path=str(mount_info["root_path"]),
|
||||
is_enabled=bool(mount_info["is_enabled"]),
|
||||
)
|
||||
if changed:
|
||||
mount_category.updated += 1
|
||||
mount_category.overwritten += 1
|
||||
else:
|
||||
mount_category.unchanged += 1
|
||||
else:
|
||||
mount_category.unchanged += 1
|
||||
|
||||
|
||||
def _seed_initial_admin(db: Session) -> None:
|
||||
def _apply_backend_defaults(
|
||||
backend: FileStorageBackend,
|
||||
*,
|
||||
name: str,
|
||||
driver_type: str,
|
||||
status: str,
|
||||
is_default: bool,
|
||||
config_json: dict[str, object],
|
||||
) -> bool:
|
||||
changed = (
|
||||
backend.name != name
|
||||
or backend.driver_type != driver_type
|
||||
or backend.status != status
|
||||
or backend.is_default is not is_default
|
||||
or backend.config_json != config_json
|
||||
)
|
||||
backend.name = name
|
||||
backend.driver_type = driver_type
|
||||
backend.status = status
|
||||
backend.is_default = is_default
|
||||
backend.config_json = config_json
|
||||
return changed
|
||||
|
||||
|
||||
def _apply_mount_defaults(
|
||||
mount: FileStorageMount,
|
||||
*,
|
||||
backend_id: int,
|
||||
name: str,
|
||||
mount_path: str,
|
||||
root_path: str,
|
||||
is_enabled: bool,
|
||||
) -> bool:
|
||||
changed = (
|
||||
mount.name != name
|
||||
or mount.backend_id != backend_id
|
||||
or mount.mount_path != mount_path
|
||||
or mount.root_path != root_path
|
||||
or mount.is_enabled is not is_enabled
|
||||
)
|
||||
mount.name = name
|
||||
mount.backend_id = backend_id
|
||||
mount.mount_path = mount_path
|
||||
mount.root_path = root_path
|
||||
mount.is_enabled = is_enabled
|
||||
return changed
|
||||
|
||||
|
||||
def _seed_initial_admin(
|
||||
db: Session,
|
||||
role_map: dict[str, Role],
|
||||
*,
|
||||
result: SeedDefaultsResult,
|
||||
) -> None:
|
||||
category = result.summary["admin_users"]
|
||||
if not settings.initial_admin_email or not settings.initial_admin_password:
|
||||
return
|
||||
|
||||
admin_role = db.scalar(select(Role).where(Role.code == "admin"))
|
||||
admin_role = role_map.get("admin")
|
||||
if not admin_role:
|
||||
return
|
||||
|
||||
@@ -519,6 +849,7 @@ def _seed_initial_admin(db: Session) -> None:
|
||||
|
||||
admin_email = settings.initial_admin_email.lower()
|
||||
user = db.scalar(select(User).where((User.id == admin_user_id) | (User.email == admin_email)))
|
||||
created = False
|
||||
if not user:
|
||||
user = User(
|
||||
id=admin_user_id,
|
||||
@@ -529,16 +860,21 @@ def _seed_initial_admin(db: Session) -> None:
|
||||
)
|
||||
db.add(user)
|
||||
db.flush()
|
||||
category.created += 1
|
||||
created = True
|
||||
|
||||
role_codes = {role.code for role in user.roles}
|
||||
if "admin" not in role_codes:
|
||||
user.roles.append(admin_role)
|
||||
category.linked += 1
|
||||
elif not created:
|
||||
category.unchanged += 1
|
||||
|
||||
|
||||
def _seed_legacy_tower_models_if_empty(db: Session) -> None:
|
||||
def _seed_legacy_tower_models_if_empty(db: Session):
|
||||
existing_count = int(db.scalar(select(func.count()).select_from(TowerModel)) or 0)
|
||||
if existing_count > 0:
|
||||
return
|
||||
return None
|
||||
|
||||
actor = db.scalar(select(User).where(User.id == settings.initial_admin_user_id))
|
||||
if actor is None:
|
||||
@@ -546,13 +882,14 @@ def _seed_legacy_tower_models_if_empty(db: Session) -> None:
|
||||
if actor is None:
|
||||
actor = db.scalar(select(User).order_by(User.created_at.asc()))
|
||||
if actor is None:
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
seed_tower_models_from_legacy(
|
||||
return seed_tower_models_from_legacy(
|
||||
db,
|
||||
actor=actor,
|
||||
overwrite_existing=False,
|
||||
)
|
||||
except Exception:
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite+pysqlite:///:memory:")
|
||||
os.environ.setdefault("MINIO_ENABLED", "false")
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from api.app import models # noqa: F401
|
||||
from api.app.api.v1.admin import router as admin_router
|
||||
from api.app.core.database import Base, get_db, init_db
|
||||
from api.app.core.dependencies import CurrentUser, get_current_user
|
||||
from api.app.models.menu import Menu
|
||||
from api.app.models.user import User
|
||||
from api.app.services.seed_service import DEFAULT_MENUS, SeedDefaultsResult, seed_defaults
|
||||
|
||||
|
||||
DEFAULT_MENU_BY_CODE = {
|
||||
str(menu["code"]): menu
|
||||
for menu in DEFAULT_MENUS
|
||||
}
|
||||
|
||||
|
||||
class DatabaseFixtureTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.engine = create_engine(
|
||||
"sqlite+pysqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
self.SessionLocal = sessionmaker(
|
||||
bind=self.engine,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
Base.metadata.drop_all(bind=self.engine)
|
||||
self.engine.dispose()
|
||||
|
||||
|
||||
class InitDbContractTest(unittest.TestCase):
|
||||
def test_init_db_does_not_open_a_seed_session(self) -> None:
|
||||
with patch("api.app.core.database._ensure_user_pk_column_compatibility"), patch(
|
||||
"api.app.core.database._ensure_user_timestamp_column_compatibility"
|
||||
), patch("api.app.core.database._ensure_user_audit_column_compatibility"), patch(
|
||||
"api.app.core.database._ensure_elevation_dataset_column_compatibility"
|
||||
), patch("api.app.core.database._ensure_tower_model_column_compatibility"), patch(
|
||||
"api.app.core.database._ensure_tower_profile_column_compatibility"
|
||||
), patch("api.app.core.database.Base.metadata.create_all") as create_all, patch(
|
||||
"api.app.core.database.SessionLocal"
|
||||
) as session_local:
|
||||
init_db()
|
||||
|
||||
create_all.assert_called_once()
|
||||
session_local.assert_not_called()
|
||||
|
||||
|
||||
class SeedDefaultsServiceTest(DatabaseFixtureTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.session = self.SessionLocal()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.session.close()
|
||||
super().tearDown()
|
||||
|
||||
def _run_seed(self, *, force: bool = False) -> SeedDefaultsResult:
|
||||
with patch(
|
||||
"api.app.services.seed_service._seed_legacy_tower_models_if_empty",
|
||||
return_value=None,
|
||||
):
|
||||
return seed_defaults(self.session, force=force)
|
||||
|
||||
def _load_menu(self, code: str) -> Menu:
|
||||
menu = self.session.scalar(select(Menu).where(Menu.code == code))
|
||||
self.assertIsNotNone(menu)
|
||||
return menu
|
||||
|
||||
def test_seed_defaults_preserves_existing_menu_changes_by_default(self) -> None:
|
||||
self._run_seed()
|
||||
|
||||
menu = self._load_menu("admin.menus")
|
||||
menu.name = "自定义菜单"
|
||||
menu.sort_order = 999
|
||||
self.session.commit()
|
||||
|
||||
result = self._run_seed()
|
||||
self.session.expire_all()
|
||||
|
||||
refreshed = self._load_menu("admin.menus")
|
||||
self.assertEqual(refreshed.name, "自定义菜单")
|
||||
self.assertEqual(refreshed.sort_order, 999)
|
||||
self.assertEqual(result.mode, "missing_only")
|
||||
self.assertFalse(result.overwrote_existing)
|
||||
|
||||
def test_seed_defaults_force_restores_default_menu_fields(self) -> None:
|
||||
self._run_seed()
|
||||
|
||||
menu = self._load_menu("admin.menus")
|
||||
menu.name = "自定义菜单"
|
||||
menu.sort_order = 999
|
||||
self.session.commit()
|
||||
|
||||
result = self._run_seed(force=True)
|
||||
self.session.expire_all()
|
||||
|
||||
refreshed = self._load_menu("admin.menus")
|
||||
default_menu = DEFAULT_MENU_BY_CODE["admin.menus"]
|
||||
self.assertEqual(refreshed.name, str(default_menu["name"]))
|
||||
self.assertEqual(refreshed.sort_order, int(default_menu["sort_order"]))
|
||||
self.assertEqual(result.mode, "force_overwrite")
|
||||
self.assertTrue(result.overwrote_existing)
|
||||
self.assertGreater(result.summary["menus"].overwritten, 0)
|
||||
|
||||
|
||||
class SeedDefaultsEndpointTest(DatabaseFixtureTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.app = FastAPI()
|
||||
self.app.include_router(admin_router, prefix="/api/v1")
|
||||
|
||||
def override_get_db():
|
||||
db = self.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
current_user = CurrentUser(
|
||||
user=User(
|
||||
id="admin",
|
||||
email="admin@example.com",
|
||||
username="admin",
|
||||
password_hash="secret",
|
||||
status="ENABLED",
|
||||
),
|
||||
role_codes={"admin"},
|
||||
permission_codes={"menu.manage"},
|
||||
)
|
||||
self.app.dependency_overrides[get_db] = override_get_db
|
||||
self.app.dependency_overrides[get_current_user] = lambda: current_user
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.client.close()
|
||||
self.app.dependency_overrides.clear()
|
||||
super().tearDown()
|
||||
|
||||
def test_seed_defaults_endpoint_triggers_seed_service(self) -> None:
|
||||
result = SeedDefaultsResult(force=True)
|
||||
result.summary["menus"].updated = 1
|
||||
result.summary["menus"].overwritten = 1
|
||||
|
||||
with patch("api.app.api.v1.admin.seed_defaults", return_value=result) as seed_defaults_mock:
|
||||
response = self.client.post("/api/v1/admin/system/seed-defaults?force=true")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
seed_defaults_mock.assert_called_once()
|
||||
self.assertTrue(seed_defaults_mock.call_args.kwargs["force"])
|
||||
|
||||
payload = response.json()
|
||||
self.assertTrue(payload["success"])
|
||||
self.assertEqual(payload["mode"], "force_overwrite")
|
||||
self.assertTrue(payload["overwrote_existing"])
|
||||
self.assertEqual(payload["summary"]["menus"]["updated"], 1)
|
||||
self.assertEqual(payload["summary"]["menus"]["overwritten"], 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user