initial commit
This commit is contained in:
1
python_server/model/Depth/DPT
Submodule
1
python_server/model/Depth/DPT
Submodule
Submodule python_server/model/Depth/DPT added at cd3fe90bb4
1
python_server/model/Depth/Depth-Anything-V2
Submodule
1
python_server/model/Depth/Depth-Anything-V2
Submodule
Submodule python_server/model/Depth/Depth-Anything-V2 added at e5a2732d3e
1
python_server/model/Depth/MiDaS
Submodule
1
python_server/model/Depth/MiDaS
Submodule
Submodule python_server/model/Depth/MiDaS added at 454597711a
1
python_server/model/Depth/ZoeDepth
Submodule
1
python_server/model/Depth/ZoeDepth
Submodule
Submodule python_server/model/Depth/ZoeDepth added at d87f17b2f5
1
python_server/model/Depth/__init__.py
Normal file
1
python_server/model/Depth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
# 确保本地克隆的 Depth Anything V2 仓库在 sys.path 中,
|
||||
# 这样其内部的 `from depth_anything_v2...` 导入才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
|
||||
if _DA_REPO_ROOT.is_dir():
|
||||
da_path = str(_DA_REPO_ROOT)
|
||||
if da_path not in sys.path:
|
||||
sys.path.insert(0, da_path)
|
||||
|
||||
from depth_anything_v2.dpt import DepthAnythingV2 # type: ignore[import]
|
||||
|
||||
|
||||
# Valid Depth Anything V2 encoder identifiers (Small/Base/Large/Giant).
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]


@dataclass
class DepthAnythingV2Config:
    """Model-selection settings for Depth Anything V2.

    Attributes:
        encoder: one of "vits" | "vitb" | "vitl" | "vitg".
        device: "cuda" | "cpu" (CPU fallback is applied at load time).
        input_size: inference input resolution (short side); 518 follows
            the official demo.
    """

    encoder: EncoderName = "vitl"
    device: str = "cuda"
    input_size: int = 518
|
||||
|
||||
|
||||
_MODEL_CONFIGS = {
|
||||
"vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
|
||||
"vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
|
||||
"vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
|
||||
"vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
|
||||
}
|
||||
|
||||
|
||||
_DA_V2_WEIGHTS_URLS = {
|
||||
# 官方权重托管在 HuggingFace:
|
||||
# - Small -> vits
|
||||
# - Base -> vitb
|
||||
# - Large -> vitl
|
||||
# - Giant -> vitg
|
||||
# 如需替换为国内镜像,可直接修改这些 URL。
|
||||
"vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
|
||||
"vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
|
||||
"vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
|
||||
"vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _DA_V2_WEIGHTS_URLS.get(encoder)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 encoder='{encoder}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
|
||||
print("\n权重下载完成。")
|
||||
|
||||
|
||||
def load_depth_anything_v2_from_config(
    cfg: DepthAnythingV2Config,
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
    """Build a Depth Anything V2 model from *cfg*.

    The checkpoint path follows the official naming convention
    ``checkpoints/depth_anything_v2_{encoder}.pth`` inside the local
    Depth-Anything-V2 repository clone; a missing checkpoint is fetched
    first via ``_download_if_missing``.

    Returns:
        (model, resolved_cfg) — the model in eval mode, and a config whose
        ``device`` reflects the device actually used (CPU fallback when
        CUDA is unavailable).

    Raises:
        ValueError: for an unknown encoder name.
    """
    if cfg.encoder not in _MODEL_CONFIGS:
        raise ValueError(f"不支持的 encoder: {cfg.encoder}")

    checkpoint = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
    _download_if_missing(cfg.encoder, checkpoint)

    net = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])
    net.load_state_dict(torch.load(str(checkpoint), map_location="cpu"))

    use_cuda = cfg.device.startswith("cuda") and torch.cuda.is_available()
    resolved_device = "cuda" if use_cuda else "cpu"

    net = net.to(resolved_device).eval()
    resolved_cfg = DepthAnythingV2Config(
        encoder=cfg.encoder,
        device=resolved_device,
        input_size=cfg.input_size,
    )
    return net, resolved_cfg
