initial commit
This commit is contained in:
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
# 确保本地克隆的 Depth Anything V2 仓库在 sys.path 中,
|
||||
# 这样其内部的 `from depth_anything_v2...` 导入才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
|
||||
if _DA_REPO_ROOT.is_dir():
|
||||
da_path = str(_DA_REPO_ROOT)
|
||||
if da_path not in sys.path:
|
||||
sys.path.insert(0, da_path)
|
||||
|
||||
from depth_anything_v2.dpt import DepthAnythingV2 # type: ignore[import]
|
||||
|
||||
|
||||
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthAnythingV2Config:
|
||||
"""
|
||||
Depth Anything V2 模型选择配置。
|
||||
|
||||
encoder: "vits" | "vitb" | "vitl" | "vitg"
|
||||
device: "cuda" | "cpu"
|
||||
input_size: 推理时的输入分辨率(短边),参考官方 demo,默认 518。
|
||||
"""
|
||||
|
||||
encoder: EncoderName = "vitl"
|
||||
device: str = "cuda"
|
||||
input_size: int = 518
|
||||
|
||||
|
||||
_MODEL_CONFIGS = {
|
||||
"vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
|
||||
"vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
|
||||
"vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
|
||||
"vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
|
||||
}
|
||||
|
||||
|
||||
_DA_V2_WEIGHTS_URLS = {
|
||||
# 官方权重托管在 HuggingFace:
|
||||
# - Small -> vits
|
||||
# - Base -> vitb
|
||||
# - Large -> vitl
|
||||
# - Giant -> vitg
|
||||
# 如需替换为国内镜像,可直接修改这些 URL。
|
||||
"vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
|
||||
"vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
|
||||
"vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
|
||||
"vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _DA_V2_WEIGHTS_URLS.get(encoder)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 encoder='{encoder}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
|
||||
print("\n权重下载完成。")
|
||||
|
||||
|
||||
def load_depth_anything_v2_from_config(
|
||||
cfg: DepthAnythingV2Config,
|
||||
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
|
||||
"""
|
||||
根据配置加载 Depth Anything V2 模型与对应配置。
|
||||
|
||||
说明:
|
||||
- 权重文件路径遵循官方命名约定:
|
||||
checkpoints/depth_anything_v2_{encoder}.pth
|
||||
例如:depth_anything_v2_vitl.pth
|
||||
- 请确保上述权重文件已下载到
|
||||
python_server/model/Depth/Depth-Anything-V2/checkpoints 下。
|
||||
"""
|
||||
if cfg.encoder not in _MODEL_CONFIGS:
|
||||
raise ValueError(f"不支持的 encoder: {cfg.encoder}")
|
||||
|
||||
ckpt_path = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
|
||||
_download_if_missing(cfg.encoder, ckpt_path)
|
||||
|
||||
model = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])
|
||||
|
||||
state_dict = torch.load(str(ckpt_path), map_location="cpu")
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if cfg.device.startswith("cuda") and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
model = model.to(device).eval()
|
||||
cfg = DepthAnythingV2Config(
|
||||
encoder=cfg.encoder,
|
||||
device=device,
|
||||
input_size=cfg.input_size,
|
||||
)
|
||||
return model, cfg
|
||||
|
||||
|
||||
def infer_depth_anything_v2(
|
||||
model: DepthAnythingV2,
|
||||
image_bgr: np.ndarray,
|
||||
input_size: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对单张 BGR 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||
image_bgr: OpenCV 读取的 BGR 图像 (H, W, 3), uint8
|
||||
"""
|
||||
depth = model.infer_image(image_bgr, input_size)
|
||||
depth = np.asarray(depth, dtype="float32")
|
||||
return depth
|
||||
|
||||
Reference in New Issue
Block a user