initial commit

This commit is contained in:
2026-04-07 20:55:30 +08:00
commit 81d1fb7856
84 changed files with 11929 additions and 0 deletions

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
"""
统一的深度模型加载入口。
当前支持:
- ZoeDepth三种ZoeD_N / ZoeD_K / ZoeD_NK
- Depth Anything V2四种 encodervits / vitb / vitl / vitg
未来如果要加 MiDaS / DPT只需要在这里再接一层即可。
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import numpy as np
from PIL import Image
from config_loader import (
load_app_config,
build_zoe_config_from_app,
build_depth_anything_v2_config_from_app,
build_dpt_config_from_app,
build_midas_config_from_app,
)
from .zoe_loader import load_zoe_from_config
from .depth_anything_v2_loader import (
load_depth_anything_v2_from_config,
infer_depth_anything_v2,
)
from .dpt_loader import load_dpt_from_config, infer_dpt
from .midas_loader import load_midas_from_config, infer_midas
class DepthBackend(str, Enum):
"""统一的深度模型后端类型。"""
ZOEDEPTH = "zoedepth"
DEPTH_ANYTHING_V2 = "depth_anything_v2"
DPT = "dpt"
MIDAS = "midas"
@dataclass
class UnifiedDepthConfig:
"""
统一深度配置。
backend: 使用哪个后端
device: 强制设备(可选),不填则使用 config.py 中的设置
"""
backend: DepthBackend = DepthBackend.ZOEDEPTH
device: str | None = None
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
app_cfg = load_app_config()
zoe_cfg = build_zoe_config_from_app(app_cfg)
if device_override is not None:
zoe_cfg.device = device_override
model, _ = load_zoe_from_config(zoe_cfg)
def _predict(img: Image.Image) -> np.ndarray:
depth = model.infer_pil(img.convert("RGB"), output_type="numpy")
return np.asarray(depth, dtype="float32").squeeze()
return _predict
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
app_cfg = load_app_config()
da_cfg = build_depth_anything_v2_config_from_app(app_cfg)
if device_override is not None:
da_cfg.device = device_override
model, da_cfg = load_depth_anything_v2_from_config(da_cfg)
def _predict(img: Image.Image) -> np.ndarray:
# Depth Anything V2 的 infer_image 接收 BGR uint8
rgb = np.array(img.convert("RGB"), dtype=np.uint8)
bgr = rgb[:, :, ::-1]
depth = infer_depth_anything_v2(model, bgr, da_cfg.input_size)
return depth.astype("float32").squeeze()
return _predict
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
app_cfg = load_app_config()
dpt_cfg = build_dpt_config_from_app(app_cfg)
if device_override is not None:
dpt_cfg.device = device_override
model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)
def _predict(img: Image.Image) -> np.ndarray:
bgr = cv2.cvtColor(np.array(img.convert("RGB"), dtype=np.uint8), cv2.COLOR_RGB2BGR)
depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
return depth.astype("float32").squeeze()
return _predict
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
app_cfg = load_app_config()
midas_cfg = build_midas_config_from_app(app_cfg)
if device_override is not None:
midas_cfg.device = device_override
model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)
def _predict(img: Image.Image) -> np.ndarray:
rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
depth = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
return depth.astype("float32").squeeze()
return _predict
def build_depth_predictor(
cfg: UnifiedDepthConfig | None = None,
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
"""
统一构建深度预测函数。
返回:
- predictor(image: PIL.Image) -> np.ndarray[H, W], float32
- 实际使用的 backend 类型
"""
cfg = cfg or UnifiedDepthConfig()
if cfg.backend == DepthBackend.ZOEDEPTH:
return _make_zoe_predictor(cfg.device), DepthBackend.ZOEDEPTH
if cfg.backend == DepthBackend.DEPTH_ANYTHING_V2:
return _make_da_v2_predictor(cfg.device), DepthBackend.DEPTH_ANYTHING_V2
if cfg.backend == DepthBackend.DPT:
return _make_dpt_predictor(cfg.device), DepthBackend.DPT
if cfg.backend == DepthBackend.MIDAS:
return _make_midas_predictor(cfg.device), DepthBackend.MIDAS
raise ValueError(f"不支持的深度后端: {cfg.backend}")