initial commit

This commit is contained in:
2026-04-07 20:55:30 +08:00
commit 81d1fb7856
84 changed files with 11929 additions and 0 deletions

1
python_server/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
outputs/

View File

@@ -0,0 +1 @@

223
python_server/config.py Normal file
View File

@@ -0,0 +1,223 @@
"""
python_server 的统一配置文件。
特点:
- 使用 Python 而不是 YAML方便在代码中集中列举所有可用模型供前端读取。
- 后端加载模型时,也从这里读取默认值,保证单一信息源。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal, TypedDict, List
from model.Depth.zoe_loader import ZoeModelName
# -----------------------------
# 1. 深度模型枚举(给前端展示用)
# -----------------------------
class DepthModelInfo(TypedDict):
    """Metadata describing one selectable depth model (consumed by the frontend)."""
    id: str  # unique ID, e.g. "zoedepth_n"
    family: str  # model family, e.g. "ZoeDepth"
    name: str  # display name, e.g. "ZoeD_N (NYU+KITTI)"
    description: str  # short human-readable description
    backend: str  # backend type: "zoedepth", "depth_anything_v2", "midas", "dpt"
# Registry of depth models exposed to the frontend. The "id" values are part
# of the public API and must stay stable.
DEPTH_MODELS: List[DepthModelInfo] = [
    # ZoeDepth family
    {
        "id": "zoedepth_n",
        "family": "ZoeDepth",
        "name": "ZoeD_N",
        "description": "ZoeDepth zero-shot 模型,适用室内/室外通用场景。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_k",
        "family": "ZoeDepth",
        "name": "ZoeD_K",
        "description": "ZoeDepth Kitti 专用版本,针对户外驾驶场景优化。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_nk",
        "family": "ZoeDepth",
        "name": "ZoeD_NK",
        "description": "ZoeDepth 双头版本NYU+KITTI综合室内/室外场景。",
        "backend": "zoedepth",
    },
    # Reserved: Depth Anything V2
    {
        "id": "depth_anything_v2_s",
        "family": "Depth Anything V2",
        "name": "Depth Anything V2 Small",
        "description": "轻量级 Depth Anything V2 小模型。",
        "backend": "depth_anything_v2",
    },
    # Reserved: MiDaS
    {
        "id": "midas_dpt_large",
        "family": "MiDaS",
        "name": "MiDaS DPT Large",
        "description": "MiDaS DPT-Large 高质量深度模型。",
        "backend": "midas",
    },
    # Reserved: DPT
    {
        "id": "dpt_large",
        "family": "DPT",
        "name": "DPT Large",
        "description": "DPT Large 单目深度估计模型。",
        "backend": "dpt",
    },
]
# -----------------------------
# 1.2 补全模型枚举(给前端展示用)
# -----------------------------
class InpaintModelInfo(TypedDict):
    """Metadata describing one selectable inpainting model (for the frontend)."""
    id: str
    family: str
    name: str
    description: str
    backend: str  # "sdxl_inpaint" | "controlnet"
# Registry of inpainting models exposed to the frontend; ids are stable API values.
INPAINT_MODELS: List[InpaintModelInfo] = [
    {
        "id": "sdxl_inpaint",
        "family": "SDXL",
        "name": "SDXL Inpainting",
        "description": "基于 diffusers 的 SDXL 补全管线(需要 prompt + mask",
        "backend": "sdxl_inpaint",
    },
    {
        "id": "controlnet",
        "family": "ControlNet",
        "name": "ControlNet (placeholder)",
        "description": "ControlNet 补全/控制生成(当前统一封装暂未实现)。",
        "backend": "controlnet",
    },
]
# -----------------------------
# 1.3 动画模型枚举(给前端展示用)
# -----------------------------
class AnimationModelInfo(TypedDict):
    """Metadata describing one selectable animation model (for the frontend)."""
    id: str
    family: str
    name: str
    description: str
    backend: str  # "animatediff"
# Registry of animation models exposed to the frontend; ids are stable API values.
ANIMATION_MODELS: List[AnimationModelInfo] = [
    {
        "id": "animatediff",
        "family": "AnimateDiff",
        "name": "AnimateDiff (Text-to-Video)",
        "description": "基于 AnimateDiff 的文生动画能力,输出 GIF 动画。",
        "backend": "animatediff",
    },
]
# -----------------------------
# 2. 后端默认配置(给服务端用)
# -----------------------------
@dataclass
class DepthConfig:
    """Server-side defaults for depth estimation backends."""
    # Depth backend selection: not chosen by the frontend; switched only here.
    backend: Literal["zoedepth", "depth_anything_v2", "dpt", "midas"] = "zoedepth"
    # Default choice within the ZoeDepth family.
    zoe_model: ZoeModelName = "ZoeD_N"
    # Default encoder for Depth Anything V2.
    da_v2_encoder: Literal["vits", "vitb", "vitl", "vitg"] = "vitl"
    # Default DPT model type.
    dpt_model_type: Literal["dpt_large", "dpt_hybrid"] = "dpt_large"
    # Default MiDaS model type.
    midas_model_type: Literal[
        "dpt_beit_large_512",
        "dpt_swin2_large_384",
        "dpt_swin2_tiny_256",
        "dpt_levit_224",
    ] = "dpt_beit_large_512"
    # Default device shared by all depth backends.
    device: str = "cuda"
@dataclass
class InpaintConfig:
    """Server-side defaults for inpainting backends."""
    # Default inpainting backend.
    backend: Literal["sdxl_inpaint", "controlnet"] = "sdxl_inpaint"
    # Base model for SDXL inpainting (HuggingFace model id or local directory).
    sdxl_base_model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    # ControlNet inpainting: base model and controlnet weights.
    controlnet_base_model: str = "runwayml/stable-diffusion-inpainting"
    controlnet_model: str = "lllyasviel/control_v11p_sd15_inpaint"
    device: str = "cuda"
@dataclass
class AnimationConfig:
    """Server-side defaults for animation generation."""
    # Default animation backend.
    backend: Literal["animatediff"] = "animatediff"
    # AnimateDiff root directory (relative to python_server/ or absolute).
    animate_diff_root: str = "model/Animation/AnimateDiff"
    # Text-to-image base model (HuggingFace model id or local directory).
    pretrained_model_path: str = "runwayml/stable-diffusion-v1-5"
    # AnimateDiff inference config path (relative to the AnimateDiff root).
    inference_config: str = "configs/inference/inference-v3.yaml"
    # Motion module and personalized base model (empty -> script defaults apply).
    motion_module: str = "v3_sd15_mm.ckpt"
    dreambooth_model: str = "realisticVisionV60B1_v51VAE.safetensors"
    lora_model: str = ""
    lora_alpha: float = 0.8
    # Some environments have poor xformers compatibility; allows manual opt-out.
    without_xformers: bool = False
    device: str = "cuda"
@dataclass
class AppConfig:
    """Top-level application config aggregating all per-feature sub-configs."""
    # default_factory avoids the mutable-default-value pitfall of dataclasses.
    depth: DepthConfig = field(default_factory=DepthConfig)
    inpaint: InpaintConfig = field(default_factory=InpaintConfig)
    animation: AnimationConfig = field(default_factory=AnimationConfig)
# Module-level singleton: backend code simply imports DEFAULT_CONFIG.
DEFAULT_CONFIG = AppConfig()
def list_depth_models() -> List[DepthModelInfo]:
    """
    Return metadata for all available depth models (read by the frontend via
    endpoints such as /models).

    Returns a shallow copy so callers cannot accidentally mutate the
    module-level registry; the entry dicts themselves are still shared.
    """
    return list(DEPTH_MODELS)
def list_inpaint_models() -> List[InpaintModelInfo]:
    """
    Return metadata for all available inpainting models (read by the frontend
    via endpoints such as /models).

    Returns a shallow copy so callers cannot accidentally mutate the
    module-level registry; the entry dicts themselves are still shared.
    """
    return list(INPAINT_MODELS)
def list_animation_models() -> List[AnimationModelInfo]:
    """
    Return metadata for all available animation models (read by the frontend
    via endpoints such as /models).

    Returns a shallow copy so callers cannot accidentally mutate the
    module-level registry; the entry dicts themselves are still shared.
    """
    return list(ANIMATION_MODELS)

View File

@@ -0,0 +1,153 @@
"""
兼容层:从 Python 配置模块中构造 zoe_loader 需要的 ZoeConfig。
后端其它代码尽量只依赖这里的函数,而不直接依赖 config.py 的具体结构,
便于以后扩展。
"""
from model.Depth.zoe_loader import ZoeConfig
from model.Depth.depth_anything_v2_loader import DepthAnythingV2Config
from model.Depth.dpt_loader import DPTConfig
from model.Depth.midas_loader import MiDaSConfig
from config import AppConfig, DEFAULT_CONFIG
def load_app_config() -> AppConfig:
    """
    Return the application config; currently just the global DEFAULT_CONFIG.

    If multi-environment support / config overrides are needed later, add
    that logic here so callers do not have to change.
    """
    return DEFAULT_CONFIG
def build_zoe_config_from_app(app_cfg: AppConfig | None = None) -> ZoeConfig:
    """
    Map AppConfig.depth onto a ZoeConfig for zoe_loader.

    Falls back to the global DEFAULT_CONFIG when app_cfg is not given.
    """
    cfg = app_cfg if app_cfg is not None else load_app_config()
    depth_cfg = cfg.depth
    return ZoeConfig(model=depth_cfg.zoe_model, device=depth_cfg.device)
def build_depth_anything_v2_config_from_app(
    app_cfg: AppConfig | None = None,
) -> DepthAnythingV2Config:
    """
    Map AppConfig.depth onto a DepthAnythingV2Config.

    Falls back to the global DEFAULT_CONFIG when app_cfg is not given.
    """
    cfg = app_cfg if app_cfg is not None else load_app_config()
    depth_cfg = cfg.depth
    return DepthAnythingV2Config(
        encoder=depth_cfg.da_v2_encoder,
        device=depth_cfg.device,
    )
def build_dpt_config_from_app(app_cfg: AppConfig | None = None) -> DPTConfig:
    """Map AppConfig.depth onto a DPTConfig (DEFAULT_CONFIG when app_cfg is None)."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    depth_cfg = cfg.depth
    return DPTConfig(model_type=depth_cfg.dpt_model_type, device=depth_cfg.device)
def build_midas_config_from_app(app_cfg: AppConfig | None = None) -> MiDaSConfig:
    """Map AppConfig.depth onto a MiDaSConfig (DEFAULT_CONFIG when app_cfg is None)."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    depth_cfg = cfg.depth
    return MiDaSConfig(model_type=depth_cfg.midas_model_type, device=depth_cfg.device)
def get_depth_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured depth backend identifier."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.depth.backend
def get_inpaint_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured inpainting backend identifier."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.inpaint.backend
def get_sdxl_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured SDXL inpainting base model id/path."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.inpaint.sdxl_base_model
def get_controlnet_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured ControlNet base model id/path."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.inpaint.controlnet_base_model
def get_controlnet_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured ControlNet weights id/path."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.inpaint.controlnet_model
def get_animation_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured animation backend identifier."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.backend
def get_animatediff_root_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured AnimateDiff root directory."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.animate_diff_root
def get_animatediff_pretrained_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured AnimateDiff text-to-image base model id/path."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.pretrained_model_path
def get_animatediff_inference_config_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured AnimateDiff inference config path."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.inference_config
def get_animatediff_motion_module_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured AnimateDiff motion module checkpoint name."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.motion_module
def get_animatediff_dreambooth_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured AnimateDiff DreamBooth checkpoint name."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.dreambooth_model
def get_animatediff_lora_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured AnimateDiff LoRA model path ("" means none)."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.lora_model
def get_animatediff_lora_alpha_from_app(app_cfg: AppConfig | None = None) -> float:
    """Return the configured AnimateDiff LoRA alpha weight."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.lora_alpha
