initial commit
This commit is contained in:
1
python_server/.gitignore
vendored
Normal file
1
python_server/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
outputs/
|
||||
1
python_server/__init__.py
Normal file
1
python_server/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
223
python_server/config.py
Normal file
223
python_server/config.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
python_server 的统一配置文件。
|
||||
|
||||
特点:
|
||||
- 使用 Python 而不是 YAML,方便在代码中集中列举所有可用模型,供前端读取。
|
||||
- 后端加载模型时,也从这里读取默认值,保证单一信息源。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, TypedDict, List
|
||||
|
||||
from model.Depth.zoe_loader import ZoeModelName
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 1. 深度模型枚举(给前端展示用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class DepthModelInfo(TypedDict):
    """Metadata describing one selectable depth model, served to the frontend."""

    id: str  # unique id, e.g. "zoedepth_n"
    family: str  # model family, e.g. "ZoeDepth"
    name: str  # display name, e.g. "ZoeD_N (NYU+KITT)"
    description: str  # short human-readable description
    backend: str  # backend type: "zoedepth", "depth_anything_v2", "midas", "dpt"
|
||||
|
||||
|
||||
# Registry of all depth models exposed to the frontend (e.g. via a /models endpoint).
DEPTH_MODELS: List[DepthModelInfo] = [
    # ZoeDepth family
    {
        "id": "zoedepth_n",
        "family": "ZoeDepth",
        "name": "ZoeD_N",
        "description": "ZoeDepth zero-shot 模型,适用室内/室外通用场景。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_k",
        "family": "ZoeDepth",
        "name": "ZoeD_K",
        "description": "ZoeDepth Kitti 专用版本,针对户外驾驶场景优化。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_nk",
        "family": "ZoeDepth",
        "name": "ZoeD_NK",
        "description": "ZoeDepth 双头版本(NYU+KITTI),综合室内/室外场景。",
        "backend": "zoedepth",
    },
    # Reserved: Depth Anything v2
    {
        "id": "depth_anything_v2_s",
        "family": "Depth Anything V2",
        "name": "Depth Anything V2 Small",
        "description": "轻量级 Depth Anything V2 小模型。",
        "backend": "depth_anything_v2",
    },
    # Reserved: MiDaS
    {
        "id": "midas_dpt_large",
        "family": "MiDaS",
        "name": "MiDaS DPT Large",
        "description": "MiDaS DPT-Large 高质量深度模型。",
        "backend": "midas",
    },
    # Reserved: DPT
    {
        "id": "dpt_large",
        "family": "DPT",
        "name": "DPT Large",
        "description": "DPT Large 单目深度估计模型。",
        "backend": "dpt",
    },
]
|
||||
|
||||
# -----------------------------
|
||||
# 1.2 补全模型枚举(给前端展示用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class InpaintModelInfo(TypedDict):
    """Metadata describing one selectable inpainting model, served to the frontend."""

    id: str  # unique id
    family: str  # model family
    name: str  # display name
    description: str  # short human-readable description
    backend: str  # "sdxl_inpaint" | "controlnet"
|
||||
|
||||
|
||||
# Registry of all inpainting models exposed to the frontend.
INPAINT_MODELS: List[InpaintModelInfo] = [
    {
        "id": "sdxl_inpaint",
        "family": "SDXL",
        "name": "SDXL Inpainting",
        "description": "基于 diffusers 的 SDXL 补全管线(需要 prompt + mask)。",
        "backend": "sdxl_inpaint",
    },
    {
        "id": "controlnet",
        "family": "ControlNet",
        "name": "ControlNet (placeholder)",
        "description": "ControlNet 补全/控制生成(当前统一封装暂未实现)。",
        "backend": "controlnet",
    },
]
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 1.3 动画模型枚举(给前端展示用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class AnimationModelInfo(TypedDict):
    """Metadata describing one selectable animation model, served to the frontend."""

    id: str  # unique id
    family: str  # model family
    name: str  # display name
    description: str  # short human-readable description
    backend: str  # "animatediff"
|
||||
|
||||
|
||||
# Registry of all animation models exposed to the frontend.
ANIMATION_MODELS: List[AnimationModelInfo] = [
    {
        "id": "animatediff",
        "family": "AnimateDiff",
        "name": "AnimateDiff (Text-to-Video)",
        "description": "基于 AnimateDiff 的文生动画能力,输出 GIF 动画。",
        "backend": "animatediff",
    },
]
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 2. 后端默认配置(给服务端用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@dataclass
class DepthConfig:
    """Backend-side defaults for depth estimation."""

    # Depth backend selection. The frontend does not take part in this choice;
    # it can only be switched here in the backend configuration.
    backend: Literal["zoedepth", "depth_anything_v2", "dpt", "midas"] = "zoedepth"
    # Default choice within the ZoeDepth family.
    zoe_model: ZoeModelName = "ZoeD_N"
    # Default encoder for Depth Anything V2.
    da_v2_encoder: Literal["vits", "vitb", "vitl", "vitg"] = "vitl"
    # Default DPT model type.
    dpt_model_type: Literal["dpt_large", "dpt_hybrid"] = "dpt_large"
    # Default MiDaS model type.
    midas_model_type: Literal[
        "dpt_beit_large_512",
        "dpt_swin2_large_384",
        "dpt_swin2_tiny_256",
        "dpt_levit_224",
    ] = "dpt_beit_large_512"
    # Default device shared by all depth backends.
    device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
class InpaintConfig:
    """Backend-side defaults for image inpainting."""

    # Default inpainting backend.
    backend: Literal["sdxl_inpaint", "controlnet"] = "sdxl_inpaint"
    # SDXL inpaint base model (HuggingFace model id or local directory).
    sdxl_base_model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    # ControlNet inpaint base model and controlnet weights.
    controlnet_base_model: str = "runwayml/stable-diffusion-inpainting"
    controlnet_model: str = "lllyasviel/control_v11p_sd15_inpaint"
    device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
class AnimationConfig:
    """Backend-side defaults for text-to-animation."""

    # Default animation backend.
    backend: Literal["animatediff"] = "animatediff"
    # AnimateDiff root directory (relative to python_server/ or absolute).
    animate_diff_root: str = "model/Animation/AnimateDiff"
    # Text-to-image base model (HuggingFace model id or local directory).
    pretrained_model_path: str = "runwayml/stable-diffusion-v1-5"
    # AnimateDiff inference configuration file.
    inference_config: str = "configs/inference/inference-v3.yaml"
    # Motion module and personalized base model (empty string lets the script
    # fall back to its own defaults).
    motion_module: str = "v3_sd15_mm.ckpt"
    dreambooth_model: str = "realisticVisionV60B1_v51VAE.safetensors"
    lora_model: str = ""
    lora_alpha: float = 0.8
    # xformers is flaky in some environments; set True to disable it manually.
    without_xformers: bool = False
    device: str = "cuda"
|
||||
|
||||
|
||||
@dataclass
class AppConfig:
    """Top-level application configuration aggregating all sub-configs."""

    # default_factory avoids the mutable-default pitfall of dataclasses.
    depth: DepthConfig = field(default_factory=DepthConfig)
    inpaint: InpaintConfig = field(default_factory=InpaintConfig)
    animation: AnimationConfig = field(default_factory=AnimationConfig)


# Backend code can simply `import DEFAULT_CONFIG` as the single source of truth.
DEFAULT_CONFIG = AppConfig()
|
||||
|
||||
|
||||
def list_depth_models() -> List[DepthModelInfo]:
    """Return metadata for every available depth model.

    Intended to be read by the frontend, e.g. through a /models endpoint.
    """
    models = DEPTH_MODELS
    return models
|
||||
|
||||
|
||||
def list_inpaint_models() -> List[InpaintModelInfo]:
    """Return metadata for every available inpainting model.

    Intended to be read by the frontend, e.g. through a /models endpoint.
    """
    models = INPAINT_MODELS
    return models
|
||||
|
||||
|
||||
def list_animation_models() -> List[AnimationModelInfo]:
    """Return metadata for every available animation model.

    Intended to be read by the frontend, e.g. through a /models endpoint.
    """
    models = ANIMATION_MODELS
    return models
|
||||
|
||||
153
python_server/config_loader.py
Normal file
153
python_server/config_loader.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
兼容层:从 Python 配置模块中构造 zoe_loader 需要的 ZoeConfig。
|
||||
|
||||
后端其它代码尽量只依赖这里的函数,而不直接依赖 config.py 的具体结构,
|
||||
便于以后扩展。
|
||||
"""
|
||||
|
||||
from model.Depth.zoe_loader import ZoeConfig
|
||||
from model.Depth.depth_anything_v2_loader import DepthAnythingV2Config
|
||||
from model.Depth.dpt_loader import DPTConfig
|
||||
from model.Depth.midas_loader import MiDaSConfig
|
||||
from config import AppConfig, DEFAULT_CONFIG
|
||||
|
||||
|
||||
def load_app_config() -> AppConfig:
    """Return the active application configuration.

    Currently always the module-level DEFAULT_CONFIG; this is the hook for
    future multi-environment / override logic.
    """
    return DEFAULT_CONFIG
|
||||
|
||||
|
||||
def build_zoe_config_from_app(app_cfg: AppConfig | None = None) -> ZoeConfig:
    """Map AppConfig.depth onto a ZoeConfig for zoe_loader.

    Falls back to the global DEFAULT_CONFIG when app_cfg is omitted.
    """
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return ZoeConfig(model=cfg.depth.zoe_model, device=cfg.depth.device)
|
||||
|
||||
|
||||
def build_depth_anything_v2_config_from_app(
    app_cfg: AppConfig | None = None,
) -> DepthAnythingV2Config:
    """Map AppConfig.depth onto a DepthAnythingV2Config.

    Falls back to the global DEFAULT_CONFIG when app_cfg is omitted.
    """
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return DepthAnythingV2Config(
        encoder=cfg.depth.da_v2_encoder,
        device=cfg.depth.device,
    )
|
||||
|
||||
|
||||
def build_dpt_config_from_app(app_cfg: AppConfig | None = None) -> DPTConfig:
    """Map AppConfig.depth onto a DPTConfig (defaults to DEFAULT_CONFIG)."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return DPTConfig(model_type=cfg.depth.dpt_model_type, device=cfg.depth.device)
|
||||
|
||||
|
||||
def build_midas_config_from_app(app_cfg: AppConfig | None = None) -> MiDaSConfig:
    """Map AppConfig.depth onto a MiDaSConfig (defaults to DEFAULT_CONFIG)."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return MiDaSConfig(model_type=cfg.depth.midas_model_type, device=cfg.depth.device)
|
||||
|
||||
|
||||
def _ensure_app_cfg(app_cfg: AppConfig | None) -> AppConfig:
    # Shared fallback used by all getters below: default to the global config.
    # Replaces fourteen identical copies of the `if app_cfg is None` boilerplate.
    return app_cfg if app_cfg is not None else load_app_config()


def get_depth_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured depth backend name."""
    return _ensure_app_cfg(app_cfg).depth.backend


def get_inpaint_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured inpainting backend name."""
    return _ensure_app_cfg(app_cfg).inpaint.backend


def get_sdxl_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the SDXL inpainting base model id/path."""
    return _ensure_app_cfg(app_cfg).inpaint.sdxl_base_model


def get_controlnet_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the ControlNet inpainting base model id/path."""
    return _ensure_app_cfg(app_cfg).inpaint.controlnet_base_model


def get_controlnet_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the ControlNet weights id/path."""
    return _ensure_app_cfg(app_cfg).inpaint.controlnet_model


def get_animation_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured animation backend name."""
    return _ensure_app_cfg(app_cfg).animation.backend


def get_animatediff_root_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff root directory (relative or absolute)."""
    return _ensure_app_cfg(app_cfg).animation.animate_diff_root


def get_animatediff_pretrained_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff text-to-image base model id/path."""
    return _ensure_app_cfg(app_cfg).animation.pretrained_model_path


def get_animatediff_inference_config_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff inference config path."""
    return _ensure_app_cfg(app_cfg).animation.inference_config


def get_animatediff_motion_module_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff motion module checkpoint name."""
    return _ensure_app_cfg(app_cfg).animation.motion_module


def get_animatediff_dreambooth_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff DreamBooth/personalized base model name."""
    return _ensure_app_cfg(app_cfg).animation.dreambooth_model


def get_animatediff_lora_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff LoRA model name (may be empty)."""
    return _ensure_app_cfg(app_cfg).animation.lora_model


def get_animatediff_lora_alpha_from_app(app_cfg: AppConfig | None = None) -> float:
    """Return the AnimateDiff LoRA alpha weight."""
    return _ensure_app_cfg(app_cfg).animation.lora_alpha


def get_animatediff_without_xformers_from_app(app_cfg: AppConfig | None = None) -> bool:
    """Return whether xformers should be disabled for AnimateDiff."""
    return _ensure_app_cfg(app_cfg).animation.without_xformers
|
||||
|
||||
|
||||
1
python_server/model/Animation/AnimateDiff
Submodule
1
python_server/model/Animation/AnimateDiff
Submodule
Submodule python_server/model/Animation/AnimateDiff added at e92bd5671b
12
python_server/model/Animation/__init__.py
Normal file
12
python_server/model/Animation/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .animation_loader import (
|
||||
AnimationBackend,
|
||||
UnifiedAnimationConfig,
|
||||
build_animation_predictor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnimationBackend",
|
||||
"UnifiedAnimationConfig",
|
||||
"build_animation_predictor",
|
||||
]
|
||||
|
||||
268
python_server/model/Animation/animation_loader.py
Normal file
268
python_server/model/Animation/animation_loader.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Unified animation model loading entry.
|
||||
|
||||
Current support:
|
||||
- AnimateDiff (script-based invocation)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
get_animatediff_root_from_app,
|
||||
get_animatediff_pretrained_model_from_app,
|
||||
get_animatediff_inference_config_from_app,
|
||||
get_animatediff_motion_module_from_app,
|
||||
get_animatediff_dreambooth_model_from_app,
|
||||
get_animatediff_lora_model_from_app,
|
||||
get_animatediff_lora_alpha_from_app,
|
||||
get_animatediff_without_xformers_from_app,
|
||||
)
|
||||
|
||||
|
||||
class AnimationBackend(str, Enum):
    """Supported animation backends."""

    ANIMATEDIFF = "animatediff"
|
||||
|
||||
|
||||
@dataclass
class UnifiedAnimationConfig:
    """Configuration for the unified animation entry point."""

    backend: AnimationBackend = AnimationBackend.ANIMATEDIFF
    # Optional overrides. If None, values come from app config.
    animate_diff_root: str | None = None
    pretrained_model_path: str | None = None
    inference_config: str | None = None
    motion_module: str | None = None
    dreambooth_model: str | None = None
    lora_model: str | None = None
    lora_alpha: float | None = None
    without_xformers: bool | None = None
    # ControlNet (SparseCtrl) overrides, only used when a control image is given.
    controlnet_path: str | None = None
    controlnet_config: str | None = None
|
||||
|
||||
|
||||
def _yaml_string(value: str) -> str:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _resolve_root(root_cfg: str) -> Path:
|
||||
root = Path(root_cfg)
|
||||
if not root.is_absolute():
|
||||
root = Path(__file__).resolve().parents[2] / root_cfg
|
||||
return root.resolve()
|
||||
|
||||
|
||||
def _make_animatediff_predictor(
    cfg: UnifiedAnimationConfig,
) -> Callable[..., Path]:
    """Build a text-to-animation predictor that shells out to AnimateDiff.

    Resolves every setting from *cfg*, falling back to the app config for any
    field left as None, then returns a closure that writes a one-off YAML
    prompt config, runs ``scripts/animate.py`` as a subprocess, and returns the
    path of the produced GIF (or PNG-sequence directory).

    Raises:
        FileNotFoundError: if the AnimateDiff script is missing.
    """
    app_cfg = load_app_config()

    root = _resolve_root(cfg.animate_diff_root or get_animatediff_root_from_app(app_cfg))
    script_path = root / "scripts" / "animate.py"
    # AnimateDiff writes each run into a new subdirectory of samples/.
    samples_dir = root / "samples"
    samples_dir.mkdir(parents=True, exist_ok=True)

    if not script_path.is_file():
        raise FileNotFoundError(f"AnimateDiff script not found: {script_path}")

    # Per-field fallback: explicit cfg value wins, else app config.
    # Note the `is None` checks (not truthiness) for fields where "" / 0.0 /
    # False are meaningful values.
    pretrained_model_path = (
        cfg.pretrained_model_path or get_animatediff_pretrained_model_from_app(app_cfg)
    )
    inference_config = cfg.inference_config or get_animatediff_inference_config_from_app(app_cfg)
    motion_module = cfg.motion_module or get_animatediff_motion_module_from_app(app_cfg)
    dreambooth_model = cfg.dreambooth_model
    if dreambooth_model is None:
        dreambooth_model = get_animatediff_dreambooth_model_from_app(app_cfg)
    lora_model = cfg.lora_model
    if lora_model is None:
        lora_model = get_animatediff_lora_model_from_app(app_cfg)
    lora_alpha = cfg.lora_alpha
    if lora_alpha is None:
        lora_alpha = get_animatediff_lora_alpha_from_app(app_cfg)
    without_xformers = cfg.without_xformers
    if without_xformers is None:
        without_xformers = get_animatediff_without_xformers_from_app(app_cfg)

    def _predict(
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        guidance_scale: float = 8.0,
        width: int = 512,
        height: int = 512,
        video_length: int = 16,
        seed: int = -1,
        control_image_path: str | None = None,
        output_format: str = "gif",
    ) -> Path:
        """Run one AnimateDiff generation; returns the GIF path or, for
        output_format="png_sequence", the directory of PNG frames."""
        if output_format not in {"gif", "png_sequence"}:
            raise ValueError("output_format must be 'gif' or 'png_sequence'")
        prompt_value = prompt.strip()
        if not prompt_value:
            raise ValueError("prompt must not be empty")

        negative_prompt_value = negative_prompt or ""

        # Optional YAML lines; each is empty when the corresponding model is
        # unset so the script falls back to its own defaults.
        # NOTE(review): the exact space counts inside these YAML literals were
        # lost in the diff rendering; reconstructed as 2-space keys / 4-space
        # nested list items — confirm against a known-good prompt config.
        motion_module_line = (
            f'  motion_module: {_yaml_string(motion_module)}\n' if motion_module else ""
        )
        dreambooth_line = (
            f'  dreambooth_path: {_yaml_string(dreambooth_model)}\n' if dreambooth_model else ""
        )
        lora_path_line = f'  lora_model_path: {_yaml_string(lora_model)}\n' if lora_model else ""
        # lora_alpha is only meaningful together with a LoRA model.
        lora_alpha_line = f"  lora_alpha: {float(lora_alpha)}\n" if lora_model else ""
        controlnet_path_value = cfg.controlnet_path or "v3_sd15_sparsectrl_rgb.ckpt"
        controlnet_config_value = cfg.controlnet_config or "configs/inference/sparsectrl/image_condition.yaml"
        control_image_line = ""
        if control_image_path:
            control_image = Path(control_image_path).expanduser().resolve()
            if not control_image.is_file():
                raise FileNotFoundError(f"control_image_path not found: {control_image}")
            # SparseCtrl conditioning: single control image bound to frame 0.
            control_image_line = (
                f'  controlnet_path: {_yaml_string(controlnet_path_value)}\n'
                f'  controlnet_config: {_yaml_string(controlnet_config_value)}\n'
                "  controlnet_images:\n"
                f'    - {_yaml_string(str(control_image))}\n'
                "  controlnet_image_indexs:\n"
                "    - 0\n"
            )

        # One-item YAML list in the format AnimateDiff's animate.py expects.
        config_text = (
            "- prompt:\n"
            f"    - {_yaml_string(prompt_value)}\n"
            "  n_prompt:\n"
            f"    - {_yaml_string(negative_prompt_value)}\n"
            f"  steps: {int(num_inference_steps)}\n"
            f"  guidance_scale: {float(guidance_scale)}\n"
            f"  W: {int(width)}\n"
            f"  H: {int(height)}\n"
            f"  L: {int(video_length)}\n"
            "  seed:\n"
            f"    - {int(seed)}\n"
            f"{motion_module_line}{dreambooth_line}{lora_path_line}{lora_alpha_line}{control_image_line}"
        )

        # Snapshot existing output dirs so we can identify the new one afterwards.
        before_dirs = {p for p in samples_dir.iterdir() if p.is_dir()}
        cfg_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            prefix="animatediff_cfg_",
            dir=str(root),
            delete=False,
            encoding="utf-8",
        )
        cfg_file_path = Path(cfg_file.name)
        try:
            cfg_file.write(config_text)
            cfg_file.flush()
            cfg_file.close()

            cmd = [
                sys.executable,
                str(script_path),
                "--pretrained-model-path",
                pretrained_model_path,
                "--inference-config",
                inference_config,
                "--config",
                str(cfg_file_path),
                "--L",
                str(int(video_length)),
                "--W",
                str(int(width)),
                "--H",
                str(int(height)),
            ]
            if without_xformers:
                cmd.append("--without-xformers")
            if output_format == "png_sequence":
                cmd.append("--save-png-sequence")

            # Make the AnimateDiff repo importable by the child process.
            env = dict(os.environ)
            existing_pythonpath = env.get("PYTHONPATH", "")
            root_pythonpath = str(root)
            env["PYTHONPATH"] = (
                f"{root_pythonpath}:{existing_pythonpath}" if existing_pythonpath else root_pythonpath
            )

            def _run_once(command: list[str]) -> subprocess.CompletedProcess[str]:
                # check=True -> CalledProcessError on nonzero exit.
                return subprocess.run(
                    command,
                    cwd=str(root),
                    check=True,
                    capture_output=True,
                    text=True,
                    env=env,
                )

            try:
                proc = _run_once(cmd)
            except subprocess.CalledProcessError as first_error:
                # Some environments crash inside xformers' memory-efficient
                # attention; retry exactly once with it disabled when the
                # stderr looks like such a failure.
                stderr_text = first_error.stderr or ""
                should_retry_without_xformers = (
                    not without_xformers
                    and "--without-xformers" not in cmd
                    and (
                        "memory_efficient_attention" in stderr_text
                        or "AcceleratorError" in stderr_text
                        or "invalid configuration argument" in stderr_text
                    )
                )
                if not should_retry_without_xformers:
                    raise

                retry_cmd = [*cmd, "--without-xformers"]
                proc = _run_once(retry_cmd)
            _ = proc
        except subprocess.CalledProcessError as e:
            # Covers both the initial run (when not retried) and the retry.
            raise RuntimeError(
                "AnimateDiff inference failed.\n"
                f"stdout:\n{e.stdout}\n"
                f"stderr:\n{e.stderr}"
            ) from e
        finally:
            # Best-effort cleanup of the temporary YAML config.
            try:
                cfg_file_path.unlink(missing_ok=True)
            except Exception:
                pass

        # Prefer directories created by this run; fall back to any directory
        # containing a sample.gif (e.g. if the snapshot diff missed it).
        after_dirs = [p for p in samples_dir.iterdir() if p.is_dir() and p not in before_dirs]
        candidates = [p for p in after_dirs if (p / "sample.gif").is_file()]
        if not candidates:
            candidates = [p for p in samples_dir.iterdir() if p.is_dir() and (p / "sample.gif").is_file()]
        if not candidates:
            raise FileNotFoundError("AnimateDiff finished but sample.gif was not found in samples/")

        # Newest output directory wins.
        latest = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)[0]
        if output_format == "png_sequence":
            frames_root = latest / "sample_frames"
            if not frames_root.is_dir():
                raise FileNotFoundError("AnimateDiff finished but sample_frames/ was not found in samples/")
            frame_dirs = sorted([p for p in frames_root.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
            if not frame_dirs:
                raise FileNotFoundError("AnimateDiff finished but no PNG sequence directory was found in sample_frames/")
            return frame_dirs[0].resolve()
        return (latest / "sample.gif").resolve()

    return _predict
|
||||
|
||||
|
||||
def build_animation_predictor(
    cfg: UnifiedAnimationConfig | None = None,
) -> tuple[Callable[..., Path], AnimationBackend]:
    """Build an animation predictor for the configured backend.

    Returns the predictor callable together with the backend actually used.
    """
    if cfg is None:
        cfg = UnifiedAnimationConfig()

    if cfg.backend != AnimationBackend.ANIMATEDIFF:
        raise ValueError(f"Unsupported animation backend: {cfg.backend}")

    return _make_animatediff_predictor(cfg), AnimationBackend.ANIMATEDIFF
|
||||
|
||||
1
python_server/model/Depth/DPT
Submodule
1
python_server/model/Depth/DPT
Submodule
Submodule python_server/model/Depth/DPT added at cd3fe90bb4
1
python_server/model/Depth/Depth-Anything-V2
Submodule
1
python_server/model/Depth/Depth-Anything-V2
Submodule
Submodule python_server/model/Depth/Depth-Anything-V2 added at e5a2732d3e
1
python_server/model/Depth/MiDaS
Submodule
1
python_server/model/Depth/MiDaS
Submodule
Submodule python_server/model/Depth/MiDaS added at 454597711a
1
python_server/model/Depth/ZoeDepth
Submodule
1
python_server/model/Depth/ZoeDepth
Submodule
Submodule python_server/model/Depth/ZoeDepth added at d87f17b2f5
1
python_server/model/Depth/__init__.py
Normal file
1
python_server/model/Depth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
# 确保本地克隆的 Depth Anything V2 仓库在 sys.path 中,
|
||||
# 这样其内部的 `from depth_anything_v2...` 导入才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
|
||||
if _DA_REPO_ROOT.is_dir():
|
||||
da_path = str(_DA_REPO_ROOT)
|
||||
if da_path not in sys.path:
|
||||
sys.path.insert(0, da_path)
|
||||
|
||||
from depth_anything_v2.dpt import DepthAnythingV2 # type: ignore[import]
|
||||
|
||||
|
||||
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]
|
||||
|
||||
|
||||
@dataclass
class DepthAnythingV2Config:
    """
    Model-selection config for Depth Anything V2.

    encoder: "vits" | "vitb" | "vitl" | "vitg"
    device: "cuda" | "cpu"
    input_size: inference input resolution (short side); 518 as in the
        official demo.
    """

    encoder: EncoderName = "vitl"
    device: str = "cuda"
    input_size: int = 518
|
||||
|
||||
|
||||
# Architecture hyperparameters per encoder, matching the official repo.
_MODEL_CONFIGS = {
    "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
    "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
    "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
    "vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
}
|
||||
|
||||
|
||||
_DA_V2_WEIGHTS_URLS = {
    # Official weights hosted on HuggingFace:
    # - Small -> vits
    # - Base  -> vitb
    # - Large -> vitl
    # - Giant -> vitg
    # Edit these URLs directly to point at a mirror if needed.
    "vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
    "vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
    "vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
    "vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
}
|
||||
|
||||
|
||||
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
    """Download the Depth Anything V2 checkpoint for *encoder* if absent.

    Streams the file to a ``.part`` temporary path and renames it into place
    only on success — previously an interrupted download left a truncated file
    at ckpt_path, which the `is_file()` check then treated as complete on the
    next run. Also adds connect/read timeouts so a dead mirror cannot hang
    model loading forever.

    Raises:
        FileNotFoundError: if the checkpoint is missing and no URL is known.
        requests.HTTPError: if the server answers with an error status.
    """
    if ckpt_path.is_file():
        return

    url = _DA_V2_WEIGHTS_URLS.get(encoder)
    if not url:
        raise FileNotFoundError(
            f"找不到权重文件: {ckpt_path}\n"
            f"且当前未为 encoder='{encoder}' 配置自动下载 URL,请手动下载到该路径。"
        )

    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")

    # Timeouts: 10s to connect, 60s max silence between chunks.
    resp = requests.get(url, stream=True, timeout=(10, 60))
    resp.raise_for_status()

    total = int(resp.headers.get("content-length", "0") or "0")
    downloaded = 0
    chunk_size = 1024 * 1024

    # Write to a sibling .part file; only a fully-written file is renamed to
    # the final path, so a crash/interrupt never leaves a bogus checkpoint.
    tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
    try:
        with tmp_path.open("wb") as f:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total > 0:
                    # Simple 50-char progress bar on one line.
                    done = int(50 * downloaded / total)
                    print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        tmp_path.replace(ckpt_path)  # atomic on the same filesystem
    except BaseException:
        # Remove the partial file before propagating the error.
        tmp_path.unlink(missing_ok=True)
        raise

    print("\n权重下载完成。")
|
||||
|
||||
|
||||
def load_depth_anything_v2_from_config(
    cfg: DepthAnythingV2Config,
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
    """
    Load a Depth Anything V2 model (and the effective config) from *cfg*.

    Notes:
    - Checkpoint paths follow the official naming convention:
      checkpoints/depth_anything_v2_{encoder}.pth
      e.g. depth_anything_v2_vitl.pth
    - The weight file is expected under
      python_server/model/Depth/Depth-Anything-V2/checkpoints (it is
      auto-downloaded there if missing).

    Returns the model in eval mode together with a new config whose device
    reflects what was actually used (falls back to CPU when CUDA is
    unavailable).
    """
    if cfg.encoder not in _MODEL_CONFIGS:
        raise ValueError(f"不支持的 encoder: {cfg.encoder}")

    ckpt_path = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
    _download_if_missing(cfg.encoder, ckpt_path)

    model = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])

    # Load on CPU first; moved to the target device below.
    state_dict = torch.load(str(ckpt_path), map_location="cpu")
    model.load_state_dict(state_dict)

    # Degrade gracefully to CPU when CUDA was requested but is unavailable.
    if cfg.device.startswith("cuda") and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model = model.to(device).eval()
    # Return a config reflecting the device actually chosen.
    cfg = DepthAnythingV2Config(
        encoder=cfg.encoder,
        device=device,
        input_size=cfg.input_size,
    )
    return model, cfg
|
||||
|
||||
|
||||
def infer_depth_anything_v2(
    model: DepthAnythingV2,
    image_bgr: np.ndarray,
    input_size: int,
) -> np.ndarray:
    """Run depth inference on one BGR image and return a float32 depth map.

    image_bgr: OpenCV-style (H, W, 3) uint8 BGR array. The returned depth
    values are raw (not normalized).
    """
    raw = model.infer_image(image_bgr, input_size)
    return np.asarray(raw, dtype="float32")
|
||||
|
||||
148
python_server/model/Depth/depth_loader.py
Normal file
148
python_server/model/Depth/depth_loader.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的深度模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- ZoeDepth(三种:ZoeD_N / ZoeD_K / ZoeD_NK)
|
||||
- Depth Anything V2(四种 encoder:vits / vitb / vitl / vitg)
|
||||
|
||||
未来如果要加 MiDaS / DPT,只需要在这里再接一层即可。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
build_zoe_config_from_app,
|
||||
build_depth_anything_v2_config_from_app,
|
||||
build_dpt_config_from_app,
|
||||
build_midas_config_from_app,
|
||||
)
|
||||
from .zoe_loader import load_zoe_from_config
|
||||
from .depth_anything_v2_loader import (
|
||||
load_depth_anything_v2_from_config,
|
||||
infer_depth_anything_v2,
|
||||
)
|
||||
from .dpt_loader import load_dpt_from_config, infer_dpt
|
||||
from .midas_loader import load_midas_from_config, infer_midas
|
||||
|
||||
|
||||
class DepthBackend(str, Enum):
    """Unified depth-model backend identifiers."""

    ZOEDEPTH = "zoedepth"
    DEPTH_ANYTHING_V2 = "depth_anything_v2"
    DPT = "dpt"
    MIDAS = "midas"
|
||||
|
||||
|
||||
@dataclass
class UnifiedDepthConfig:
    """
    Unified depth configuration.

    backend: which backend to use
    device: optional device override; None means use the config.py setting
    """

    backend: DepthBackend = DepthBackend.ZOEDEPTH
    device: str | None = None
|
||||
|
||||
|
||||
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a ZoeDepth predictor from the app-level configuration.

    device_override: if given, replaces the configured device.
    """
    zoe_cfg = build_zoe_config_from_app(load_app_config())
    if device_override is not None:
        zoe_cfg.device = device_override

    model, _ = load_zoe_from_config(zoe_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        rgb = img.convert("RGB")
        raw = model.infer_pil(rgb, output_type="numpy")
        return np.asarray(raw, dtype="float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a Depth Anything V2 predictor from the app-level configuration.

    device_override: if given, replaces the configured device.
    """
    da_cfg = build_depth_anything_v2_config_from_app(load_app_config())
    if device_override is not None:
        da_cfg.device = device_override

    model, da_cfg = load_depth_anything_v2_from_config(da_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # infer_image expects a BGR uint8 array; flip the channel axis.
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        depth = infer_depth_anything_v2(model, rgb[:, :, ::-1], da_cfg.input_size)
        return depth.astype("float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a DPT predictor from the app-level configuration.

    device_override: if given, replaces the configured device.
    """
    app_cfg = load_app_config()
    dpt_cfg = build_dpt_config_from_app(app_cfg)
    if device_override is not None:
        dpt_cfg.device = device_override

    model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # Bug fix: the original called cv2.cvtColor, but cv2 is never imported
        # in this module, so every DPT prediction raised NameError. Convert
        # RGB -> BGR with a numpy channel flip instead, matching what the
        # Depth Anything V2 predictor in this file already does.
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
        return depth.astype("float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a MiDaS predictor from the app-level configuration.

    device_override: if given, replaces the configured device.
    """
    midas_cfg = build_midas_config_from_app(load_app_config())
    if device_override is not None:
        midas_cfg.device = device_override

    model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # MiDaS transforms expect float32 RGB scaled to [0, 1].
        rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
        depth = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
        return depth.astype("float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def build_depth_predictor(
    cfg: UnifiedDepthConfig | None = None,
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
    """
    Build the unified depth-prediction function.

    Returns:
        - predictor(image: PIL.Image) -> np.ndarray[H, W], float32
        - the backend that was actually used
    """
    cfg = cfg or UnifiedDepthConfig()

    # One factory per supported backend; each receives the device override.
    factories = {
        DepthBackend.ZOEDEPTH: _make_zoe_predictor,
        DepthBackend.DEPTH_ANYTHING_V2: _make_da_v2_predictor,
        DepthBackend.DPT: _make_dpt_predictor,
        DepthBackend.MIDAS: _make_midas_predictor,
    }
    factory = factories.get(cfg.backend)
    if factory is None:
        raise ValueError(f"不支持的深度后端: {cfg.backend}")
    return factory(cfg.device), cfg.backend
|
||||
|
||||
156
python_server/model/Depth/dpt_loader.py
Normal file
156
python_server/model/Depth/dpt_loader.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
|
||||
if _DPT_REPO_ROOT.is_dir():
|
||||
dpt_path = str(_DPT_REPO_ROOT)
|
||||
if dpt_path not in sys.path:
|
||||
sys.path.insert(0, dpt_path)
|
||||
|
||||
from dpt.models import DPTDepthModel # type: ignore[import]
|
||||
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
|
||||
from torchvision.transforms import Compose
|
||||
import cv2
|
||||
|
||||
|
||||
# The two DPT variants published by isl-org/DPT.
DPTModelType = Literal["dpt_large", "dpt_hybrid"]
|
||||
|
||||
|
||||
@dataclass
class DPTConfig:
    """Selection of DPT model variant and target device."""

    # "dpt_large" (ViT-L/16 backbone) or "dpt_hybrid" (ViT-B + ResNet50).
    model_type: DPTModelType = "dpt_large"
    # "cuda" or "cpu"; load_dpt_from_config falls back to CPU when CUDA is unavailable.
    device: str = "cuda"
|
||||
|
||||
|
||||
# Official DPT model weights are hosted at:
# https://github.com/isl-org/DPT#models
_DPT_WEIGHTS_URLS = {
    "dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _DPT_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 DPT 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nDPT 权重下载完成。")
|
||||
|
||||
|
||||
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
    """
    Load a DPT depth model together with its matching preprocessing transform.

    Returns:
        (model, cfg, transform) where ``cfg`` is a fresh DPTConfig reflecting
        the device actually selected (a CUDA request degrades to CPU when
        CUDA is unavailable).
    """
    # Checkpoint filenames published by isl-org/DPT for each variant.
    ckpt_name = {
        "dpt_large": "dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
    }[cfg.model_type]

    ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
    _download_if_missing(cfg.model_type, ckpt_path)

    if cfg.model_type == "dpt_large":
        net_w = net_h = 384
        model = DPTDepthModel(
            path=str(ckpt_path),
            backbone="vitl16_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    else:
        # dpt_hybrid: ViT-B + ResNet50 backbone, same input size/normalization.
        net_w = net_h = 384
        model = DPTDepthModel(
            path=str(ckpt_path),
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # Degrade gracefully to CPU when CUDA was requested but is unavailable.
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model.to(device).eval()

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    # Return a fresh config so callers see the resolved device.
    cfg = DPTConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform
|
||||
|
||||
|
||||
def infer_dpt(
    model: DPTDepthModel,
    transform: Compose,
    image_bgr: np.ndarray,
    device: str,
) -> np.ndarray:
    """
    Run depth inference on a single BGR image.

    Returns a float32 depth map (not normalized), upsampled back to the
    input image's (H, W) resolution.
    """
    img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    sample = transform({"image": img})["image"]

    input_batch = torch.from_numpy(sample).to(device).unsqueeze(0)
    with torch.no_grad():
        prediction = model(input_batch)
        # Upsample the network output back to the original image size.
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )

    return prediction.astype("float32")
|
||||
|
||||
127
python_server/model/Depth/midas_loader.py
Normal file
127
python_server/model/Depth/midas_loader.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
|
||||
if _MIDAS_REPO_ROOT.is_dir():
|
||||
midas_path = str(_MIDAS_REPO_ROOT)
|
||||
if midas_path not in sys.path:
|
||||
sys.path.insert(0, midas_path)
|
||||
|
||||
from midas.model_loader import load_model, default_models # type: ignore[import]
|
||||
import utils # from MiDaS repo
|
||||
|
||||
|
||||
# Supported MiDaS v3.1 model variants (see the MiDaS repository README).
MiDaSModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]
|
||||
|
||||
|
||||
@dataclass
class MiDaSConfig:
    """Selection of MiDaS model variant and target device."""

    # One of the MiDaS v3.1 checkpoints listed in MiDaSModelType.
    model_type: MiDaSModelType = "dpt_beit_large_512"
    # "cuda" or "cpu"; load_midas_from_config falls back to CPU when CUDA is unavailable.
    device: str = "cuda"
|
||||
|
||||
|
||||
# Official weights; see the MiDaS repository README for the full list.
_MIDAS_WEIGHTS_URLS = {
    "dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
    "dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
    "dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
    "dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _MIDAS_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 MiDaS 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nMiDaS 权重下载完成。")
|
||||
|
||||
|
||||
def load_midas_from_config(
    cfg: MiDaSConfig,
) -> Tuple[torch.nn.Module, MiDaSConfig, callable, int, int]:
    """
    Load a MiDaS model together with its preprocessing transform.

    Returns:
        (model, cfg, transform, net_w, net_h) where ``cfg`` is a fresh
        MiDaSConfig reflecting the device actually selected.
    """
    # default_models maps each model type to its default checkpoint location.
    # NOTE(review): this assumes the mapped value exposes a ``.path`` attribute;
    # upstream MiDaS historically maps to plain path strings — verify.
    model_info = default_models[cfg.model_type]
    ckpt_path = _MIDAS_REPO_ROOT / model_info.path
    _download_if_missing(cfg.model_type, ckpt_path)

    # Degrade gracefully to CPU when CUDA was requested but is unavailable.
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model, transform, net_w, net_h = load_model(
        device=torch.device(device),
        model_path=str(ckpt_path),
        model_type=cfg.model_type,
        optimize=False,
        height=None,
        square=False,
    )

    # Return a fresh config so callers see the resolved device.
    cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform, net_w, net_h
|
||||
|
||||
|
||||
def infer_midas(
    model: torch.nn.Module,
    transform: callable,
    image_rgb: np.ndarray,
    net_w: int,
    net_h: int,
    device: str,
) -> np.ndarray:
    """
    Run depth inference on a single RGB image.

    Returns a float32 depth map (not normalized) at the input resolution.
    """
    image = transform({"image": image_rgb})["image"]
    prediction = utils.process(
        torch.device(device),
        model,
        # The exact string barely affects utils.process's logic, as long as
        # it does not contain "openvino".
        model_type="dpt",
        image=image,
        input_size=(net_w, net_h),
        # shape[1::-1] yields (W, H), the desired output size.
        target_size=image_rgb.shape[1::-1],
        optimize=False,
        use_camera=False,
    )
    return np.asarray(prediction, dtype="float32").squeeze()
|
||||
|
||||
74
python_server/model/Depth/zoe_loader.py
Normal file
74
python_server/model/Depth/zoe_loader.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# 确保本地克隆的 ZoeDepth 仓库在 sys.path 中,
|
||||
# 这样其内部的 `import zoedepth...` 才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
|
||||
if _ZOE_REPO_ROOT.is_dir():
|
||||
zoe_path = str(_ZOE_REPO_ROOT)
|
||||
if zoe_path not in sys.path:
|
||||
sys.path.insert(0, zoe_path)
|
||||
|
||||
from zoedepth.models.builder import build_model
|
||||
from zoedepth.utils.config import get_config
|
||||
|
||||
|
||||
# The three published ZoeDepth checkpoints: NYU, KITTI, and dual-head NYU+KITTI.
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]
|
||||
|
||||
|
||||
@dataclass
class ZoeConfig:
    """
    ZoeDepth model selection configuration.

    model: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK"
    device: "cuda" | "cpu"
    """

    model: ZoeModelName = "ZoeD_N"
    device: str = "cuda"
|
||||
|
||||
|
||||
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
    """
    Manually load one of the three ZoeDepth variants:
    - "ZoeD_N"
    - "ZoeD_K"
    - "ZoeD_NK"

    Returns (model, conf); the model is placed on CUDA when requested and
    available, otherwise on CPU, and switched to eval mode.
    """
    # (positional args, keyword args) for get_config, per variant.
    config_args = {
        "ZoeD_N": (("zoedepth", "infer"), {}),
        "ZoeD_K": (("zoedepth", "infer"), {"config_version": "kitti"}),
        "ZoeD_NK": (("zoedepth_nk", "infer"), {}),
    }
    if name not in config_args:
        raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")
    args, kwargs = config_args[name]
    conf = get_config(*args, **kwargs)

    model = build_model(conf)

    target = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model = model.to(target)
    model.eval()
    return model, conf
|
||||
|
||||
|
||||
def load_zoe_from_config(config: ZoeConfig):
    """
    Load a ZoeDepth model as described by ``config``.

    Example:
        cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
        model, conf = load_zoe_from_config(cfg)
    """
    return load_zoe_from_name(config.model, config.device)
|
||||
|
||||
1
python_server/model/Inpaint/ControlNet
Submodule
1
python_server/model/Inpaint/ControlNet
Submodule
Submodule python_server/model/Inpaint/ControlNet added at ed85cd1e25
1
python_server/model/Inpaint/__init__.py
Normal file
1
python_server/model/Inpaint/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
@@ -0,0 +1,413 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的补全(Inpaint)模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- SDXL Inpaint(diffusers AutoPipelineForInpainting)
|
||||
- ControlNet(占位,暂未统一封装)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
get_sdxl_base_model_from_app,
|
||||
get_controlnet_base_model_from_app,
|
||||
get_controlnet_model_from_app,
|
||||
)
|
||||
|
||||
|
||||
class InpaintBackend(str, Enum):
    """Available inpainting backends."""

    SDXL_INPAINT = "sdxl_inpaint"
    CONTROLNET = "controlnet"
|
||||
|
||||
|
||||
@dataclass
class UnifiedInpaintConfig:
    """Configuration for the unified inpainting entry point."""

    # Which pipeline family to load.
    backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
    # Target device; None defers to the application config.
    device: str | None = None
    # SDXL base model (HF id or local directory); when unset the default
    # from config.py is used.
    sdxl_base_model: str | None = None
|
||||
|
||||
|
||||
@dataclass
class UnifiedDrawConfig:
    """
    Unified drawing configuration:
    - pure text-to-image: image=None
    - image-to-image (imitating the input image): image=<reference image>
    """

    # Target device; None defers to the application config.
    device: str | None = None
    # SDXL base model (HF id or local directory); None uses the default.
    sdxl_base_model: str | None = None
|
||||
|
||||
|
||||
def _resolve_device_and_dtype(device: str | None):
    """Resolve the effective device string and the matching torch dtype.

    When ``device`` is None the configured inpaint device is used; any
    "cuda*" request degrades to "cpu" when CUDA is unavailable. The dtype
    is fp16 on CUDA and fp32 on CPU.
    """
    import torch

    app_cfg = load_app_config()
    requested = app_cfg.inpaint.device if device is None else device
    use_cuda = requested.startswith("cuda") and torch.cuda.is_available()
    resolved = "cuda" if use_cuda else "cpu"
    return resolved, (torch.float16 if use_cuda else torch.float32)
|
||||
|
||||
|
||||
def _enable_memory_opts(pipe, device: str) -> None:
|
||||
if device == "cuda":
|
||||
try:
|
||||
pipe.enable_attention_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_tiling()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_model_cpu_offload()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
|
||||
run_w, run_h = orig_w, orig_h
|
||||
if max_side > 0 and max(orig_w, orig_h) > max_side:
|
||||
scale = max_side / float(max(orig_w, orig_h))
|
||||
run_w = int(round(orig_w * scale))
|
||||
run_h = int(round(orig_h * scale))
|
||||
run_w = max(8, run_w - (run_w % 8))
|
||||
run_h = max(8, run_h - (run_h % 8))
|
||||
return run_w, run_h
|
||||
|
||||
|
||||
def _make_sdxl_inpaint_predictor(
    cfg: UnifiedInpaintConfig,
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
    """
    Build the SDXL inpainting function.

    Returns a predictor:
        - input: image (PIL RGB), mask (PIL L/1), prompt, negative_prompt
        - output: PIL RGB result image

    The text-to-image pipeline is loaded first and the inpainting pipeline is
    derived from it via ``from_pipe`` so the two share weights. Device/dtype
    resolution, memory optimizations and size alignment reuse the module's
    ``_resolve_device_and_dtype``, ``_enable_memory_opts`` and ``_align_size``
    helpers instead of duplicating their logic inline.
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting

    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        # Request fp16 weight files only on the CUDA/half-precision path.
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)

    # Memory savings that do not change output semantics. Note: CPU offload is
    # noticeably slower but significantly reduces VRAM usage.
    _enable_memory_opts(pipe, device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        max_side: int = 1024,
    ) -> Image.Image:
        image = image.convert("RGB")
        # diffusers expects a single-channel mask; white marks the repaint region.
        mask = mask.convert("L")

        # SDXL requires width/height to be multiples of 8; to avoid OOM the
        # image is scaled so its longer side does not exceed max_side, and the
        # result is resized back so the output matches the input resolution.
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)

        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
|
||||
|
||||
|
||||
def _make_controlnet_predictor(_: UnifiedInpaintConfig):
    """
    Build a ControlNet-based inpainting function.

    Returns a predictor with the same shape as the SDXL one:
        - input: image (PIL RGB), mask (PIL L/1), prompt, negative_prompt
        - output: PIL RGB result image

    A Canny edge map of the (resized) input image is used as the control
    signal for every call.
    """
    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

    app_cfg = load_app_config()

    # Resolve the device: config override, then app config, then a CPU
    # fallback when CUDA is unavailable. fp16 on CUDA, fp32 on CPU.
    device = _.device
    if device is None:
        device = app_cfg.inpaint.device
    device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    base_model = get_controlnet_base_model_from_app(app_cfg)
    controlnet_id = get_controlnet_model_from_app(app_cfg)

    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        base_model,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    )

    # Best-effort memory optimizations on CUDA. When CPU offload succeeds the
    # pipeline manages device placement itself; otherwise (and on CPU) the
    # pipeline is moved explicitly.
    if device == "cuda":
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass
        try:
            pipe.enable_vae_slicing()
        except Exception:
            pass
        try:
            pipe.enable_vae_tiling()
        except Exception:
            pass
        try:
            pipe.enable_model_cpu_offload()
        except Exception:
            pipe.to(device)
    else:
        pipe.to(device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        controlnet_conditioning_scale: float = 1.0,
        max_side: int = 768,
    ) -> Image.Image:
        import cv2
        import numpy as np

        image = image.convert("RGB")
        mask = mask.convert("L")

        # Scale so the longer side does not exceed max_side, then floor each
        # side to a multiple of 8 (minimum 8), as the pipeline requires.
        orig_w, orig_h = image.size
        run_w, run_h = orig_w, orig_h
        if max(orig_w, orig_h) > max_side:
            scale = max_side / float(max(orig_w, orig_h))
            run_w = int(round(orig_w * scale))
            run_h = int(round(orig_h * scale))
        run_w = max(8, run_w - (run_w % 8))
        run_h = max(8, run_h - (run_h % 8))

        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask

        # Control image: Canny edges (the most generic constraint).
        rgb = np.array(image_run, dtype=np.uint8)
        edges = cv2.Canny(rgb, 100, 200)
        edges3 = np.stack([edges, edges, edges], axis=-1)
        control_image = Image.fromarray(edges3)

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            control_image=control_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            width=run_w,
            height=run_h,
        ).images[0]

        # Resize the result back to the original input resolution.
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
|
||||
|
||||
|
||||
def build_inpaint_predictor(
    cfg: UnifiedInpaintConfig | None = None,
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
    """
    Build the unified inpainting predictor for the configured backend.
    """
    cfg = cfg or UnifiedInpaintConfig()

    factories = {
        InpaintBackend.SDXL_INPAINT: _make_sdxl_inpaint_predictor,
        InpaintBackend.CONTROLNET: _make_controlnet_predictor,
    }
    if cfg.backend not in factories:
        raise ValueError(f"不支持的补全后端: {cfg.backend}")
    return factories[cfg.backend](cfg), cfg.backend
|
||||
|
||||
|
||||
def build_draw_predictor(
    cfg: UnifiedDrawConfig | None = None,
) -> Callable[..., Image.Image]:
    """
    Build the unified drawing function:
    - text-to-image: draw(prompt, image=None, ...)
    - image-to-image: draw(prompt, image=ref_image, strength=0.55, ...)

    Returns a callable producing a PIL RGB image.
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

    cfg = cfg or UnifiedDrawConfig()
    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        # Request fp16 weight files only on the CUDA/half-precision path.
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    # Derive img2img from txt2img via from_pipe so both share the same weights.
    pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i).to(device)

    _enable_memory_opts(pipe_t2i, device)
    _enable_memory_opts(pipe_i2i, device)

    def _draw(
        prompt: str,
        image: Image.Image | None = None,
        negative_prompt: str = "",
        strength: float = 0.55,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        width: int = 1024,
        height: int = 1024,
        max_side: int = 1024,
    ) -> Image.Image:
        prompt = prompt or ""
        negative_prompt = negative_prompt or ""

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        if image is None:
            # Text-to-image: render at the requested size, capped at max_side
            # and aligned to multiples of 8.
            run_w, run_h = _align_size(width, height, max_side=max_side)
            out = pipe_t2i(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=run_w,
                height=run_h,
            ).images[0]
            return out.convert("RGB")

        # Image-to-image: run at a capped/aligned size, then resize the result
        # back to the reference image's resolution.
        image = image.convert("RGB")
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
        else:
            image_run = image

        out = pipe_i2i(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0].convert("RGB")

        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _draw
|
||||
|
||||
1
python_server/model/Inpaint/sdxl-inpaint
Submodule
1
python_server/model/Inpaint/sdxl-inpaint
Submodule
Submodule python_server/model/Inpaint/sdxl-inpaint added at 29867f540b
1
python_server/model/Seg/Mask2Former
Submodule
1
python_server/model/Seg/Mask2Former
Submodule
Submodule python_server/model/Seg/Mask2Former added at 9b0651c6c1
1
python_server/model/Seg/__init__.py
Normal file
1
python_server/model/Seg/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
59
python_server/model/Seg/mask2former_loader.py
Normal file
59
python_server/model/Seg/mask2former_loader.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@dataclass
class Mask2FormerHFConfig:
    """
    Semantic segmentation via the HuggingFace transformers port of Mask2Former.

    model_id: HuggingFace model id (defaults to the ADE20K semantic checkpoint)
    device: "cuda" | "cpu"
    """

    model_id: str = "facebook/mask2former-swin-large-ade-semantic"
    device: str = "cuda"
|
||||
|
||||
|
||||
def build_mask2former_hf_predictor(
    cfg: Mask2FormerHFConfig | None = None,
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
    """
    Returns predictor(image_rgb_uint8) -> label_map(int32), along with a
    fresh config reflecting the device actually selected.
    """
    cfg = cfg or Mask2FormerHFConfig()

    import torch
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    # Degrade gracefully to CPU when CUDA was requested but is unavailable.
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"

    processor = AutoImageProcessor.from_pretrained(cfg.model_id)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
    model.to(device).eval()

    cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)

    @torch.no_grad()
    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        # Coerce to uint8; non-uint8 inputs are cast without rescaling.
        if image_rgb.dtype != np.uint8:
            image_rgb_u8 = image_rgb.astype("uint8")
        else:
            image_rgb_u8 = image_rgb

        pil = Image.fromarray(image_rgb_u8, mode="RGB")
        inputs = processor(images=pil, return_tensors="pt").to(device)
        outputs = model(**inputs)

        # post-process to original size
        target_sizes = [(pil.height, pil.width)]
        seg = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
        return seg.detach().to("cpu").numpy().astype("int32")

    return _predict, cfg
|
||||
|
||||
168
python_server/model/Seg/seg_loader.py
Normal file
168
python_server/model/Seg/seg_loader.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的分割模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- SAM (segment-anything)
|
||||
- Mask2Former(使用 HuggingFace transformers 的语义分割实现)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class SegBackend(str, Enum):
    """Available segmentation backends."""

    SAM = "sam"
    MASK2FORMER = "mask2former"
|
||||
|
||||
|
||||
@dataclass
class UnifiedSegConfig:
    """Configuration for the unified segmentation entry point."""

    # Which segmentation backend to load.
    backend: SegBackend = SegBackend.SAM
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# SAM (Segment Anything)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _ensure_sam_on_path() -> Path:
    """Make the vendored segment-anything checkout importable.

    Prepends the repo directory to ``sys.path`` (once) so that
    ``import segment_anything`` resolves, and returns the repo root.

    Raises:
        FileNotFoundError: when the checkout directory is missing.
    """
    repo_root = _THIS_DIR / "segment-anything"
    if not repo_root.is_dir():
        raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {repo_root}")
    entry = str(repo_root)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return repo_root
|
||||
|
||||
|
||||
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
|
||||
import requests
|
||||
|
||||
ckpt_dir = sam_root / "checkpoints"
|
||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
|
||||
|
||||
if ckpt_path.is_file():
|
||||
return ckpt_path
|
||||
|
||||
url = (
|
||||
"https://dl.fbaipublicfiles.com/segment_anything/"
|
||||
"sam_vit_h_4b8939.pth"
|
||||
)
|
||||
print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print(
|
||||
"\r[{}{}] {:.1f}%".format(
|
||||
"#" * done,
|
||||
"." * (50 - done),
|
||||
downloaded * 100 / total,
|
||||
),
|
||||
end="",
|
||||
)
|
||||
print("\nSAM 权重下载完成。")
|
||||
return ckpt_path
|
||||
|
||||
|
||||
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """
    Build a segmentation function:
    - input: RGB uint8 image (H, W, 3)
    - output: label map (H, W), one int id per detected object (starting at
      1; 0 means no mask covers the pixel)
    """
    sam_root = _ensure_sam_on_path()
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)

    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # type: ignore[import]
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry["vit_h"](
        checkpoint=str(ckpt_path),
    ).to(device)

    mask_generator = SamAutomaticMaskGenerator(sam)

    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        # Coerce to uint8; non-uint8 inputs are cast without rescaling.
        if image_rgb.dtype != np.uint8:
            image_rgb_u8 = image_rgb.astype("uint8")
        else:
            image_rgb_u8 = image_rgb

        masks = mask_generator.generate(image_rgb_u8)
        h, w, _ = image_rgb_u8.shape
        label_map = np.zeros((h, w), dtype="int32")

        # Paint each mask with its 1-based index; where masks overlap, a
        # later mask overwrites an earlier one.
        for idx, m in enumerate(masks, start=1):
            seg = m.get("segmentation")
            if seg is None:
                continue
            label_map[seg.astype(bool)] = idx

        return label_map

    return _predict
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Mask2Former (占位)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build a semantic-segmentation predictor backed by HF Mask2Former."""
    from .mask2former_loader import build_mask2former_hf_predictor

    # The resolved config is not needed here, only the predictor itself.
    predictor, _cfg = build_mask2former_hf_predictor()
    return predictor
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 统一构建函数
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """
    Build the unified segmentation predictor.

    Returns:
      - predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W], int32
      - the backend actually used

    Raises ValueError for an unsupported backend.
    """
    cfg = cfg or UnifiedSegConfig()
    backend = cfg.backend

    if backend == SegBackend.SAM:
        return _make_sam_predictor(), SegBackend.SAM
    if backend == SegBackend.MASK2FORMER:
        return _make_mask2former_predictor(), SegBackend.MASK2FORMER

    raise ValueError(f"不支持的分割后端: {cfg.backend}")
|
||||
|
||||
1
python_server/model/Seg/segment-anything
Submodule
1
python_server/model/Seg/segment-anything
Submodule
Submodule python_server/model/Seg/segment-anything added at dca509fe79
1
python_server/model/__init__.py
Normal file
1
python_server/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
407
python_server/server.py
Normal file
407
python_server/server.py
Normal file
@@ -0,0 +1,407 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from config_loader import load_app_config, get_depth_backend_from_app
|
||||
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
|
||||
|
||||
from model.Seg.seg_loader import UnifiedSegConfig, SegBackend, build_seg_predictor
|
||||
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
|
||||
from model.Animation.animation_loader import (
|
||||
UnifiedAnimationConfig,
|
||||
AnimationBackend,
|
||||
build_animation_predictor,
|
||||
)
|
||||
|
||||
|
||||
# All server artifacts are written under ./outputs next to this file; the
# directory is created eagerly at import time.
APP_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = APP_ROOT / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title="HFUT Model Server", version="0.1.0")
|
||||
|
||||
|
||||
class ImageInput(BaseModel):
    # Base request payload shared by the image-processing endpoints.
    image_b64: str = Field(..., description="PNG/JPG 编码后的 base64(不含 data: 前缀)")
    model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
|
||||
|
||||
|
||||
class DepthRequest(ImageInput):
    # Depth endpoint needs only the shared image/model fields.
    pass
|
||||
|
||||
|
||||
class SegmentRequest(ImageInput):
    # Segmentation endpoint needs only the shared image/model fields.
    pass
|
||||
|
||||
|
||||
class InpaintRequest(ImageInput):
    # Inpainting-specific knobs on top of the shared image payload.
    prompt: Optional[str] = Field("", description="补全 prompt")
    strength: float = Field(0.8, ge=0.0, le=1.0)
    negative_prompt: Optional[str] = Field("", description="负向 prompt")
    # Optional mask (white pixels = region to repaint).
    mask_b64: Optional[str] = Field(None, description="mask PNG base64(可选)")
    # Upper bound on the longer image side during inference (avoids OOM).
    max_side: int = Field(1024, ge=128, le=2048)
|
||||
|
||||
|
||||
class AnimateRequest(BaseModel):
    # Text-to-video generation request (no input image).
    model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
    prompt: str = Field(..., description="文本提示词")
    negative_prompt: Optional[str] = Field("", description="负向提示词")
    num_inference_steps: int = Field(25, ge=1, le=200)
    guidance_scale: float = Field(8.0, ge=0.0, le=30.0)
    width: int = Field(512, ge=128, le=2048)
    height: int = Field(512, ge=128, le=2048)
    video_length: int = Field(16, ge=1, le=128)
    seed: int = Field(-1, description="-1 表示随机种子")
|
||||
|
||||
|
||||
def _b64_to_pil_image(b64: str) -> Image.Image:
    """Decode a base64 string (no data: prefix) into a PIL image."""
    decoded = base64.b64decode(b64)
    buffer = io.BytesIO(decoded)
    return Image.open(buffer)
|
||||
|
||||
|
||||
def _pil_image_to_png_b64(img: Image.Image) -> str:
    """Encode a PIL image as PNG and return its base64 string (ASCII)."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode("ascii")
|
||||
|
||||
|
||||
def _depth_to_png16_b64(depth: np.ndarray) -> str:
    """Min-max normalise a depth map and encode it as a 16-bit PNG base64 string."""
    values = np.asarray(depth, dtype=np.float32)
    lo = float(values.min())
    hi = float(values.max())
    # A constant map normalises to all zeros instead of dividing by zero.
    if hi > lo:
        norm = (values - lo) / (hi - lo)
    else:
        norm = np.zeros_like(values, dtype=np.float32)
    u16 = (norm * 65535.0).clip(0, 65535).astype(np.uint16)
    return _pil_image_to_png_b64(Image.fromarray(u16, mode="I;16"))
|
||||
|
||||
def _depth_to_png16_bytes(depth: np.ndarray) -> bytes:
    """
    Min-max normalise a depth map and return PNG bytes.

    NOTE(review): despite the "png16" in the name this emits an *8-bit*
    grayscale PNG with the values inverted (see contract comment below);
    the actual 16-bit variant is `_depth_to_png16_b64`. The name is kept
    because callers depend on it.
    """
    depth = np.asarray(depth, dtype=np.float32)
    dmin = float(depth.min())
    dmax = float(depth.max())
    if dmax > dmin:
        norm = (depth - dmin) / (dmax - dmin)
    else:
        # Constant depth map: avoid division by zero, emit all zeros.
        norm = np.zeros_like(depth, dtype=np.float32)
    # Frontend/backend contract: farthest = 255, nearest = 0 (8-bit).
    # NOTE(review): (1 - norm) maps the *largest* predictor output to 0, so
    # this assumes the predictor emits disparity-like values (larger =
    # nearer) — TODO confirm against the depth backend.
    u8 = ((1.0 - norm) * 255.0).clip(0, 255).astype(np.uint8)
    img = Image.fromarray(u8, mode="L")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
|
||||
|
||||
|
||||
def _default_half_mask(img: Image.Image) -> Image.Image:
    """Fallback inpaint mask: white (repaint) over the right half of *img*."""
    width, height = img.size
    canvas = Image.new("L", (width, height), 0)
    ImageDraw.Draw(canvas).rectangle([width // 2, 0, width, height], fill=255)
    return canvas
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# /models(给前端/GUI 使用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@app.get("/models")
def get_models() -> Dict[str, Any]:
    """
    Return the model catalogue in a schema the Qt frontend understands:
    {
        "models": {
            "depth": { "key": { "name": "..."} ... },
            "segment": { ... },
            "inpaint": { "key": { "name": "...", "params": [...] } ... }
        }
    }
    """
    return {
        "models": {
            "depth": {
                # Legacy default key, kept for old configs.
                "midas": {"name": "MiDaS (default)"},
                "zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
                "zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
                "zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
                "depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
                "depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
                "depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
                "depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
                "dpt_large": {"name": "DPT (large)"},
                "dpt_hybrid": {"name": "DPT (hybrid)"},
                "midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
                "midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
                "midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
                "midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
            },
            "segment": {
                "sam": {"name": "SAM (vit_h)"},
                # Legacy default key, kept for old configs (aliased to SAM).
                "mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
                "mask2former": {"name": "Mask2Former (not implemented)"},
            },
            "inpaint": {
                # Legacy default: "copy" performs no inpainting at all.
                "copy": {"name": "Copy (no-op)", "params": []},
                "sdxl_inpaint": {
                    "name": "SDXL Inpaint",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
                "controlnet": {
                    "name": "ControlNet Inpaint (canny)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
            },
            "animation": {
                "animatediff": {
                    "name": "AnimateDiff (Text-to-Video)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": False},
                        {"id": "negative_prompt", "label": "负向提示词", "optional": True},
                        {"id": "num_inference_steps", "label": "采样步数", "optional": True},
                        {"id": "guidance_scale", "label": "CFG Scale", "optional": True},
                        {"id": "width", "label": "宽度", "optional": True},
                        {"id": "height", "label": "高度", "optional": True},
                        {"id": "video_length", "label": "帧数", "optional": True},
                        {"id": "seed", "label": "随机种子", "optional": True},
                    ],
                },
            },
        }
    }
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Depth
|
||||
# -----------------------------
|
||||
|
||||
|
||||
# Lazily-built, process-wide depth predictor shared by all /depth requests.
_depth_predictor = None
_depth_backend: DepthBackend | None = None
|
||||
|
||||
|
||||
def _ensure_depth_predictor() -> None:
    """
    Lazily build the process-wide depth predictor from config.py.

    No-op when the predictor already exists; raises ValueError when the
    configured backend string is not a valid DepthBackend member.
    """
    global _depth_predictor, _depth_backend
    if _depth_predictor is not None and _depth_backend is not None:
        return

    app_cfg = load_app_config()
    backend_str = get_depth_backend_from_app(app_cfg)
    try:
        backend = DepthBackend(backend_str)
    except Exception as e:
        raise ValueError(f"config.py 中 depth.backend 不合法: {backend_str}") from e

    _depth_predictor, _depth_backend = build_depth_predictor(UnifiedDepthConfig(backend=backend))
|
||||
|
||||
|
||||
@app.post("/depth")
def depth(req: DepthRequest):
    """
    Compute depth and return binary PNG bytes (8-bit grayscale, inverted —
    see `_depth_to_png16_bytes`).

    Contract:
      - The frontend does not pick a model; the backend is fixed in config.py.
      - Success: HTTP 200 + Content-Type: image/png
      - Failure: HTTP 500 with the error message in `detail`
    """
    try:
        _ensure_depth_predictor()
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        depth_arr = _depth_predictor(pil)  # type: ignore[misc]
        png_bytes = _depth_to_png16_bytes(np.asarray(depth_arr))
        return Response(content=png_bytes, media_type="image/png")
    except Exception as e:
        # Chain the cause so the original traceback survives the re-raise.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Segment
|
||||
# -----------------------------
|
||||
|
||||
|
||||
_seg_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _get_seg_predictor(model_name: str):
    """
    Return the segmentation predictor for *model_name*, building and
    caching it on first use.

    Raises ValueError for unknown model names.
    """
    # Normalise the legacy alias BEFORE the cache lookup. Previously the
    # remap happened after it, so "mask2former_debug" requests never hit
    # the cache and rebuilt the SAM model on every single call.
    if model_name == "mask2former_debug":
        model_name = "sam"

    if model_name in _seg_cache:
        return _seg_cache[model_name]

    if model_name == "sam":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.SAM))
        _seg_cache[model_name] = pred
        return pred

    if model_name == "mask2former":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.MASK2FORMER))
        _seg_cache[model_name] = pred
        return pred

    raise ValueError(f"未知 segment model_name: {model_name}")
|
||||
|
||||
|
||||
@app.post("/segment")
def segment(req: SegmentRequest) -> Dict[str, Any]:
    """
    Run segmentation and write the label map to disk.

    Returns {"success": True, "label_path": ...} on success, or
    {"success": False, "error": ...} on failure (always HTTP 200).
    """
    try:
        model_name = req.model_name or "sam"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)

        predictor = _get_seg_predictor(model_name)
        label_map = predictor(rgb).astype(np.int32)

        out_dir = OUTPUT_DIR / "segment"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_label.png"
        # Saved as 8-bit grayscale: labels above 255 are clamped (lossy in
        # theory, but SAM rarely yields that many objects on one image).
        Image.fromarray(np.clip(label_map, 0, 255).astype(np.uint8), mode="L").save(out_path)

        return {"success": True, "label_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Inpaint
|
||||
# -----------------------------
|
||||
|
||||
|
||||
_inpaint_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _get_inpaint_predictor(model_name: str):
    """Return (building and caching on first use) the inpaint predictor for *model_name*."""
    cached = _inpaint_cache.get(model_name)
    if cached is not None:
        return cached

    if model_name == "copy":
        # No-op backend: hand the input image straight back.
        def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
            return image.convert("RGB")

        predictor = _copy
    elif model_name == "sdxl_inpaint":
        predictor, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.SDXL_INPAINT))
    elif model_name == "controlnet":
        predictor, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.CONTROLNET))
    else:
        raise ValueError(f"未知 inpaint model_name: {model_name}")

    _inpaint_cache[model_name] = predictor
    return predictor
|
||||
|
||||
|
||||
@app.post("/inpaint")
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
    """
    Run inpainting and write the result to disk.

    When no mask is supplied the right half of the image is repainted.
    Returns {"success": ..., "output_path"/"error": ...} (always HTTP 200).
    """
    try:
        model_name = req.model_name or "sdxl_inpaint"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")

        if req.mask_b64:
            mask = _b64_to_pil_image(req.mask_b64).convert("L")
        else:
            # Fallback: repaint the right half of the image.
            mask = _default_half_mask(pil)

        predictor = _get_inpaint_predictor(model_name)
        out = predictor(
            pil,
            mask,
            req.prompt or "",
            req.negative_prompt or "",
            strength=req.strength,
            max_side=req.max_side,
        )

        out_dir = OUTPUT_DIR / "inpaint"
        out_dir.mkdir(parents=True, exist_ok=True)
        # NOTE: fixed filename per model — repeated requests overwrite it.
        out_path = out_dir / f"{model_name}_inpaint.png"
        out.save(out_path)

        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
_animation_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _get_animation_predictor(model_name: str):
    """Return (building and caching on first use) the animation predictor for *model_name*."""
    try:
        return _animation_cache[model_name]
    except KeyError:
        pass

    if model_name != "animatediff":
        raise ValueError(f"未知 animation model_name: {model_name}")

    predictor, _ = build_animation_predictor(
        UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
    )
    _animation_cache[model_name] = predictor
    return predictor
|
||||
|
||||
|
||||
@app.post("/animate")
def animate(req: AnimateRequest) -> Dict[str, Any]:
    """
    Run text-to-video generation and copy the resulting GIF into outputs/.

    Returns {"success": ..., "output_path"/"error": ...} (always HTTP 200).
    """
    try:
        model_name = req.model_name or "animatediff"
        predictor = _get_animation_predictor(model_name)

        result_path = predictor(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or "",
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
            width=req.width,
            height=req.height,
            video_length=req.video_length,
            seed=req.seed,
        )

        out_dir = OUTPUT_DIR / "animation"
        out_dir.mkdir(parents=True, exist_ok=True)
        # Timestamped filename so repeated runs never overwrite each other.
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = out_dir / f"{model_name}_{ts}.gif"
        shutil.copy2(result_path, out_path)

        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always reports an OK status."""
    payload = {"status": "ok"}
    return payload
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Port is overridable via the MODEL_SERVER_PORT environment variable.
    port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
|
||||
77
python_server/test_animation.py
Normal file
77
python_server/test_animation.py
Normal file
@@ -0,0 +1,77 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
from model.Animation.animation_loader import (
|
||||
build_animation_predictor,
|
||||
UnifiedAnimationConfig,
|
||||
AnimationBackend,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------
# Configuration (edit as needed)
# -----------------------------
OUTPUT_DIR = "outputs/test_animation"
ANIMATION_BACKEND = AnimationBackend.ANIMATEDIFF
OUTPUT_FORMAT = "png_sequence"  # "gif" | "png_sequence"

PROMPT = "a cinematic mountain landscape, camera slowly pans left"
NEGATIVE_PROMPT = "blurry, low quality"
NUM_INFERENCE_STEPS = 25
GUIDANCE_SCALE = 8.0
WIDTH = 512
HEIGHT = 512
VIDEO_LENGTH = 16
SEED = -1  # -1 means a random seed
CONTROL_IMAGE_PATH = "path/to/your_image.png"  # placeholder — must be replaced before running
|
||||
|
||||
|
||||
def main() -> None:
    """Run a one-shot animation generation using the module-level config."""
    base_dir = Path(__file__).resolve().parent
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)

    predictor, used_backend = build_animation_predictor(
        UnifiedAnimationConfig(backend=ANIMATION_BACKEND)
    )

    # Refuse to run while the control-image path is still the placeholder.
    if CONTROL_IMAGE_PATH.strip() in {"", "path/to/your_image.png"}:
        raise ValueError("请先设置 CONTROL_IMAGE_PATH 为你的输入图片路径(png/jpg)。")

    control_image = (base_dir / CONTROL_IMAGE_PATH).resolve()
    if not control_image.is_file():
        raise FileNotFoundError(f"control image not found: {control_image}")

    result_path = predictor(
        prompt=PROMPT,
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        width=WIDTH,
        height=HEIGHT,
        video_length=VIDEO_LENGTH,
        seed=SEED,
        control_image_path=str(control_image),
        output_format=OUTPUT_FORMAT,
    )

    source = Path(result_path)
    if OUTPUT_FORMAT == "png_sequence":
        # PNG sequence: replace any previous frame directory wholesale.
        out_seq_dir = out_dir / f"{used_backend.value}_frames"
        if out_seq_dir.exists():
            shutil.rmtree(out_seq_dir)
        shutil.copytree(source, out_seq_dir)
        print(f"[Animation] backend={used_backend.value}, saved={out_seq_dir}")
        return

    # GIF: copy the bytes to a backend-named file in the output directory.
    out_path = out_dir / f"{used_backend.value}.gif"
    out_path.write_bytes(source.read_bytes())
    print(f"[Animation] backend={used_backend.value}, saved={out_path}")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
|
||||
101
python_server/test_depth.py
Normal file
101
python_server/test_depth.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
# 解除大图限制
|
||||
Image.MAX_IMAGE_PIXELS = None
|
||||
|
||||
from model.Depth.depth_loader import build_depth_predictor, UnifiedDepthConfig, DepthBackend
|
||||
|
||||
# ================= Configuration =================
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
OUTPUT_DIR = "outputs/test_depth_v4"
DEPTH_BACKEND = DepthBackend.DEPTH_ANYTHING_V2

# Edge-capture parameters.
# Raising the percentile keeps only the strongest edges (thinner result);
# lowering it also keeps weaker gradients (thicker result).
EDGE_TOP_PERCENTILE = 96.0  # keep pixels above the 96th gradient percentile (top ~4%)
# CLAHE local-enhancement sensitivity; 2.0 - 5.0 is a reasonable range.
CLAHE_CLIP_LIMIT = 3.0
# Morphological kernel size.
MORPH_SIZE = 5
# =================================================
|
||||
|
||||
def _extract_robust_edges(depth_norm: np.ndarray) -> np.ndarray:
    """
    Extract closed edge contours from a normalised ([0, 1]) depth map via
    local contrast enhancement + Sobel gradients.

    Returns an 8-bit mask (0 / 255) intended as SAM guidance.
    """
    # 1. Convert to 8-bit grayscale.
    depth_u8 = (depth_norm * 255).astype(np.uint8)

    # 2. CLAHE adaptive local contrast enhancement — amplifies the subtle
    #    depth differences of buildings/figures in old paintings.
    clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=(16, 16))
    enhanced_depth = clahe.apply(depth_u8)

    # 3. Gaussian blur to suppress digitisation noise.
    blurred = cv2.GaussianBlur(enhanced_depth, (5, 5), 0)

    # 4. Sobel gradient magnitude.
    grad_x = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=3)
    grad_mag = np.sqrt(grad_x**2 + grad_y**2)

    # 5. Statistical threshold: keep the top-X% strongest gradients rather
    #    than sticking to a fixed magnitude cut-off.
    threshold = np.percentile(grad_mag, EDGE_TOP_PERCENTILE)
    binary_edges = (grad_mag >= threshold).astype(np.uint8) * 255

    # 6. Morphological closing: bridge small gaps so contours connect.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (MORPH_SIZE, MORPH_SIZE))
    closed = cv2.morphologyEx(binary_edges, cv2.MORPH_CLOSE, kernel)

    # 7. Slight extra dilation to give SAM a wider guidance band.
    final_mask = cv2.dilate(closed, kernel, iterations=1)

    return final_mask
