Files
hfut-bishe/python_server/model/Depth/midas_loader.py
2026-04-07 20:55:30 +08:00

128 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal, Tuple

import numpy as np
import requests
import torch
_THIS_DIR = Path(__file__).resolve().parent
_MIDAS_REPO_ROOT = _THIS_DIR.joinpath("MiDaS")

# When a MiDaS checkout sits next to this file, put it at the front of
# sys.path (once) so the `midas` package and the top-level `utils` module
# imported right after this resolve to that checkout.
if _MIDAS_REPO_ROOT.is_dir():
    midas_path = str(_MIDAS_REPO_ROOT)
    if midas_path not in sys.path:
        sys.path.insert(0, midas_path)
from midas.model_loader import load_model, default_models # type: ignore[import]
import utils # from MiDaS repo
# Closed set of MiDaS variants accepted by `MiDaSConfig.model_type` (a static
# typing hint only; it is not enforced at runtime). The loader uses the chosen
# name as a key into MiDaS's `default_models` and into `_MIDAS_WEIGHTS_URLS`,
# which lists an auto-download URL for every entry here.
MiDaSModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]
@dataclass
class MiDaSConfig:
    """Settings for ``load_midas_from_config``: which MiDaS variant to load and on which device."""

    # MiDaS variant name; the loader uses it to pick both the checkpoint path
    # and the download URL.
    model_type: MiDaSModelType = "dpt_beit_large_512"
    # Requested device. The loader only uses CUDA when this starts with "cuda"
    # and CUDA is available; otherwise it falls back to "cpu".
    device: str = "cuda"
# Official MiDaS v3.1 release checkpoints (listed in the MiDaS repository
# README). Every asset follows the same "<release>/<model_type>.pt" naming
# scheme, so the table is generated from the model names, in this order.
_MIDAS_WEIGHTS_URLS = {
    name: f"https://github.com/isl-org/MiDaS/releases/download/v3_1/{name}.pt"
    for name in (
        "dpt_beit_large_512",
        "dpt_swin2_large_384",
        "dpt_swin2_tiny_256",
        "dpt_levit_224",
    )
}
def _download_if_missing(
    model_type: str,
    ckpt_path: Path,
    *,
    timeout: Tuple[float, float] = (10.0, 60.0),
) -> None:
    """Download the official MiDaS checkpoint for ``model_type`` unless ``ckpt_path`` exists.

    The body is streamed into a sibling ``<name>.part`` file, which is renamed
    to ``ckpt_path`` only after the transfer completed. An interrupted or failed
    download therefore never leaves a truncated file at ``ckpt_path``. A
    truncated file would otherwise be accepted by the ``is_file()`` check on
    every later call.

    Args:
        model_type: Key into ``_MIDAS_WEIGHTS_URLS``.
        ckpt_path: Destination path of the checkpoint; parent directories are
            created if needed.
        timeout: ``(connect, read)`` timeout in seconds passed to ``requests``.
            The read timeout bounds the wait between received bytes, not the
            total download time.

    Raises:
        FileNotFoundError: The file is missing and no URL is configured for
            ``model_type``.
        requests.HTTPError: The server answered with an error status.
    """
    if ckpt_path.is_file():
        return
    url = _MIDAS_WEIGHTS_URLS.get(model_type)
    if not url:
        raise FileNotFoundError(
            f"找不到 MiDaS 权重文件: {ckpt_path}\n"
            f"且当前未为 model_type='{model_type}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")

    tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
    chunk_size = 1024 * 1024
    try:
        # The context manager closes the streamed response (and releases its
        # connection) on both success and error.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            # "wb" also truncates any stale .part left by a killed earlier run.
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        # Publish the finished file; os.replace semantics (atomic on POSIX,
        # overwrites an existing target).
        tmp_path.replace(ckpt_path)
    except BaseException:
        # BaseException so Ctrl-C also cleans up; cleanup is best-effort and
        # must not mask the original error.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise
    print("\nMiDaS 权重下载完成。")
