# hfut-bishe/python_server/model/Depth/depth_anything_v2_loader.py
# Snapshot exported 2026-04-07 20:55:30 +08:00 (148 lines, 4.7 KiB).
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path

import numpy as np
import torch
import requests

# Make sure the locally cloned Depth Anything V2 repository is on sys.path,
# so that its internal `from depth_anything_v2...` imports resolve.
_THIS_DIR = Path(__file__).resolve().parent
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
if _DA_REPO_ROOT.is_dir():
    da_path = str(_DA_REPO_ROOT)
    if da_path not in sys.path:
        sys.path.insert(0, da_path)

from depth_anything_v2.dpt import DepthAnythingV2  # type: ignore[import]
# The four DINOv2 backbone variants supported by Depth Anything V2.
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]


@dataclass
class DepthAnythingV2Config:
    """Model-selection options for Depth Anything V2.

    Attributes:
        encoder: backbone variant, one of "vits" | "vitb" | "vitl" | "vitg".
        device: target device string, "cuda" or "cpu".
        input_size: inference input resolution (short side); the official
            demo defaults to 518.
    """

    encoder: EncoderName = "vitl"
    device: str = "cuda"
    input_size: int = 518
# Per-encoder DPT head hyper-parameters (feature width and per-stage output
# channels), mirroring the official repository's model configurations.
_MODEL_CONFIGS = {
    name: {"encoder": name, "features": feats, "out_channels": list(chans)}
    for name, feats, chans in (
        ("vits", 64, (48, 96, 192, 384)),
        ("vitb", 128, (96, 192, 384, 768)),
        ("vitl", 256, (256, 512, 1024, 1024)),
        ("vitg", 384, (1536, 1536, 1536, 1536)),
    )
}
# Official weights are hosted on HuggingFace:
#   Small -> vits, Base -> vitb, Large -> vitl, Giant -> vitg
# Point these URLs at a domestic mirror if HuggingFace is unreachable.
_DA_V2_WEIGHTS_URLS = {
    enc: (
        f"https://huggingface.co/depth-anything/Depth-Anything-V2-{size}"
        f"/resolve/main/depth_anything_v2_{enc}.pth"
    )
    for enc, size in (
        ("vits", "Small"),
        ("vitb", "Base"),
        ("vitl", "Large"),
        ("vitg", "Giant"),
    )
}
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
    """Download the Depth Anything V2 checkpoint for ``encoder`` if absent.

    The payload is streamed to ``<ckpt_path>.part`` first and atomically
    renamed on success, so an interrupted download is never mistaken for a
    complete checkpoint on the next run.

    Args:
        encoder: encoder name, used to look up the download URL.
        ckpt_path: final location of the ``.pth`` checkpoint file.

    Raises:
        FileNotFoundError: no download URL is configured for ``encoder``.
        requests.HTTPError: the server answered with an error status.
    """
    if ckpt_path.is_file():
        return
    url = _DA_V2_WEIGHTS_URLS.get(encoder)
    if not url:
        raise FileNotFoundError(
            f"找不到权重文件: {ckpt_path}\n"
            f"且当前未为 encoder='{encoder}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
    # Write to a temp file first; only a fully downloaded file gets the
    # final name, so `ckpt_path.is_file()` above stays trustworthy.
    tmp_path = ckpt_path.with_suffix(ckpt_path.suffix + ".part")
    # Stream in 1 MiB chunks; the timeout keeps a dead connection from
    # hanging the whole server, and `with` guarantees the response is closed.
    with requests.get(url, stream=True, timeout=(10, 60)) as resp:
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", "0") or "0")
        downloaded = 0
        chunk_size = 1024 * 1024
        try:
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        except BaseException:
            # Remove the partial file so a retry starts clean.
            tmp_path.unlink(missing_ok=True)
            raise
    tmp_path.replace(ckpt_path)
    print("\n权重下载完成。")
def load_depth_anything_v2_from_config(
    cfg: DepthAnythingV2Config,
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
    """Build a Depth Anything V2 model from ``cfg`` and return it together
    with the effective (possibly device-adjusted) configuration.

    Notes:
        - The checkpoint path follows the official naming convention:
          checkpoints/depth_anything_v2_{encoder}.pth
          (e.g. depth_anything_v2_vitl.pth) under
          python_server/model/Depth/Depth-Anything-V2/checkpoints.
        - A missing checkpoint is downloaded automatically when a URL is
          configured for the encoder.

    Raises:
        ValueError: ``cfg.encoder`` is not a known encoder name.
    """
    encoder = cfg.encoder
    if encoder not in _MODEL_CONFIGS:
        raise ValueError(f"不支持的 encoder: {encoder}")

    ckpt_path = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{encoder}.pth"
    _download_if_missing(encoder, ckpt_path)

    model = DepthAnythingV2(**_MODEL_CONFIGS[encoder])
    # NOTE(review): torch.load is pickle-based; since this checkpoint is
    # downloaded, consider weights_only=True on torch >= 1.13.
    state_dict = torch.load(str(ckpt_path), map_location="cpu")
    model.load_state_dict(state_dict)

    # Fall back to CPU when CUDA was requested but is unavailable.
    use_cuda = cfg.device.startswith("cuda") and torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    model = model.to(device).eval()

    effective_cfg = DepthAnythingV2Config(
        encoder=encoder,
        device=device,
        input_size=cfg.input_size,
    )
    return model, effective_cfg
def infer_depth_anything_v2(
    model: DepthAnythingV2,
    image_bgr: np.ndarray,
    input_size: int,
) -> np.ndarray:
    """Run depth inference on a single BGR image.

    Args:
        model: a loaded Depth Anything V2 model.
        image_bgr: OpenCV-style BGR image (H, W, 3), uint8.
        input_size: inference input resolution (short side) fed to the network.

    Returns:
        The raw float32 depth map (not normalized).
    """
    raw = model.infer_image(image_bgr, input_size)
    return np.asarray(raw, dtype=np.float32)