def get_animatediff_without_xformers_from_app(app_cfg: AppConfig | None = None) -> bool:
    """Return whether AnimateDiff should run with xformers disabled."""
    cfg = app_cfg if app_cfg is not None else load_app_config()
    return cfg.animation.without_xformers

Submodule python_server/model/Animation/AnimateDiff added at e92bd5671b

View File

@@ -0,0 +1,12 @@
# Re-export the public animation API from the loader module.
from .animation_loader import (
    AnimationBackend,
    UnifiedAnimationConfig,
    build_animation_predictor,
)
# Explicit public surface of this package.
__all__ = [
    "AnimationBackend",
    "UnifiedAnimationConfig",
    "build_animation_predictor",
]

View File

@@ -0,0 +1,268 @@
from __future__ import annotations
"""
Unified animation model loading entry.
Current support:
- AnimateDiff (script-based invocation)
"""
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
import json
import os
import subprocess
import sys
import tempfile
from typing import Callable
from config_loader import (
load_app_config,
get_animatediff_root_from_app,
get_animatediff_pretrained_model_from_app,
get_animatediff_inference_config_from_app,
get_animatediff_motion_module_from_app,
get_animatediff_dreambooth_model_from_app,
get_animatediff_lora_model_from_app,
get_animatediff_lora_alpha_from_app,
get_animatediff_without_xformers_from_app,
)
class AnimationBackend(str, Enum):
    """Supported animation backends (str-valued so it serializes cleanly)."""
    ANIMATEDIFF = "animatediff"
@dataclass
class UnifiedAnimationConfig:
    """Options for building an animation predictor.

    Any field left as None falls back to the value in the global app config.
    """
    backend: AnimationBackend = AnimationBackend.ANIMATEDIFF
    # Optional overrides. If None, values come from app config.
    animate_diff_root: str | None = None
    pretrained_model_path: str | None = None
    inference_config: str | None = None
    motion_module: str | None = None
    dreambooth_model: str | None = None
    lora_model: str | None = None
    lora_alpha: float | None = None
    without_xformers: bool | None = None
    controlnet_path: str | None = None
    controlnet_config: str | None = None