|
||||
|
||||
|
||||
def infer_depth_anything_v2(
    model: DepthAnythingV2,
    image_bgr: np.ndarray,
    input_size: int,
) -> np.ndarray:
    """Run depth inference on a single BGR image.

    Args:
        model: a loaded Depth Anything V2 model.
        image_bgr: OpenCV-style (H, W, 3) uint8 BGR image.
        input_size: inference resolution passed through to ``infer_image``.

    Returns:
        The raw (un-normalized) depth map as float32.
    """
    raw = model.infer_image(image_bgr, input_size)
    return np.asarray(raw, dtype="float32")
|
||||
|
||||
148
python_server/model/Depth/depth_loader.py
Normal file
148
python_server/model/Depth/depth_loader.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的深度模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- ZoeDepth(三种:ZoeD_N / ZoeD_K / ZoeD_NK)
|
||||
- Depth Anything V2(四种 encoder:vits / vitb / vitl / vitg)
|
||||
|
||||
未来如果要加 MiDaS / DPT,只需要在这里再接一层即可。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
build_zoe_config_from_app,
|
||||
build_depth_anything_v2_config_from_app,
|
||||
build_dpt_config_from_app,
|
||||
build_midas_config_from_app,
|
||||
)
|
||||
from .zoe_loader import load_zoe_from_config
|
||||
from .depth_anything_v2_loader import (
|
||||
load_depth_anything_v2_from_config,
|
||||
infer_depth_anything_v2,
|
||||
)
|
||||
from .dpt_loader import load_dpt_from_config, infer_dpt
|
||||
from .midas_loader import load_midas_from_config, infer_midas
|
||||
|
||||
|
||||
class DepthBackend(str, Enum):
    """Identifies which depth-estimation backend to use."""

    ZOEDEPTH = "zoedepth"
    DEPTH_ANYTHING_V2 = "depth_anything_v2"
    DPT = "dpt"
    MIDAS = "midas"


@dataclass
class UnifiedDepthConfig:
    """Unified depth configuration.

    Attributes:
        backend: which backend to build a predictor for.
        device: optional device override; ``None`` means "use the setting
            from config.py".
    """

    backend: DepthBackend = DepthBackend.ZOEDEPTH
    device: str | None = None
