Files
hfut-bishe/python_server/model/Depth/dpt_loader.py
2026-04-07 20:55:30 +08:00

157 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path
import numpy as np
import torch
import requests
_THIS_DIR = Path(__file__).resolve().parent
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
if _DPT_REPO_ROOT.is_dir():
dpt_path = str(_DPT_REPO_ROOT)
if dpt_path not in sys.path:
sys.path.insert(0, dpt_path)
from dpt.models import DPTDepthModel # type: ignore[import]
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
from torchvision.transforms import Compose
import cv2
# The DPT variants this loader knows how to build; each one corresponds to a
# checkpoint file / backbone pair chosen in load_dpt_from_config.
DPTModelType = Literal["dpt_large", "dpt_hybrid"]


@dataclass()
class DPTConfig:
    """Selects which DPT variant to load and which device to run it on.

    Fields are declared in positional order, so ``DPTConfig("dpt_hybrid", "cpu")``
    sets ``model_type`` then ``device``.
    """

    # Pretrained DPT variant to load.
    model_type: DPTModelType = "dpt_large"
    # Requested torch device string. load_dpt_from_config runs on "cpu" instead
    # when this does not start with "cuda" or CUDA is unavailable.
    device: str = "cuda"
_DPT_WEIGHTS_URLS = {
# 官方 DPT 模型权重托管在:
# https://github.com/isl-org/DPT#models
"dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
}
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
if ckpt_path.is_file():
return
url = _DPT_WEIGHTS_URLS.get(model_type)
if not url:
raise FileNotFoundError(
f"找不到 DPT 权重文件: {ckpt_path}\n"
f"且当前未为 model_type='{model_type}' 配置自动下载 URL请手动下载到该路径。"
)
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
resp = requests.get(url, stream=True)
resp.raise_for_status()
total = int(resp.headers.get("content-length", "0") or "0")
downloaded = 0
chunk_size = 1024 * 1024
with ckpt_path.open("wb") as f:
for chunk in resp.iter_content(chunk_size=chunk_size):
if not chunk:
continue
f.write(chunk)
downloaded += len(chunk)
if total > 0:
done = int(50 * downloaded / total)
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
print("\nDPT 权重下载完成。")
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
    """Load a DPT depth model together with its preprocessing transform.

    The checkpoint is expected under ``<this dir>/DPT/weights/``.
    ``_download_if_missing`` fetches it there first when it is absent.

    Args:
        cfg: Variant (``cfg.model_type``) and requested device (``cfg.device``).
            A device string starting with ``"cuda"`` is used as given when
            ``torch.cuda.is_available()``. This includes an explicit index such
            as ``"cuda:1"``. In every other case the model runs on ``"cpu"``.

    Returns:
        A tuple ``(model, resolved_cfg, transform)``:

        * ``model``: the model in eval mode on the resolved device.
        * ``resolved_cfg``: a new ``DPTConfig`` whose ``device`` is the device
          actually used.
        * ``transform``: the Resize -> NormalizeImage -> PrepareForNet pipeline
          for the model's inputs.

    Raises:
        KeyError: ``cfg.model_type`` is neither ``"dpt_large"`` nor ``"dpt_hybrid"``.
    """
    # model_type -> (checkpoint file name, backbone id passed to DPTDepthModel).
    # The file names must match the tails of the URLs in _DPT_WEIGHTS_URLS.
    variants = {
        "dpt_large": ("dpt_large-midas-2f21e586.pt", "vitl16_384"),
        "dpt_hybrid": ("dpt_hybrid-midas-501f0c75.pt", "vitb_rn50_384"),
    }
    ckpt_name, backbone = variants[cfg.model_type]
    ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
    _download_if_missing(cfg.model_type, ckpt_path)

    model = DPTDepthModel(
        path=str(ckpt_path),
        backbone=backbone,
        non_negative=True,
        enable_attention_hooks=False,
    )

    # Keep an explicit device such as "cuda:1" rather than collapsing it to the
    # default GPU. Non-CUDA requests, or no CUDA at runtime, fall back to CPU.
    use_cuda = cfg.device.startswith("cuda") and torch.cuda.is_available()
    device = cfg.device if use_cuda else "cpu"
    model.to(device).eval()

    # Both variants share a 384x384 base input size and mean/std 0.5
    # normalization. NOTE(review): the normalization presumably expects RGB
    # input scaled to [0, 1], as in the upstream DPT scripts — confirm.
    net_w = net_h = 384
    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        ]
    )
    resolved_cfg = DPTConfig(model_type=cfg.model_type, device=device)
    return model, resolved_cfg, transform