def _yaml_string(value: str) -> str:
return json.dumps(value, ensure_ascii=False)
def _resolve_root(root_cfg: str) -> Path:
root = Path(root_cfg)
if not root.is_absolute():
root = Path(__file__).resolve().parents[2] / root_cfg
return root.resolve()
def _make_animatediff_predictor(
    cfg: UnifiedAnimationConfig,
) -> Callable[..., Path]:
    """Build a predictor that runs AnimateDiff's scripts/animate.py in a subprocess.

    All model paths and options are resolved up-front from *cfg*, falling back
    to the global app config for any field left as None. The returned callable
    renders one animation per invocation and returns the path to the result
    (a GIF file, or a PNG-sequence directory).
    """
    app_cfg = load_app_config()
    root = _resolve_root(cfg.animate_diff_root or get_animatediff_root_from_app(app_cfg))
    script_path = root / "scripts" / "animate.py"
    samples_dir = root / "samples"
    samples_dir.mkdir(parents=True, exist_ok=True)
    if not script_path.is_file():
        raise FileNotFoundError(f"AnimateDiff script not found: {script_path}")
    # Per-field fallback: an explicit cfg value wins, else the app-config default.
    pretrained_model_path = (
        cfg.pretrained_model_path or get_animatediff_pretrained_model_from_app(app_cfg)
    )
    inference_config = cfg.inference_config or get_animatediff_inference_config_from_app(app_cfg)
    motion_module = cfg.motion_module or get_animatediff_motion_module_from_app(app_cfg)
    # These three use `is None` (not truthiness) so that explicit "empty"
    # overrides like lora_model="" or lora_alpha=0.0 are honored.
    dreambooth_model = cfg.dreambooth_model
    if dreambooth_model is None:
        dreambooth_model = get_animatediff_dreambooth_model_from_app(app_cfg)
    lora_model = cfg.lora_model
    if lora_model is None:
        lora_model = get_animatediff_lora_model_from_app(app_cfg)
    lora_alpha = cfg.lora_alpha
    if lora_alpha is None:
        lora_alpha = get_animatediff_lora_alpha_from_app(app_cfg)
    without_xformers = cfg.without_xformers
    if without_xformers is None:
        without_xformers = get_animatediff_without_xformers_from_app(app_cfg)
    def _predict(
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        guidance_scale: float = 8.0,
        width: int = 512,
        height: int = 512,
        video_length: int = 16,
        seed: int = -1,
        control_image_path: str | None = None,
        output_format: str = "gif",
    ) -> Path:
        """Render one animation; returns the GIF path or PNG-sequence directory.

        Raises ValueError on bad arguments, FileNotFoundError when expected
        inputs/outputs are missing, RuntimeError when the subprocess fails.
        """
        if output_format not in {"gif", "png_sequence"}:
            raise ValueError("output_format must be 'gif' or 'png_sequence'")
        prompt_value = prompt.strip()
        if not prompt_value:
            raise ValueError("prompt must not be empty")
        negative_prompt_value = negative_prompt or ""
        # Optional YAML lines — emitted only when the corresponding value is set.
        motion_module_line = (
            f' motion_module: {_yaml_string(motion_module)}\n' if motion_module else ""
        )
        dreambooth_line = (
            f' dreambooth_path: {_yaml_string(dreambooth_model)}\n' if dreambooth_model else ""
        )
        lora_path_line = f' lora_model_path: {_yaml_string(lora_model)}\n' if lora_model else ""
        lora_alpha_line = f" lora_alpha: {float(lora_alpha)}\n" if lora_model else ""
        controlnet_path_value = cfg.controlnet_path or "v3_sd15_sparsectrl_rgb.ckpt"
        controlnet_config_value = cfg.controlnet_config or "configs/inference/sparsectrl/image_condition.yaml"
        control_image_line = ""
        if control_image_path:
            control_image = Path(control_image_path).expanduser().resolve()
            if not control_image.is_file():
                raise FileNotFoundError(f"control_image_path not found: {control_image}")
            control_image_line = (
                f' controlnet_path: {_yaml_string(controlnet_path_value)}\n'
                f' controlnet_config: {_yaml_string(controlnet_config_value)}\n'
                " controlnet_images:\n"
                f' - {_yaml_string(str(control_image))}\n'
                " controlnet_image_indexs:\n"
                " - 0\n"
            )
        # Build the one-entry YAML prompt config that animate.py consumes.
        config_text = (
            "- prompt:\n"
            f" - {_yaml_string(prompt_value)}\n"
            " n_prompt:\n"
            f" - {_yaml_string(negative_prompt_value)}\n"
            f" steps: {int(num_inference_steps)}\n"
            f" guidance_scale: {float(guidance_scale)}\n"
            f" W: {int(width)}\n"
            f" H: {int(height)}\n"
            f" L: {int(video_length)}\n"
            " seed:\n"
            f" - {int(seed)}\n"
            f"{motion_module_line}{dreambooth_line}{lora_path_line}{lora_alpha_line}{control_image_line}"
        )
        # Snapshot pre-existing sample dirs so the new output can be identified.
        before_dirs = {p for p in samples_dir.iterdir() if p.is_dir()}
        cfg_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            prefix="animatediff_cfg_",
            dir=str(root),
            delete=False,
            encoding="utf-8",
        )
        cfg_file_path = Path(cfg_file.name)
        try:
            cfg_file.write(config_text)
            cfg_file.flush()
            cfg_file.close()
            cmd = [
                sys.executable,
                str(script_path),
                "--pretrained-model-path",
                pretrained_model_path,
                "--inference-config",
                inference_config,
                "--config",
                str(cfg_file_path),
                "--L",
                str(int(video_length)),
                "--W",
                str(int(width)),
                "--H",
                str(int(height)),
            ]
            if without_xformers:
                cmd.append("--without-xformers")
            if output_format == "png_sequence":
                cmd.append("--save-png-sequence")
            # Prepend the AnimateDiff root to PYTHONPATH so its imports resolve.
            env = dict(os.environ)
            existing_pythonpath = env.get("PYTHONPATH", "")
            root_pythonpath = str(root)
            env["PYTHONPATH"] = (
                f"{root_pythonpath}:{existing_pythonpath}" if existing_pythonpath else root_pythonpath
            )
            def _run_once(command: list[str]) -> subprocess.CompletedProcess[str]:
                # check=True raises CalledProcessError on non-zero exit.
                return subprocess.run(
                    command,
                    cwd=str(root),
                    check=True,
                    capture_output=True,
                    text=True,
                    env=env,
                )
            try:
                proc = _run_once(cmd)
            except subprocess.CalledProcessError as first_error:
                # Some environments crash inside xformers' attention kernels;
                # retry once with xformers disabled when stderr suggests that.
                stderr_text = first_error.stderr or ""
                should_retry_without_xformers = (
                    not without_xformers
                    and "--without-xformers" not in cmd
                    and (
                        "memory_efficient_attention" in stderr_text
                        or "AcceleratorError" in stderr_text
                        or "invalid configuration argument" in stderr_text
                    )
                )
                if not should_retry_without_xformers:
                    raise
                retry_cmd = [*cmd, "--without-xformers"]
                proc = _run_once(retry_cmd)
            _ = proc
        except subprocess.CalledProcessError as e:
            # Wrap the failure with the captured output for easier debugging.
            raise RuntimeError(
                "AnimateDiff inference failed.\n"
                f"stdout:\n{e.stdout}\n"
                f"stderr:\n{e.stderr}"
            ) from e
        finally:
            # Always clean up the temporary YAML config, even on failure.
            try:
                cfg_file_path.unlink(missing_ok=True)
            except Exception:
                pass
        # Prefer dirs created by this run; fall back to any dir with a sample.gif.
        after_dirs = [p for p in samples_dir.iterdir() if p.is_dir() and p not in before_dirs]
        candidates = [p for p in after_dirs if (p / "sample.gif").is_file()]
        if not candidates:
            candidates = [p for p in samples_dir.iterdir() if p.is_dir() and (p / "sample.gif").is_file()]
        if not candidates:
            raise FileNotFoundError("AnimateDiff finished but sample.gif was not found in samples/")
        latest = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)[0]
        if output_format == "png_sequence":
            frames_root = latest / "sample_frames"
            if not frames_root.is_dir():
                raise FileNotFoundError("AnimateDiff finished but sample_frames/ was not found in samples/")
            frame_dirs = sorted([p for p in frames_root.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
            if not frame_dirs:
                raise FileNotFoundError("AnimateDiff finished but no PNG sequence directory was found in sample_frames/")
            return frame_dirs[0].resolve()
        return (latest / "sample.gif").resolve()
    return _predict
def build_animation_predictor(
    cfg: UnifiedAnimationConfig | None = None,
) -> tuple[Callable[..., Path], AnimationBackend]:
    """Construct (predictor, backend) for the configured animation backend.

    Raises ValueError for backends that are not implemented.
    """
    resolved = cfg or UnifiedAnimationConfig()
    if resolved.backend != AnimationBackend.ANIMATEDIFF:
        raise ValueError(f"Unsupported animation backend: {resolved.backend}")
    return _make_animatediff_predictor(resolved), AnimationBackend.ANIMATEDIFF

Submodule python_server/model/Depth/DPT added at cd3fe90bb4

Submodule python_server/model/Depth/Depth-Anything-V2 added at e5a2732d3e

Submodule python_server/model/Depth/MiDaS added at 454597711a

Submodule python_server/model/Depth/ZoeDepth added at d87f17b2f5

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,147 @@
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path
import numpy as np
import torch
import requests
# 确保本地克隆的 Depth Anything V2 仓库在 sys.path 中,
# 这样其内部的 `from depth_anything_v2...` 导入才能正常工作。
_THIS_DIR = Path(__file__).resolve().parent
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
if _DA_REPO_ROOT.is_dir():
da_path = str(_DA_REPO_ROOT)
if da_path not in sys.path:
sys.path.insert(0, da_path)
from depth_anything_v2.dpt import DepthAnythingV2 # type: ignore[import]
# Allowed encoder identifiers for Depth Anything V2.
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]
@dataclass
class DepthAnythingV2Config:
    """
    Depth Anything V2 model selection config.

    encoder: "vits" | "vitb" | "vitl" | "vitg"
    device: "cuda" | "cpu"
    input_size: inference input resolution (short side); 518 follows the
        official demo default.
    """
    encoder: EncoderName = "vitl"
    device: str = "cuda"
    input_size: int = 518
# Architecture hyper-parameters per encoder, mirroring the official repo.
_MODEL_CONFIGS = {
    "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
    "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
    "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
    "vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
}
_DA_V2_WEIGHTS_URLS = {
    # Official weights are hosted on HuggingFace:
    # - Small -> vits
    # - Base  -> vitb
    # - Large -> vitl
    # - Giant -> vitg
    # To use a regional mirror, edit these URLs directly.
    "vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
    "vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
    "vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
    "vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
}
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
    """
    Download the Depth Anything V2 checkpoint for *encoder* to *ckpt_path*
    if it is not already present.

    Fixes over the naive version:
    - request timeout so an unreachable mirror cannot hang the server forever;
    - download to a temporary ".part" file and atomically rename on success,
      so an interrupted download never leaves a truncated checkpoint that
      would later fail to load (and suppress re-download).

    Raises FileNotFoundError when no download URL is configured for *encoder*.
    """
    if ckpt_path.is_file():
        return
    url = _DA_V2_WEIGHTS_URLS.get(encoder)
    if not url:
        raise FileNotFoundError(
            f"找不到权重文件: {ckpt_path}\n"
            f"且当前未为 encoder='{encoder}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
    tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
    try:
        # (connect, read) timeouts: fail fast instead of blocking indefinitely.
        with requests.get(url, stream=True, timeout=(10, 60)) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            chunk_size = 1024 * 1024
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        # Atomic on POSIX: readers never observe a half-written checkpoint.
        tmp_path.replace(ckpt_path)
    except BaseException:
        # Remove the partial file so the next call retries cleanly.
        tmp_path.unlink(missing_ok=True)
        raise
    print("\n权重下载完成。")
def load_depth_anything_v2_from_config(
    cfg: DepthAnythingV2Config,
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
    """
    Load the Depth Anything V2 model described by *cfg*.

    Notes:
    - The checkpoint path follows the official naming convention:
        checkpoints/depth_anything_v2_{encoder}.pth
      e.g. depth_anything_v2_vitl.pth
    - Weights are expected under (or auto-downloaded to)
      python_server/model/Depth/Depth-Anything-V2/checkpoints.

    Returns the eval-mode model and a rebuilt config whose `device` reflects
    the device actually used (falls back to CPU when CUDA is unavailable).
    """
    if cfg.encoder not in _MODEL_CONFIGS:
        raise ValueError(f"不支持的 encoder: {cfg.encoder}")
    ckpt_path = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
    _download_if_missing(cfg.encoder, ckpt_path)
    model = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])
    # Load onto CPU first so loading works regardless of the target device.
    state_dict = torch.load(str(ckpt_path), map_location="cpu")
    model.load_state_dict(state_dict)
    if cfg.device.startswith("cuda") and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = model.to(device).eval()
    # Rebuild the config so the caller sees the device actually chosen.
    cfg = DepthAnythingV2Config(
        encoder=cfg.encoder,
        device=device,
        input_size=cfg.input_size,
    )
    return model, cfg
def infer_depth_anything_v2(
    model: DepthAnythingV2,
    image_bgr: np.ndarray,
    input_size: int,
) -> np.ndarray:
    """
    Run depth inference on a single BGR image and return a raw float32 depth
    map (not normalized).

    image_bgr: OpenCV-style BGR image (H, W, 3), uint8.
    """
    raw = model.infer_image(image_bgr, input_size)
    return np.asarray(raw, dtype="float32")

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
"""
统一的深度模型加载入口。
当前支持:
- ZoeDepth三种ZoeD_N / ZoeD_K / ZoeD_NK
- Depth Anything V2四种 encodervits / vitb / vitl / vitg
未来如果要加 MiDaS / DPT只需要在这里再接一层即可。
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import numpy as np
from PIL import Image
from config_loader import (
load_app_config,
build_zoe_config_from_app,
build_depth_anything_v2_config_from_app,
build_dpt_config_from_app,
build_midas_config_from_app,
)
from .zoe_loader import load_zoe_from_config
from .depth_anything_v2_loader import (
load_depth_anything_v2_from_config,
infer_depth_anything_v2,
)
from .dpt_loader import load_dpt_from_config, infer_dpt
from .midas_loader import load_midas_from_config, infer_midas
class DepthBackend(str, Enum):
    """Unified depth-model backend identifiers (str-valued for configs/APIs)."""
    ZOEDEPTH = "zoedepth"
    DEPTH_ANYTHING_V2 = "depth_anything_v2"
    DPT = "dpt"
    MIDAS = "midas"
@dataclass
class UnifiedDepthConfig:
    """
    Unified depth config.

    backend: which backend to use
    device: optional device override; when None, the config.py setting is used
    """
    backend: DepthBackend = DepthBackend.ZOEDEPTH
    device: str | None = None
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load ZoeDepth per the app config and return a PIL-image -> depth callable."""
    zoe_cfg = build_zoe_config_from_app(load_app_config())
    if device_override is not None:
        zoe_cfg.device = device_override
    model, _ = load_zoe_from_config(zoe_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        rgb = img.convert("RGB")
        raw = model.infer_pil(rgb, output_type="numpy")
        return np.asarray(raw, dtype="float32").squeeze()
    return _predict
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load Depth Anything V2 per the app config and return a predictor callable."""
    da_cfg = build_depth_anything_v2_config_from_app(load_app_config())
    if device_override is not None:
        da_cfg.device = device_override
    model, da_cfg = load_depth_anything_v2_from_config(da_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        # Depth Anything V2's infer_image expects an OpenCV-style BGR uint8 array.
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        depth = infer_depth_anything_v2(model, rgb[:, :, ::-1], da_cfg.input_size)
        return depth.astype("float32").squeeze()
    return _predict
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load DPT per the app config and return a PIL-image -> depth callable."""
    app_cfg = load_app_config()
    dpt_cfg = build_dpt_config_from_app(app_cfg)
    if device_override is not None:
        dpt_cfg.device = device_override
    model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        # BUG FIX: this module never imports cv2, so the previous
        # cv2.cvtColor call raised NameError at inference time. Convert
        # RGB -> BGR with a numpy channel flip instead, matching the
        # Depth Anything V2 predictor above. ascontiguousarray avoids
        # passing a negative-stride view to downstream OpenCV code.
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
        return depth.astype("float32").squeeze()
    return _predict
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load MiDaS per the app config and return a PIL-image -> depth callable."""
    midas_cfg = build_midas_config_from_app(load_app_config())
    if device_override is not None:
        midas_cfg.device = device_override
    model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        # MiDaS transforms take float RGB scaled into [0, 1].
        rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
        depth = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
        return depth.astype("float32").squeeze()
    return _predict
def build_depth_predictor(
    cfg: UnifiedDepthConfig | None = None,
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
    """
    Build the unified depth prediction function.

    Returns:
        - predictor(image: PIL.Image) -> np.ndarray[H, W], float32
        - the backend actually used
    """
    cfg = cfg or UnifiedDepthConfig()
    # Dispatch table: one factory per supported backend.
    factories = {
        DepthBackend.ZOEDEPTH: _make_zoe_predictor,
        DepthBackend.DEPTH_ANYTHING_V2: _make_da_v2_predictor,
        DepthBackend.DPT: _make_dpt_predictor,
        DepthBackend.MIDAS: _make_midas_predictor,
    }
    factory = factories.get(cfg.backend)
    if factory is None:
        raise ValueError(f"不支持的深度后端: {cfg.backend}")
    return factory(cfg.device), cfg.backend

View File

@@ -0,0 +1,156 @@
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path
import numpy as np
import torch
import requests
_THIS_DIR = Path(__file__).resolve().parent
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
if _DPT_REPO_ROOT.is_dir():
dpt_path = str(_DPT_REPO_ROOT)
if dpt_path not in sys.path:
sys.path.insert(0, dpt_path)
from dpt.models import DPTDepthModel # type: ignore[import]
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
from torchvision.transforms import Compose
import cv2
# Supported DPT variants.
DPTModelType = Literal["dpt_large", "dpt_hybrid"]
@dataclass
class DPTConfig:
    """DPT model selection config (device falls back to CPU at load time)."""
    model_type: DPTModelType = "dpt_large"
    device: str = "cuda"
_DPT_WEIGHTS_URLS = {
    # Official DPT weights are hosted at:
    # https://github.com/isl-org/DPT#models
    "dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
}
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
    """
    Download the DPT checkpoint for *model_type* to *ckpt_path* if absent.

    Fixes over the naive version:
    - request timeout so an unreachable host cannot hang the server forever;
    - download to a temporary ".part" file and atomically rename on success,
      so an interrupted download never leaves a truncated checkpoint.

    Raises FileNotFoundError when no download URL is configured.
    """
    if ckpt_path.is_file():
        return
    url = _DPT_WEIGHTS_URLS.get(model_type)
    if not url:
        raise FileNotFoundError(
            f"找不到 DPT 权重文件: {ckpt_path}\n"
            f"且当前未为 model_type='{model_type}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
    tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
    try:
        # (connect, read) timeouts: fail fast instead of blocking indefinitely.
        with requests.get(url, stream=True, timeout=(10, 60)) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            chunk_size = 1024 * 1024
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        # Atomic on POSIX: readers never observe a half-written checkpoint.
        tmp_path.replace(ckpt_path)
    except BaseException:
        # Remove the partial file so the next call retries cleanly.
        tmp_path.unlink(missing_ok=True)
        raise
    print("\nDPT 权重下载完成。")
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
    """
    Load the DPT model and its matching preprocessing transform.

    Returns (model, cfg, transform); cfg is rebuilt so its `device` reflects
    the device actually used (CPU fallback when CUDA is unavailable).
    """
    ckpt_name = {
        "dpt_large": "dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
    }[cfg.model_type]
    ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
    _download_if_missing(cfg.model_type, ckpt_path)
    # Both variants share input resolution and normalization; only the ViT
    # backbone differs, so the duplicated if/else branches were collapsed.
    backbone = "vitl16_384" if cfg.model_type == "dpt_large" else "vitb_rn50_384"
    net_w = net_h = 384
    model = DPTDepthModel(
        path=str(ckpt_path),
        backbone=backbone,
        non_negative=True,
        enable_attention_hooks=False,
    )
    normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model.to(device).eval()
    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )
    # Rebuild the config so the caller sees the device actually chosen.
    cfg = DPTConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform
def infer_dpt(
    model: DPTDepthModel,
    transform: Compose,
    image_bgr: np.ndarray,
    device: str,
) -> np.ndarray:
    """
    Run depth inference on a single BGR image; returns a float32 depth map
    (not normalized).
    """
    # The DPT transforms expect RGB input.
    img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    sample = transform({"image": img})["image"]
    input_batch = torch.from_numpy(sample).to(device).unsqueeze(0)
    with torch.no_grad():
        prediction = model(input_batch)
        # Upsample the prediction back to the original image resolution.
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )
    return prediction.astype("float32")

View File

@@ -0,0 +1,127 @@
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path
import numpy as np
import torch
import requests
_THIS_DIR = Path(__file__).resolve().parent
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
if _MIDAS_REPO_ROOT.is_dir():
midas_path = str(_MIDAS_REPO_ROOT)
if midas_path not in sys.path:
sys.path.insert(0, midas_path)
from midas.model_loader import load_model, default_models # type: ignore[import]
import utils # from MiDaS repo
# MiDaS v3.1 model variants supported by this loader.
MiDaSModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]
@dataclass
class MiDaSConfig:
    """MiDaS model selection config (device falls back to CPU at load time)."""
    model_type: MiDaSModelType = "dpt_beit_large_512"
    device: str = "cuda"
_MIDAS_WEIGHTS_URLS = {
    # See the MiDaS repository README for the official weights.
    "dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
    "dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
    "dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
    "dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
}
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
    """
    Download the MiDaS checkpoint for *model_type* to *ckpt_path* if absent.

    Fixes over the naive version:
    - request timeout so an unreachable host cannot hang the server forever;
    - download to a temporary ".part" file and atomically rename on success,
      so an interrupted download never leaves a truncated checkpoint.

    Raises FileNotFoundError when no download URL is configured.
    """
    if ckpt_path.is_file():
        return
    url = _MIDAS_WEIGHTS_URLS.get(model_type)
    if not url:
        raise FileNotFoundError(
            f"找不到 MiDaS 权重文件: {ckpt_path}\n"
            f"且当前未为 model_type='{model_type}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
    tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
    try:
        # (connect, read) timeouts: fail fast instead of blocking indefinitely.
        with requests.get(url, stream=True, timeout=(10, 60)) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            chunk_size = 1024 * 1024
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        # Atomic on POSIX: readers never observe a half-written checkpoint.
        tmp_path.replace(ckpt_path)
    except BaseException:
        # Remove the partial file so the next call retries cleanly.
        tmp_path.unlink(missing_ok=True)
        raise
    print("\nMiDaS 权重下载完成。")
def load_midas_from_config(
    cfg: MiDaSConfig,
) -> Tuple[torch.nn.Module, MiDaSConfig, callable, int, int]:
    """
    Load the MiDaS model and its preprocessing transform.

    Returns: model, cfg (with the resolved device), transform, net_w, net_h
    """
    # default_models maps model_type to the default checkpoint path. In the
    # upstream MiDaS repo this mapping holds plain path strings, so accessing
    # `.path` on the entry would raise AttributeError; accept both a plain
    # string and an object exposing `.path` to be robust across repo versions.
    model_info = default_models[cfg.model_type]
    model_rel_path = getattr(model_info, "path", model_info)
    ckpt_path = _MIDAS_REPO_ROOT / model_rel_path
    _download_if_missing(cfg.model_type, ckpt_path)
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model, transform, net_w, net_h = load_model(
        device=torch.device(device),
        model_path=str(ckpt_path),
        model_type=cfg.model_type,
        optimize=False,
        height=None,
        square=False,
    )
    # Rebuild the config so the caller sees the device actually chosen.
    cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform, net_w, net_h
def infer_midas(
    model: torch.nn.Module,
    transform: callable,
    image_rgb: np.ndarray,
    net_w: int,
    net_h: int,
    device: str,
) -> np.ndarray:
    """Run depth inference on one RGB image; returns a raw float32 depth map
    (not normalised)."""
    sample = transform({"image": image_rgb})["image"]
    # shape[1::-1] turns (H, W, ...) into (W, H) as the target size.
    prediction = utils.process(
        torch.device(device),
        model,
        model_type="dpt",  # exact string barely matters to utils.process, as long as it does not contain "openvino"
        image=sample,
        input_size=(net_w, net_h),
        target_size=image_rgb.shape[1::-1],
        optimize=False,
        use_camera=False,
    )
    depth = np.asarray(prediction, dtype="float32")
    return depth.squeeze()

View File

@@ -0,0 +1,74 @@
from dataclasses import dataclass
from typing import Literal
import sys
from pathlib import Path
import torch
# Make sure the locally cloned ZoeDepth repository is on sys.path so that its
# internal `import zoedepth...` statements resolve correctly.
_THIS_DIR = Path(__file__).resolve().parent
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
if _ZOE_REPO_ROOT.is_dir():
    zoe_path = str(_ZOE_REPO_ROOT)
    # Avoid duplicate entries when this module is imported more than once.
    if zoe_path not in sys.path:
        sys.path.insert(0, zoe_path)
from zoedepth.models.builder import build_model
from zoedepth.utils.config import get_config
# Names of the three pretrained ZoeDepth variants.
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]


@dataclass
class ZoeConfig:
    """Selection of the ZoeDepth model variant to load."""

    # Which pretrained head to use: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK".
    model: ZoeModelName = "ZoeD_N"
    # Target device string: "cuda" | "cpu".
    device: str = "cuda"
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
    """Manually load one of the three ZoeDepth models.

    Supported names: "ZoeD_N", "ZoeD_K", "ZoeD_NK". The model is moved to
    CUDA only when requested and available, otherwise to CPU, and is put
    into eval mode.
    """
    # Table of (config name, extra get_config kwargs) per model variant.
    _CONFIG_ARGS = {
        "ZoeD_N": ("zoedepth", {}),
        "ZoeD_K": ("zoedepth", {"config_version": "kitti"}),
        "ZoeD_NK": ("zoedepth_nk", {}),
    }
    try:
        config_name, extra_kwargs = _CONFIG_ARGS[name]
    except KeyError:
        raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}") from None
    conf = get_config(config_name, "infer", **extra_kwargs)
    model = build_model(conf)
    use_cuda = device.startswith("cuda") and torch.cuda.is_available()
    model = model.to("cuda" if use_cuda else "cpu")
    model.eval()
    return model, conf
def load_zoe_from_config(config: ZoeConfig):
    """Load a ZoeDepth model as described by *config*.

    Example:
        cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
        model, conf = load_zoe_from_config(cfg)
    """
    return load_zoe_from_name(config.model, config.device)

Submodule python_server/model/Inpaint/ControlNet added at ed85cd1e25

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,413 @@
from __future__ import annotations
"""
统一的补全Inpaint模型加载入口。
当前支持:
- SDXL Inpaintdiffusers AutoPipelineForInpainting
- ControlNet占位暂未统一封装
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import numpy as np
from PIL import Image
from config_loader import (
load_app_config,
get_sdxl_base_model_from_app,
get_controlnet_base_model_from_app,
get_controlnet_model_from_app,
)
class InpaintBackend(str, Enum):
    """String identifiers for the supported inpainting backends."""

    SDXL_INPAINT = "sdxl_inpaint"
    CONTROLNET = "controlnet"
@dataclass
class UnifiedInpaintConfig:
    """Configuration for building an inpainting predictor."""

    # Which inpainting backend to construct.
    backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
    # Target device ("cuda"/"cpu"); None defers to the app-level config.
    device: str | None = None
    # SDXL base model (HF id or local directory); falls back to the default
    # from config.py when left as None.
    sdxl_base_model: str | None = None
@dataclass
class UnifiedDrawConfig:
"""
统一绘图配置:
- 纯文生图image=None
- 图生图模仿输入图image=某张参考图
"""
device: str | None = None
sdxl_base_model: str | None = None
def _resolve_device_and_dtype(device: str | None):
import torch
app_cfg = load_app_config()
if device is None:
device = app_cfg.inpaint.device
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
return device, torch_dtype
def _enable_memory_opts(pipe, device: str) -> None:
if device == "cuda":
try:
pipe.enable_attention_slicing()
except Exception:
pass
try:
pipe.enable_vae_slicing()
except Exception:
pass
try:
pipe.enable_vae_tiling()
except Exception:
pass
try:
pipe.enable_model_cpu_offload()
except Exception:
pass
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
run_w, run_h = orig_w, orig_h
if max_side > 0 and max(orig_w, orig_h) > max_side:
scale = max_side / float(max(orig_w, orig_h))
run_w = int(round(orig_w * scale))
run_h = int(round(orig_h * scale))
run_w = max(8, run_w - (run_w % 8))
run_h = max(8, run_h - (run_h % 8))
return run_w, run_h
def _make_sdxl_inpaint_predictor(
    cfg: UnifiedInpaintConfig,
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
    """
    Build the SDXL inpainting predictor.

    The returned callable takes (image: PIL RGB, mask: PIL L/1, prompt,
    negative_prompt, ...) and returns the inpainted PIL RGB image at the
    original resolution.

    Refactor vs. original: device/dtype resolution, memory optimisation and
    size alignment now go through the shared module helpers instead of
    duplicating them inline (the inline copy also carried dead
    ``run_w <= 0`` checks already covered by ``max(8, ...)``).
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting

    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    # Shared helper keeps device selection consistent across backends.
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)
    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)
    # Memory savers (attention/VAE slicing, tiling, CPU offload). CPU offload
    # is markedly slower but significantly reduces VRAM usage.
    _enable_memory_opts(pipe, device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        max_side: int = 1024,
    ) -> Image.Image:
        image = image.convert("RGB")
        # diffusers expects a single-channel mask; white = region to repaint.
        mask = mask.convert("L")
        # SDXL/diffusers want width/height that are multiples of 8; also cap
        # the longer side at max_side (default 1024) to avoid OOM, then
        # resize the result back so the output matches the input resolution.
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
def _make_controlnet_predictor(cfg: UnifiedInpaintConfig):
    """
    Build a Stable Diffusion ControlNet inpainting predictor that uses Canny
    edges of the input image as the control signal.

    Fixes vs. original: the config parameter was named ``_`` (conventionally
    "unused") yet was read for the device — renamed to ``cfg``; device/dtype
    resolution and size alignment now use the shared module helpers.
    """
    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

    app_cfg = load_app_config()
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)
    base_model = get_controlnet_base_model_from_app(app_cfg)
    controlnet_id = get_controlnet_model_from_app(app_cfg)
    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        base_model,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    )
    if device == "cuda":
        # Best-effort memory savers. Kept inline (not _enable_memory_opts)
        # because this backend falls back to a plain .to(device) when CPU
        # offload is unavailable — offload itself moves the model.
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass
        try:
            pipe.enable_vae_slicing()
        except Exception:
            pass
        try:
            pipe.enable_vae_tiling()
        except Exception:
            pass
        try:
            pipe.enable_model_cpu_offload()
        except Exception:
            pipe.to(device)
    else:
        pipe.to(device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        controlnet_conditioning_scale: float = 1.0,
        max_side: int = 768,
    ) -> Image.Image:
        import cv2
        import numpy as np

        image = image.convert("RGB")
        mask = mask.convert("L")
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask
        # Control image: Canny edges replicated to 3 channels (the most
        # generic conditioning cue).
        rgb = np.array(image_run, dtype=np.uint8)
        edges = cv2.Canny(rgb, 100, 200)
        edges3 = np.stack([edges, edges, edges], axis=-1)
        control_image = Image.fromarray(edges3)
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            control_image=control_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
def build_inpaint_predictor(
    cfg: UnifiedInpaintConfig | None = None,
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
    """
    Build the inpainting predictor for the configured backend.

    Returns (predictor, backend actually used).
    """
    cfg = cfg or UnifiedInpaintConfig()
    factories = {
        InpaintBackend.SDXL_INPAINT: _make_sdxl_inpaint_predictor,
        InpaintBackend.CONTROLNET: _make_controlnet_predictor,
    }
    factory = factories.get(cfg.backend)
    if factory is None:
        raise ValueError(f"不支持的补全后端: {cfg.backend}")
    return factory(cfg), cfg.backend
def build_draw_predictor(
    cfg: UnifiedDrawConfig | None = None,
) -> Callable[..., Image.Image]:
    """
    Build the unified drawing callable.

    - text-to-image: draw(prompt, image=None, ...)
    - image-to-image: draw(prompt, image=ref_image, strength=0.55, ...)
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

    cfg = cfg or UnifiedDrawConfig()
    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)
    text2img = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    # Image-to-image pipeline shares weights with the text-to-image one.
    img2img = AutoPipelineForImage2Image.from_pipe(text2img).to(device)
    for pipeline in (text2img, img2img):
        _enable_memory_opts(pipeline, device)

    def _draw(
        prompt: str,
        image: Image.Image | None = None,
        negative_prompt: str = "",
        strength: float = 0.55,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        width: int = 1024,
        height: int = 1024,
        max_side: int = 1024,
    ) -> Image.Image:
        prompt = prompt or ""
        negative_prompt = negative_prompt or ""
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        if image is None:
            # Pure text-to-image: clamp/align the requested canvas size.
            gen_w, gen_h = _align_size(width, height, max_side=max_side)
            result = text2img(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=gen_w,
                height=gen_h,
            ).images[0]
            return result.convert("RGB")
        # Image-to-image: run at a capped size, then restore the original one.
        ref = image.convert("RGB")
        orig_w, orig_h = ref.size
        gen_w, gen_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (gen_w, gen_h) != (orig_w, orig_h):
            ref_run = ref.resize((gen_w, gen_h), resample=Image.BICUBIC)
        else:
            ref_run = ref
        result = img2img(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=ref_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0].convert("RGB")
        if result.size != (orig_w, orig_h):
            result = result.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return result

    return _draw

Submodule python_server/model/Inpaint/sdxl-inpaint added at 29867f540b

Submodule python_server/model/Seg/Mask2Former added at 9b0651c6c1

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Literal, Tuple
import numpy as np
from PIL import Image
@dataclass
class Mask2FormerHFConfig:
    """Config for HuggingFace-transformers Mask2Former semantic segmentation.

    model_id: HuggingFace model id (default: ADE20K semantic checkpoint).
    device: "cuda" | "cpu".
    """

    model_id: str = "facebook/mask2former-swin-large-ade-semantic"
    device: str = "cuda"
def build_mask2former_hf_predictor(
    cfg: Mask2FormerHFConfig | None = None,
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
    """
    Build a semantic-segmentation predictor.

    Returns (predictor, resolved_cfg); predictor maps an RGB uint8 array
    (H, W, 3) to an int32 label map of the same spatial size.
    """
    cfg = cfg or Mask2FormerHFConfig()
    import torch
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    # Fall back to CPU when CUDA is requested but unavailable.
    resolved = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    processor = AutoImageProcessor.from_pretrained(cfg.model_id)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
    model.to(resolved).eval()
    cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=resolved)

    @torch.no_grad()
    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        rgb_u8 = image_rgb if image_rgb.dtype == np.uint8 else image_rgb.astype("uint8")
        pil = Image.fromarray(rgb_u8, mode="RGB")
        inputs = processor(images=pil, return_tensors="pt").to(resolved)
        outputs = model(**inputs)
        # Post-process back to the original (height, width).
        seg = processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(pil.height, pil.width)]
        )[0]
        return seg.detach().to("cpu").numpy().astype("int32")

    return _predict, cfg

