initial commit
This commit is contained in:
74
python_server/model/Depth/zoe_loader.py
Normal file
74
python_server/model/Depth/zoe_loader.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# 确保本地克隆的 ZoeDepth 仓库在 sys.path 中,
|
||||
# 这样其内部的 `import zoedepth...` 才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
|
||||
if _ZOE_REPO_ROOT.is_dir():
|
||||
zoe_path = str(_ZOE_REPO_ROOT)
|
||||
if zoe_path not in sys.path:
|
||||
sys.path.insert(0, zoe_path)
|
||||
|
||||
from zoedepth.models.builder import build_model
|
||||
from zoedepth.utils.config import get_config
|
||||
|
||||
|
||||
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZoeConfig:
|
||||
"""
|
||||
ZoeDepth 模型选择配置。
|
||||
|
||||
model: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK"
|
||||
device: "cuda" | "cpu"
|
||||
"""
|
||||
|
||||
model: ZoeModelName = "ZoeD_N"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
|
||||
"""
|
||||
手动加载 ZoeDepth 三种模型之一:
|
||||
- "ZoeD_N"
|
||||
- "ZoeD_K"
|
||||
- "ZoeD_NK"
|
||||
"""
|
||||
if name == "ZoeD_N":
|
||||
conf = get_config("zoedepth", "infer")
|
||||
elif name == "ZoeD_K":
|
||||
conf = get_config("zoedepth", "infer", config_version="kitti")
|
||||
elif name == "ZoeD_NK":
|
||||
conf = get_config("zoedepth_nk", "infer")
|
||||
else:
|
||||
raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")
|
||||
|
||||
model = build_model(conf)
|
||||
|
||||
if device.startswith("cuda") and torch.cuda.is_available():
|
||||
model = model.to("cuda")
|
||||
else:
|
||||
model = model.to("cpu")
|
||||
|
||||
model.eval()
|
||||
return model, conf
|
||||
|
||||
|
||||
def load_zoe_from_config(config: ZoeConfig):
|
||||
"""
|
||||
根据 ZoeConfig 加载模型。
|
||||
|
||||
示例:
|
||||
cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
|
||||
model, conf = load_zoe_from_config(cfg)
|
||||
"""
|
||||
return load_zoe_from_name(config.model, config.device)
|
||||
|
||||
Reference in New Issue
Block a user