initial commit
This commit is contained in:
127
python_server/model/Depth/midas_loader.py
Normal file
127
python_server/model/Depth/midas_loader.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
# Make the vendored MiDaS repository importable: upstream MiDaS is not an
# installable package, so its checkout directory must be on sys.path before
# the `midas` / `utils` imports below can resolve.
_THIS_DIR = Path(__file__).resolve().parent
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
if _MIDAS_REPO_ROOT.is_dir():
    midas_path = str(_MIDAS_REPO_ROOT)
    # Avoid duplicate sys.path entries if this module is imported twice.
    if midas_path not in sys.path:
        sys.path.insert(0, midas_path)
|
||||
|
||||
from midas.model_loader import load_model, default_models # type: ignore[import]
|
||||
import utils # from MiDaS repo
|
||||
|
||||
|
||||
# Checkpoint names supported by this loader; each has a matching download URL
# in _MIDAS_WEIGHTS_URLS below.
MiDaSModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]
|
||||
|
||||
|
||||
@dataclass
class MiDaSConfig:
    """Configuration selecting which MiDaS model to load and where to run it."""

    # Checkpoint variant to load (one of the MiDaSModelType literals).
    model_type: MiDaSModelType = "dpt_beit_large_512"
    # Requested torch device string; the loader falls back to "cpu" when a
    # "cuda" device is requested but CUDA is unavailable.
    device: str = "cuda"
|
||||
|
||||
|
||||
_MIDAS_WEIGHTS_URLS = {
|
||||
# 官方权重参见 MiDaS 仓库 README
|
||||
"dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
|
||||
"dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
|
||||
"dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
|
||||
"dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _MIDAS_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 MiDaS 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nMiDaS 权重下载完成。")
|
||||
|
||||
|
||||
def load_midas_from_config(
    cfg: MiDaSConfig,
) -> Tuple[torch.nn.Module, MiDaSConfig, callable, int, int]:
    """
    Load a MiDaS model and its matching input transform.

    Returns: model, cfg, transform, net_w, net_h
        The returned ``cfg`` is a new instance whose ``device`` field reflects
        the device actually chosen (may be "cpu" even if "cuda" was requested).
    """
    # default_models provides the default checkpoint path for each model type.
    # NOTE(review): upstream MiDaS exposes default_models as a dict of plain
    # path strings, which would make `model_info.path` an AttributeError —
    # confirm the vendored MiDaS version actually yields objects with `.path`.
    model_info = default_models[cfg.model_type]
    ckpt_path = _MIDAS_REPO_ROOT / model_info.path
    _download_if_missing(cfg.model_type, ckpt_path)

    # Fall back to CPU when CUDA was requested but is not available.
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model, transform, net_w, net_h = load_model(
        device=torch.device(device),
        model_path=str(ckpt_path),
        model_type=cfg.model_type,
        optimize=False,
        height=None,
        square=False,
    )

    # Re-create the config so callers see the device that was actually used.
    cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform, net_w, net_h
|
||||
|
||||
|
||||
def infer_midas(
    model: torch.nn.Module,
    transform: callable,
    image_rgb: np.ndarray,
    net_w: int,
    net_h: int,
    device: str,
) -> np.ndarray:
    """
    Run depth inference on a single RGB image.

    Returns the raw (non-normalised) depth map as a float32 array.
    """
    # Apply the model-specific preprocessing transform to the input image.
    net_input = transform({"image": image_rgb})["image"]

    # Resize the prediction back to the source image's (width, height).
    out_size = image_rgb.shape[1::-1]

    # The exact model_type string barely matters to utils.process — it only
    # needs to not contain "openvino".
    depth = utils.process(
        torch.device(device),
        model,
        model_type="dpt",
        image=net_input,
        input_size=(net_w, net_h),
        target_size=out_size,
        optimize=False,
        use_camera=False,
    )

    return np.asarray(depth, dtype="float32").squeeze()
|
||||
|
||||
Reference in New Issue
Block a user