View File

@@ -0,0 +1,168 @@
from __future__ import annotations
"""
统一的分割模型加载入口。
当前支持:
- SAM (segment-anything)
- Mask2Former使用 HuggingFace transformers 的语义分割实现)
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import sys
from pathlib import Path
import numpy as np
# Directory containing this module; vendored model repos live next to it.
_THIS_DIR = Path(__file__).resolve().parent


class SegBackend(str, Enum):
    """String identifiers for the supported segmentation backends."""

    SAM = "sam"
    MASK2FORMER = "mask2former"


@dataclass
class UnifiedSegConfig:
    """Selects which segmentation backend to build."""

    backend: SegBackend = SegBackend.SAM
# -----------------------------
# SAM (Segment Anything)
# -----------------------------
def _ensure_sam_on_path() -> Path:
    """Make the vendored segment-anything repo importable; return its root.

    Raises FileNotFoundError when the repository directory is missing.
    """
    repo_root = _THIS_DIR / "segment-anything"
    if not repo_root.is_dir():
        raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {repo_root}")
    repo_str = str(repo_root)
    # Avoid duplicate sys.path entries on repeated calls.
    if repo_str not in sys.path:
        sys.path.insert(0, repo_str)
    return repo_root
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
import requests
ckpt_dir = sam_root / "checkpoints"
ckpt_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
if ckpt_path.is_file():
return ckpt_path
url = (
"https://dl.fbaipublicfiles.com/segment_anything/"
"sam_vit_h_4b8939.pth"
)
print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
resp = requests.get(url, stream=True)
resp.raise_for_status()
total = int(resp.headers.get("content-length", "0") or "0")
downloaded = 0
chunk_size = 1024 * 1024
with ckpt_path.open("wb") as f:
for chunk in resp.iter_content(chunk_size=chunk_size):
if not chunk:
continue
f.write(chunk)
downloaded += len(chunk)
if total > 0:
done = int(50 * downloaded / total)
print(
"\r[{}{}] {:.1f}%".format(
"#" * done,
"." * (50 - done),
downloaded * 100 / total,
),
end="",
)
print("\nSAM 权重下载完成。")
return ckpt_path
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """
    Build a SAM-based segmentation function.

    Input: RGB uint8 image (H, W, 3).
    Output: int32 label map (H, W); each generated mask gets an id starting
    at 1, background stays 0, and later masks overwrite earlier ones where
    they overlap.
    """
    sam_root = _ensure_sam_on_path()
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # type: ignore[import]
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    mask_generator = SamAutomaticMaskGenerator(sam)

    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        rgb_u8 = image_rgb if image_rgb.dtype == np.uint8 else image_rgb.astype("uint8")
        height, width, _ = rgb_u8.shape
        label_map = np.zeros((height, width), dtype="int32")
        for obj_id, mask_info in enumerate(mask_generator.generate(rgb_u8), start=1):
            seg = mask_info.get("segmentation")
            if seg is not None:
                label_map[seg.astype(bool)] = obj_id
        return label_map

    return _predict
# -----------------------------
# Mask2Former (占位)
# -----------------------------
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build the Mask2Former predictor via the HuggingFace-based loader."""
    from .mask2former_loader import build_mask2former_hf_predictor
    # The loader also returns a resolved config, which is not needed here.
    predictor, _ = build_mask2former_hf_predictor()
    return predictor
