from dataclasses import dataclass
from typing import Literal
import sys
from pathlib import Path

import torch

# Make sure the locally cloned ZoeDepth repository is on sys.path,
# so that its internal `import zoedepth...` statements resolve.
_THIS_DIR = Path(__file__).resolve().parent
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
if _ZOE_REPO_ROOT.is_dir():
    zoe_path = str(_ZOE_REPO_ROOT)
    if zoe_path not in sys.path:
        sys.path.insert(0, zoe_path)

from zoedepth.models.builder import build_model
from zoedepth.utils.config import get_config

ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]


@dataclass
class ZoeConfig:
    """
    Model-selection config for ZoeDepth.

    model:  "ZoeD_N" | "ZoeD_K" | "ZoeD_NK"
    device: "cuda" | "cpu"
    """
    model: ZoeModelName = "ZoeD_N"
    device: str = "cuda"


def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
    """
    Manually load one of the three ZoeDepth models:
      - "ZoeD_N"
      - "ZoeD_K"
      - "ZoeD_NK"
    """
    if name == "ZoeD_N":
        conf = get_config("zoedepth", "infer")
    elif name == "ZoeD_K":
        conf = get_config("zoedepth", "infer", config_version="kitti")
    elif name == "ZoeD_NK":
        conf = get_config("zoedepth_nk", "infer")
    else:
        raise ValueError(f"Unsupported ZoeDepth model name: {name}")

    model = build_model(conf)

    # Fall back to CPU if CUDA was requested but is not available.
    if device.startswith("cuda") and torch.cuda.is_available():
        model = model.to("cuda")
    else:
        model = model.to("cpu")

    model.eval()
    return model, conf


def load_zoe_from_config(config: ZoeConfig):
    """
    Load a model from a ZoeConfig.

    Example:
        cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
        model, conf = load_zoe_from_config(cfg)
    """
    return load_zoe_from_name(config.model, config.device)
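The branching in `load_zoe_from_name` maps each model name to a `get_config` call. A minimal, self-contained sketch of that mapping is shown below; `config_args` is a hypothetical helper (not part of the module above) introduced only to make the name-to-config table explicit and testable without the ZoeDepth repository installed.

```python
from dataclasses import dataclass
from typing import Literal, Optional, Tuple

ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]


@dataclass
class ZoeConfig:
    model: ZoeModelName = "ZoeD_N"
    device: str = "cuda"


def config_args(cfg: ZoeConfig) -> Tuple[str, Optional[str]]:
    """Map a model name to the (config_name, config_version) pair
    that load_zoe_from_name passes to zoedepth's get_config.
    config_version=None means no version override is needed."""
    table = {
        "ZoeD_N": ("zoedepth", None),
        "ZoeD_K": ("zoedepth", "kitti"),
        "ZoeD_NK": ("zoedepth_nk", None),
    }
    if cfg.model not in table:
        raise ValueError(f"Unsupported ZoeDepth model name: {cfg.model}")
    return table[cfg.model]


# Example: ZoeD_K uses the base "zoedepth" config with the "kitti" version.
print(config_args(ZoeConfig(model="ZoeD_K")))  # → ('zoedepth', 'kitti')
```

The table form makes it easy to see at a glance which checkpoint family each name selects; the real loader keeps the equivalent logic as an `if/elif` chain.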