initial commit
This commit is contained in:
156
python_server/model/Depth/dpt_loader.py
Normal file
156
python_server/model/Depth/dpt_loader.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
|
||||
if _DPT_REPO_ROOT.is_dir():
|
||||
dpt_path = str(_DPT_REPO_ROOT)
|
||||
if dpt_path not in sys.path:
|
||||
sys.path.insert(0, dpt_path)
|
||||
|
||||
from dpt.models import DPTDepthModel # type: ignore[import]
|
||||
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
|
||||
from torchvision.transforms import Compose
|
||||
import cv2
|
||||
|
||||
|
||||
DPTModelType = Literal["dpt_large", "dpt_hybrid"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DPTConfig:
|
||||
model_type: DPTModelType = "dpt_large"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
_DPT_WEIGHTS_URLS = {
|
||||
# 官方 DPT 模型权重托管在:
|
||||
# https://github.com/isl-org/DPT#models
|
||||
"dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _DPT_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 DPT 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nDPT 权重下载完成。")
|
||||
|
||||
|
||||
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
|
||||
"""
|
||||
加载 DPT 模型与对应的预处理 transform。
|
||||
"""
|
||||
ckpt_name = {
|
||||
"dpt_large": "dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
|
||||
}[cfg.model_type]
|
||||
|
||||
ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
|
||||
_download_if_missing(cfg.model_type, ckpt_path)
|
||||
|
||||
if cfg.model_type == "dpt_large":
|
||||
net_w = net_h = 384
|
||||
model = DPTDepthModel(
|
||||
path=str(ckpt_path),
|
||||
backbone="vitl16_384",
|
||||
non_negative=True,
|
||||
enable_attention_hooks=False,
|
||||
)
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
else:
|
||||
net_w = net_h = 384
|
||||
model = DPTDepthModel(
|
||||
path=str(ckpt_path),
|
||||
backbone="vitb_rn50_384",
|
||||
non_negative=True,
|
||||
enable_attention_hooks=False,
|
||||
)
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
model.to(device).eval()
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method="minimal",
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
cfg = DPTConfig(model_type=cfg.model_type, device=device)
|
||||
return model, cfg, transform
|
||||
|
||||
|
||||
def infer_dpt(
|
||||
model: DPTDepthModel,
|
||||
transform: Compose,
|
||||
image_bgr: np.ndarray,
|
||||
device: str,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对单张 BGR 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||
"""
|
||||
img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
||||
sample = transform({"image": img})["image"]
|
||||
|
||||
input_batch = torch.from_numpy(sample).to(device).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
prediction = model(input_batch)
|
||||
prediction = (
|
||||
torch.nn.functional.interpolate(
|
||||
prediction.unsqueeze(1),
|
||||
size=img.shape[:2],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
.squeeze()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
return prediction.astype("float32")
|
||||
|
||||
Reference in New Issue
Block a user