# -----------------------------
# 统一构建函数
# -----------------------------
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """
    Build the segmentation predictor for the configured backend.

    Returns:
        - predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W], int32
        - the backend actually used
    """
    cfg = cfg or UnifiedSegConfig()
    factories = {
        SegBackend.SAM: _make_sam_predictor,
        SegBackend.MASK2FORMER: _make_mask2former_predictor,
    }
    factory = factories.get(cfg.backend)
    if factory is None:
        raise ValueError(f"不支持的分割后端: {cfg.backend}")
    return factory(), cfg.backend

Submodule python_server/model/Seg/segment-anything added at dca509fe79

View File

@@ -0,0 +1 @@

407
python_server/server.py Normal file
View File

@@ -0,0 +1,407 @@
from __future__ import annotations
import base64
import datetime
import io
import os
import shutil
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel, Field
from PIL import Image, ImageDraw
from config_loader import load_app_config, get_depth_backend_from_app
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
from model.Seg.seg_loader import UnifiedSegConfig, SegBackend, build_seg_predictor
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
from model.Animation.animation_loader import (
UnifiedAnimationConfig,
AnimationBackend,
build_animation_predictor,
)
# Server root directory and the directory all generated artifacts go to.
APP_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = APP_ROOT / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI(title="HFUT Model Server", version="0.1.0")
class ImageInput(BaseModel):
    """Base request carrying a base64-encoded image and an optional model key."""

    image_b64: str = Field(..., description="PNG/JPG 编码后的 base64不含 data: 前缀)")
    model_name: Optional[str] = Field(None, description="模型 key来自 /models")
class DepthRequest(ImageInput):
    """Request body for /depth (image only; no extra fields)."""

    pass
class SegmentRequest(ImageInput):
    """Request body for /segment (image plus optional model_name)."""

    pass
class InpaintRequest(ImageInput):
    """Request body for /inpaint."""

    prompt: Optional[str] = Field("", description="补全 prompt")
    strength: float = Field(0.8, ge=0.0, le=1.0)
    negative_prompt: Optional[str] = Field("", description="负向 prompt")
    # Optional mask; white pixels mark the region to repaint.
    mask_b64: Optional[str] = Field(None, description="mask PNG base64可选")
    # Upper bound on the longer image side during inference (avoids OOM).
    max_side: int = Field(1024, ge=128, le=2048)
class AnimateRequest(BaseModel):
    """Request body for /animate (text-to-video generation parameters)."""

    model_name: Optional[str] = Field(None, description="模型 key来自 /models")
    prompt: str = Field(..., description="文本提示词")
    negative_prompt: Optional[str] = Field("", description="负向提示词")
    num_inference_steps: int = Field(25, ge=1, le=200)
    guidance_scale: float = Field(8.0, ge=0.0, le=30.0)
    width: int = Field(512, ge=128, le=2048)
    height: int = Field(512, ge=128, le=2048)
    # Number of frames to generate.
    video_length: int = Field(16, ge=1, le=128)
    seed: int = Field(-1, description="-1 表示随机种子")
def _b64_to_pil_image(b64: str) -> "Image.Image":
    """Decode a base64 string (no ``data:`` prefix) into a PIL image."""
    decoded = base64.b64decode(b64)
    return Image.open(io.BytesIO(decoded))
