"""Helpers for loading MiDaS depth models and running single-image inference.

Requires the MiDaS repository (https://github.com/isl-org/MiDaS) checked out
next to this file as ``MiDaS/``; its modules are imported from there.
"""

import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal, Tuple

import numpy as np
import requests
import torch

# Make the vendored MiDaS repository importable (expected at <this dir>/MiDaS).
_THIS_DIR = Path(__file__).resolve().parent
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
if _MIDAS_REPO_ROOT.is_dir():
    midas_path = str(_MIDAS_REPO_ROOT)
    if midas_path not in sys.path:
        sys.path.insert(0, midas_path)

from midas.model_loader import load_model, default_models  # type: ignore[import]
import utils  # from MiDaS repo

# Checkpoints supported by this wrapper (MiDaS v3.1 release names).
MiDaSModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]


@dataclass
class MiDaSConfig:
    """Configuration for loading a MiDaS model."""

    # Which MiDaS v3.1 checkpoint/architecture to use.
    model_type: MiDaSModelType = "dpt_beit_large_512"
    # Requested device; load_midas_from_config falls back to "cpu"
    # when CUDA is unavailable.
    device: str = "cuda"


# Official weight URLs — see the MiDaS repository README.
_MIDAS_WEIGHTS_URLS = {
    "dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
    "dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
    "dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
    "dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
}


def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
    """Download the checkpoint for *model_type* to *ckpt_path* if absent.

    Streams the download with a text progress bar on stdout.

    Raises:
        FileNotFoundError: when the file is missing and no download URL is
            configured for *model_type*.
        requests.HTTPError: when the HTTP request fails.
    """
    if ckpt_path.is_file():
        return
    url = _MIDAS_WEIGHTS_URLS.get(model_type)
    if not url:
        raise FileNotFoundError(
            f"找不到 MiDaS 权重文件: {ckpt_path}\n"
            f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
    # Stream into a temporary ".part" file and rename at the end, so an
    # interrupted download is never mistaken for a complete checkpoint by
    # the ``is_file()`` check above on the next run.
    tmp_path = ckpt_path.with_suffix(ckpt_path.suffix + ".part")
    chunk_size = 1024 * 1024  # 1 MiB per chunk
    try:
        # Context manager + timeout: always release the connection, never
        # hang forever on an unresponsive server.
        with requests.get(url, stream=True, timeout=30) as resp:
            resp.raise_for_status()
            # content-length may be missing; 0 disables the progress bar.
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:  # skip keep-alive chunks
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print(
                            "\r[{}{}] {:.1f}%".format(
                                "#" * done,
                                "." * (50 - done),
                                downloaded * 100 / total,
                            ),
                            end="",
                        )
        tmp_path.replace(ckpt_path)
    except BaseException:
        # Don't leave partial files behind on failure or Ctrl-C.
        tmp_path.unlink(missing_ok=True)
        raise
    print("\nMiDaS 权重下载完成。")


def load_midas_from_config(
    cfg: MiDaSConfig,
) -> Tuple[torch.nn.Module, MiDaSConfig, Callable, int, int]:
    """Load a MiDaS model and its matching input transform.

    Downloads the checkpoint on first use.

    Returns:
        ``(model, cfg, transform, net_w, net_h)`` where ``cfg`` is a new
        config reflecting the device actually used (falls back to ``"cpu"``
        when CUDA is unavailable).
    """
    # default_models maps the model type to its default checkpoint path as a
    # plain string relative to the MiDaS repo root (e.g.
    # "weights/dpt_beit_large_512.pt"). The previous ``model_info.path``
    # attribute access raised AttributeError on that string.
    ckpt_path = _MIDAS_REPO_ROOT / default_models[cfg.model_type]
    _download_if_missing(cfg.model_type, ckpt_path)

    # Honor any "cuda"/"cuda:N" request only when CUDA is actually available.
    if cfg.device.startswith("cuda") and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model, transform, net_w, net_h = load_model(
        device=torch.device(device),
        model_path=str(ckpt_path),
        model_type=cfg.model_type,
        optimize=False,
        height=None,
        square=False,
    )
    # Return a fresh config so the caller sees the resolved device.
    cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform, net_w, net_h


def infer_midas(
    model: torch.nn.Module,
    transform: Callable,
    image_rgb: np.ndarray,
    net_w: int,
    net_h: int,
    device: str,
) -> np.ndarray:
    """Run depth inference on a single RGB image.

    Args:
        model: loaded MiDaS network.
        transform: preprocessing transform from ``load_midas_from_config``.
        image_rgb: RGB image array; assumed HxWx3 in the value range the
            MiDaS transform expects — TODO confirm against the caller.
        net_w: network input width from ``load_midas_from_config``.
        net_h: network input height from ``load_midas_from_config``.
        device: torch device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        float32 depth prediction (not normalized), squeezed to the input
        image's spatial size.
    """
    image = transform({"image": image_rgb})["image"]
    prediction = utils.process(
        torch.device(device),
        model,
        # The exact string barely affects utils.process's logic, as long as
        # it does not contain "openvino".
        model_type="dpt",
        image=image,
        input_size=(net_w, net_h),
        target_size=image_rgb.shape[1::-1],  # (W, H) of the original image
        optimize=False,
        use_camera=False,
    )
    return np.asarray(prediction, dtype="float32").squeeze()