initial commit

This commit is contained in:
2026-04-07 20:55:30 +08:00
commit 81d1fb7856
84 changed files with 11929 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
from dataclasses import dataclass
from typing import Literal
import sys
from pathlib import Path
import torch
# 确保本地克隆的 ZoeDepth 仓库在 sys.path 中,
# 这样其内部的 `import zoedepth...` 才能正常工作。
_THIS_DIR = Path(__file__).resolve().parent
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
if _ZOE_REPO_ROOT.is_dir():
zoe_path = str(_ZOE_REPO_ROOT)
if zoe_path not in sys.path:
sys.path.insert(0, zoe_path)
from zoedepth.models.builder import build_model
from zoedepth.utils.config import get_config
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]
@dataclass
class ZoeConfig:
"""
ZoeDepth 模型选择配置。
model: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK"
device: "cuda" | "cpu"
"""
model: ZoeModelName = "ZoeD_N"
device: str = "cuda"
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
"""
手动加载 ZoeDepth 三种模型之一:
- "ZoeD_N"
- "ZoeD_K"
- "ZoeD_NK"
"""
if name == "ZoeD_N":
conf = get_config("zoedepth", "infer")
elif name == "ZoeD_K":
conf = get_config("zoedepth", "infer", config_version="kitti")
elif name == "ZoeD_NK":
conf = get_config("zoedepth_nk", "infer")
else:
raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")
model = build_model(conf)
if device.startswith("cuda") and torch.cuda.is_available():
model = model.to("cuda")
else:
model = model.to("cpu")
model.eval()
return model, conf
def load_zoe_from_config(config: ZoeConfig):
"""
根据 ZoeConfig 加载模型。
示例:
cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
model, conf = load_zoe_from_config(cfg)
"""
return load_zoe_from_name(config.model, config.device)