def _pil_image_to_png_b64(img: Image.Image) -> str:
buf = io.BytesIO()
img.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode("ascii")
def _depth_to_png16_b64(depth: np.ndarray) -> str:
    """Min-max normalise a depth map and return it as a 16-bit PNG base64."""
    arr = np.asarray(depth, dtype=np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if hi > lo:
        norm = (arr - lo) / (hi - lo)
    else:
        # Constant-depth input: emit an all-zero image instead of dividing by 0.
        norm = np.zeros_like(arr, dtype=np.float32)
    u16 = (norm * 65535.0).clip(0, 65535).astype(np.uint16)
    return _pil_image_to_png_b64(Image.fromarray(u16, mode="I;16"))
def _depth_to_png16_bytes(depth: np.ndarray) -> bytes:
    """Normalise a depth map and return it as PNG bytes.

    NOTE(review): despite the name this emits an **8-bit** grayscale PNG
    (mode "L"); by the front/back-end convention the farthest point maps to
    255 and the nearest to 0 (normalised value is inverted). Renaming would
    break callers, so only documenting it here.
    """
    arr = np.asarray(depth, dtype=np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if hi > lo:
        norm = (arr - lo) / (hi - lo)
    else:
        norm = np.zeros_like(arr, dtype=np.float32)
    inverted = ((1.0 - norm) * 255.0).clip(0, 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(inverted, mode="L").save(buf, format="PNG")
    return buf.getvalue()
def _default_half_mask(img: "Image.Image") -> "Image.Image":
    """Return an L-mode mask marking the right half of *img* (fill=255)."""
    width, height = img.size
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle([width // 2, 0, width, height], fill=255)
    return mask
# -----------------------------
# /models给前端/GUI 使用)
# -----------------------------
@app.get("/models")
def get_models() -> Dict[str, Any]:
    """
    Return a model-listing schema compatible with the Qt frontend:
    {
        "models": {
            "depth": { "key": { "name": "..."} ... },
            "segment": { ... },
            "inpaint": { "key": { "name": "...", "params": [...] } ... }
        }
    }
    """
    return {
        "models": {
            "depth": {
                # Legacy default key kept for old-config compatibility.
                "midas": {"name": "MiDaS (default)"},
                "zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
                "zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
                "zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
                "depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
                "depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
                "depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
                "depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
                "dpt_large": {"name": "DPT (large)"},
                "dpt_hybrid": {"name": "DPT (hybrid)"},
                "midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
                "midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
                "midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
                "midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
            },
            "segment": {
                "sam": {"name": "SAM (vit_h)"},
                # Legacy default key kept for old-config compatibility.
                "mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
                "mask2former": {"name": "Mask2Former (not implemented)"},
            },
            "inpaint": {
                # Legacy default; "copy" performs no inpainting at all.
                "copy": {"name": "Copy (no-op)", "params": []},
                "sdxl_inpaint": {
                    "name": "SDXL Inpaint",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
                "controlnet": {
                    "name": "ControlNet Inpaint (canny)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
            },
            "animation": {
                "animatediff": {
                    "name": "AnimateDiff (Text-to-Video)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": False},
                        {"id": "negative_prompt", "label": "负向提示词", "optional": True},
                        {"id": "num_inference_steps", "label": "采样步数", "optional": True},
                        {"id": "guidance_scale", "label": "CFG Scale", "optional": True},
                        {"id": "width", "label": "宽度", "optional": True},
                        {"id": "height", "label": "高度", "optional": True},
                        {"id": "video_length", "label": "帧数", "optional": True},
                        {"id": "seed", "label": "随机种子", "optional": True},
                    ],
                },
            },
        }
    }
# -----------------------------
# Depth
# -----------------------------
# Lazily-initialised depth predictor shared across requests.
_depth_predictor = None
_depth_backend: DepthBackend | None = None


def _ensure_depth_predictor() -> None:
    """Initialise the module-level depth predictor on first use.

    The backend choice comes exclusively from config.py; an invalid value
    raises ValueError.
    """
    global _depth_predictor, _depth_backend
    if _depth_predictor is None or _depth_backend is None:
        app_cfg = load_app_config()
        backend_str = get_depth_backend_from_app(app_cfg)
        try:
            backend = DepthBackend(backend_str)
        except Exception as e:
            raise ValueError(f"config.py 中 depth.backend 不合法: {backend_str}") from e
        _depth_predictor, _depth_backend = build_depth_predictor(
            UnifiedDepthConfig(backend=backend)
        )
@app.post("/depth")
def depth(req: DepthRequest):
    """
    Compute depth and return a binary PNG directly (grayscale; the helper
    writes mode "L", i.e. 8-bit, far=255 / near=0 — see _depth_to_png16_bytes).
    Constraints:
    - The frontend neither sends nor selects a model; the choice is fixed
      in backend config.py.
    - Success: HTTP 200 + Content-Type: image/png
    - Failure: HTTP 500 with the error text in `detail`
    """
    try:
        _ensure_depth_predictor()
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        depth_arr = _depth_predictor(pil)  # type: ignore[misc]
        png_bytes = _depth_to_png16_bytes(np.asarray(depth_arr))
        return Response(content=png_bytes, media_type="image/png")
    except Exception as e:
        # Surface any failure as a 500 carrying the exception text.
        raise HTTPException(status_code=500, detail=str(e))
# -----------------------------
# Segment
# -----------------------------
_seg_cache: Dict[str, Any] = {}
def _get_seg_predictor(model_name: str):
if model_name in _seg_cache:
return _seg_cache[model_name]
# 兼容旧默认 key
if model_name == "mask2former_debug":
model_name = "sam"
if model_name == "sam":
pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.SAM))
_seg_cache[model_name] = pred
return pred
if model_name == "mask2former":
pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.MASK2FORMER))
_seg_cache[model_name] = pred
return pred
raise ValueError(f"未知 segment model_name: {model_name}")
@app.post("/segment")
def segment(req: SegmentRequest) -> Dict[str, Any]:
    """Run segmentation, save the label map as a PNG and return its path."""
    try:
        model_name = req.model_name or "sam"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)
        predictor = _get_seg_predictor(model_name)
        label_map = predictor(rgb).astype(np.int32)
        out_dir = OUTPUT_DIR / "segment"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_label.png"
        # Saved as 8-bit grayscale; labels above 255 are clipped (SAM label
        # counts normally stay well below that).
        Image.fromarray(np.clip(label_map, 0, 255).astype(np.uint8), mode="L").save(out_path)
        return {"success": True, "label_path": str(out_path)}
    except Exception as e:
        # Errors are reported in the JSON payload instead of an HTTP error.
        return {"success": False, "error": str(e)}
# -----------------------------
# Inpaint
# -----------------------------
_inpaint_cache: Dict[str, Any] = {}
def _get_inpaint_predictor(model_name: str):
if model_name in _inpaint_cache:
return _inpaint_cache[model_name]
if model_name == "copy":
def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
return image.convert("RGB")
_inpaint_cache[model_name] = _copy
return _copy
if model_name == "sdxl_inpaint":
pred, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.SDXL_INPAINT))
_inpaint_cache[model_name] = pred
return pred
if model_name == "controlnet":
pred, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.CONTROLNET))
_inpaint_cache[model_name] = pred
return pred
raise ValueError(f"未知 inpaint model_name: {model_name}")
@app.post("/inpaint")
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
    """Run inpainting, save the result PNG and return its path.

    When no mask is supplied, the right half of the image is repainted
    (see _default_half_mask).
    """
    try:
        model_name = req.model_name or "sdxl_inpaint"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        if req.mask_b64:
            mask = _b64_to_pil_image(req.mask_b64).convert("L")
        else:
            # Fallback: repaint the right half of the image.
            mask = _default_half_mask(pil)
        predictor = _get_inpaint_predictor(model_name)
        out = predictor(
            pil,
            mask,
            req.prompt or "",
            req.negative_prompt or "",
            strength=req.strength,
            max_side=req.max_side,
        )
        out_dir = OUTPUT_DIR / "inpaint"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_inpaint.png"
        out.save(out_path)
        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        # Errors are reported in the JSON payload instead of an HTTP error.
        return {"success": False, "error": str(e)}