def _bgr_to_rgb_unit_float(image_bgr: np.ndarray) -> np.ndarray:
    """Return *image_bgr* as a new C-contiguous float32 RGB array scaled to [0, 1].

    Accepted inputs:

    * HxW or HxWx1 grayscale;
    * HxWx3 BGR;
    * HxWx4 BGRA (alpha is dropped).

    Integer images are divided by their dtype's maximum (255 for uint8).
    Floating-point images are assumed to already lie in [0, 1].

    Raises:
        ValueError: The array has none of the shapes above.
    """
    img = np.asarray(image_bgr)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim == 2:
        # Grayscale: replicate the single channel so the RGB pipeline applies.
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]
    elif img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(
            f"expected an HxW, HxWx1, HxWx3 or HxWx4 image, got shape {img.shape}"
        )
    if np.issubdtype(img.dtype, np.integer):
        scale = float(np.iinfo(img.dtype).max)
    else:
        scale = 1.0
    # Reverse the channel axis (BGR -> RGB) and materialize a contiguous
    # float32 copy. The reversed view is never contiguous, so the caller's
    # array is never modified below.
    rgb = np.ascontiguousarray(img[:, :, ::-1], dtype=np.float32)
    if scale != 1.0:
        rgb /= scale
    return rgb


def infer_dpt(
    model: "DPTDepthModel",
    transform: "Compose",
    image_bgr: np.ndarray,
    device: str,
) -> np.ndarray:
    """Run DPT on a single BGR image and return a float32 map at the input size.

    The image is converted to RGB and scaled to [0, 1] before ``transform`` is
    applied. The transform built by ``load_dpt_from_config`` normalizes with
    mean = std = 0.5, which maps [0, 1] input onto [-1, 1]. Raw 0-255 pixels
    would instead reach the network roughly in [-1, 509].

    Args:
        model: Network returned by ``load_dpt_from_config``. Assumed to map a
            (1, C, H', W') batch to a (1, H', W') prediction — TODO confirm.
        transform: Preprocessing pipeline from ``load_dpt_from_config``. It
            receives ``{"image": HxWx3 float32 RGB}`` and must return a dict
            whose ``"image"`` is a CHW numpy array.
        image_bgr: OpenCV-style image: HxW gray, HxWx3 BGR or HxWx4 BGRA.
            Integer dtypes are scaled by their maximum (255 for uint8). Float
            images are assumed to already be in [0, 1].
        device: Torch device the model lives on, e.g. the ``device`` of the
            config returned by ``load_dpt_from_config``.

    Returns:
        float32 array of shape (H, W) matching ``image_bgr``'s height and width.
        It is bicubically resampled from the network output and not normalized.
        NOTE(review): upstream DPT/MiDaS describe this output as relative
        inverse depth — confirm before treating larger values as farther.

    Raises:
        ValueError: ``image_bgr`` has an unsupported shape.
    """
    img = _bgr_to_rgb_unit_float(image_bgr)
    height, width = img.shape[:2]
    sample = transform({"image": img})["image"]
    input_batch = torch.from_numpy(sample).to(device).unsqueeze(0)
    with torch.no_grad():
        prediction = model(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=(height, width),
            mode="bicubic",
            align_corners=False,
        )
    # Index batch/channel explicitly: .squeeze() would also drop a spatial
    # axis for 1-pixel-high or 1-pixel-wide images.
    depth = prediction[0, 0].cpu().numpy()
    return depth.astype(np.float32)