Files
hfut-bishe/python_server/model/Seg/seg_loader.py
2026-04-07 20:55:30 +08:00

169 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
统一的分割模型加载入口。
当前支持:
- SAM (segment-anything)
- Mask2Former(使用 HuggingFace transformers 的语义分割实现)
"""
from __future__ import annotations

import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Callable

import numpy as np
# Absolute directory containing this module (symlinks resolved); the vendored
# ``segment-anything`` checkout is looked up relative to it.
_THIS_DIR = Path(__file__).resolve().parents[0]
class SegBackend(str, Enum):
    """Identifiers of the segmentation backends this loader can build.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (``SegBackend.SAM == "sam"``).
    """

    SAM = "sam"
    MASK2FORMER = "mask2former"
@dataclass
class UnifiedSegConfig:
    """Options consumed by ``build_seg_predictor``.

    Attributes:
        backend: Which segmentation backend to build.
    """

    # Backend selector; SAM unless the caller asks for another one.
    backend: SegBackend = SegBackend.SAM
# -----------------------------
# SAM (Segment Anything)
# -----------------------------
def _ensure_sam_on_path() -> Path:
    """Make the vendored ``segment-anything`` checkout importable.

    The checkout is expected at ``<this module's directory>/segment-anything``.
    Its path is prepended to ``sys.path`` unless it is already listed there.

    Returns:
        Path of the ``segment-anything`` checkout.

    Raises:
        FileNotFoundError: The checkout directory does not exist.
    """
    repo_dir = _THIS_DIR / "segment-anything"
    if not repo_dir.is_dir():
        raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {repo_dir}")

    entry = str(repo_dir)
    # Prepend so the vendored package wins over any installed copy.
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return repo_dir
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
    """Return the SAM ViT-H checkpoint path, downloading the file on first use.

    The checkpoint lives at ``<sam_root>/checkpoints/sam_vit_h_4b8939.pth``.
    If that file already exists, it is returned immediately. Otherwise it is
    streamed from the official URL while a console progress bar is printed.

    The data is written to a ``.part`` sibling file. It is renamed to the final
    name only after the download completed. An interrupted or truncated
    download therefore never leaves a file that later calls would mistake for
    a valid checkpoint.

    Args:
        sam_root: Root directory of the segment-anything checkout.

    Returns:
        Path of the local checkpoint file.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Connection problems or timeouts.
        OSError: Fewer bytes arrived than ``Content-Length`` announced, or a
            local file-system operation failed.
    """
    ckpt_dir = sam_root / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
    if ckpt_path.is_file():
        return ckpt_path

    # Imported lazily so an already cached checkpoint does not need `requests`.
    import requests

    url = (
        "https://dl.fbaipublicfiles.com/segment_anything/"
        "sam_vit_h_4b8939.pth"
    )
    print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")

    tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
    chunk_size = 1024 * 1024
    try:
        # timeout=(connect, read) seconds keeps a stalled server from hanging
        # forever; the `with` block releases the streamed connection.
        with requests.get(url, stream=True, timeout=(10, 60)) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print(
                            "\r[{}{}] {:.1f}%".format(
                                "#" * done,
                                "." * (50 - done),
                                downloaded * 100 / total,
                            ),
                            end="",
                        )
        # A connection closed early ends iter_content without an error, so
        # compare the byte count against the announced size explicitly.
        if total > 0 and downloaded != total:
            raise OSError(
                f"SAM 权重下载不完整: 仅收到 {downloaded}/{total} 字节"
            )
        # Atomic on the same file system: the final name only ever holds a
        # complete file.
        tmp_path.replace(ckpt_path)
    except BaseException:
        # Also covers KeyboardInterrupt; drop the partial file, then re-raise.
        if tmp_path.exists():
            tmp_path.unlink()
        raise

    print("\nSAM 权重下载完成。")
    return ckpt_path
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build a SAM (ViT-H) automatic-mask predictor.

    Steps:
    - put the vendored segment-anything checkout on ``sys.path``;
    - make sure the ViT-H checkpoint exists locally;
    - load the model on CUDA when ``torch.cuda.is_available()``, else on CPU;
    - wrap a ``SamAutomaticMaskGenerator`` around the model.

    Returns:
        A function ``predict(image_rgb) -> label_map``:

        - ``image_rgb``: an ``(H, W, 3)`` RGB image, ideally ``uint8``. Other
          dtypes are clipped to ``[0, 255]`` and cast to ``uint8``.
        - ``label_map``: an ``(H, W)`` ``int32`` array. 0 marks pixels covered
          by no mask. The mask at 1-based position ``i`` of the generator
          output is labelled ``i``.

        ``predict`` raises ``ValueError`` when the input is not ``(H, W, 3)``.
    """
    sam_root = _ensure_sam_on_path()
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)

    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # type: ignore[import]
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](
        checkpoint=str(ckpt_path),
    ).to(device)
    mask_generator = SamAutomaticMaskGenerator(sam)

    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
            raise ValueError(
                f"expected an RGB image of shape (H, W, 3), got {image_rgb.shape}"
            )
        if image_rgb.dtype != np.uint8:
            # Clip before casting: a bare astype would wrap out-of-range
            # values (e.g. 256 -> 0, -1 -> 255) and corrupt the image.
            image_rgb = np.clip(image_rgb, 0, 255).astype(np.uint8)

        masks = mask_generator.generate(image_rgb)
        h, w = image_rgb.shape[:2]
        label_map = np.zeros((h, w), dtype=np.int32)

        # Ids stay tied to each mask's position in the generator output.
        # Larger masks are painted first, though: later writes win, so painting
        # in output order could let a big mask erase a smaller nested one
        # entirely. The sort is stable, so equal areas keep output order.
        labelled = [
            (idx, np.asarray(m["segmentation"], dtype=bool))
            for idx, m in enumerate(masks, start=1)
            if m.get("segmentation") is not None
        ]
        labelled.sort(key=lambda item: int(np.count_nonzero(item[1])), reverse=True)
        for idx, seg in labelled:
            label_map[seg] = idx
        return label_map

    return _predict