|
||||
|
||||
def main() -> None:
    """Depth-estimate one image and export its depth map plus an edge mask."""
    base_dir = Path(__file__).resolve().parent
    img_path = Path(INPUT_IMAGE)
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1. Build the depth predictor.
    predictor, used_backend = build_depth_predictor(UnifiedDepthConfig(backend=DEPTH_BACKEND))

    # 2. Load the image.
    print(f"[Loading] 正在处理: {img_path.name}")
    img_pil = Image.open(img_path).convert("RGB")
    w_orig, h_orig = img_pil.size  # currently unused; kept for reference

    # 3. Depth prediction.
    print(f"[Depth] 正在进行深度估计 (Large Image)...")
    depth = np.asarray(predictor(img_pil), dtype=np.float32).squeeze()

    # 4. Normalise and save a 16-bit depth PNG (epsilon avoids div-by-zero).
    dmin, dmax = depth.min(), depth.max()
    depth_norm = (depth - dmin) / (dmax - dmin + 1e-8)
    depth_u16 = (depth_norm * 65535.0).astype(np.uint16)
    Image.fromarray(depth_u16).save(out_dir / f"{img_path.stem}.depth.png")

    # 5. Extract robust edges.
    print(f"[Edge] 正在应用 CLAHE + Sobel 增强算法提取边缘...")
    edge_mask = _extract_robust_edges(depth_norm)

    # 6. Export the mask and report edge density for tuning.
    mask_path = out_dir / f"{img_path.stem}.edge_mask_robust.png"
    Image.fromarray(edge_mask).save(mask_path)

    edge_ratio = float((edge_mask > 0).sum()) / float(edge_mask.size)
    print("-" * 30)
    print(f"提取完成!")
    print(f"边缘密度: {edge_ratio:.2%} (目标通常应在 1% ~ 8% 之间)")
    print(f"如果 Mask 依然太黑,请调低 EDGE_TOP_PERCENTILE (如 90.0)")
    print(f"如果 Mask 太乱,请调高 EDGE_TOP_PERCENTILE (如 96.0)")
    print("-" * 30)
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
284
python_server/test_inpaint.py
Normal file
284
python_server/test_inpaint.py
Normal file
@@ -0,0 +1,284 @@
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from model.Inpaint.inpaint_loader import (
|
||||
build_inpaint_predictor,
|
||||
UnifiedInpaintConfig,
|
||||
InpaintBackend,
|
||||
build_draw_predictor,
|
||||
UnifiedDrawConfig,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------
# Configuration (edit as needed)
# -----------------------------
# Task mode:
#   - "inpaint": hole filling / completion
#   - "draw":    generation (text-to-image / image-to-image)
TASK_MODE = "draw"

# Input image, e.g. image_0.png
INPUT_IMAGE = "/home/dwh/code/hfut-bishe/python_server/outputs/test_seg_v2/Up the River During Qingming (detail) - Court painters.objects/Up the River During Qingming (detail) - Court painters.obj_06.png"
INPUT_MASK = ""  # empty = no hand-made mask
OUTPUT_DIR = "outputs/test_inpaint_v2"
MASK_RECT = (256, 0, 512, 512)  # x1, y1, x2, y2 (used when AUTO_MASK is off)
USE_AUTO_MASK = True  # True: infer the repaint region from the image itself
AUTO_MASK_BLACK_THRESHOLD = 6  # near-black pixel threshold for the auto mask
AUTO_MASK_DILATE_ITER = 4  # extra dilation so branch edges are fully covered
FALLBACK_TO_RECT_MASK = False  # fall back to the rect mask if the auto mask is empty

# Transparency handling for outputs.
PRESERVE_TRANSPARENCY = True  # keep the output transparent
EXPAND_ALPHA_WITH_MASK = True  # make inpainted regions opaque in the alpha
NON_BLACK_ALPHA_THRESHOLD = 6  # near-black = transparent when input has no alpha

INPAINT_BACKEND = InpaintBackend.SDXL_INPAINT  # SDXL_INPAINT gives the best results

# -----------------------------
# Prompts: generate trees only
# -----------------------------
# PROMPT stresses tree species, leaves and style so the fill blends in.
PROMPT = (
    "A highly detailed traditional Chinese ink brush painting. "
    "Restore and complete the existing trees naturally. "
    "Extend the complex tree branches and dense, varied green and teal leaves. "
    "Add more harmonious foliage and intricate bark texture to match the style and flow of the original image_0.png trees. "
    "Focus solely on vegetation. "
    "Maintain the light beige background color and texture."
)

# NEGATIVE_PROMPT explicitly excludes unwanted content.
NEGATIVE_PROMPT = (
    "buildings, architecture, houses, pavilions, temples, windows, doors, "
    "people, figures, persons, characters, figures, clothing, faces, hands, "
    "text, writing, characters, words, letters, signatures, seals, stamps, "
    "calligraphy, objects, artifacts, boxes, baskets, tools, "
    "extra branches crossing unnaturally, bad composition, watermark, signature"
)

STRENGTH = 0.8  # keep high for generative fills
GUIDANCE_SCALE = 8.0  # slightly higher = follow the prompt more strictly
NUM_INFERENCE_STEPS = 35  # slightly more steps for extra detail
MAX_SIDE = 1024
CONTROLNET_SCALE = 1.0

# -----------------------------
# Draw-mode parameters
# -----------------------------
# DRAW_INPUT_IMAGE empty  -> text-to-image
# DRAW_INPUT_IMAGE set    -> image-to-image (imitate / repaint the reference)
DRAW_INPUT_IMAGE = ""
DRAW_PROMPT = """
Chinese ink wash painting, vast snowy river under a pale sky,
a small lonely boat at the horizon where water meets sky,
an old fisherman wearing a straw hat sits at the stern, fishing quietly,
gentle snowfall, misty atmosphere, distant mountains barely visible,
minimalist composition, large empty space, soft brush strokes,
calm, cold, and silent mood, poetic and serene
"""

DRAW_NEGATIVE_PROMPT = """
blurry, low quality, many people, bright colors, modern elements, crowded, noisy
"""
DRAW_STRENGTH = 0.55  # img2img only; higher = stronger repaint
DRAW_GUIDANCE_SCALE = 9
DRAW_STEPS = 64
DRAW_WIDTH = 2560
DRAW_HEIGHT = 1440
DRAW_MAX_SIDE = 2560
|
||||
|
||||
|
||||
def _dilate(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
|
||||
out = mask.astype(bool)
|
||||
for _ in range(max(0, iterations)):
|
||||
p = np.pad(out, ((1, 1), (1, 1)), mode="constant", constant_values=False)
|
||||
out = (
|
||||
p[:-2, :-2] | p[:-2, 1:-1] | p[:-2, 2:]
|
||||
| p[1:-1, :-2] | p[1:-1, 1:-1] | p[1:-1, 2:]
|
||||
| p[2:, :-2] | p[2:, 1:-1] | p[2:, 2:]
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _make_rect_mask(img_size: tuple[int, int]) -> Image.Image:
    """
    Build an L-mode mask with the module-level MASK_RECT filled white (255).

    Coordinates are clamped to the image bounds and swapped if reversed, so
    any MASK_RECT yields a valid (possibly tiny) rectangle.
    """
    w, h = img_size
    x1, y1, x2, y2 = MASK_RECT
    # Clamp to valid pixel coordinates.
    x1 = max(0, min(x1, w - 1))
    x2 = max(0, min(x2, w - 1))
    y1 = max(0, min(y1, h - 1))
    y2 = max(0, min(y2, h - 1))
    # Normalise reversed corners.
    if x2 < x1:
        x1, x2 = x2, x1
    if y2 < y1:
        y1, y2 = y2, y1

    mask = Image.new("L", (w, h), 0)
    draw = ImageDraw.Draw(mask)
    draw.rectangle([x1, y1, x2, y2], fill=255)
    return mask
|
||||
|
||||
|
||||
def _auto_mask_from_image(img_path: Path) -> Image.Image:
    """
    Infer the missing region automatically:
    1) if the input has an alpha channel, transparent pixels become the mask;
    2) otherwise near-black pixels are treated as the candidate missing area.
    """
    raw = Image.open(img_path)
    arr = np.asarray(raw)

    if arr.ndim == 3 and arr.shape[2] == 4:
        alpha = arr[:, :, 3]
        # Anything not (almost) fully opaque counts as missing.
        mask_bool = alpha < 250
    else:
        rgb = np.asarray(raw.convert("RGB"), dtype=np.uint8)
        # Cut-outs often store the transparent area as black, so treat
        # near-black regions as missing first.
        dark = np.all(rgb <= AUTO_MASK_BLACK_THRESHOLD, axis=-1)
        mask_bool = dark

    # Grow the mask so thin structures (e.g. branch edges) are covered.
    mask_bool = _dilate(mask_bool, AUTO_MASK_DILATE_ITER)
    return Image.fromarray((mask_bool.astype(np.uint8) * 255), mode="L")
|
||||
|
||||
|
||||
def _build_alpha_from_input(img_path: Path, img_rgb: Image.Image) -> np.ndarray:
    """
    Build the output alpha channel:
    - if the input file has alpha, reuse it as-is;
    - otherwise treat near-black pixels as transparent background.
    """
    raw = Image.open(img_path)
    arr = np.asarray(raw)
    if arr.ndim == 3 and arr.shape[2] == 4:
        return arr[:, :, 3].astype(np.uint8)

    rgb = np.asarray(img_rgb.convert("RGB"), dtype=np.uint8)
    non_black = np.any(rgb > NON_BLACK_ALPHA_THRESHOLD, axis=-1)
    return (non_black.astype(np.uint8) * 255)
|
||||
|
||||
|
||||
def _load_or_make_mask(base_dir: Path, img_path: Path, img_rgb: Image.Image) -> Image.Image:
    """
    Resolve the inpaint mask, in priority order:
    1) an explicit INPUT_MASK file,
    2) the auto-inferred mask (when USE_AUTO_MASK),
    3) the rectangular fallback (only if FALLBACK_TO_RECT_MASK allows it).
    """
    if INPUT_MASK:
        raw_mask_path = Path(INPUT_MASK)
        mask_path = raw_mask_path if raw_mask_path.is_absolute() else (base_dir / raw_mask_path)
        if mask_path.is_file():
            return Image.open(mask_path).convert("L")

    if USE_AUTO_MASK:
        auto = _auto_mask_from_image(img_path)
        auto_arr = np.asarray(auto, dtype=np.uint8)
        if (auto_arr > 0).any():
            return auto
        # Empty auto mask: either fall back to the rectangle or fail loudly.
        if not FALLBACK_TO_RECT_MASK:
            raise ValueError("自动mask为空,请检查输入图像是否存在透明/黑色缺失区域。")

    return _make_rect_mask(img_rgb.size)
|
||||
|
||||
|
||||
def _resolve_path(base_dir: Path, p: str) -> Path:
|
||||
raw = Path(p)
|
||||
return raw if raw.is_absolute() else (base_dir / raw)
|
||||
|
||||
|
||||
def run_inpaint_test(base_dir: Path, out_dir: Path) -> None:
    """
    Inpaint INPUT_IMAGE with the configured backend and mask strategy,
    saving the mask actually used plus the (optionally RGBA) result.
    """
    img_path = _resolve_path(base_dir, INPUT_IMAGE)
    if not img_path.is_file():
        raise FileNotFoundError(f"找不到输入图像,请修改 INPUT_IMAGE: {img_path}")

    predictor, used_backend = build_inpaint_predictor(
        UnifiedInpaintConfig(backend=INPAINT_BACKEND)
    )

    img = Image.open(img_path).convert("RGB")
    mask = _load_or_make_mask(base_dir, img_path, img)
    # Persist the mask for inspection/debugging.
    mask_out = out_dir / f"{img_path.stem}.mask_used.png"
    mask.save(mask_out)

    kwargs = dict(
        strength=STRENGTH,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_INFERENCE_STEPS,
        max_side=MAX_SIDE,
    )
    # Only the ControlNet backend accepts the conditioning-scale argument.
    if used_backend == InpaintBackend.CONTROLNET:
        kwargs["controlnet_conditioning_scale"] = CONTROLNET_SCALE

    out = predictor(img, mask, PROMPT, NEGATIVE_PROMPT, **kwargs)
    out_path = out_dir / f"{img_path.stem}.{used_backend.value}.inpaint.png"
    if PRESERVE_TRANSPARENCY:
        alpha = _build_alpha_from_input(img_path, img)
        mask_u8 = np.asarray(mask, dtype=np.uint8)
        if EXPAND_ALPHA_WITH_MASK:
            # Freshly inpainted pixels become opaque in the output alpha.
            alpha = np.maximum(alpha, mask_u8)

        out_rgb = np.asarray(out.convert("RGB"), dtype=np.uint8)
        out_rgba = np.concatenate([out_rgb, alpha[..., None]], axis=-1)
        Image.fromarray(out_rgba, mode="RGBA").save(out_path)
    else:
        out.save(out_path)

    # Report which fraction of the image the mask covered.
    ratio = float((np.asarray(mask, dtype=np.uint8) > 0).sum()) / float(mask.size[0] * mask.size[1])
    print(f"[Inpaint] backend={used_backend.value}, saved={out_path}")
    print(f"[Mask] saved={mask_out}, ratio={ratio:.4f}")
|
||||
|
||||
|
||||
def run_draw_test(base_dir: Path, out_dir: Path) -> None:
    """
    Drawing test:
    - text-to-image when DRAW_INPUT_IMAGE == ""
    - image-to-image when DRAW_INPUT_IMAGE points at a reference image
    """
    draw_predictor = build_draw_predictor(UnifiedDrawConfig())
    ref_image: Image.Image | None = None
    mode = "text2img"

    if DRAW_INPUT_IMAGE:
        ref_path = _resolve_path(base_dir, DRAW_INPUT_IMAGE)
        if not ref_path.is_file():
            raise FileNotFoundError(f"找不到参考图,请修改 DRAW_INPUT_IMAGE: {ref_path}")
        ref_image = Image.open(ref_path).convert("RGB")
        mode = "img2img"

    out = draw_predictor(
        prompt=DRAW_PROMPT,
        image=ref_image,
        negative_prompt=DRAW_NEGATIVE_PROMPT,
        strength=DRAW_STRENGTH,
        guidance_scale=DRAW_GUIDANCE_SCALE,
        num_inference_steps=DRAW_STEPS,
        width=DRAW_WIDTH,
        height=DRAW_HEIGHT,
        max_side=DRAW_MAX_SIDE,
    )

    # Timestamped filename so repeated runs never overwrite each other.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = out_dir / f"draw_{mode}_{ts}.png"
    out.save(out_path)
    print(f"[Draw] mode={mode}, saved={out_path}")
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: dispatch to inpaint or draw mode based on TASK_MODE."""
    base_dir = Path(__file__).resolve().parent
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)

    if TASK_MODE == "inpaint":
        run_inpaint_test(base_dir, out_dir)
        return
    if TASK_MODE == "draw":
        run_draw_test(base_dir, out_dir)
        return

    raise ValueError(f"不支持的 TASK_MODE: {TASK_MODE}(可选: inpaint / draw)")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
|
||||
163
python_server/test_seg.py
Normal file
163
python_server/test_seg.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw
|
||||
from scipy.ndimage import label as nd_label
|
||||
|
||||
from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed
|
||||
|
||||
# ================= Configuration =================
# NOTE(review): absolute, machine-specific input paths — adjust before running elsewhere.
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png"
OUTPUT_DIR = "outputs/test_seg_v2"
SEG_BACKEND = SegBackend.SAM

# Object filtering parameters
TARGET_MIN_AREA = 1000  # drop fragments smaller than this many pixels
TARGET_MAX_OBJECTS = 20  # maximum number of objects to extract
SAM_MAX_SIDE = 2048  # long-side cap applied to the image before SAM inference

# Visualization settings
MASK_ALPHA = 0.4  # blend weight of the fill color over the source image
BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8)  # boundary color (green)
TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8)  # per-object fill color (yellow)
SAVE_OBJECT_PNG = True  # also export each object as a cropped RGBA PNG
# =========================================
|
||||
|
||||
def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
|
||||
h, w = arr.shape[:2]
|
||||
if max_side <= 0 or max(h, w) <= max_side:
|
||||
return arr, (h, w), (h, w)
|
||||
scale = float(max_side) / float(max(h, w))
|
||||
run_w, run_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
|
||||
resample = Image.NEAREST if is_mask else Image.BICUBIC
|
||||
pil = Image.fromarray(arr)
|
||||
out = pil.resize((run_w, run_h), resample=resample)
|
||||
return np.asarray(out), (h, w), (run_h, run_w)
|
||||
|
||||
def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20):
|
||||
"""
|
||||
分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点
|
||||
"""
|
||||
# 确保是布尔类型
|
||||
mask_bool = edge_mask > 127
|
||||
# 连通域标记
|
||||
labeled_array, num_features = nd_label(mask_bool)
|
||||
|
||||
prompts = []
|
||||
component_info = []
|
||||
for i in range(1, num_features + 1):
|
||||
coords = np.argwhere(labeled_array == i)
|
||||
area = len(coords)
|
||||
if area < 100: continue # 过滤噪声
|
||||
# 取几何中心作为引导点
|
||||
center_y, center_x = np.median(coords, axis=0).astype(int)
|
||||
component_info.append(((center_x, center_y), area))
|
||||
|
||||
# 按面积排序,优先处理大面积线条覆盖的物体
|
||||
component_info.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
for pt, _ in component_info[:max_components]:
|
||||
prompts.append({
|
||||
"point_coords": np.array([pt], dtype=np.float32),
|
||||
"point_labels": np.array([1], dtype=np.int32)
|
||||
})
|
||||
return prompts
|
||||
|
||||
def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str):
|
||||
obj_dir = out_dir / f"{stem}.objects"
|
||||
obj_dir.mkdir(parents=True, exist_ok=True)
|
||||
for idx, lb in enumerate(targets, start=1):
|
||||
mask = (label_map == lb)
|
||||
ys, xs = np.where(mask)
|
||||
if len(ys) == 0: continue
|
||||
y1, y2, x1, x2 = ys.min(), ys.max(), xs.min(), xs.max()
|
||||
rgb_crop = rgb[y1:y2+1, x1:x2+1]
|
||||
alpha = (mask[y1:y2+1, x1:x2+1].astype(np.uint8) * 255)[..., None]
|
||||
rgba = np.concatenate([rgb_crop, alpha], axis=-1)
|
||||
Image.fromarray(rgba).save(obj_dir / f"{stem}.obj_{idx:02d}.png")
|
||||
|
||||
def main():
    """Run edge-mask-guided SAM segmentation on the configured image.

    Pipeline: load the image and its edge mask, shrink both for inference,
    derive one point prompt per edge cluster, run SAM once per prompt, merge
    the winning masks into a label map, upscale it back to the original
    resolution, then write a blended visualization (and, optionally,
    per-object RGBA crops).
    """
    # 1. Load input resources.
    img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK)
    out_dir = Path(__file__).parent / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)

    rgb_orig = np.asarray(Image.open(img_path).convert("RGB"))
    mask_orig = np.asarray(Image.open(mask_path).convert("L"))

    # 2. Downscale for inference to bound GPU memory use.
    rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE)
    mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True)

    # 3. Initialize SAM directly through SamPredictor (the project wrapper
    #    does not expose point/bbox prompts).
    import torch

    _ensure_sam_on_path()
    sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything"
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    predictor = SamPredictor(sam)

    print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}")

    # 4. Extract guidance points from the edge mask.
    prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS)
    print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点")

    # 5. Run inference, one prompt at a time.
    final_label_map_run = np.zeros(run_hw, dtype=np.int32)
    predictor.set_image(rgb_run)

    # TARGET_MIN_AREA is defined at original resolution. The run image is
    # scaled by this linear factor, so pixel areas scale by its SQUARE
    # (fix: the previous code scaled the area threshold linearly).
    area_scale = run_hw[0] / orig_hw[0]

    for idx, p in enumerate(prompts, start=1):
        # multimask_output=True yields several candidates; keep the best-scored.
        masks, scores, _ = predictor.predict(
            point_coords=p["point_coords"],
            point_labels=p["point_labels"],
            multimask_output=True,
        )
        best_mask = masks[np.argmax(scores)]

        # Only keep sufficiently large masks.
        if np.sum(best_mask) > TARGET_MIN_AREA * area_scale ** 2:
            # Later prompts overwrite earlier ones where masks overlap.
            final_label_map_run[best_mask > 0] = idx

    # 6. Map the label map back to the original resolution.
    #    Nearest-neighbor keeps label IDs integral.
    label_map = np.asarray(
        Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST)
    )

    # 7. Export and visualize.
    unique_labels = [l for l in np.unique(label_map) if l > 0]

    marked_img = rgb_orig.copy()

    # Blend the fill color over each object and trace its outer contour.
    # (Removed a dead `ImageDraw.Draw` call that operated on a throwaway copy.)
    overlay = rgb_orig.astype(np.float32)
    for lb in unique_labels:
        m = label_map == lb
        overlay[m] = overlay[m] * (1 - MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(marked_img, contours, -1, (0, 255, 0), 2)

    final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0))
    final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png")

    if SAVE_OBJECT_PNG:
        _save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem)

    print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}")
|
||||
|
||||
# Entry point when executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user