initial commit
This commit is contained in:
148
python_server/model/Depth/depth_loader.py
Normal file
148
python_server/model/Depth/depth_loader.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的深度模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- ZoeDepth(三种:ZoeD_N / ZoeD_K / ZoeD_NK)
|
||||
- Depth Anything V2(四种 encoder:vits / vitb / vitl / vitg)
|
||||
|
||||
未来如果要加 MiDaS / DPT,只需要在这里再接一层即可。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
build_zoe_config_from_app,
|
||||
build_depth_anything_v2_config_from_app,
|
||||
build_dpt_config_from_app,
|
||||
build_midas_config_from_app,
|
||||
)
|
||||
from .zoe_loader import load_zoe_from_config
|
||||
from .depth_anything_v2_loader import (
|
||||
load_depth_anything_v2_from_config,
|
||||
infer_depth_anything_v2,
|
||||
)
|
||||
from .dpt_loader import load_dpt_from_config, infer_dpt
|
||||
from .midas_loader import load_midas_from_config, infer_midas
|
||||
|
||||
|
||||
class DepthBackend(str, Enum):
|
||||
"""统一的深度模型后端类型。"""
|
||||
|
||||
ZOEDEPTH = "zoedepth"
|
||||
DEPTH_ANYTHING_V2 = "depth_anything_v2"
|
||||
DPT = "dpt"
|
||||
MIDAS = "midas"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedDepthConfig:
|
||||
"""
|
||||
统一深度配置。
|
||||
|
||||
backend: 使用哪个后端
|
||||
device: 强制设备(可选),不填则使用 config.py 中的设置
|
||||
"""
|
||||
|
||||
backend: DepthBackend = DepthBackend.ZOEDEPTH
|
||||
device: str | None = None
|
||||
|
||||
|
||||
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
zoe_cfg = build_zoe_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
zoe_cfg.device = device_override
|
||||
|
||||
model, _ = load_zoe_from_config(zoe_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
depth = model.infer_pil(img.convert("RGB"), output_type="numpy")
|
||||
return np.asarray(depth, dtype="float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
da_cfg = build_depth_anything_v2_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
da_cfg.device = device_override
|
||||
|
||||
model, da_cfg = load_depth_anything_v2_from_config(da_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
# Depth Anything V2 的 infer_image 接收 BGR uint8
|
||||
rgb = np.array(img.convert("RGB"), dtype=np.uint8)
|
||||
bgr = rgb[:, :, ::-1]
|
||||
depth = infer_depth_anything_v2(model, bgr, da_cfg.input_size)
|
||||
return depth.astype("float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
dpt_cfg = build_dpt_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
dpt_cfg.device = device_override
|
||||
|
||||
model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
bgr = cv2.cvtColor(np.array(img.convert("RGB"), dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
||||
depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
|
||||
return depth.astype("float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||
app_cfg = load_app_config()
|
||||
midas_cfg = build_midas_config_from_app(app_cfg)
|
||||
if device_override is not None:
|
||||
midas_cfg.device = device_override
|
||||
|
||||
model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)
|
||||
|
||||
def _predict(img: Image.Image) -> np.ndarray:
|
||||
rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
|
||||
depth = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
|
||||
return depth.astype("float32").squeeze()
|
||||
|
||||
return _predict
|
||||
|
||||
|
||||
def build_depth_predictor(
|
||||
cfg: UnifiedDepthConfig | None = None,
|
||||
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
|
||||
"""
|
||||
统一构建深度预测函数。
|
||||
|
||||
返回:
|
||||
- predictor(image: PIL.Image) -> np.ndarray[H, W], float32
|
||||
- 实际使用的 backend 类型
|
||||
"""
|
||||
cfg = cfg or UnifiedDepthConfig()
|
||||
|
||||
if cfg.backend == DepthBackend.ZOEDEPTH:
|
||||
return _make_zoe_predictor(cfg.device), DepthBackend.ZOEDEPTH
|
||||
|
||||
if cfg.backend == DepthBackend.DEPTH_ANYTHING_V2:
|
||||
return _make_da_v2_predictor(cfg.device), DepthBackend.DEPTH_ANYTHING_V2
|
||||
|
||||
if cfg.backend == DepthBackend.DPT:
|
||||
return _make_dpt_predictor(cfg.device), DepthBackend.DPT
|
||||
|
||||
if cfg.backend == DepthBackend.MIDAS:
|
||||
return _make_midas_predictor(cfg.device), DepthBackend.MIDAS
|
||||
|
||||
raise ValueError(f"不支持的深度后端: {cfg.backend}")
|
||||
|
||||
Reference in New Issue
Block a user