|
||||
|
||||
|
||||
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a ZoeDepth predictor: PIL image -> float32 depth map."""
    zoe_cfg = build_zoe_config_from_app(load_app_config())
    if device_override is not None:
        zoe_cfg.device = device_override

    model, _ = load_zoe_from_config(zoe_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        raw = model.infer_pil(img.convert("RGB"), output_type="numpy")
        return np.asarray(raw, dtype="float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a Depth Anything V2 predictor: PIL image -> float32 depth map."""
    da_cfg = build_depth_anything_v2_config_from_app(load_app_config())
    if device_override is not None:
        da_cfg.device = device_override

    model, da_cfg = load_depth_anything_v2_from_config(da_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # Depth Anything V2's infer_image expects a BGR uint8 array, so
        # reverse the channel axis of the RGB frame.
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        raw = infer_depth_anything_v2(model, rgb[:, :, ::-1], da_cfg.input_size)
        return raw.astype("float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a DPT predictor: PIL image -> float32 depth map.

    Bug fix: the previous body called ``cv2.cvtColor`` although this module
    never imports cv2, so the returned predictor raised ``NameError`` on
    first use.  The RGB->BGR conversion is now a NumPy channel reversal
    (identical result, no extra dependency), mirroring
    ``_make_da_v2_predictor``.
    """
    app_cfg = load_app_config()
    dpt_cfg = build_dpt_config_from_app(app_cfg)
    if device_override is not None:
        dpt_cfg.device = device_override

    model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        # Reverse the channel axis (RGB -> BGR); copy so that OpenCV code
        # downstream in infer_dpt receives a contiguous array.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
        return depth.astype("float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a MiDaS predictor: PIL image -> float32 depth map."""
    midas_cfg = build_midas_config_from_app(load_app_config())
    if device_override is not None:
        midas_cfg.device = device_override

    model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)

    def _predict(img: Image.Image) -> np.ndarray:
        # The MiDaS transform pipeline expects RGB floats scaled to [0, 1].
        rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
        raw = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
        return raw.astype("float32").squeeze()

    return _predict
|
||||
|
||||
|
||||
def build_depth_predictor(
    cfg: UnifiedDepthConfig | None = None,
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
    """Build the depth-prediction function for the configured backend.

    Returns:
        - predictor(image: PIL.Image) -> np.ndarray[H, W], float32
        - the backend actually used

    Raises:
        ValueError: for an unrecognized backend.
    """
    cfg = cfg or UnifiedDepthConfig()

    # Table dispatch keeps adding a new backend to a one-line change.
    factories = {
        DepthBackend.ZOEDEPTH: _make_zoe_predictor,
        DepthBackend.DEPTH_ANYTHING_V2: _make_da_v2_predictor,
        DepthBackend.DPT: _make_dpt_predictor,
        DepthBackend.MIDAS: _make_midas_predictor,
    }
    factory = factories.get(cfg.backend)
    if factory is None:
        raise ValueError(f"不支持的深度后端: {cfg.backend}")

    return factory(cfg.device), cfg.backend
|
||||
|
||||
156
python_server/model/Depth/dpt_loader.py
Normal file
156
python_server/model/Depth/dpt_loader.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
|
||||
if _DPT_REPO_ROOT.is_dir():
|
||||
dpt_path = str(_DPT_REPO_ROOT)
|
||||
if dpt_path not in sys.path:
|
||||
sys.path.insert(0, dpt_path)
|
||||
|
||||
from dpt.models import DPTDepthModel # type: ignore[import]
|
||||
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
|
||||
from torchvision.transforms import Compose
|
||||
import cv2
|
||||
|
||||
|
||||
# Supported DPT checkpoint variants.
DPTModelType = Literal["dpt_large", "dpt_hybrid"]


@dataclass
class DPTConfig:
    """DPT model selection: checkpoint variant and target device."""

    # "dpt_large" uses the ViT-L/16 backbone, "dpt_hybrid" ViT-B + ResNet-50.
    model_type: DPTModelType = "dpt_large"
    # "cuda" or "cpu"; CPU fallback is applied at load time.
    device: str = "cuda"
|
||||
|
||||
|
||||
_DPT_WEIGHTS_URLS = {
|
||||
# 官方 DPT 模型权重托管在:
|
||||
# https://github.com/isl-org/DPT#models
|
||||
"dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _DPT_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 DPT 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nDPT 权重下载完成。")
|
||||
|
||||
|
||||
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
    """Load a DPT model plus its preprocessing transform.

    Returns:
        (model, resolved_cfg, transform) — ``resolved_cfg.device`` reflects
        the device actually used (CPU fallback when CUDA is unavailable).

    Raises:
        KeyError: for a model_type outside DPTModelType.

    Improvement: the previous code duplicated the whole construction in two
    branches that differed only in the backbone string — the input
    resolution (384) and the normalization were identical.  The branches
    are collapsed into a backbone lookup.
    """
    # Checkpoint file names follow the official DPT release naming.
    ckpt_name = {
        "dpt_large": "dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
    }[cfg.model_type]

    ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
    _download_if_missing(cfg.model_type, ckpt_path)

    backbone = {
        "dpt_large": "vitl16_384",
        "dpt_hybrid": "vitb_rn50_384",
    }[cfg.model_type]

    net_w = net_h = 384
    model = DPTDepthModel(
        path=str(ckpt_path),
        backbone=backbone,
        non_negative=True,
        enable_attention_hooks=False,
    )
    normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model.to(device).eval()

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    cfg = DPTConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform
|
||||
|
||||
|
||||
def infer_dpt(
    model: DPTDepthModel,
    transform: Compose,
    image_bgr: np.ndarray,
    device: str,
) -> np.ndarray:
    """Run depth inference on a single BGR image.

    Args:
        model: a loaded DPTDepthModel in eval mode.
        transform: the preprocessing Compose from ``load_dpt_from_config``.
        image_bgr: OpenCV-style (H, W, 3) BGR image.
        device: torch device string the model lives on.

    Returns:
        The raw (un-normalized) depth map as float32, at input resolution.
    """
    rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    net_input = transform({"image": rgb})["image"]

    batch = torch.from_numpy(net_input).to(device).unsqueeze(0)
    with torch.no_grad():
        pred = model(batch)
        # Upsample the network output back to the original image size.
        pred = torch.nn.functional.interpolate(
            pred.unsqueeze(1),
            size=rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
        pred = pred.squeeze().cpu().numpy()

    return pred.astype("float32")
|
||||
|
||||
127
python_server/model/Depth/midas_loader.py
Normal file
127
python_server/model/Depth/midas_loader.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
|
||||
if _MIDAS_REPO_ROOT.is_dir():
|
||||
midas_path = str(_MIDAS_REPO_ROOT)
|
||||
if midas_path not in sys.path:
|
||||
sys.path.insert(0, midas_path)
|
||||
|
||||
from midas.model_loader import load_model, default_models # type: ignore[import]
|
||||
import utils # from MiDaS repo
|
||||
|
||||
|
||||
# Supported MiDaS v3.1 checkpoint variants.
MiDaSModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]


@dataclass
class MiDaSConfig:
    """MiDaS model selection: checkpoint variant and target device."""

    model_type: MiDaSModelType = "dpt_beit_large_512"
    # "cuda" or "cpu"; CPU fallback is applied at load time.
    device: str = "cuda"
|
||||
|
||||
|
||||
_MIDAS_WEIGHTS_URLS = {
|
||||
# 官方权重参见 MiDaS 仓库 README
|
||||
"dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
|
||||
"dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
|
||||
"dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
|
||||
"dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||
if ckpt_path.is_file():
|
||||
return
|
||||
|
||||
url = _MIDAS_WEIGHTS_URLS.get(model_type)
|
||||
if not url:
|
||||
raise FileNotFoundError(
|
||||
f"找不到 MiDaS 权重文件: {ckpt_path}\n"
|
||||
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||
)
|
||||
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
|
||||
total = int(resp.headers.get("content-length", "0") or "0")
|
||||
downloaded = 0
|
||||
chunk_size = 1024 * 1024
|
||||
|
||||
with ckpt_path.open("wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
if total > 0:
|
||||
done = int(50 * downloaded / total)
|
||||
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||
print("\nMiDaS 权重下载完成。")
|
||||
|
||||
|
||||
def load_midas_from_config(
    cfg: MiDaSConfig,
) -> Tuple[torch.nn.Module, MiDaSConfig, callable, int, int]:
    """Load a MiDaS model plus its preprocessing transform.

    Returns:
        (model, resolved_cfg, transform, net_w, net_h); ``resolved_cfg.device``
        reflects the device actually used (CPU fallback when CUDA is absent).

    Bug fix: in the MiDaS repo, ``default_models`` maps a model type to its
    relative checkpoint *path string* (e.g. "weights/dpt_beit_large_512.pt").
    The previous code accessed ``.path`` on that string, which raises
    AttributeError; the mapping is now indexed directly.
    NOTE(review): confirm against the pinned MiDaS submodule revision that
    ``default_models`` values are plain strings there too.
    """
    ckpt_path = _MIDAS_REPO_ROOT / default_models[cfg.model_type]
    _download_if_missing(cfg.model_type, ckpt_path)

    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model, transform, net_w, net_h = load_model(
        device=torch.device(device),
        model_path=str(ckpt_path),
        model_type=cfg.model_type,
        optimize=False,
        height=None,
        square=False,
    )

    cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform, net_w, net_h
|
||||
|
||||
|
||||
def infer_midas(
    model: torch.nn.Module,
    transform: callable,
    image_rgb: np.ndarray,
    net_w: int,
    net_h: int,
    device: str,
) -> np.ndarray:
    """Run depth inference on a single RGB image.

    Args:
        model: a loaded MiDaS model.
        transform: the transform returned by ``load_midas_from_config``.
        image_rgb: (H, W, 3) float RGB image (MiDaS transform convention).
        net_w/net_h: network input size from ``load_midas_from_config``.
        device: torch device string the model lives on.

    Returns:
        The raw (un-normalized) depth map as float32.
    """
    net_input = transform({"image": image_rgb})["image"]
    # target_size is (W, H) — the reversed leading dims of the array.
    pred = utils.process(
        torch.device(device),
        model,
        # The exact string barely matters to utils.process as long as it
        # does not contain "openvino".
        model_type="dpt",
        image=net_input,
        input_size=(net_w, net_h),
        target_size=image_rgb.shape[1::-1],
        optimize=False,
        use_camera=False,
    )
    return np.asarray(pred, dtype="float32").squeeze()
|
||||
|
||||
74
python_server/model/Depth/zoe_loader.py
Normal file
74
python_server/model/Depth/zoe_loader.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# 确保本地克隆的 ZoeDepth 仓库在 sys.path 中,
|
||||
# 这样其内部的 `import zoedepth...` 才能正常工作。
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
|
||||
if _ZOE_REPO_ROOT.is_dir():
|
||||
zoe_path = str(_ZOE_REPO_ROOT)
|
||||
if zoe_path not in sys.path:
|
||||
sys.path.insert(0, zoe_path)
|
||||
|
||||
from zoedepth.models.builder import build_model
|
||||
from zoedepth.utils.config import get_config
|
||||
|
||||
|
||||
# The three published ZoeDepth variants.
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]


@dataclass
class ZoeConfig:
    """ZoeDepth model-selection settings.

    Attributes:
        model: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK".
        device: "cuda" | "cpu" (CPU fallback is applied at load time).
    """

    model: ZoeModelName = "ZoeD_N"
    device: str = "cuda"
|
||||
|
||||
|
||||
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
    """Manually load one of the ZoeDepth variants: "ZoeD_N", "ZoeD_K" or "ZoeD_NK".

    Returns:
        (model, conf) — the model moved to *device* (CPU fallback when CUDA
        is unavailable) and switched to eval mode, plus the ZoeDepth config
        object it was built from.

    Raises:
        ValueError: for an unknown model name.
    """
    # Map each variant to the get_config() arguments used by the repo.
    conf_spec = {
        "ZoeD_N": ("zoedepth", {}),
        "ZoeD_K": ("zoedepth", {"config_version": "kitti"}),
        "ZoeD_NK": ("zoedepth_nk", {}),
    }.get(name)
    if conf_spec is None:
        raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")

    config_name, extra_kwargs = conf_spec
    conf = get_config(config_name, "infer", **extra_kwargs)

    model = build_model(conf)

    target = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model = model.to(target)
    model.eval()
    return model, conf
|
||||
|
||||
|
||||
def load_zoe_from_config(config: ZoeConfig):
    """Load a ZoeDepth model according to *config*.

    Thin convenience wrapper around ``load_zoe_from_name``.

    Example:
        cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
        model, conf = load_zoe_from_config(cfg)
    """
    return load_zoe_from_name(config.model, config.device)
|
||||
|
||||
Reference in New Issue
Block a user