initial commit
This commit is contained in:
223
python_server/config.py
Normal file
223
python_server/config.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
python_server 的统一配置文件。
|
||||
|
||||
特点:
|
||||
- 使用 Python 而不是 YAML,方便在代码中集中列举所有可用模型,供前端读取。
|
||||
- 后端加载模型时,也从这里读取默认值,保证单一信息源。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, TypedDict, List
|
||||
|
||||
from model.Depth.zoe_loader import ZoeModelName
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 1. 深度模型枚举(给前端展示用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class DepthModelInfo(TypedDict):
|
||||
id: str # 唯一 ID,如 "zoedepth_n"
|
||||
family: str # 模型家族,如 "ZoeDepth"
|
||||
name: str # 展示名,如 "ZoeD_N (NYU+KITT)"
|
||||
description: str # 简短描述
|
||||
backend: str # 后端类型,如 "zoedepth", "depth_anything_v2", "midas", "dpt"
|
||||
|
||||
|
||||
DEPTH_MODELS: List[DepthModelInfo] = [
|
||||
# ZoeDepth 系列
|
||||
{
|
||||
"id": "zoedepth_n",
|
||||
"family": "ZoeDepth",
|
||||
"name": "ZoeD_N",
|
||||
"description": "ZoeDepth zero-shot 模型,适用室内/室外通用场景。",
|
||||
"backend": "zoedepth",
|
||||
},
|
||||
{
|
||||
"id": "zoedepth_k",
|
||||
"family": "ZoeDepth",
|
||||
"name": "ZoeD_K",
|
||||
"description": "ZoeDepth Kitti 专用版本,针对户外驾驶场景优化。",
|
||||
"backend": "zoedepth",
|
||||
},
|
||||
{
|
||||
"id": "zoedepth_nk",
|
||||
"family": "ZoeDepth",
|
||||
"name": "ZoeD_NK",
|
||||
"description": "ZoeDepth 双头版本(NYU+KITTI),综合室内/室外场景。",
|
||||
"backend": "zoedepth",
|
||||
},
|
||||
# 预留:Depth Anything v2
|
||||
{
|
||||
"id": "depth_anything_v2_s",
|
||||
"family": "Depth Anything V2",
|
||||
"name": "Depth Anything V2 Small",
|
||||
"description": "轻量级 Depth Anything V2 小模型。",
|
||||
"backend": "depth_anything_v2",
|
||||
},
|
||||
# 预留:MiDaS
|
||||
{
|
||||
"id": "midas_dpt_large",
|
||||
"family": "MiDaS",
|
||||
"name": "MiDaS DPT Large",
|
||||
"description": "MiDaS DPT-Large 高质量深度模型。",
|
||||
"backend": "midas",
|
||||
},
|
||||
# 预留:DPT
|
||||
{
|
||||
"id": "dpt_large",
|
||||
"family": "DPT",
|
||||
"name": "DPT Large",
|
||||
"description": "DPT Large 单目深度估计模型。",
|
||||
"backend": "dpt",
|
||||
},
|
||||
]
|
||||
|
||||
# -----------------------------
|
||||
# 1.2 补全模型枚举(给前端展示用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class InpaintModelInfo(TypedDict):
|
||||
id: str
|
||||
family: str
|
||||
name: str
|
||||
description: str
|
||||
backend: str # "sdxl_inpaint" | "controlnet"
|
||||
|
||||
|
||||
INPAINT_MODELS: List[InpaintModelInfo] = [
|
||||
{
|
||||
"id": "sdxl_inpaint",
|
||||
"family": "SDXL",
|
||||
"name": "SDXL Inpainting",
|
||||
"description": "基于 diffusers 的 SDXL 补全管线(需要 prompt + mask)。",
|
||||
"backend": "sdxl_inpaint",
|
||||
},
|
||||
{
|
||||
"id": "controlnet",
|
||||
"family": "ControlNet",
|
||||
"name": "ControlNet (placeholder)",
|
||||
"description": "ControlNet 补全/控制生成(当前统一封装暂未实现)。",
|
||||
"backend": "controlnet",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 1.3 动画模型枚举(给前端展示用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class AnimationModelInfo(TypedDict):
|
||||
id: str
|
||||
family: str
|
||||
name: str
|
||||
description: str
|
||||
backend: str # "animatediff"
|
||||
|
||||
|
||||
ANIMATION_MODELS: List[AnimationModelInfo] = [
|
||||
{
|
||||
"id": "animatediff",
|
||||
"family": "AnimateDiff",
|
||||
"name": "AnimateDiff (Text-to-Video)",
|
||||
"description": "基于 AnimateDiff 的文生动画能力,输出 GIF 动画。",
|
||||
"backend": "animatediff",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 2. 后端默认配置(给服务端用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthConfig:
|
||||
# 深度后端选择:前端不参与选择;只允许在后端配置中切换
|
||||
backend: Literal["zoedepth", "depth_anything_v2", "dpt", "midas"] = "zoedepth"
|
||||
# ZoeDepth 家族默认选择
|
||||
zoe_model: ZoeModelName = "ZoeD_N"
|
||||
# Depth Anything V2 默认 encoder
|
||||
da_v2_encoder: Literal["vits", "vitb", "vitl", "vitg"] = "vitl"
|
||||
# DPT 默认模型类型
|
||||
dpt_model_type: Literal["dpt_large", "dpt_hybrid"] = "dpt_large"
|
||||
# MiDaS 默认模型类型
|
||||
midas_model_type: Literal[
|
||||
"dpt_beit_large_512",
|
||||
"dpt_swin2_large_384",
|
||||
"dpt_swin2_tiny_256",
|
||||
"dpt_levit_224",
|
||||
] = "dpt_beit_large_512"
|
||||
# 统一的默认运行设备
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InpaintConfig:
|
||||
# 统一补全默认后端
|
||||
backend: Literal["sdxl_inpaint", "controlnet"] = "sdxl_inpaint"
|
||||
# SDXL Inpaint 的基础模型(可写 HuggingFace model id 或本地目录)
|
||||
sdxl_base_model: str = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
# ControlNet Inpaint 基础模型与 controlnet 权重
|
||||
controlnet_base_model: str = "runwayml/stable-diffusion-inpainting"
|
||||
controlnet_model: str = "lllyasviel/control_v11p_sd15_inpaint"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimationConfig:
|
||||
# 统一动画默认后端
|
||||
backend: Literal["animatediff"] = "animatediff"
|
||||
# AnimateDiff 根目录(相对 python_server/ 或绝对路径)
|
||||
animate_diff_root: str = "model/Animation/AnimateDiff"
|
||||
# 文生图基础模型(HuggingFace model id 或本地目录)
|
||||
pretrained_model_path: str = "runwayml/stable-diffusion-v1-5"
|
||||
# AnimateDiff 推理配置
|
||||
inference_config: str = "configs/inference/inference-v3.yaml"
|
||||
# 运动模块与个性化底模(为空则由脚本按默认处理)
|
||||
motion_module: str = "v3_sd15_mm.ckpt"
|
||||
dreambooth_model: str = "realisticVisionV60B1_v51VAE.safetensors"
|
||||
lora_model: str = ""
|
||||
lora_alpha: float = 0.8
|
||||
# 部分环境 xformers 兼容性差,可手动关闭
|
||||
without_xformers: bool = False
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
# 使用 default_factory 避免 dataclass 的可变默认值问题
|
||||
depth: DepthConfig = field(default_factory=DepthConfig)
|
||||
inpaint: InpaintConfig = field(default_factory=InpaintConfig)
|
||||
animation: AnimationConfig = field(default_factory=AnimationConfig)
|
||||
|
||||
|
||||
# 后端代码直接 import DEFAULT_CONFIG 即可
|
||||
DEFAULT_CONFIG = AppConfig()
|
||||
|
||||
|
||||
def list_depth_models() -> List[DepthModelInfo]:
|
||||
"""
|
||||
返回所有可用深度模型的元信息,方便前端通过 /models 等接口读取。
|
||||
"""
|
||||
return DEPTH_MODELS
|
||||
|
||||
|
||||
def list_inpaint_models() -> List[InpaintModelInfo]:
|
||||
"""
|
||||
返回所有可用补全模型的元信息,方便前端通过 /models 等接口读取。
|
||||
"""
|
||||
return INPAINT_MODELS
|
||||
|
||||
|
||||
def list_animation_models() -> List[AnimationModelInfo]:
|
||||
"""
|
||||
返回所有可用动画模型的元信息,方便前端通过 /models 等接口读取。
|
||||
"""
|
||||
return ANIMATION_MODELS
|
||||
|
||||
Reference in New Issue
Block a user