Files
hfut-bishe/python_server/config.py
2026-04-07 20:55:30 +08:00

224 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
python_server 的统一配置文件。
特点:
- 使用 Python 而不是 YAML,方便在代码中集中列举所有可用模型,供前端读取。
- 后端加载模型时,也从这里读取默认值,保证单一信息源。
"""
from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import Literal, TypedDict, List

from model.Depth.zoe_loader import ZoeModelName
# -----------------------------
# 1. 深度模型枚举(给前端展示用)
# -----------------------------
class DepthModelInfo(TypedDict):
    """Metadata for one depth model, as listed for the frontend."""

    id: str  # unique ID, e.g. "zoedepth_n"
    family: str  # model family, e.g. "ZoeDepth"
    name: str  # display name, e.g. "ZoeD_N"
    description: str  # short human-readable description
    # Backend key. Typed as a Literal (instead of a bare str plus a comment) so
    # type checkers reject unknown backends; same value set as DepthConfig.backend.
    backend: Literal["zoedepth", "depth_anything_v2", "midas", "dpt"]
# Depth models advertised to the frontend; every entry follows DepthModelInfo
# with keys in the order id, family, name, description, backend.
DEPTH_MODELS: List[DepthModelInfo] = [
    # --- ZoeDepth family ---
    {
        "id": "zoedepth_n", "family": "ZoeDepth", "name": "ZoeD_N",
        "description": "ZoeDepth zero-shot 模型,适用室内/室外通用场景。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_k", "family": "ZoeDepth", "name": "ZoeD_K",
        "description": "ZoeDepth Kitti 专用版本,针对户外驾驶场景优化。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_nk", "family": "ZoeDepth", "name": "ZoeD_NK",
        "description": "ZoeDepth 双头版本NYU+KITTI综合室内/室外场景。",
        "backend": "zoedepth",
    },
    # --- Reserved: Depth Anything V2 ---
    {
        "id": "depth_anything_v2_s", "family": "Depth Anything V2",
        "name": "Depth Anything V2 Small",
        "description": "轻量级 Depth Anything V2 小模型。",
        "backend": "depth_anything_v2",
    },
    # --- Reserved: MiDaS ---
    {
        "id": "midas_dpt_large", "family": "MiDaS", "name": "MiDaS DPT Large",
        "description": "MiDaS DPT-Large 高质量深度模型。",
        "backend": "midas",
    },
    # --- Reserved: DPT ---
    {
        "id": "dpt_large", "family": "DPT", "name": "DPT Large",
        "description": "DPT Large 单目深度估计模型。",
        "backend": "dpt",
    },
]
# -----------------------------
# 1.2 补全模型枚举(给前端展示用)
# -----------------------------
class InpaintModelInfo(TypedDict):
    """Metadata for one inpainting model, as listed for the frontend."""

    id: str  # unique ID, e.g. "sdxl_inpaint"
    family: str  # model family, e.g. "SDXL"
    name: str  # display name
    description: str  # short human-readable description
    # Backend key. Typed as a Literal so type checkers reject unknown backends;
    # same value set as InpaintConfig.backend.
    backend: Literal["sdxl_inpaint", "controlnet"]
# Inpainting models advertised to the frontend; every entry follows
# InpaintModelInfo with keys in the order id, family, name, description, backend.
INPAINT_MODELS: List[InpaintModelInfo] = [
    {
        "id": "sdxl_inpaint",
        "family": "SDXL",
        "name": "SDXL Inpainting",
        "description": "基于 diffusers 的 SDXL 补全管线(需要 prompt + mask)。",
        "backend": "sdxl_inpaint",
    },
    # Listed for the frontend, but its own description states the unified
    # wrapper is not implemented yet.
    {
        "id": "controlnet",
        "family": "ControlNet",
        "name": "ControlNet (placeholder)",
        "description": "ControlNet 补全/控制生成(当前统一封装暂未实现)。",
        "backend": "controlnet",
    },
]
# -----------------------------
# 1.3 动画模型枚举(给前端展示用)
# -----------------------------
class AnimationModelInfo(TypedDict):
    """Metadata for one animation model, as listed for the frontend."""

    id: str  # unique ID, e.g. "animatediff"
    family: str  # model family, e.g. "AnimateDiff"
    name: str  # display name
    description: str  # short human-readable description
    # Backend key. Typed as a Literal so type checkers reject unknown backends;
    # same value set as AnimationConfig.backend.
    backend: Literal["animatediff"]
# Animation models advertised to the frontend (AnimateDiff is the only entry).
ANIMATION_MODELS: List[AnimationModelInfo] = [
    {
        "id": "animatediff", "family": "AnimateDiff",
        "name": "AnimateDiff (Text-to-Video)",
        "description": "基于 AnimateDiff 的文生动画能力,输出 GIF 动画。",
        "backend": "animatediff",
    },
]
# -----------------------------
# 2. 后端默认配置(给服务端用)
# -----------------------------
# Closed value sets accepted by the DepthConfig fields below.
_DepthBackend = Literal["zoedepth", "depth_anything_v2", "dpt", "midas"]
_DaV2Encoder = Literal["vits", "vitb", "vitl", "vitg"]
_DptModelType = Literal["dpt_large", "dpt_hybrid"]
_MidasModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]


@dataclass
class DepthConfig:
    """Server-side defaults for depth estimation.

    The frontend does not pick the depth backend; switching it is only
    possible here, in the server configuration.
    """

    backend: _DepthBackend = "zoedepth"
    # Default checkpoint within the ZoeDepth family.
    zoe_model: ZoeModelName = "ZoeD_N"
    # Default encoder for Depth Anything V2.
    da_v2_encoder: _DaV2Encoder = "vitl"
    # Default DPT model type.
    dpt_model_type: _DptModelType = "dpt_large"
    # Default MiDaS model type.
    midas_model_type: _MidasModelType = "dpt_beit_large_512"
    # Default device used for every depth backend.
    device: str = "cuda"
@dataclass
class InpaintConfig:
    """Server-side defaults for the inpainting backends (SDXL and ControlNet)."""

    # Backend used when the caller does not choose one.
    backend: Literal["sdxl_inpaint", "controlnet"] = "sdxl_inpaint"
    # SDXL inpainting base model: a HuggingFace model id or a local directory.
    sdxl_base_model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    # ControlNet inpainting: the base pipeline model, then the ControlNet weights.
    controlnet_base_model: str = "runwayml/stable-diffusion-inpainting"
    controlnet_model: str = "lllyasviel/control_v11p_sd15_inpaint"
    device: str = "cuda"
@dataclass
class AnimationConfig:
    """Server-side defaults for the AnimateDiff animation backend."""

    # Only one animation backend exists for now.
    backend: Literal["animatediff"] = "animatediff"
    # AnimateDiff checkout, relative to python_server/ or as an absolute path.
    animate_diff_root: str = "model/Animation/AnimateDiff"
    # Base text-to-image model: a HuggingFace model id or a local directory.
    pretrained_model_path: str = "runwayml/stable-diffusion-v1-5"
    # AnimateDiff inference configuration file.
    inference_config: str = "configs/inference/inference-v3.yaml"
    # Motion module and personalised base checkpoint; if left empty the
    # AnimateDiff script applies its own defaults.
    motion_module: str = "v3_sd15_mm.ckpt"
    dreambooth_model: str = "realisticVisionV60B1_v51VAE.safetensors"
    # LoRA weights and their blend strength.
    lora_model: str = ""
    lora_alpha: float = 0.8
    # xformers is poorly supported in some environments; set True to turn it off.
    without_xformers: bool = False
    device: str = "cuda"
@dataclass
class AppConfig:
    """Top-level server configuration grouping the depth, inpaint and animation sections.

    Every section is created through ``default_factory``, so each AppConfig
    instance builds its own section objects rather than sharing one mutable
    default instance.
    """

    depth: DepthConfig = field(default_factory=DepthConfig)
    inpaint: InpaintConfig = field(default_factory=InpaintConfig)
    animation: AnimationConfig = field(default_factory=AnimationConfig)


# Module-wide default configuration; backend code imports DEFAULT_CONFIG directly.
DEFAULT_CONFIG = AppConfig()
def list_depth_models() -> List[DepthModelInfo]:
    """
    Return metadata for every available depth model, e.g. for a ``/models`` endpoint.

    The result is a deep copy of ``DEPTH_MODELS``. Callers cannot mutate the
    module-level list (or its entries) through it, so this module stays the
    single source of truth.

    Returns:
        A new list of new ``DepthModelInfo`` dicts, in ``DEPTH_MODELS`` order.
    """
    return copy.deepcopy(DEPTH_MODELS)
def list_inpaint_models() -> List[InpaintModelInfo]:
    """
    Return metadata for every available inpainting model, e.g. for a ``/models`` endpoint.

    The result is a deep copy of ``INPAINT_MODELS``. Callers cannot mutate the
    module-level list (or its entries) through it, so this module stays the
    single source of truth.

    Returns:
        A new list of new ``InpaintModelInfo`` dicts, in ``INPAINT_MODELS`` order.
    """
    return copy.deepcopy(INPAINT_MODELS)
def list_animation_models() -> List[AnimationModelInfo]:
    """
    Return metadata for every available animation model, e.g. for a ``/models`` endpoint.

    The result is a deep copy of ``ANIMATION_MODELS``. Callers cannot mutate
    the module-level list (or its entries) through it, so this module stays
    the single source of truth.

    Returns:
        A new list of new ``AnimationModelInfo`` dicts, in ``ANIMATION_MODELS`` order.
    """
    return copy.deepcopy(ANIMATION_MODELS)