_animation_cache: Dict[str, Any] = {}
def _get_animation_predictor(model_name: str):
if model_name in _animation_cache:
return _animation_cache[model_name]
if model_name == "animatediff":
pred, _ = build_animation_predictor(
UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
)
_animation_cache[model_name] = pred
return pred
raise ValueError(f"未知 animation model_name: {model_name}")
@app.post("/animate")
def animate(req: AnimateRequest) -> Dict[str, Any]:
    """Generate an animation and copy it into outputs/animation with a
    timestamped GIF filename; returns the destination path."""
    try:
        model_name = req.model_name or "animatediff"
        predictor = _get_animation_predictor(model_name)
        result_path = predictor(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or "",
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
            width=req.width,
            height=req.height,
            video_length=req.video_length,
            seed=req.seed,
        )
        out_dir = OUTPUT_DIR / "animation"
        out_dir.mkdir(parents=True, exist_ok=True)
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = out_dir / f"{model_name}_{ts}.gif"
        # Copy (not move) so the predictor's own output stays intact.
        shutil.copy2(result_path, out_path)
        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        # Errors are reported in the JSON payload instead of an HTTP error.
        return {"success": False, "error": str(e)}
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe; always reports ok."""
    return dict(status="ok")
if __name__ == "__main__":
    import uvicorn
    # Port comes from MODEL_SERVER_PORT (default 8000); binds all interfaces.
    port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
from pathlib import Path
import shutil
from model.Animation.animation_loader import (
build_animation_predictor,
UnifiedAnimationConfig,
AnimationBackend,
)
# -----------------------------
# Configuration area (edit as needed)
# -----------------------------
OUTPUT_DIR = "outputs/test_animation"
ANIMATION_BACKEND = AnimationBackend.ANIMATEDIFF
OUTPUT_FORMAT = "png_sequence"  # "gif" | "png_sequence"
PROMPT = "a cinematic mountain landscape, camera slowly pans left"
NEGATIVE_PROMPT = "blurry, low quality"
NUM_INFERENCE_STEPS = 25
GUIDANCE_SCALE = 8.0
WIDTH = 512
HEIGHT = 512
VIDEO_LENGTH = 16
SEED = -1
# Must point at a real input image; the placeholder below is rejected by main().
CONTROL_IMAGE_PATH = "path/to/your_image.png"
def main() -> None:
    """Run one animation generation and copy the result into OUTPUT_DIR."""
    script_dir = Path(__file__).resolve().parent
    dest_dir = script_dir / OUTPUT_DIR
    dest_dir.mkdir(parents=True, exist_ok=True)

    predictor, used_backend = build_animation_predictor(
        UnifiedAnimationConfig(backend=ANIMATION_BACKEND)
    )

    # Refuse to run while CONTROL_IMAGE_PATH is still the placeholder.
    if CONTROL_IMAGE_PATH.strip() in {"", "path/to/your_image.png"}:
        raise ValueError("请先设置 CONTROL_IMAGE_PATH 为你的输入图片路径png/jpg")
    control_image = (script_dir / CONTROL_IMAGE_PATH).resolve()
    if not control_image.is_file():
        raise FileNotFoundError(f"control image not found: {control_image}")

    result_path = predictor(
        prompt=PROMPT,
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        width=WIDTH,
        height=HEIGHT,
        video_length=VIDEO_LENGTH,
        seed=SEED,
        control_image_path=str(control_image),
        output_format=OUTPUT_FORMAT,
    )

    source = Path(result_path)
    if OUTPUT_FORMAT == "png_sequence":
        # PNG sequences are copied as a whole directory (replacing any old run).
        out_seq_dir = dest_dir / f"{used_backend.value}_frames"
        if out_seq_dir.exists():
            shutil.rmtree(out_seq_dir)
        shutil.copytree(source, out_seq_dir)
        print(f"[Animation] backend={used_backend.value}, saved={out_seq_dir}")
        return
    out_path = dest_dir / f"{used_backend.value}.gif"
    out_path.write_bytes(source.read_bytes())
    print(f"[Animation] backend={used_backend.value}, saved={out_path}")
if __name__ == "__main__":
main()

101
python_server/test_depth.py Normal file
View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from pathlib import Path
import numpy as np
import cv2
from PIL import Image
# Lift PIL's large-image safety limit (inputs are very large scans).
Image.MAX_IMAGE_PIXELS = None
from model.Depth.depth_loader import build_depth_predictor, UnifiedDepthConfig, DepthBackend
# ================= Configuration =================
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
OUTPUT_DIR = "outputs/test_depth_v4"
DEPTH_BACKEND = DepthBackend.DEPTH_ANYTHING_V2
# Edge-extraction parameters.
# Raising this value keeps only finer/stronger edges; lowering it keeps more
# (weaker) edge information.
EDGE_TOP_PERCENTILE = 96.0  # keep pixels above the 96th gradient percentile (top 4%)
# Sensitivity of the local contrast enhancement; 2.0 - 5.0 is a sane range.
CLAHE_CLIP_LIMIT = 3.0
# Morphological kernel size.
MORPH_SIZE = 5
# =========================================
def _extract_robust_edges(depth_norm: np.ndarray) -> np.ndarray:
    """
    Extract (mostly) closed edges via local contrast enhancement plus Sobel
    gradient thresholding. Input is a [0, 1] normalised depth map; output is
    a uint8 binary mask (0/255).
    """
    # 1. To 8-bit grayscale.
    gray_u8 = (depth_norm * 255).astype(np.uint8)
    # 2. [Key step] CLAHE adaptive local contrast enhancement — amplifies the
    #    subtle building/figure depth differences in old paintings.
    clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=(16, 16))
    enhanced = clahe.apply(gray_u8)
    # 3. Gaussian blur suppresses digitisation noise.
    smoothed = cv2.GaussianBlur(enhanced, (5, 5), 0)
    # 4. Sobel gradient magnitude.
    grad_x = cv2.Sobel(smoothed, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(smoothed, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(grad_x**2 + grad_y**2)
    # 5. Statistical threshold: keep the strongest top-X% of gradients
    #    instead of a fixed cutoff.
    cutoff = np.percentile(magnitude, EDGE_TOP_PERCENTILE)
    edges_bin = (magnitude >= cutoff).astype(np.uint8) * 255
    # 6. Morphological closing bridges small gaps so contours connect.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (MORPH_SIZE, MORPH_SIZE))
    closed = cv2.morphologyEx(edges_bin, cv2.MORPH_CLOSE, kernel)
    # 7. A light dilation widens the band to better guide SAM.
    return cv2.dilate(closed, kernel, iterations=1)
def main() -> None:
    """Run one depth-estimation pass on INPUT_IMAGE and export depth + edge mask."""
    script_dir = Path(__file__).resolve().parent
    source = Path(INPUT_IMAGE)
    result_dir = script_dir / OUTPUT_DIR
    result_dir.mkdir(parents=True, exist_ok=True)
    # Build the configured depth predictor.
    predictor, used_backend = build_depth_predictor(UnifiedDepthConfig(backend=DEPTH_BACKEND))
    # Load the (possibly very large) source painting.
    print(f"[Loading] 正在处理: {source.name}")
    painting = Image.open(source).convert("RGB")
    width, height = painting.size  # kept for parity with the original script; unused
    # Monocular depth inference.
    print(f"[Depth] 正在进行深度估计 (Large Image)...")
    depth_map = np.asarray(predictor(painting), dtype=np.float32).squeeze()
    # Min-max normalize and persist as a 16-bit PNG.
    lo, hi = depth_map.min(), depth_map.max()
    normalized = (depth_map - lo) / (hi - lo + 1e-8)
    as_u16 = (normalized * 65535.0).astype(np.uint16)
    Image.fromarray(as_u16).save(result_dir / f"{source.stem}.depth.png")
    # Robust edge extraction (CLAHE + Sobel).
    print(f"[Edge] 正在应用 CLAHE + Sobel 增强算法提取边缘...")
    robust_edges = _extract_robust_edges(normalized)
    # Export the guidance mask and report its density.
    mask_file = result_dir / f"{source.stem}.edge_mask_robust.png"
    Image.fromarray(robust_edges).save(mask_file)
    edge_ratio = float((robust_edges > 0).sum()) / float(robust_edges.size)
    print("-" * 30)
    print(f"提取完成!")
    print(f"边缘密度: {edge_ratio:.2%} (目标通常应在 1% ~ 8% 之间)")
    print(f"如果 Mask 依然太黑,请调低 EDGE_TOP_PERCENTILE (如 90.0)")
    print(f"如果 Mask 太乱,请调高 EDGE_TOP_PERCENTILE (如 96.0)")
    print("-" * 30)
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,284 @@
from __future__ import annotations
from datetime import datetime
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
from model.Inpaint.inpaint_loader import (
build_inpaint_predictor,
UnifiedInpaintConfig,
InpaintBackend,
build_draw_predictor,
UnifiedDrawConfig,
)
# -----------------------------
# Configuration (edit as needed)
# -----------------------------
# Task mode:
# - "inpaint": completion / hole filling
# - "draw": generation (text-to-image / image-to-image)
TASK_MODE = "draw"
# Input image for inpainting, e.g. image_0.png
INPUT_IMAGE = "/home/dwh/code/hfut-bishe/python_server/outputs/test_seg_v2/Up the River During Qingming (detail) - Court painters.objects/Up the River During Qingming (detail) - Court painters.obj_06.png"
INPUT_MASK = "" # empty means: do not use a hand-made mask
OUTPUT_DIR = "outputs/test_inpaint_v2"
MASK_RECT = (256, 0, 512, 512) # x1, y1, x2, y2 (used when AUTO_MASK is off)
USE_AUTO_MASK = True # True: infer the completion region from the image itself
AUTO_MASK_BLACK_THRESHOLD = 6 # near-black pixel threshold for the auto mask
AUTO_MASK_DILATE_ITER = 4 # extra dilation so branch edges are fully covered
FALLBACK_TO_RECT_MASK = False # fall back to the rectangular mask when the auto mask is empty
# Transparency of the output
PRESERVE_TRANSPARENCY = True # keep the output transparent
EXPAND_ALPHA_WITH_MASK = True # make the inpainted region fully opaque
NON_BLACK_ALPHA_THRESHOLD = 6 # near-black test for transparent background when input has no alpha
INPAINT_BACKEND = InpaintBackend.SDXL_INPAINT # SDXL_INPAINT recommended for better results
# -----------------------------
# Key prompt tweak: generate trees only
# -----------------------------
# Refined PROMPT: stresses tree species, leaves and style so the fill blends in
PROMPT = (
    "A highly detailed traditional Chinese ink brush painting. "
    "Restore and complete the existing trees naturally. "
    "Extend the complex tree branches and dense, varied green and teal leaves. "
    "Add more harmonious foliage and intricate bark texture to match the style and flow of the original image_0.png trees. "
    "Focus solely on vegetation. "
    "Maintain the light beige background color and texture."
)
# NEGATIVE_PROMPT explicitly excludes unwanted content
NEGATIVE_PROMPT = (
    "buildings, architecture, houses, pavilions, temples, windows, doors, "
    "people, figures, persons, characters, figures, clothing, faces, hands, "
    "text, writing, characters, words, letters, signatures, seals, stamps, "
    "calligraphy, objects, artifacts, boxes, baskets, tools, "
    "extra branches crossing unnaturally, bad composition, watermark, signature"
)
STRENGTH = 0.8 # keep strength high for generation
GUIDANCE_SCALE = 8.0 # slightly higher: follow the prompt more strictly
NUM_INFERENCE_STEPS = 35 # slightly more steps for extra detail
MAX_SIDE = 1024
CONTROLNET_SCALE = 1.0
# -----------------------------
# Drawing ("draw") parameters
# -----------------------------
# DRAW_INPUT_IMAGE empty  -> text-to-image
# DRAW_INPUT_IMAGE set    -> image-to-image (imitate / repaint the reference)
DRAW_INPUT_IMAGE = ""
DRAW_PROMPT = """
Chinese ink wash painting, vast snowy river under a pale sky,
a small lonely boat at the horizon where water meets sky,
an old fisherman wearing a straw hat sits at the stern, fishing quietly,
gentle snowfall, misty atmosphere, distant mountains barely visible,
minimalist composition, large empty space, soft brush strokes,
calm, cold, and silent mood, poetic and serene
"""
DRAW_NEGATIVE_PROMPT = """
blurry, low quality, many people, bright colors, modern elements, crowded, noisy
"""
DRAW_STRENGTH = 0.55 # img2img only; larger means heavier repainting
DRAW_GUIDANCE_SCALE = 9
DRAW_STEPS = 64
DRAW_WIDTH = 2560
DRAW_HEIGHT = 1440
DRAW_MAX_SIDE = 2560
def _dilate(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
out = mask.astype(bool)
for _ in range(max(0, iterations)):
p = np.pad(out, ((1, 1), (1, 1)), mode="constant", constant_values=False)
out = (
p[:-2, :-2] | p[:-2, 1:-1] | p[:-2, 2:]
| p[1:-1, :-2] | p[1:-1, 1:-1] | p[1:-1, 2:]
| p[2:, :-2] | p[2:, 1:-1] | p[2:, 2:]
)
return out
def _make_rect_mask(img_size: tuple[int, int]) -> Image.Image:
    """Build an L-mode mask that is white inside MASK_RECT and black elsewhere."""
    width, height = img_size
    left, top, right, bottom = MASK_RECT
    # Clamp every corner into the image bounds...
    left = max(0, min(left, width - 1))
    right = max(0, min(right, width - 1))
    top = max(0, min(top, height - 1))
    bottom = max(0, min(bottom, height - 1))
    # ...then sort so left <= right and top <= bottom.
    if right < left:
        left, right = right, left
    if bottom < top:
        top, bottom = bottom, top
    canvas = Image.new("L", (width, height), 0)
    ImageDraw.Draw(canvas).rectangle([left, top, right, bottom], fill=255)
    return canvas
def _auto_mask_from_image(img_path: Path) -> Image.Image:
    """Infer the missing region automatically.

    1) If the input carries an alpha channel, transparent pixels become the mask.
    2) Otherwise, near-black pixels are treated as the candidate hole.
    """
    source = Image.open(img_path)
    pixels = np.asarray(source)
    if pixels.ndim == 3 and pixels.shape[2] == 4:
        hole = pixels[:, :, 3] < 250
    else:
        # Cut-out tools often export transparency as pure black, so near-black
        # regions are treated as missing first.
        rgb = np.asarray(source.convert("RGB"), dtype=np.uint8)
        hole = np.all(rgb <= AUTO_MASK_BLACK_THRESHOLD, axis=-1)
    hole = _dilate(hole, AUTO_MASK_DILATE_ITER)
    return Image.fromarray(hole.astype(np.uint8) * 255, mode="L")
def _build_alpha_from_input(img_path: Path, img_rgb: Image.Image) -> np.ndarray:
    """Derive the output alpha channel.

    - If the input file carries alpha, reuse it as-is.
    - Otherwise, near-black pixels are considered transparent background.
    """
    original = Image.open(img_path)
    data = np.asarray(original)
    if data.ndim == 3 and data.shape[2] == 4:
        return data[:, :, 3].astype(np.uint8)
    rgb = np.asarray(img_rgb.convert("RGB"), dtype=np.uint8)
    opaque = np.any(rgb > NON_BLACK_ALPHA_THRESHOLD, axis=-1)
    return opaque.astype(np.uint8) * 255
def _load_or_make_mask(base_dir: Path, img_path: Path, img_rgb: Image.Image) -> Image.Image:
    """Resolve the mask to use: explicit file > auto-inferred > rectangle."""
    if INPUT_MASK:
        candidate = Path(INPUT_MASK)
        if not candidate.is_absolute():
            candidate = base_dir / candidate
        if candidate.is_file():
            return Image.open(candidate).convert("L")
    if USE_AUTO_MASK:
        inferred = _auto_mask_from_image(img_path)
        if (np.asarray(inferred, dtype=np.uint8) > 0).any():
            return inferred
        if not FALLBACK_TO_RECT_MASK:
            raise ValueError("自动mask为空请检查输入图像是否存在透明/黑色缺失区域。")
    return _make_rect_mask(img_rgb.size)
def _resolve_path(base_dir: Path, p: str) -> Path:
raw = Path(p)
return raw if raw.is_absolute() else (base_dir / raw)
def run_inpaint_test(base_dir: Path, out_dir: Path) -> None:
    """Run one inpainting pass on INPUT_IMAGE and save the mask + result.

    Steps: resolve the input path, build the configured inpaint predictor,
    derive the mask, run the model, then (optionally) re-attach an alpha
    channel so the output stays transparent outside the painted content.
    """
    img_path = _resolve_path(base_dir, INPUT_IMAGE)
    if not img_path.is_file():
        raise FileNotFoundError(f"找不到输入图像,请修改 INPUT_IMAGE: {img_path}")
    predictor, used_backend = build_inpaint_predictor(
        UnifiedInpaintConfig(backend=INPAINT_BACKEND)
    )
    img = Image.open(img_path).convert("RGB")
    # Mask priority: explicit INPUT_MASK file > auto-inferred > rectangle.
    mask = _load_or_make_mask(base_dir, img_path, img)
    # Persist the mask actually used, for inspection/debugging.
    mask_out = out_dir / f"{img_path.stem}.mask_used.png"
    mask.save(mask_out)
    kwargs = dict(
        strength=STRENGTH,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_INFERENCE_STEPS,
        max_side=MAX_SIDE,
    )
    # Only the ControlNet backend accepts the conditioning-scale knob.
    if used_backend == InpaintBackend.CONTROLNET:
        kwargs["controlnet_conditioning_scale"] = CONTROLNET_SCALE
    out = predictor(img, mask, PROMPT, NEGATIVE_PROMPT, **kwargs)
    out_path = out_dir / f"{img_path.stem}.{used_backend.value}.inpaint.png"
    if PRESERVE_TRANSPARENCY:
        # Rebuild alpha from the input; optionally force the inpainted
        # region fully opaque so new content is visible.
        alpha = _build_alpha_from_input(img_path, img)
        mask_u8 = np.asarray(mask, dtype=np.uint8)
        if EXPAND_ALPHA_WITH_MASK:
            alpha = np.maximum(alpha, mask_u8)
        out_rgb = np.asarray(out.convert("RGB"), dtype=np.uint8)
        out_rgba = np.concatenate([out_rgb, alpha[..., None]], axis=-1)
        Image.fromarray(out_rgba, mode="RGBA").save(out_path)
    else:
        out.save(out_path)
    # Fraction of pixels covered by the mask, for a quick sanity check.
    ratio = float((np.asarray(mask, dtype=np.uint8) > 0).sum()) / float(mask.size[0] * mask.size[1])
    print(f"[Inpaint] backend={used_backend.value}, saved={out_path}")
    print(f"[Mask] saved={mask_out}, ratio={ratio:.4f}")
def run_draw_test(base_dir: Path, out_dir: Path) -> None:
    """Drawing test.

    - text-to-image when DRAW_INPUT_IMAGE is empty
    - image-to-image when DRAW_INPUT_IMAGE points at a reference picture
    """
    painter = build_draw_predictor(UnifiedDrawConfig())
    reference: Image.Image | None = None
    mode = "text2img"
    if DRAW_INPUT_IMAGE:
        ref_path = _resolve_path(base_dir, DRAW_INPUT_IMAGE)
        if not ref_path.is_file():
            raise FileNotFoundError(f"找不到参考图,请修改 DRAW_INPUT_IMAGE: {ref_path}")
        reference = Image.open(ref_path).convert("RGB")
        mode = "img2img"
    result = painter(
        prompt=DRAW_PROMPT,
        image=reference,
        negative_prompt=DRAW_NEGATIVE_PROMPT,
        strength=DRAW_STRENGTH,
        guidance_scale=DRAW_GUIDANCE_SCALE,
        num_inference_steps=DRAW_STEPS,
        width=DRAW_WIDTH,
        height=DRAW_HEIGHT,
        max_side=DRAW_MAX_SIDE,
    )
    # Timestamped filename so repeated runs never overwrite each other.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = out_dir / f"draw_{mode}_{stamp}.png"
    result.save(target)
    print(f"[Draw] mode={mode}, saved={target}")
def main() -> None:
    """Dispatch to the inpaint or draw test according to TASK_MODE."""
    base_dir = Path(__file__).resolve().parent
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    if TASK_MODE == "inpaint":
        run_inpaint_test(base_dir, out_dir)
    elif TASK_MODE == "draw":
        run_draw_test(base_dir, out_dir)
    else:
        raise ValueError(f"不支持的 TASK_MODE: {TASK_MODE}(可选: inpaint / draw")
if __name__ == "__main__":
    main()

163
python_server/test_seg.py Normal file
View File

@@ -0,0 +1,163 @@
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageDraw
from scipy.ndimage import label as nd_label
from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed
# ================= Configuration =================
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png"
OUTPUT_DIR = "outputs/test_seg_v2"
SEG_BACKEND = SegBackend.SAM
# Object-filtering parameters
TARGET_MIN_AREA = 1000 # drop fragments smaller than this many pixels
TARGET_MAX_OBJECTS = 20 # maximum number of objects to extract
SAM_MAX_SIDE = 2048 # long-side cap for SAM inference
# Visualization
MASK_ALPHA = 0.4
BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8) # green boundary
TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8)
SAVE_OBJECT_PNG = True
# =========================================
def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
    """Shrink *arr* so its longer side is at most *max_side*.

    Returns (resized_array, original_hw, resized_hw). Masks are resampled
    with NEAREST so label values stay exact; images use BICUBIC.
    """
    height, width = arr.shape[:2]
    longest = max(height, width)
    if max_side <= 0 or longest <= max_side:
        return arr, (height, width), (height, width)
    factor = float(max_side) / float(longest)
    new_w = max(1, int(round(width * factor)))
    new_h = max(1, int(round(height * factor)))
    method = Image.NEAREST if is_mask else Image.BICUBIC
    shrunk = Image.fromarray(arr).resize((new_w, new_h), resample=method)
    return np.asarray(shrunk), (height, width), (new_h, new_w)
def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20):
"""
分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点
"""
# 确保是布尔类型
mask_bool = edge_mask > 127
# 连通域标记
labeled_array, num_features = nd_label(mask_bool)
prompts = []
component_info = []
for i in range(1, num_features + 1):
coords = np.argwhere(labeled_array == i)
area = len(coords)
if area < 100: continue # 过滤噪声
# 取几何中心作为引导点
center_y, center_x = np.median(coords, axis=0).astype(int)
component_info.append(((center_x, center_y), area))
# 按面积排序,优先处理大面积线条覆盖的物体
component_info.sort(key=lambda x: x[1], reverse=True)
for pt, _ in component_info[:max_components]:
prompts.append({
"point_coords": np.array([pt], dtype=np.float32),
"point_labels": np.array([1], dtype=np.int32)
})
return prompts
def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str):
    """Crop every labelled object and save it as a transparent RGBA PNG."""
    obj_dir = out_dir / f"{stem}.objects"
    obj_dir.mkdir(parents=True, exist_ok=True)
    for ordinal, label_id in enumerate(targets, start=1):
        selected = label_map == label_id
        ys, xs = np.where(selected)
        if ys.size == 0:
            continue  # label has no pixels; nothing to export
        top, bottom = ys.min(), ys.max()
        left, right = xs.min(), xs.max()
        color_patch = rgb[top:bottom + 1, left:right + 1]
        alpha_patch = (selected[top:bottom + 1, left:right + 1].astype(np.uint8) * 255)[..., None]
        rgba = np.concatenate([color_patch, alpha_patch], axis=-1)
        Image.fromarray(rgba).save(obj_dir / f"{stem}.obj_{ordinal:02d}.png")
def main():
    """Guided SAM segmentation: seed SAM with points derived from an edge
    mask, then export a visualization and per-object transparent crops.
    """
    # 1. Load inputs and prepare the output directory.
    img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK)
    out_dir = Path(__file__).parent / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    rgb_orig = np.asarray(Image.open(img_path).convert("RGB"))
    mask_orig = np.asarray(Image.open(mask_path).convert("L"))
    # 2. Downscale for inference to bound GPU memory usage.
    rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE)
    mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True)
    # 3. Initialize SAM through SamPredictor directly (the project wrapper
    #    does not expose point/bbox prompts).
    import torch
    _ensure_sam_on_path()
    sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything"
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    predictor = SamPredictor(sam)
    print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}")
    # 4. Derive one seed point per connected edge cluster.
    prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS)
    print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点")
    # 5. Run inference; later objects overwrite earlier ones in the label map.
    final_label_map_run = np.zeros(run_hw, dtype=np.int32)
    predictor.set_image(rgb_run)
    for idx, p in enumerate(prompts, start=1):
        # multimask_output=True is more stable; keep the best-scoring mask.
        masks, scores, _ = predictor.predict(
            point_coords=p["point_coords"],
            point_labels=p["point_labels"],
            multimask_output=True
        )
        best_mask = masks[np.argmax(scores)]
        # NOTE(review): area scales with the *square* of the resize ratio; a
        # linear ratio here under-filters small fragments — confirm intent.
        if np.sum(best_mask) > TARGET_MIN_AREA * (run_hw[0] / orig_hw[0]):
            final_label_map_run[best_mask > 0] = idx
    # 6. Map labels back to the original resolution (NEAREST keeps label IDs
    #    integral).
    label_map = np.asarray(Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST))
    # 7. Visualize: translucent fill + green contours per object.
    unique_labels = [l for l in np.unique(label_map) if l > 0]
    marked_img = rgb_orig.copy()
    # (dead code removed: an unused ImageDraw.Draw on a throwaway copy)
    overlay = rgb_orig.astype(np.float32)
    for lb in unique_labels:
        m = (label_map == lb)
        overlay[m] = overlay[m] * (1-MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA
        # Draw each object's outer contour in green on the marked image.
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(marked_img, contours, -1, (0, 255, 0), 2)
    final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0))
    final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png")
    if SAVE_OBJECT_PNG:
        _save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem)
    print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}")
if __name__ == "__main__":
    main()