def load_midas_from_config(
    cfg: MiDaSConfig,
) -> Tuple[torch.nn.Module, MiDaSConfig, Callable[[dict], dict], int, int]:
    """Load a MiDaS model together with its matching input transform.

    The checkpoint is downloaded first if it is not present yet.

    Args:
        cfg: Requested model variant and device. ``cfg.device`` is honoured
            only if it starts with ``"cuda"`` and CUDA is available; otherwise
            ``"cpu"`` is used. An explicit GPU index such as ``"cuda:1"`` is
            preserved instead of being collapsed to the default GPU.

    Returns:
        ``(model, cfg, transform, net_w, net_h)``. The returned ``cfg`` is a
        copy of the input whose ``device`` is the device actually used. The
        other four values are passed through from MiDaS's ``load_model``.

    Raises:
        The lookup of ``cfg.model_type`` in MiDaS's ``default_models`` fails
        for unknown names (``KeyError`` for a plain dict).
        Errors from the checkpoint download propagate unchanged.
    """
    import dataclasses  # local: only this function needs it

    model_info = default_models[cfg.model_type]
    # NOTE(review): upstream MiDaS maps names to plain relative path strings in
    # `default_models`, while this code used to read a `.path` attribute
    # (perhaps a patched vendored copy). Accept both forms — confirm which
    # one the vendored MiDaS uses.
    rel_path = getattr(model_info, "path", model_info)
    ckpt_path = _MIDAS_REPO_ROOT / rel_path
    _download_if_missing(cfg.model_type, ckpt_path)

    if cfg.device.startswith("cuda") and torch.cuda.is_available():
        device = cfg.device  # keep "cuda:N" so the model lands on the requested GPU
    else:
        device = "cpu"

    model, transform, net_w, net_h = load_model(
        device=torch.device(device),
        model_path=str(ckpt_path),
        model_type=cfg.model_type,
        optimize=False,
        height=None,
        square=False,
    )
    # Return a copy rather than mutating the caller's config; `replace` also
    # carries over any fields added to MiDaSConfig later.
    return model, dataclasses.replace(cfg, device=device), transform, net_w, net_h
def infer_midas(
    model: torch.nn.Module,
    transform: Callable[[dict], dict],
    image_rgb: np.ndarray,
    net_w: int,
    net_h: int,
    device: str,
) -> np.ndarray:
    """Run MiDaS depth inference on a single RGB image.

    Args:
        model: Model returned by ``load_midas_from_config``.
        transform: Matching transform. It is called as
            ``transform({"image": image_rgb})`` and its ``"image"`` entry is
            fed to the network.
        image_rgb: Input image whose first two dimensions are read as
            (height, width) when the prediction is resized back. NOTE(review):
            upstream MiDaS feeds float RGB scaled to [0, 1]; confirm callers
            do the same.
        net_w: Network input width, as returned by the loader.
        net_h: Network input height, as returned by the loader.
        device: Device string the model lives on (e.g. the ``device`` of the
            config returned by the loader).

    Returns:
        The raw (un-normalized) depth prediction as a squeezed float32 array.
    """
    sample = transform({"image": image_rgb})["image"]
    # Pure inference: without no_grad autograd records the whole forward pass
    # (extra memory), and the output requires grad. That makes a `.numpy()`
    # conversion fail if the helper does one, as upstream MiDaS's `process`
    # does.
    # NOTE(review): upstream MiDaS defines `process` in run.py; confirm the
    # vendored `utils` module really provides it.
    with torch.no_grad():
        prediction = utils.process(
            torch.device(device),
            model,
            # Per the original author, the exact string only matters in that
            # it must not contain "openvino".
            model_type="dpt",
            image=sample,
            input_size=(net_w, net_h),
            # shape[1::-1] == (shape[1], shape[0]), i.e. (width, height).
            target_size=image_rgb.shape[1::-1],
            optimize=False,
            use_camera=False,
        )
    return np.asarray(prediction, dtype="float32").squeeze()