initial commit
This commit is contained in:
1
python_server/model/Animation/AnimateDiff
Submodule
1
python_server/model/Animation/AnimateDiff
Submodule
Submodule python_server/model/Animation/AnimateDiff added at e92bd5671b
12
python_server/model/Animation/__init__.py
Normal file
12
python_server/model/Animation/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .animation_loader import (
|
||||
AnimationBackend,
|
||||
UnifiedAnimationConfig,
|
||||
build_animation_predictor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnimationBackend",
|
||||
"UnifiedAnimationConfig",
|
||||
"build_animation_predictor",
|
||||
]
|
||||
|
||||
268
python_server/model/Animation/animation_loader.py
Normal file
268
python_server/model/Animation/animation_loader.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Unified animation model loading entry.
|
||||
|
||||
Current support:
|
||||
- AnimateDiff (script-based invocation)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
get_animatediff_root_from_app,
|
||||
get_animatediff_pretrained_model_from_app,
|
||||
get_animatediff_inference_config_from_app,
|
||||
get_animatediff_motion_module_from_app,
|
||||
get_animatediff_dreambooth_model_from_app,
|
||||
get_animatediff_lora_model_from_app,
|
||||
get_animatediff_lora_alpha_from_app,
|
||||
get_animatediff_without_xformers_from_app,
|
||||
)
|
||||
|
||||
|
||||
class AnimationBackend(str, Enum):
|
||||
ANIMATEDIFF = "animatediff"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedAnimationConfig:
|
||||
backend: AnimationBackend = AnimationBackend.ANIMATEDIFF
|
||||
# Optional overrides. If None, values come from app config.
|
||||
animate_diff_root: str | None = None
|
||||
pretrained_model_path: str | None = None
|
||||
inference_config: str | None = None
|
||||
motion_module: str | None = None
|
||||
dreambooth_model: str | None = None
|
||||
lora_model: str | None = None
|
||||
lora_alpha: float | None = None
|
||||
without_xformers: bool | None = None
|
||||
controlnet_path: str | None = None
|
||||
controlnet_config: str | None = None
|
||||
|
||||
|
||||
def _yaml_string(value: str) -> str:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _resolve_root(root_cfg: str) -> Path:
|
||||
root = Path(root_cfg)
|
||||
if not root.is_absolute():
|
||||
root = Path(__file__).resolve().parents[2] / root_cfg
|
||||
return root.resolve()
|
||||
|
||||
|
||||
def _make_animatediff_predictor(
|
||||
cfg: UnifiedAnimationConfig,
|
||||
) -> Callable[..., Path]:
|
||||
app_cfg = load_app_config()
|
||||
|
||||
root = _resolve_root(cfg.animate_diff_root or get_animatediff_root_from_app(app_cfg))
|
||||
script_path = root / "scripts" / "animate.py"
|
||||
samples_dir = root / "samples"
|
||||
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not script_path.is_file():
|
||||
raise FileNotFoundError(f"AnimateDiff script not found: {script_path}")
|
||||
|
||||
pretrained_model_path = (
|
||||
cfg.pretrained_model_path or get_animatediff_pretrained_model_from_app(app_cfg)
|
||||
)
|
||||
inference_config = cfg.inference_config or get_animatediff_inference_config_from_app(app_cfg)
|
||||
motion_module = cfg.motion_module or get_animatediff_motion_module_from_app(app_cfg)
|
||||
dreambooth_model = cfg.dreambooth_model
|
||||
if dreambooth_model is None:
|
||||
dreambooth_model = get_animatediff_dreambooth_model_from_app(app_cfg)
|
||||
lora_model = cfg.lora_model
|
||||
if lora_model is None:
|
||||
lora_model = get_animatediff_lora_model_from_app(app_cfg)
|
||||
lora_alpha = cfg.lora_alpha
|
||||
if lora_alpha is None:
|
||||
lora_alpha = get_animatediff_lora_alpha_from_app(app_cfg)
|
||||
without_xformers = cfg.without_xformers
|
||||
if without_xformers is None:
|
||||
without_xformers = get_animatediff_without_xformers_from_app(app_cfg)
|
||||
|
||||
def _predict(
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 25,
|
||||
guidance_scale: float = 8.0,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
video_length: int = 16,
|
||||
seed: int = -1,
|
||||
control_image_path: str | None = None,
|
||||
output_format: str = "gif",
|
||||
) -> Path:
|
||||
if output_format not in {"gif", "png_sequence"}:
|
||||
raise ValueError("output_format must be 'gif' or 'png_sequence'")
|
||||
prompt_value = prompt.strip()
|
||||
if not prompt_value:
|
||||
raise ValueError("prompt must not be empty")
|
||||
|
||||
negative_prompt_value = negative_prompt or ""
|
||||
|
||||
motion_module_line = (
|
||||
f' motion_module: {_yaml_string(motion_module)}\n' if motion_module else ""
|
||||
)
|
||||
dreambooth_line = (
|
||||
f' dreambooth_path: {_yaml_string(dreambooth_model)}\n' if dreambooth_model else ""
|
||||
)
|
||||
lora_path_line = f' lora_model_path: {_yaml_string(lora_model)}\n' if lora_model else ""
|
||||
lora_alpha_line = f" lora_alpha: {float(lora_alpha)}\n" if lora_model else ""
|
||||
controlnet_path_value = cfg.controlnet_path or "v3_sd15_sparsectrl_rgb.ckpt"
|
||||
controlnet_config_value = cfg.controlnet_config or "configs/inference/sparsectrl/image_condition.yaml"
|
||||
control_image_line = ""
|
||||
if control_image_path:
|
||||
control_image = Path(control_image_path).expanduser().resolve()
|
||||
if not control_image.is_file():
|
||||
raise FileNotFoundError(f"control_image_path not found: {control_image}")
|
||||
control_image_line = (
|
||||
f' controlnet_path: {_yaml_string(controlnet_path_value)}\n'
|
||||
f' controlnet_config: {_yaml_string(controlnet_config_value)}\n'
|
||||
" controlnet_images:\n"
|
||||
f' - {_yaml_string(str(control_image))}\n'
|
||||
" controlnet_image_indexs:\n"
|
||||
" - 0\n"
|
||||
)
|
||||
|
||||
config_text = (
|
||||
"- prompt:\n"
|
||||
f" - {_yaml_string(prompt_value)}\n"
|
||||
" n_prompt:\n"
|
||||
f" - {_yaml_string(negative_prompt_value)}\n"
|
||||
f" steps: {int(num_inference_steps)}\n"
|
||||
f" guidance_scale: {float(guidance_scale)}\n"
|
||||
f" W: {int(width)}\n"
|
||||
f" H: {int(height)}\n"
|
||||
f" L: {int(video_length)}\n"
|
||||
" seed:\n"
|
||||
f" - {int(seed)}\n"
|
||||
f"{motion_module_line}{dreambooth_line}{lora_path_line}{lora_alpha_line}{control_image_line}"
|
||||
)
|
||||
|
||||
before_dirs = {p for p in samples_dir.iterdir() if p.is_dir()}
|
||||
cfg_file = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
suffix=".yaml",
|
||||
prefix="animatediff_cfg_",
|
||||
dir=str(root),
|
||||
delete=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
cfg_file_path = Path(cfg_file.name)
|
||||
try:
|
||||
cfg_file.write(config_text)
|
||||
cfg_file.flush()
|
||||
cfg_file.close()
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(script_path),
|
||||
"--pretrained-model-path",
|
||||
pretrained_model_path,
|
||||
"--inference-config",
|
||||
inference_config,
|
||||
"--config",
|
||||
str(cfg_file_path),
|
||||
"--L",
|
||||
str(int(video_length)),
|
||||
"--W",
|
||||
str(int(width)),
|
||||
"--H",
|
||||
str(int(height)),
|
||||
]
|
||||
if without_xformers:
|
||||
cmd.append("--without-xformers")
|
||||
if output_format == "png_sequence":
|
||||
cmd.append("--save-png-sequence")
|
||||
|
||||
env = dict(os.environ)
|
||||
existing_pythonpath = env.get("PYTHONPATH", "")
|
||||
root_pythonpath = str(root)
|
||||
env["PYTHONPATH"] = (
|
||||
f"{root_pythonpath}:{existing_pythonpath}" if existing_pythonpath else root_pythonpath
|
||||
)
|
||||
|
||||
def _run_once(command: list[str]) -> subprocess.CompletedProcess[str]:
|
||||
return subprocess.run(
|
||||
command,
|
||||
cwd=str(root),
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
try:
|
||||
proc = _run_once(cmd)
|
||||
except subprocess.CalledProcessError as first_error:
|
||||
stderr_text = first_error.stderr or ""
|
||||
should_retry_without_xformers = (
|
||||
not without_xformers
|
||||
and "--without-xformers" not in cmd
|
||||
and (
|
||||
"memory_efficient_attention" in stderr_text
|
||||
or "AcceleratorError" in stderr_text
|
||||
or "invalid configuration argument" in stderr_text
|
||||
)
|
||||
)
|
||||
if not should_retry_without_xformers:
|
||||
raise
|
||||
|
||||
retry_cmd = [*cmd, "--without-xformers"]
|
||||
proc = _run_once(retry_cmd)
|
||||
_ = proc
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(
|
||||
"AnimateDiff inference failed.\n"
|
||||
f"stdout:\n{e.stdout}\n"
|
||||
f"stderr:\n{e.stderr}"
|
||||
) from e
|
||||
finally:
|
||||
try:
|
||||
cfg_file_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
after_dirs = [p for p in samples_dir.iterdir() if p.is_dir() and p not in before_dirs]
|
||||
candidates = [p for p in after_dirs if (p / "sample.gif").is_file()]
|
||||
if not candidates:
|
||||
candidates = [p for p in samples_dir.iterdir() if p.is_dir() and (p / "sample.gif").is_file()]
|
||||
if not candidates:
|
||||
raise FileNotFoundError("AnimateDiff finished but sample.gif was not found in samples/")
|
||||
|
||||
latest = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)[0]
|
||||
if output_format == "png_sequence":
|
||||
frames_root = latest / "sample_frames"
|
||||
if not frames_root.is_dir():
|
||||
raise FileNotFoundError("AnimateDiff finished but sample_frames/ was not found in samples/")
|
||||
frame_dirs = sorted([p for p in frames_root.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
if not frame_dirs:
|
||||
raise FileNotFoundError("AnimateDiff finished but no PNG sequence directory was found in sample_frames/")
|
||||
return frame_dirs[0].resolve()
|
||||
return (latest / "sample.gif").resolve()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def build_animation_predictor(
|
||||
cfg: UnifiedAnimationConfig | None = None,
|
||||
) -> tuple[Callable[..., Path], AnimationBackend]:
|
||||
cfg = cfg or UnifiedAnimationConfig()
|
||||
|
||||
if cfg.backend == AnimationBackend.ANIMATEDIFF:
|
||||
return _make_animatediff_predictor(cfg), AnimationBackend.ANIMATEDIFF
|
||||
|
||||
raise ValueError(f"Unsupported animation backend: {cfg.backend}")
|
||||
|
||||
1
python_server/model/Depth/DPT
Submodule
1
python_server/model/Depth/DPT
Submodule
Submodule python_server/model/Depth/DPT added at cd3fe90bb4
1
python_server/model/Depth/Depth-Anything-V2
Submodule
1
python_server/model/Depth/Depth-Anything-V2
Submodule
Submodule python_server/model/Depth/Depth-Anything-V2 added at e5a2732d3e
1
python_server/model/Depth/MiDaS
Submodule
1
python_server/model/Depth/MiDaS
Submodule
Submodule python_server/model/Depth/MiDaS added at 454597711a
1
python_server/model/Depth/ZoeDepth
Submodule
1
python_server/model/Depth/ZoeDepth
Submodule
Submodule python_server/model/Depth/ZoeDepth added at d87f17b2f5
1
python_server/model/Depth/__init__.py
Normal file
1
python_server/model/Depth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
# 确保本地克隆的 Depth Anything V2 仓库在 sys.path 中,
|
||||
# 这样其内部的 `from depth_anything_v2...` 导入才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
|
||||
if _DA_REPO_ROOT.is_dir():
|
||||
da_path = str(_DA_REPO_ROOT)
|
||||
if da_path not in sys.path:
|
||||
sys.path.insert(0, da_path)
|
||||
|
||||
from depth_anything_v2.dpt import DepthAnythingV2 # type: ignore[import]
|
||||
|
||||
|
||||
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthAnythingV2Config:
|
||||
"""
|
||||
Depth Anything V2 模型选择配置。
|
||||
|
||||
encoder: "vits" | "vitb" | "vitl" | "vitg"
|
||||
device: "cuda" | "cpu"
|
||||
input_size: 推理时的输入分辨率(短边),参考官方 demo,默认 518。
|
||||
"""
|
||||
|
||||
encoder: EncoderName = "vitl"
|
||||
device: str = "cuda"
|
||||
input_size: int = 518
|
||||
|
||||
|
||||
_MODEL_CONFIGS = {
|
||||
"vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
|
||||
"vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
|
||||
"vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
|
||||
"vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
|
||||
}
|
||||
|
||||
|
||||
_DA_V2_WEIGHTS_URLS = {
|
||||
# 官方权重托管在 HuggingFace:
|
||||
# - Small -> vits
|
||||
# - Base -> vitb
|
||||
# - Large -> vitl
|
||||
# - Giant -> vitg
|
||||
# 如需替换为国内镜像,可直接修改这些 URL。
|
||||
"vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
|
||||
"vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
|
||||
"vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
|
||||
"vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _DA_V2_WEIGHTS_URLS.get(encoder)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 encoder='{encoder}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
|
||||
print("\n权重下载完成。")
|
||||
|
||||
|
||||
def load_depth_anything_v2_from_config(
|
||||
cfg: DepthAnythingV2Config,
|
||||
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
|
||||
"""
|
||||
根据配置加载 Depth Anything V2 模型与对应配置。
|
||||
|
||||
说明:
|
||||
- 权重文件路径遵循官方命名约定:
|
||||
checkpoints/depth_anything_v2_{encoder}.pth
|
||||
例如:depth_anything_v2_vitl.pth
|
||||
- 请确保上述权重文件已下载到
|
||||
python_server/model/Depth/Depth-Anything-V2/checkpoints 下。
|
||||
"""
|
||||
if cfg.encoder not in _MODEL_CONFIGS:
|
||||
raise ValueError(f"不支持的 encoder: {cfg.encoder}")
|
||||
|
||||
ckpt_path = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
|
||||
_download_if_missing(cfg.encoder, ckpt_path)
|
||||
|
||||
model = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])
|
||||
|
||||
state_dict = torch.load(str(ckpt_path), map_location="cpu")
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if cfg.device.startswith("cuda") and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
model = model.to(device).eval()
|
||||
cfg = DepthAnythingV2Config(
|
||||
encoder=cfg.encoder,
|
||||
device=device,
|
||||
input_size=cfg.input_size,
|
||||
)
|
||||
return model, cfg
|
||||
|
||||
|
||||
def infer_depth_anything_v2(
|
||||
model: DepthAnythingV2,
|
||||
image_bgr: np.ndarray,
|
||||
input_size: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对单张 BGR 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||
image_bgr: OpenCV 读取的 BGR 图像 (H, W, 3), uint8
|
||||
"""
|
||||
depth = model.infer_image(image_bgr, input_size)
|
||||
depth = np.asarray(depth, dtype="float32")
|
||||
return depth
|
||||
|
||||
148
python_server/model/Depth/depth_loader.py
Normal file
148
python_server/model/Depth/depth_loader.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的深度模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- ZoeDepth(三种:ZoeD_N / ZoeD_K / ZoeD_NK)
|
||||
- Depth Anything V2(四种 encoder:vits / vitb / vitl / vitg)
|
||||
|
||||
未来如果要加 MiDaS / DPT,只需要在这里再接一层即可。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
build_zoe_config_from_app,
|
||||
build_depth_anything_v2_config_from_app,
|
||||
build_dpt_config_from_app,
|
||||
build_midas_config_from_app,
|
||||
)
|
||||
from .zoe_loader import load_zoe_from_config
|
||||
from .depth_anything_v2_loader import (
|
||||
load_depth_anything_v2_from_config,
|
||||
infer_depth_anything_v2,
|
||||
)
|
||||
from .dpt_loader import load_dpt_from_config, infer_dpt
|
||||
from .midas_loader import load_midas_from_config, infer_midas
|
||||
|
||||
|
||||
class DepthBackend(str, Enum):
|
||||
"""统一的深度模型后端类型。"""
|
||||
|
||||
ZOEDEPTH = "zoedepth"
|
||||
DEPTH_ANYTHING_V2 = "depth_anything_v2"
|
||||
DPT = "dpt"
|
||||
MIDAS = "midas"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedDepthConfig:
|
||||
"""
|
||||
统一深度配置。
|
||||
|
||||
backend: 使用哪个后端
|
||||
device: 强制设备(可选),不填则使用 config.py 中的设置
|
||||
"""
|
||||
|
||||
backend: DepthBackend = DepthBackend.ZOEDEPTH
|
||||
device: str | None = None
|
||||
|
||||
|
||||
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
zoe_cfg = build_zoe_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
zoe_cfg.device = device_override
|
||||
|
||||
model, _ = load_zoe_from_config(zoe_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
depth = model.infer_pil(img.convert("RGB"), output_type="numpy")
|
||||
return np.asarray(depth, dtype="float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
da_cfg = build_depth_anything_v2_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
da_cfg.device = device_override
|
||||
|
||||
model, da_cfg = load_depth_anything_v2_from_config(da_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
# Depth Anything V2 的 infer_image 接收 BGR uint8
|
||||
rgb = np.array(img.convert("RGB"), dtype=np.uint8)
|
||||
bgr = rgb[:, :, ::-1]
|
||||
depth = infer_depth_anything_v2(model, bgr, da_cfg.input_size)
|
||||
return depth.astype("float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
dpt_cfg = build_dpt_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
dpt_cfg.device = device_override
|
||||
|
||||
model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
bgr = cv2.cvtColor(np.array(img.convert("RGB"), dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
||||
depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
|
||||
return depth.astype("float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
midas_cfg = build_midas_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
midas_cfg.device = device_override
|
||||
|
||||
model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
|
||||
depth = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
|
||||
return depth.astype("float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def build_depth_predictor(
|
||||
cfg: UnifiedDepthConfig | None = None,
|
||||
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
|
||||
"""
|
||||
统一构建深度预测函数。
|
||||
|
||||
返回:
|
||||
- predictor(image: PIL.Image) -> np.ndarray[H, W], float32
|
||||
- 实际使用的 backend 类型
|
||||
"""
|
||||
cfg = cfg or UnifiedDepthConfig()
|
||||
|
||||
if cfg.backend == DepthBackend.ZOEDEPTH:
|
||||
return _make_zoe_predictor(cfg.device), DepthBackend.ZOEDEPTH
|
||||
|
||||
if cfg.backend == DepthBackend.DEPTH_ANYTHING_V2:
|
||||
return _make_da_v2_predictor(cfg.device), DepthBackend.DEPTH_ANYTHING_V2
|
||||
|
||||
if cfg.backend == DepthBackend.DPT:
|
||||
return _make_dpt_predictor(cfg.device), DepthBackend.DPT
|
||||
|
||||
if cfg.backend == DepthBackend.MIDAS:
|
||||
return _make_midas_predictor(cfg.device), DepthBackend.MIDAS
|
||||
|
||||
raise ValueError(f"不支持的深度后端: {cfg.backend}")
|
||||
|
||||
156
python_server/model/Depth/dpt_loader.py
Normal file
156
python_server/model/Depth/dpt_loader.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
|
||||
if _DPT_REPO_ROOT.is_dir():
|
||||
dpt_path = str(_DPT_REPO_ROOT)
|
||||
if dpt_path not in sys.path:
|
||||
sys.path.insert(0, dpt_path)
|
||||
|
||||
from dpt.models import DPTDepthModel # type: ignore[import]
|
||||
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
|
||||
from torchvision.transforms import Compose
|
||||
import cv2
|
||||
|
||||
|
||||
DPTModelType = Literal["dpt_large", "dpt_hybrid"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DPTConfig:
|
||||
model_type: DPTModelType = "dpt_large"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
_DPT_WEIGHTS_URLS = {
|
||||
# 官方 DPT 模型权重托管在:
|
||||
# https://github.com/isl-org/DPT#models
|
||||
"dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _DPT_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 DPT 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nDPT 权重下载完成。")
|
||||
|
||||
|
||||
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
|
||||
"""
|
||||
加载 DPT 模型与对应的预处理 transform。
|
||||
"""
|
||||
ckpt_name = {
|
||||
"dpt_large": "dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
|
||||
}[cfg.model_type]
|
||||
|
||||
ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
|
||||
_download_if_missing(cfg.model_type, ckpt_path)
|
||||
|
||||
if cfg.model_type == "dpt_large":
|
||||
net_w = net_h = 384
|
||||
model = DPTDepthModel(
|
||||
path=str(ckpt_path),
|
||||
backbone="vitl16_384",
|
||||
non_negative=True,
|
||||
enable_attention_hooks=False,
|
||||
)
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
else:
|
||||
net_w = net_h = 384
|
||||
model = DPTDepthModel(
|
||||
path=str(ckpt_path),
|
||||
backbone="vitb_rn50_384",
|
||||
non_negative=True,
|
||||
enable_attention_hooks=False,
|
||||
)
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
model.to(device).eval()
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method="minimal",
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
cfg = DPTConfig(model_type=cfg.model_type, device=device)
|
||||
return model, cfg, transform
|
||||
|
||||
|
||||
def infer_dpt(
|
||||
model: DPTDepthModel,
|
||||
transform: Compose,
|
||||
image_bgr: np.ndarray,
|
||||
device: str,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对单张 BGR 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||
"""
|
||||
img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
||||
sample = transform({"image": img})["image"]
|
||||
|
||||
input_batch = torch.from_numpy(sample).to(device).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
prediction = model(input_batch)
|
||||
prediction = (
|
||||
torch.nn.functional.interpolate(
|
||||
prediction.unsqueeze(1),
|
||||
size=img.shape[:2],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
.squeeze()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
return prediction.astype("float32")
|
||||
|
||||
127
python_server/model/Depth/midas_loader.py
Normal file
127
python_server/model/Depth/midas_loader.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
|
||||
if _MIDAS_REPO_ROOT.is_dir():
|
||||
midas_path = str(_MIDAS_REPO_ROOT)
|
||||
if midas_path not in sys.path:
|
||||
sys.path.insert(0, midas_path)
|
||||
|
||||
from midas.model_loader import load_model, default_models # type: ignore[import]
|
||||
import utils # from MiDaS repo
|
||||
|
||||
|
||||
MiDaSModelType = Literal[
|
||||
"dpt_beit_large_512",
|
||||
"dpt_swin2_large_384",
|
||||
"dpt_swin2_tiny_256",
|
||||
"dpt_levit_224",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MiDaSConfig:
|
||||
model_type: MiDaSModelType = "dpt_beit_large_512"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
_MIDAS_WEIGHTS_URLS = {
|
||||
# 官方权重参见 MiDaS 仓库 README
|
||||
"dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
|
||||
"dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
|
||||
"dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
|
||||
"dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _MIDAS_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 MiDaS 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nMiDaS 权重下载完成。")
|
||||
|
||||
|
||||
def load_midas_from_config(
|
||||
cfg: MiDaSConfig,
|
||||
) -> Tuple[torch.nn.Module, MiDaSConfig, callable, int, int]:
|
||||
"""
|
||||
加载 MiDaS 模型与对应 transform。
|
||||
返回: model, cfg, transform, net_w, net_h
|
||||
"""
|
||||
# default_models 中给了默认权重路径名
|
||||
model_info = default_models[cfg.model_type]
|
||||
ckpt_path = _MIDAS_REPO_ROOT / model_info.path
|
||||
_download_if_missing(cfg.model_type, ckpt_path)
|
||||
|
||||
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
model, transform, net_w, net_h = load_model(
|
||||
device=torch.device(device),
|
||||
model_path=str(ckpt_path),
|
||||
model_type=cfg.model_type,
|
||||
optimize=False,
|
||||
height=None,
|
||||
square=False,
|
||||
)
|
||||
|
||||
cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
|
||||
return model, cfg, transform, net_w, net_h
|
||||
|
||||
|
||||
def infer_midas(
|
||||
model: torch.nn.Module,
|
||||
transform: callable,
|
||||
image_rgb: np.ndarray,
|
||||
net_w: int,
|
||||
net_h: int,
|
||||
device: str,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对单张 RGB 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||
"""
|
||||
image = transform({"image": image_rgb})["image"]
|
||||
prediction = utils.process(
|
||||
torch.device(device),
|
||||
model,
|
||||
model_type="dpt", # 这里具体字符串对 utils.process 的逻辑影响不大,只要不包含 "openvino"
|
||||
image=image,
|
||||
input_size=(net_w, net_h),
|
||||
target_size=image_rgb.shape[1::-1],
|
||||
optimize=False,
|
||||
use_camera=False,
|
||||
)
|
||||
return np.asarray(prediction, dtype="float32").squeeze()
|
||||
|
||||
74
python_server/model/Depth/zoe_loader.py
Normal file
74
python_server/model/Depth/zoe_loader.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# 确保本地克隆的 ZoeDepth 仓库在 sys.path 中,
|
||||
# 这样其内部的 `import zoedepth...` 才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
|
||||
if _ZOE_REPO_ROOT.is_dir():
|
||||
zoe_path = str(_ZOE_REPO_ROOT)
|
||||
if zoe_path not in sys.path:
|
||||
sys.path.insert(0, zoe_path)
|
||||
|
||||
from zoedepth.models.builder import build_model
|
||||
from zoedepth.utils.config import get_config
|
||||
|
||||
|
||||
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZoeConfig:
|
||||
"""
|
||||
ZoeDepth 模型选择配置。
|
||||
|
||||
model: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK"
|
||||
device: "cuda" | "cpu"
|
||||
"""
|
||||
|
||||
model: ZoeModelName = "ZoeD_N"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
|
||||
"""
|
||||
手动加载 ZoeDepth 三种模型之一:
|
||||
- "ZoeD_N"
|
||||
- "ZoeD_K"
|
||||
- "ZoeD_NK"
|
||||
"""
|
||||
if name == "ZoeD_N":
|
||||
conf = get_config("zoedepth", "infer")
|
||||
elif name == "ZoeD_K":
|
||||
conf = get_config("zoedepth", "infer", config_version="kitti")
|
||||
elif name == "ZoeD_NK":
|
||||
conf = get_config("zoedepth_nk", "infer")
|
||||
else:
|
||||
raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")
|
||||
|
||||
model = build_model(conf)
|
||||
|
||||
if device.startswith("cuda") and torch.cuda.is_available():
|
||||
model = model.to("cuda")
|
||||
else:
|
||||
model = model.to("cpu")
|
||||
|
||||
model.eval()
|
||||
return model, conf
|
||||
|
||||
|
||||
def load_zoe_from_config(config: ZoeConfig):
|
||||
"""
|
||||
根据 ZoeConfig 加载模型。
|
||||
|
||||
示例:
|
||||
cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
|
||||
model, conf = load_zoe_from_config(cfg)
|
||||
"""
|
||||
return load_zoe_from_name(config.model, config.device)
|
||||
|
||||
1
python_server/model/Inpaint/ControlNet
Submodule
1
python_server/model/Inpaint/ControlNet
Submodule
Submodule python_server/model/Inpaint/ControlNet added at ed85cd1e25
1
python_server/model/Inpaint/__init__.py
Normal file
1
python_server/model/Inpaint/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
@@ -0,0 +1,413 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的补全(Inpaint)模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- SDXL Inpaint(diffusers AutoPipelineForInpainting)
|
||||
- ControlNet(占位,暂未统一封装)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
get_sdxl_base_model_from_app,
|
||||
get_controlnet_base_model_from_app,
|
||||
get_controlnet_model_from_app,
|
||||
)
|
||||
|
||||
|
||||
class InpaintBackend(str, Enum):
|
||||
SDXL_INPAINT = "sdxl_inpaint"
|
||||
CONTROLNET = "controlnet"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedInpaintConfig:
|
||||
backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
|
||||
device: str | None = None
|
||||
# SDXL base model (HF id 或本地目录),不填则用 config.py 的默认值
|
||||
sdxl_base_model: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedDrawConfig:
|
||||
"""
|
||||
统一绘图配置:
|
||||
- 纯文生图:image=None
|
||||
- 图生图(模仿输入图):image=某张参考图
|
||||
"""
|
||||
device: str | None = None
|
||||
sdxl_base_model: str | None = None
|
||||
|
||||
|
||||
def _resolve_device_and_dtype(device: str | None):
|
||||
import torch
|
||||
|
||||
app_cfg = load_app_config()
|
||||
if device is None:
|
||||
device = app_cfg.inpaint.device
|
||||
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
||||
return device, torch_dtype
|
||||
|
||||
|
||||
def _enable_memory_opts(pipe, device: str) -> None:
|
||||
if device == "cuda":
|
||||
try:
|
||||
pipe.enable_attention_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_tiling()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_model_cpu_offload()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
|
||||
run_w, run_h = orig_w, orig_h
|
||||
if max_side > 0 and max(orig_w, orig_h) > max_side:
|
||||
scale = max_side / float(max(orig_w, orig_h))
|
||||
run_w = int(round(orig_w * scale))
|
||||
run_h = int(round(orig_h * scale))
|
||||
run_w = max(8, run_w - (run_w % 8))
|
||||
run_h = max(8, run_h - (run_h % 8))
|
||||
return run_w, run_h
|
||||
|
||||
|
||||
def _make_sdxl_inpaint_predictor(
|
||||
cfg: UnifiedInpaintConfig,
|
||||
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
|
||||
"""
|
||||
返回补全函数:
|
||||
- 输入:image(PIL RGB), mask(PIL L/1), prompt, negative_prompt
|
||||
- 输出:PIL RGB 结果图
|
||||
"""
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
|
||||
|
||||
app_cfg = load_app_config()
|
||||
base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
|
||||
|
||||
device = cfg.device
|
||||
if device is None:
|
||||
device = app_cfg.inpaint.device
|
||||
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
|
||||
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
||||
|
||||
pipe_t2i = AutoPipelineForText2Image.from_pretrained(
|
||||
base_model,
|
||||
torch_dtype=torch_dtype,
|
||||
variant="fp16" if device == "cuda" else None,
|
||||
use_safetensors=True,
|
||||
).to(device)
|
||||
pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)
|
||||
|
||||
# 省显存设置(尽量不改变输出语义)
|
||||
# 注意:CPU offload 会明显变慢,但能显著降低显存占用。
|
||||
if device == "cuda":
|
||||
try:
|
||||
pipe.enable_attention_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_tiling()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_model_cpu_offload()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _predict(
|
||||
image: Image.Image,
|
||||
mask: Image.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
strength: float = 0.8,
|
||||
guidance_scale: float = 7.5,
|
||||
num_inference_steps: int = 30,
|
||||
max_side: int = 1024,
|
||||
) -> Image.Image:
|
||||
image = image.convert("RGB")
|
||||
# diffusers 要求 mask 为单通道,白色区域为需要重绘
|
||||
mask = mask.convert("L")
|
||||
|
||||
# SDXL / diffusers 通常要求宽高为 8 的倍数;同时为了避免 OOM,
|
||||
# 推理时将图像按比例缩放到不超过 max_side(默认 1024)并对齐到 8 的倍数。
|
||||
# 推理后再 resize 回原始尺寸,保证输出与原图分辨率一致。
|
||||
orig_w, orig_h = image.size
|
||||
run_w, run_h = orig_w, orig_h
|
||||
|
||||
if max(orig_w, orig_h) > max_side:
|
||||
scale = max_side / float(max(orig_w, orig_h))
|
||||
run_w = int(round(orig_w * scale))
|
||||
run_h = int(round(orig_h * scale))
|
||||
|
||||
run_w = max(8, run_w - (run_w % 8))
|
||||
run_h = max(8, run_h - (run_h % 8))
|
||||
if run_w <= 0:
|
||||
run_w = 8
|
||||
if run_h <= 0:
|
||||
run_h = 8
|
||||
|
||||
if (run_w, run_h) != (orig_w, orig_h):
|
||||
image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
|
||||
mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
|
||||
else:
|
||||
image_run = image
|
||||
mask_run = mask
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=image_run,
|
||||
mask_image=mask_run,
|
||||
strength=strength,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
width=run_w,
|
||||
height=run_h,
|
||||
).images[0]
|
||||
out = out.convert("RGB")
|
||||
if out.size != (orig_w, orig_h):
|
||||
out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
|
||||
return out
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def _make_controlnet_predictor(_: UnifiedInpaintConfig):
|
||||
import torch
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
|
||||
|
||||
app_cfg = load_app_config()
|
||||
|
||||
device = _.device
|
||||
if device is None:
|
||||
device = app_cfg.inpaint.device
|
||||
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
||||
|
||||
base_model = get_controlnet_base_model_from_app(app_cfg)
|
||||
controlnet_id = get_controlnet_model_from_app(app_cfg)
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
|
||||
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
)
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
pipe.enable_attention_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_tiling()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_model_cpu_offload()
|
||||
except Exception:
|
||||
pipe.to(device)
|
||||
else:
|
||||
pipe.to(device)
|
||||
|
||||
def _predict(
|
||||
image: Image.Image,
|
||||
mask: Image.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
strength: float = 0.8,
|
||||
guidance_scale: float = 7.5,
|
||||
num_inference_steps: int = 30,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
max_side: int = 768,
|
||||
) -> Image.Image:
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
image = image.convert("RGB")
|
||||
mask = mask.convert("L")
|
||||
|
||||
orig_w, orig_h = image.size
|
||||
run_w, run_h = orig_w, orig_h
|
||||
if max(orig_w, orig_h) > max_side:
|
||||
scale = max_side / float(max(orig_w, orig_h))
|
||||
run_w = int(round(orig_w * scale))
|
||||
run_h = int(round(orig_h * scale))
|
||||
run_w = max(8, run_w - (run_w % 8))
|
||||
run_h = max(8, run_h - (run_h % 8))
|
||||
|
||||
if (run_w, run_h) != (orig_w, orig_h):
|
||||
image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
|
||||
mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
|
||||
else:
|
||||
image_run = image
|
||||
mask_run = mask
|
||||
|
||||
# control image:使用 canny 边缘作为约束(最通用)
|
||||
rgb = np.array(image_run, dtype=np.uint8)
|
||||
edges = cv2.Canny(rgb, 100, 200)
|
||||
edges3 = np.stack([edges, edges, edges], axis=-1)
|
||||
control_image = Image.fromarray(edges3)
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=image_run,
|
||||
mask_image=mask_run,
|
||||
control_image=control_image,
|
||||
strength=strength,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
width=run_w,
|
||||
height=run_h,
|
||||
).images[0]
|
||||
|
||||
out = out.convert("RGB")
|
||||
if out.size != (orig_w, orig_h):
|
||||
out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
|
||||
return out
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def build_inpaint_predictor(
|
||||
cfg: UnifiedInpaintConfig | None = None,
|
||||
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
|
||||
"""
|
||||
统一构建补全预测函数。
|
||||
"""
|
||||
cfg = cfg or UnifiedInpaintConfig()
|
||||
|
||||
if cfg.backend == InpaintBackend.SDXL_INPAINT:
|
||||
return _make_sdxl_inpaint_predictor(cfg), InpaintBackend.SDXL_INPAINT
|
||||
|
||||
if cfg.backend == InpaintBackend.CONTROLNET:
|
||||
return _make_controlnet_predictor(cfg), InpaintBackend.CONTROLNET
|
||||
|
||||
raise ValueError(f"不支持的补全后端: {cfg.backend}")
|
||||
|
||||
|
||||
def build_draw_predictor(
|
||||
cfg: UnifiedDrawConfig | None = None,
|
||||
) -> Callable[..., Image.Image]:
|
||||
"""
|
||||
构建统一绘图函数:
|
||||
- 文生图:draw(prompt, image=None, ...)
|
||||
- 图生图:draw(prompt, image=ref_image, strength=0.55, ...)
|
||||
"""
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||
|
||||
cfg = cfg or UnifiedDrawConfig()
|
||||
app_cfg = load_app_config()
|
||||
base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
|
||||
device, torch_dtype = _resolve_device_and_dtype(cfg.device)
|
||||
|
||||
pipe_t2i = AutoPipelineForText2Image.from_pretrained(
|
||||
base_model,
|
||||
torch_dtype=torch_dtype,
|
||||
variant="fp16" if device == "cuda" else None,
|
||||
use_safetensors=True,
|
||||
).to(device)
|
||||
pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i).to(device)
|
||||
|
||||
_enable_memory_opts(pipe_t2i, device)
|
||||
_enable_memory_opts(pipe_i2i, device)
|
||||
|
||||
def _draw(
|
||||
prompt: str,
|
||||
image: Image.Image | None = None,
|
||||
negative_prompt: str = "",
|
||||
strength: float = 0.55,
|
||||
guidance_scale: float = 7.5,
|
||||
num_inference_steps: int = 30,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
max_side: int = 1024,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or ""
|
||||
negative_prompt = negative_prompt or ""
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if image is None:
|
||||
run_w, run_h = _align_size(width, height, max_side=max_side)
|
||||
out = pipe_t2i(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
width=run_w,
|
||||
height=run_h,
|
||||
).images[0]
|
||||
return out.convert("RGB")
|
||||
|
||||
image = image.convert("RGB")
|
||||
orig_w, orig_h = image.size
|
||||
run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
|
||||
if (run_w, run_h) != (orig_w, orig_h):
|
||||
image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
|
||||
else:
|
||||
image_run = image
|
||||
|
||||
out = pipe_i2i(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=image_run,
|
||||
strength=strength,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0].convert("RGB")
|
||||
|
||||
if out.size != (orig_w, orig_h):
|
||||
out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
|
||||
return out
|
||||
|
||||
return _draw
|
||||
|
||||
1
python_server/model/Inpaint/sdxl-inpaint
Submodule
1
python_server/model/Inpaint/sdxl-inpaint
Submodule
Submodule python_server/model/Inpaint/sdxl-inpaint added at 29867f540b
1
python_server/model/Seg/Mask2Former
Submodule
1
python_server/model/Seg/Mask2Former
Submodule
Submodule python_server/model/Seg/Mask2Former added at 9b0651c6c1
1
python_server/model/Seg/__init__.py
Normal file
1
python_server/model/Seg/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
59
python_server/model/Seg/mask2former_loader.py
Normal file
59
python_server/model/Seg/mask2former_loader.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mask2FormerHFConfig:
|
||||
"""
|
||||
使用 HuggingFace transformers 版本的 Mask2Former 语义分割。
|
||||
|
||||
model_id: HuggingFace 模型 id(默认 ADE20K semantic)
|
||||
device: "cuda" | "cpu"
|
||||
"""
|
||||
|
||||
model_id: str = "facebook/mask2former-swin-large-ade-semantic"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
def build_mask2former_hf_predictor(
|
||||
cfg: Mask2FormerHFConfig | None = None,
|
||||
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
|
||||
"""
|
||||
返回 predictor(image_rgb_uint8) -> label_map(int32)。
|
||||
"""
|
||||
cfg = cfg or Mask2FormerHFConfig()
|
||||
|
||||
import torch
|
||||
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
|
||||
|
||||
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained(cfg.model_id)
|
||||
model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
|
||||
model.to(device).eval()
|
||||
|
||||
cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def _predict(image_rgb: np.ndarray) -> np.ndarray:
|
||||
if image_rgb.dtype != np.uint8:
|
||||
image_rgb_u8 = image_rgb.astype("uint8")
|
||||
else:
|
||||
image_rgb_u8 = image_rgb
|
||||
|
||||
pil = Image.fromarray(image_rgb_u8, mode="RGB")
|
||||
inputs = processor(images=pil, return_tensors="pt").to(device)
|
||||
outputs = model(**inputs)
|
||||
|
||||
# post-process to original size
|
||||
target_sizes = [(pil.height, pil.width)]
|
||||
seg = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
|
||||
return seg.detach().to("cpu").numpy().astype("int32")
|
||||
|
||||
return _predict, cfg
|
||||
|
||||
168
python_server/model/Seg/seg_loader.py
Normal file
168
python_server/model/Seg/seg_loader.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的分割模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- SAM (segment-anything)
|
||||
- Mask2Former(使用 HuggingFace transformers 的语义分割实现)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class SegBackend(str, Enum):
|
||||
SAM = "sam"
|
||||
MASK2FORMER = "mask2former"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedSegConfig:
|
||||
backend: SegBackend = SegBackend.SAM
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# SAM (Segment Anything)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _ensure_sam_on_path() -> Path:
|
||||
sam_root = _THIS_DIR / "segment-anything"
|
||||
if not sam_root.is_dir():
|
||||
raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {sam_root}")
|
||||
sam_path = str(sam_root)
|
||||
if sam_path not in sys.path:
|
||||
sys.path.insert(0, sam_path)
|
||||
return sam_root
|
||||
|
||||
|
||||
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
|
||||
import requests
|
||||
|
||||
ckpt_dir = sam_root / "checkpoints"
|
||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
|
||||
|
||||
if ckpt_path.is_file():
|
||||
return ckpt_path
|
||||
|
||||
url = (
|
||||
"https://dl.fbaipublicfiles.com/segment_anything/"
|
||||
"sam_vit_h_4b8939.pth"
|
||||
)
|
||||
print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print(
|
||||
"\r[{}{}] {:.1f}%".format(
|
||||
"#" * done,
|
||||
"." * (50 - done),
|
||||
downloaded * 100 / total,
|
||||
),
|
||||
end="",
|
||||
)
|
||||
print("\nSAM 权重下载完成。")
|
||||
return ckpt_path
|
||||
|
||||
|
||||
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
|
||||
"""
|
||||
返回一个分割函数:
|
||||
- 输入:RGB uint8 图像 (H, W, 3)
|
||||
- 输出:语义标签图 (H, W),每个目标一个 int id(从 1 开始)
|
||||
"""
|
||||
sam_root = _ensure_sam_on_path()
|
||||
ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
|
||||
|
||||
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator # type: ignore[import]
|
||||
import torch
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
sam = sam_model_registry["vit_h"](
|
||||
checkpoint=str(ckpt_path),
|
||||
).to(device)
|
||||
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
|
||||
def _predict(image_rgb: np.ndarray) -> np.ndarray:
|
||||
if image_rgb.dtype != np.uint8:
|
||||
image_rgb_u8 = image_rgb.astype("uint8")
|
||||
else:
|
||||
image_rgb_u8 = image_rgb
|
||||
|
||||
masks = mask_generator.generate(image_rgb_u8)
|
||||
h, w, _ = image_rgb_u8.shape
|
||||
label_map = np.zeros((h, w), dtype="int32")
|
||||
|
||||
for idx, m in enumerate(masks, start=1):
|
||||
seg = m.get("segmentation")
|
||||
if seg is None:
|
||||
continue
|
||||
label_map[seg.astype(bool)] = idx
|
||||
|
||||
return label_map
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Mask2Former (占位)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
|
||||
from .mask2former_loader import build_mask2former_hf_predictor
|
||||
|
||||
predictor, _ = build_mask2former_hf_predictor()
|
||||
return predictor
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 统一构建函数
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def build_seg_predictor(
|
||||
cfg: UnifiedSegConfig | None = None,
|
||||
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
|
||||
"""
|
||||
统一构建分割预测函数。
|
||||
|
||||
返回:
|
||||
- predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W], int32
|
||||
- 实际使用的 backend
|
||||
"""
|
||||
cfg = cfg or UnifiedSegConfig()
|
||||
|
||||
if cfg.backend == SegBackend.SAM:
|
||||
return _make_sam_predictor(), SegBackend.SAM
|
||||
|
||||
if cfg.backend == SegBackend.MASK2FORMER:
|
||||
return _make_mask2former_predictor(), SegBackend.MASK2FORMER
|
||||
|
||||
raise ValueError(f"不支持的分割后端: {cfg.backend}")
|
||||
|
||||
1
python_server/model/Seg/segment-anything
Submodule
1
python_server/model/Seg/segment-anything
Submodule
Submodule python_server/model/Seg/segment-anything added at dca509fe79
1
python_server/model/__init__.py
Normal file
1
python_server/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
Reference in New Issue
Block a user