Files
hfut-bishe/python_server/model/Depth/depth_loader.py
2026-04-07 20:55:30 +08:00

149 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
"""
统一的深度模型加载入口。
当前支持:
- ZoeDepth三种ZoeD_N / ZoeD_K / ZoeD_NK
- Depth Anything V2四种 encodervits / vitb / vitl / vitg
未来如果要加 MiDaS / DPT只需要在这里再接一层即可。
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import numpy as np
from PIL import Image
from config_loader import (
load_app_config,
build_zoe_config_from_app,
build_depth_anything_v2_config_from_app,
build_dpt_config_from_app,
build_midas_config_from_app,
)
from .zoe_loader import load_zoe_from_config
from .depth_anything_v2_loader import (
load_depth_anything_v2_from_config,
infer_depth_anything_v2,
)
from .dpt_loader import load_dpt_from_config, infer_dpt
from .midas_loader import load_midas_from_config, infer_midas
class DepthBackend(str, Enum):
    """Identifiers of the depth-model backends that build_depth_predictor can build.

    The enum mixes in ``str``, so every member is also a string equal to its
    value (e.g. ``DepthBackend.DPT`` has the value ``"dpt"``).
    """

    # ZoeDepth, built by _make_zoe_predictor via load_zoe_from_config.
    ZOEDEPTH = "zoedepth"
    # Depth Anything V2, built by _make_da_v2_predictor.
    DEPTH_ANYTHING_V2 = "depth_anything_v2"
    # DPT, built by _make_dpt_predictor.
    DPT = "dpt"
    # MiDaS, built by _make_midas_predictor.
    MIDAS = "midas"
@dataclass
class UnifiedDepthConfig:
    """Backend selection passed to build_depth_predictor.

    Attributes:
        backend: Which depth backend to build. Defaults to ZoeDepth.
        device: Optional device override. When None, the device from the
            backend config derived from load_app_config() is kept as-is.
    """

    # Backend to instantiate; build_depth_predictor raises ValueError for
    # anything that matches none of the DepthBackend members.
    backend: DepthBackend = DepthBackend.ZOEDEPTH
    # Device string written onto the backend config when not None.
    device: str | None = None
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a depth predictor backed by ZoeDepth.

    The ZoeDepth config is derived from the application config. When
    ``device_override`` is not None, it replaces the configured device. The
    model is loaded once, here, and captured by the returned closure.

    Args:
        device_override: Optional device string that overrides the config.

    Returns:
        A callable that maps a PIL image to a float32 depth array with
        singleton dimensions squeezed out.
    """
    zoe_cfg = build_zoe_config_from_app(load_app_config())
    if device_override is not None:
        zoe_cfg.device = device_override
    zoe_model, _unused_cfg = load_zoe_from_config(zoe_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        rgb_img = img.convert("RGB")
        raw_depth = zoe_model.infer_pil(rgb_img, output_type="numpy")
        return np.asarray(raw_depth, dtype=np.float32).squeeze()

    return _predict
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a depth predictor backed by Depth Anything V2.

    The Depth Anything V2 config is derived from the application config.
    ``device_override``, when not None, replaces its device. The model is
    loaded once, and the config returned by the loader supplies the
    ``input_size`` used at inference time.

    Args:
        device_override: Optional device string that overrides the config.

    Returns:
        A callable that maps a PIL image to a float32 depth array with
        singleton dimensions squeezed out.
    """
    cfg = build_depth_anything_v2_config_from_app(load_app_config())
    if device_override is not None:
        cfg.device = device_override
    net, cfg = load_depth_anything_v2_from_config(cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # Depth Anything V2's infer_image expects BGR uint8, so reverse the
        # RGB channel order.
        bgr = np.array(img.convert("RGB"), dtype=np.uint8)[..., ::-1]
        result = infer_depth_anything_v2(net, bgr, cfg.input_size)
        return result.astype(np.float32).squeeze()

    return _predict
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a depth predictor backed by DPT.

    The DPT config is derived from the application config. ``device_override``,
    when not None, replaces its device. The model and its input transform are
    loaded once and captured by the returned closure.

    Args:
        device_override: Optional device string that overrides the config.

    Returns:
        A callable that maps a PIL image to a float32 depth array with
        singleton dimensions squeezed out.
    """
    app_cfg = load_app_config()
    dpt_cfg = build_dpt_config_from_app(app_cfg)
    if device_override is not None:
        dpt_cfg.device = device_override
    model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # infer_dpt is fed a BGR uint8 image. The channels are flipped with
        # NumPy rather than cv2.cvtColor because this module does not import
        # cv2; the previous cv2 call raised NameError on every prediction.
        rgb = np.asarray(img.convert("RGB"), dtype=np.uint8)
        # Reversing the channel axis yields a negative-stride view; copy it to a
        # C-contiguous array, matching the layout cv2.cvtColor would return.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
        return depth.astype("float32").squeeze()

    return _predict
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Create a depth predictor backed by MiDaS.

    The MiDaS config is derived from the application config. ``device_override``,
    when not None, replaces its device. The model, its transform, and the
    network input size (``net_w``, ``net_h``) are loaded once and captured by
    the returned closure.

    Args:
        device_override: Optional device string that overrides the config.

    Returns:
        A callable that maps a PIL image to a float32 depth array with
        singleton dimensions squeezed out.
    """
    cfg = build_midas_config_from_app(load_app_config())
    if device_override is not None:
        cfg.device = device_override
    net, cfg, midas_transform, net_w, net_h = load_midas_from_config(cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # infer_midas receives an RGB float32 image with 8-bit values scaled
        # down by 255.
        pixels = np.array(img.convert("RGB"), dtype=np.float32)
        normalized = pixels / 255.0
        result = infer_midas(net, midas_transform, normalized, net_w, net_h, cfg.device)
        return result.astype(np.float32).squeeze()

    return _predict
def build_depth_predictor(
    cfg: UnifiedDepthConfig | None = None,
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
    """Build a depth-prediction callable for the configured backend.

    Args:
        cfg: Backend selection and optional device override. A falsy value
            (e.g. None) means a default UnifiedDepthConfig (ZoeDepth).

    Returns:
        A ``(predictor, backend)`` pair, where ``predictor(image: PIL.Image)``
        returns a float32 depth array and ``backend`` is the DepthBackend
        member that was matched.

    Raises:
        ValueError: If ``cfg.backend`` matches none of the known backends.
    """
    cfg = cfg if cfg else UnifiedDepthConfig()
    # Ordered (backend, factory) pairs. They are compared with ``==`` in this
    # order, exactly as a chain of if-statements would, so matching semantics
    # are unchanged.
    factories = (
        (DepthBackend.ZOEDEPTH, _make_zoe_predictor),
        (DepthBackend.DEPTH_ANYTHING_V2, _make_da_v2_predictor),
        (DepthBackend.DPT, _make_dpt_predictor),
        (DepthBackend.MIDAS, _make_midas_predictor),
    )
    for backend, factory in factories:
        if cfg.backend == backend:
            return factory(cfg.device), backend
    raise ValueError(f"不支持的深度后端: {cfg.backend}")