# -----------------------------
# Mask2Former (built via mask2former_loader.build_mask2former_hf_predictor)
# -----------------------------
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build the Mask2Former predictor via the sibling ``mask2former_loader`` module.

    ``build_mask2former_hf_predictor()`` returns a pair. Only its first element,
    the predictor callable, is returned; the second element is discarded here.
    """
    # Imported lazily, so mask2former_loader is loaded only when this backend
    # is actually built.
    from .mask2former_loader import build_mask2former_hf_predictor

    predict_fn, _unused = build_mask2former_hf_predictor()
    return predict_fn
# -----------------------------
# 统一构建函数
# -----------------------------
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """Build the segmentation predictor selected by ``cfg``.

    Args:
        cfg: Loader configuration. When omitted (or falsy), a default
            ``UnifiedSegConfig()`` is used, which selects SAM.

    Returns:
        A pair ``(predictor, backend)``:

        - ``predictor(image_rgb)`` maps an ``(H, W, 3)`` uint8 RGB image to an
          ``(H, W)`` int32 label map.
        - ``backend`` is the backend that was actually built.

    Raises:
        ValueError: ``cfg.backend`` is not a supported backend.
    """
    config = cfg if cfg else UnifiedSegConfig()
    chosen = config.backend

    # Plain equality (not a dict lookup): SegBackend is a str enum, so a raw
    # string such as "sam" must keep matching its member.
    if chosen == SegBackend.SAM:
        return _make_sam_predictor(), SegBackend.SAM
    elif chosen == SegBackend.MASK2FORMER:
        return _make_mask2former_predictor(), SegBackend.MASK2FORMER

    raise ValueError(f"不支持的分割后端: {config.backend}")