from dataclasses import dataclass
from typing import Literal, Tuple

import sys
from pathlib import Path

import numpy as np
import torch
import requests

# Ensure the locally cloned Depth Anything V2 repository is on sys.path,
# so that its internal `from depth_anything_v2...` imports resolve correctly.
# NOTE: this must run before the `depth_anything_v2` import below.
_THIS_DIR = Path(__file__).resolve().parent
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
if _DA_REPO_ROOT.is_dir():
    da_path = str(_DA_REPO_ROOT)
    if da_path not in sys.path:
        sys.path.insert(0, da_path)

from depth_anything_v2.dpt import DepthAnythingV2  # type: ignore[import]


# Encoder variants supported by Depth Anything V2
# (ViT small / base / large / giant).
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]
@dataclass
class DepthAnythingV2Config:
    """
    Model-selection configuration for Depth Anything V2.

    encoder: one of "vits" | "vitb" | "vitl" | "vitg"
    device: "cuda" | "cpu" (a "cuda" request silently falls back to "cpu"
        at load time if CUDA is unavailable)
    input_size: inference input resolution (short side); 518 follows the
        official demo's default.
    """

    encoder: EncoderName = "vitl"
    device: str = "cuda"
    input_size: int = 518
# Per-encoder architecture hyperparameters (DPT feature width and decoder
# out_channels), mirroring the official Depth Anything V2 model configs.
_MODEL_CONFIGS = {
    "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
    "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
    "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
    "vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
}
_DA_V2_WEIGHTS_URLS = {
    # Official weights are hosted on HuggingFace:
    # - Small -> vits
    # - Base -> vitb
    # - Large -> vitl
    # - Giant -> vitg
    # To use a mirror (e.g. a CN-local one) instead, edit these URLs directly.
    "vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
    "vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
    "vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
    "vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
}
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
|
||
if ckpt_path.is_file():
|
||
return
|
||
|
||
url = _DA_V2_WEIGHTS_URLS.get(encoder)
|
||
if not url:
|
||
raise FileNotFoundError(
|
||
f"找不到权重文件: {ckpt_path}\n"
|
||
f"且当前未为 encoder='{encoder}' 配置自动下载 URL,请手动下载到该路径。"
|
||
)
|
||
|
||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||
print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
|
||
|
||
resp = requests.get(url, stream=True)
|
||
resp.raise_for_status()
|
||
|
||
total = int(resp.headers.get("content-length", "0") or "0")
|
||
downloaded = 0
|
||
chunk_size = 1024 * 1024
|
||
|
||
with ckpt_path.open("wb") as f:
|
||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||
if not chunk:
|
||
continue
|
||
f.write(chunk)
|
||
downloaded += len(chunk)
|
||
if total > 0:
|
||
done = int(50 * downloaded / total)
|
||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||
|
||
print("\n权重下载完成。")
|
||
|
||
|
||
def load_depth_anything_v2_from_config(
    cfg: DepthAnythingV2Config,
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
    """
    Build a Depth Anything V2 model from *cfg* and return (model, effective_cfg).

    Notes:
    - The checkpoint path follows the official naming convention:
      checkpoints/depth_anything_v2_{encoder}.pth
      (e.g. depth_anything_v2_vitl.pth), and is auto-downloaded when missing.
    - The weight file lives under
      python_server/model/Depth/Depth-Anything-V2/checkpoints.
    - The returned config reflects the device actually used: a "cuda" request
      falls back to "cpu" when CUDA is unavailable.
    """
    if cfg.encoder not in _MODEL_CONFIGS:
        raise ValueError(f"不支持的 encoder: {cfg.encoder}")

    checkpoint = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
    _download_if_missing(cfg.encoder, checkpoint)

    net = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])
    # Load on CPU first; move to the resolved device afterwards.
    net.load_state_dict(torch.load(str(checkpoint), map_location="cpu"))

    wants_cuda = cfg.device.startswith("cuda") and torch.cuda.is_available()
    resolved_device = "cuda" if wants_cuda else "cpu"

    net = net.to(resolved_device).eval()
    effective_cfg = DepthAnythingV2Config(
        encoder=cfg.encoder,
        device=resolved_device,
        input_size=cfg.input_size,
    )
    return net, effective_cfg
def infer_depth_anything_v2(
    model: DepthAnythingV2,
    image_bgr: np.ndarray,
    input_size: int,
) -> np.ndarray:
    """
    Run depth inference on a single BGR image and return the raw (unnormalized)
    float32 depth map.

    image_bgr: BGR image as read by OpenCV, shape (H, W, 3), uint8.
    input_size: inference resolution passed through to the model.
    """
    raw_depth = model.infer_image(image_bgr, input_size)
    return np.asarray(raw_depth, dtype=np.float32)