149 lines
4.5 KiB
Python
149 lines
4.5 KiB
Python
from __future__ import annotations
|
||
|
||
"""
|
||
统一的深度模型加载入口。
|
||
|
||
当前支持:

- ZoeDepth(三种:ZoeD_N / ZoeD_K / ZoeD_NK)
- Depth Anything V2(四种 encoder:vits / vitb / vitl / vitg)
- DPT
- MiDaS

如需新增后端:在 DepthBackend 中添加一个成员,实现对应的 _make_*_predictor,
再在 build_depth_predictor 中加入该分支即可。
|
||
"""
|
||
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from typing import Callable
|
||
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
from config_loader import (
|
||
load_app_config,
|
||
build_zoe_config_from_app,
|
||
build_depth_anything_v2_config_from_app,
|
||
build_dpt_config_from_app,
|
||
build_midas_config_from_app,
|
||
)
|
||
from .zoe_loader import load_zoe_from_config
|
||
from .depth_anything_v2_loader import (
|
||
load_depth_anything_v2_from_config,
|
||
infer_depth_anything_v2,
|
||
)
|
||
from .dpt_loader import load_dpt_from_config, infer_dpt
|
||
from .midas_loader import load_midas_from_config, infer_midas
|
||
|
||
|
||
class DepthBackend(str, Enum):
    """Identifier of a depth-estimation backend.

    Members subclass ``str``, so each one compares equal to its plain string
    value (``DepthBackend.DPT == "dpt"``) and can be recovered from it with
    ``DepthBackend("dpt")``.
    """

    # ZoeDepth family (ZoeD_N / ZoeD_K / ZoeD_NK)
    ZOEDEPTH = "zoedepth"
    # Depth Anything V2
    DEPTH_ANYTHING_V2 = "depth_anything_v2"
    # DPT
    DPT = "dpt"
    # MiDaS
    MIDAS = "midas"
|
||
|
||
|
||
@dataclass
class UnifiedDepthConfig:
    """Backend-agnostic settings consumed by ``build_depth_predictor``.

    Attributes:
        backend: which depth backend to build; defaults to ZoeDepth.
        device: optional device override. When ``None``, the backend-specific
            config built from the app config keeps its own device setting.
    """

    backend: DepthBackend = DepthBackend.ZOEDEPTH
    # None -> no override; the per-backend config's device is used unchanged.
    device: str | None = None
|
||
|
||
|
||
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load a ZoeDepth model once and return a closure that runs it.

    Args:
        device_override: when not ``None``, replaces the device taken from the
            app-level configuration.

    Returns:
        A callable mapping a PIL image to a squeezed float32 depth array.
    """
    zoe_settings = build_zoe_config_from_app(load_app_config())
    if device_override is not None:
        zoe_settings.device = device_override

    # The loader yields a pair; only the model itself is needed here.
    zoe_model, _unused = load_zoe_from_config(zoe_settings)

    def _predict(img: Image.Image) -> np.ndarray:
        # Normalise the input to 3-channel RGB before handing it to ZoeDepth.
        rgb_img = img.convert("RGB")
        raw_depth = zoe_model.infer_pil(rgb_img, output_type="numpy")
        as_f32 = np.asarray(raw_depth, dtype="float32")
        return as_f32.squeeze()

    return _predict
|
||
|
||
|
||
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load a Depth Anything V2 model once and return a closure that runs it.

    Args:
        device_override: when not ``None``, replaces the device taken from the
            app-level configuration.

    Returns:
        A callable mapping a PIL image to a squeezed float32 depth array.
    """
    settings = build_depth_anything_v2_config_from_app(load_app_config())
    if device_override is not None:
        settings.device = device_override

    # The loader hands back the config alongside the model; its ``input_size``
    # is what inference uses below.
    net, settings = load_depth_anything_v2_from_config(settings)

    def _predict(img: Image.Image) -> np.ndarray:
        # Depth Anything V2's infer_image expects BGR uint8 pixels, so reverse
        # the channel axis of the RGB array.
        rgb_pixels = np.array(img.convert("RGB"), dtype=np.uint8)
        bgr_pixels = rgb_pixels[..., ::-1]
        return infer_depth_anything_v2(net, bgr_pixels, settings.input_size).astype("float32").squeeze()

    return _predict
|
||
|
||
|
||
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load a DPT model once and return a closure that runs it.

    Args:
        device_override: when not ``None``, replaces the device taken from the
            app-level configuration.

    Returns:
        A callable mapping a PIL image to a squeezed float32 depth array.
    """
    app_cfg = load_app_config()
    dpt_cfg = build_dpt_config_from_app(app_cfg)
    if device_override is not None:
        dpt_cfg.device = device_override

    # The loader returns the model, the config and the input transform that
    # infer_dpt is called with below.
    model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        # infer_dpt is fed BGR uint8 pixels. The channel axis is reversed with
        # NumPy because cv2 is not imported by this module. ascontiguousarray
        # materialises a C-contiguous copy (as cv2.cvtColor would return) in
        # case downstream code rejects negative-stride views.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
        return depth.astype("float32").squeeze()

    return _predict
|
||
|
||
|
||
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Load a MiDaS model once and return a closure that runs it.

    Args:
        device_override: when not ``None``, replaces the device taken from the
            app-level configuration.

    Returns:
        A callable mapping a PIL image to a squeezed float32 depth array.
    """
    midas_settings = build_midas_config_from_app(load_app_config())
    if device_override is not None:
        midas_settings.device = device_override

    # Besides the model and config, the loader supplies the input transform and
    # the network input width/height, all of which infer_midas is given.
    net, midas_settings, prep, net_w, net_h = load_midas_from_config(midas_settings)

    def _predict(img: Image.Image) -> np.ndarray:
        # infer_midas is fed RGB float32 pixels scaled from [0, 255] to [0, 1].
        rgb_img = img.convert("RGB")
        unit_rgb = np.array(rgb_img, dtype=np.float32) / 255.0
        out = infer_midas(net, prep, unit_rgb, net_w, net_h, midas_settings.device)
        return out.astype("float32").squeeze()

    return _predict
|
||
|
||
|
||
def build_depth_predictor(
    cfg: UnifiedDepthConfig | None = None,
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
    """Build the depth-prediction callable for the configured backend.

    Args:
        cfg: backend selection plus optional device override. A falsy value
            (e.g. ``None``) means a default ``UnifiedDepthConfig()``, i.e.
            ZoeDepth with no device override.

    Returns:
        ``(predictor, backend)``. ``predictor(image: PIL.Image)`` returns a
        squeezed float32 ``np.ndarray`` (intended to be ``[H, W]``), and
        ``backend`` is the ``DepthBackend`` member that was matched.

    Raises:
        ValueError: if ``cfg.backend`` matches none of the supported backends.
    """
    if not cfg:
        cfg = UnifiedDepthConfig()

    # Matched with ``==`` in a fixed order rather than via a dict lookup, so a
    # plain string such as "dpt" still matches its str-based enum member
    # (Enum hashes by member name, not by value).
    factories = (
        (DepthBackend.ZOEDEPTH, _make_zoe_predictor),
        (DepthBackend.DEPTH_ANYTHING_V2, _make_da_v2_predictor),
        (DepthBackend.DPT, _make_dpt_predictor),
        (DepthBackend.MIDAS, _make_midas_predictor),
    )
    for backend, make_predictor in factories:
        if cfg.backend == backend:
            return make_predictor(cfg.device), backend

    raise ValueError(f"不支持的深度后端: {cfg.backend}")
|
||
|