75 lines
1.8 KiB
Python
75 lines
1.8 KiB
Python
from dataclasses import dataclass
|
|
from typing import Literal
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
# Make the locally cloned ZoeDepth repository (a "ZoeDepth" directory next to
# this file) importable, so that its internal `import zoedepth...` statements
# resolve. Nothing is changed when the directory is absent.
_THIS_DIR = Path(__file__).resolve().parents[0]
_ZOE_REPO_ROOT = _THIS_DIR.joinpath("ZoeDepth")

if _ZOE_REPO_ROOT.is_dir():
    zoe_path = str(_ZOE_REPO_ROOT)
    # Prepend (not append) so this checkout takes precedence during import
    # resolution; skip it if the entry is already present to avoid duplicates.
    if zoe_path not in sys.path:
        sys.path[:0] = [zoe_path]
|
|
|
|
from zoedepth.models.builder import build_model
|
|
from zoedepth.utils.config import get_config
|
|
|
|
|
|
# Names of the supported ZoeDepth model variants.
ZoeModelName = Literal[
    "ZoeD_N",
    "ZoeD_K",
    "ZoeD_NK",
]


@dataclass()
class ZoeConfig:
    """Choice of ZoeDepth model variant and target device.

    Attributes:
        model: Which variant to build: ``"ZoeD_N"``, ``"ZoeD_K"`` or ``"ZoeD_NK"``.
        device: Requested device string, e.g. ``"cuda"`` or ``"cpu"``.
    """

    # Model variant; defaults to ZoeD_N.
    model: ZoeModelName = "ZoeD_N"
    # Requested device; defaults to "cuda".
    device: str = "cuda"
|
|
|
|
|
|
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
    """Build one of the three ZoeDepth models and place it on a device.

    Args:
        name: ``"ZoeD_N"``, ``"ZoeD_K"`` or ``"ZoeD_NK"``.
        device: Requested device. Any string starting with ``"cuda"``
            (including an explicit index such as ``"cuda:1"``) is honored
            when CUDA is available; in every other case the model is placed
            on the CPU. A ``torch.device`` is also accepted (it is converted
            with ``str``).

    Returns:
        A ``(model, conf)`` tuple: the built model, switched to eval mode,
        and the config object it was built from.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    if name == "ZoeD_N":
        conf = get_config("zoedepth", "infer")
    elif name == "ZoeD_K":
        # Same model name as ZoeD_N, but with the "kitti" config version.
        conf = get_config("zoedepth", "infer", config_version="kitti")
    elif name == "ZoeD_NK":
        conf = get_config("zoedepth_nk", "infer")
    else:
        raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")

    requested = str(device)
    if requested.startswith("cuda") and torch.cuda.is_available():
        # Use the exact requested string so an explicit GPU index such as
        # "cuda:1" is respected instead of always landing on the default GPU.
        target = requested
    else:
        # CUDA is unavailable or a non-CUDA device was requested: use the CPU.
        target = "cpu"

    model = build_model(conf).to(target)
    model.eval()
    return model, conf
|
|
|
|
|
|
def load_zoe_from_config(config: ZoeConfig):
    """Build the ZoeDepth model described by ``config``.

    Forwards ``config.model`` and ``config.device`` to ``load_zoe_from_name``
    and returns its ``(model, conf)`` result unchanged.

    Example::

        cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
        model, conf = load_zoe_from_config(cfg)
    """
    model_name = config.model
    target_device = config.device
    return load_zoe_from_name(model_name, target_device)
|
|
|