Files
hfut-bishe/python_server/model/Seg/seg_loader.py
2026-04-08 14:37:01 +08:00

290 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
"""
Unified entry point for loading segmentation models.
Currently supported:
- SAM (segment-anything)
- Mask2Former (the semantic-segmentation implementation from HuggingFace transformers)
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable, List
import sys
from pathlib import Path
import numpy as np
# -----------------------------
# SAM interactive predictor (SamPredictor), cached separately from the AutomaticMaskGenerator path
# -----------------------------
# Process-wide cache for the interactive predictor; filled on first use.
_sam_prompt_predictor = None


def get_sam_prompt_predictor():
    """Return the shared ``SamPredictor`` (vit_h), building it on first call.

    The first call makes the vendored segment-anything checkout importable,
    ensures the vit_h checkpoint exists (downloading it if necessary), loads
    the model onto CUDA when available (CPU otherwise) and stores the
    predictor in the module-level cache. Later calls return that same
    object. The predictor is used for point/box prompted segmentation.
    """
    global _sam_prompt_predictor
    if _sam_prompt_predictor is None:
        repo_dir = _ensure_sam_on_path()
        weights_file = _download_sam_checkpoint_if_needed(repo_dir)

        # Imported lazily: segment_anything only becomes importable once the
        # checkout has been put on sys.path by the call above.
        from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]
        import torch

        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        build_vit_h = sam_model_registry["vit_h"]
        model = build_vit_h(checkpoint=str(weights_file)).to(target_device)
        _sam_prompt_predictor = SamPredictor(model)
    return _sam_prompt_predictor
def run_sam_prompt(
    image_rgb: np.ndarray,
    point_coords: np.ndarray,
    point_labels: np.ndarray,
    box_xyxy: np.ndarray | None = None,
) -> np.ndarray:
    """Segment an RGB image with SAM from point prompts and an optional box.

    All coordinates are pixel coordinates of the input image.

    Args:
        image_rgb: HWC RGB image. Non-uint8 input is cast to uint8 with a
            plain ``astype`` (no rescaling).
        point_coords: ``(N, 2)`` point prompts, converted to float32.
        point_labels: ``(N,)`` labels, converted to int64
            (1 = foreground point, 0 = background).
        box_xyxy: optional ``[x1, y1, x2, y2]`` box (4 values) or ``None``.

    Returns:
        Boolean mask of shape ``(H, W)``: the highest-scoring of the
        candidate masks returned by the predictor.

    Raises:
        ValueError: if the image is not 3-dimensional with 3 channels, if
            ``point_coords`` is not Nx2, if ``point_labels`` does not have one
            entry per point, or if ``box_xyxy`` does not hold exactly 4 values.
    """
    # Validate every argument up front so that a malformed request fails
    # before the (expensive) model load and image embedding are triggered.
    image_rgb = np.asarray(image_rgb)
    if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
        raise ValueError(f"image_rgb 期望 HWC RGB uint8当前 shape={image_rgb.shape}")
    if image_rgb.dtype != np.uint8:
        image_rgb = image_rgb.astype(np.uint8)
    image_rgb = np.ascontiguousarray(image_rgb)

    pc = np.asarray(point_coords, dtype=np.float32)
    pl = np.asarray(point_labels, dtype=np.int64)
    if pc.ndim != 2 or pc.shape[1] != 2:
        raise ValueError("point_coords 应为 Nx2")
    if pl.ndim != 1 or pl.shape[0] != pc.shape[0]:
        raise ValueError("point_labels 长度须与 point_coords 行数一致")

    box_arg = None
    if box_xyxy is not None:
        # reshape(4) raises ValueError when the box does not hold 4 values.
        box_arg = np.asarray(box_xyxy, dtype=np.float32).reshape(4)

    predictor = get_sam_prompt_predictor()
    predictor.set_image(image_rgb)
    masks, scores, _low_res = predictor.predict(
        point_coords=pc,
        point_labels=pl,
        box=box_arg,
        multimask_output=True,
    )

    # masks is C x H x W; keep the candidate with the highest score.
    best_mask = masks[int(np.argmax(scores))]
    if best_mask.dtype != np.bool_:
        best_mask = best_mask > 0.5
    return best_mask
def mask_to_contour_xy(
    mask_bool: np.ndarray,
    epsilon_px: float = 2.0,
) -> List[List[float]]:
    """Extract the largest outer contour of a binary mask as a simplified polygon.

    The contour is simplified with Douglas-Peucker (``cv2.approxPolyDP``)
    using a tolerance of ``max(epsilon_px, 0.1% of the perimeter)``. The
    returned polyline is not explicitly closed (the first point is not
    repeated); callers may close it themselves.

    When OpenCV is unavailable, or it finds no contour, the axis-aligned
    bounding rectangle of the non-zero pixels is returned instead.

    Args:
        mask_bool: 2-D mask; any non-zero value counts as foreground.
        epsilon_px: minimum simplification tolerance in pixels.

    Returns:
        ``[[x, y], ...]`` in the mask's own pixel coordinates (the original
        comment describes this as the cropped-image frame), or ``[]`` when
        the mask is empty or the largest contour has an area below 1 pixel.
    """
    # Normalise once so the OpenCV path and the fallback path agree on what
    # counts as foreground (previously one used a uint8 cast, the other the
    # raw values).
    mask = np.asarray(mask_bool).astype(bool)

    try:
        import cv2  # type: ignore[import]
    except ImportError:
        cv2 = None

    contours = None
    if cv2 is not None:
        try:
            # findContours returns (contours, hierarchy) on OpenCV 4.x and
            # (image, contours, hierarchy) on 3.x; [-2] is the contour list
            # in both layouts.
            contours = cv2.findContours(
                mask.astype(np.uint8) * 255,
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )[-2]
        except cv2.error:
            # Best effort: fall through to the bounding-rectangle fallback.
            contours = None

    if not contours:
        # Minimal fallback: bounding rectangle of the foreground pixels.
        ys, xs = np.nonzero(mask)
        if ys.size == 0:
            return []
        x0, x1 = float(xs.min()), float(xs.max())
        y0, y1 = float(ys.min()), float(ys.max())
        return [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]

    largest = max(contours, key=cv2.contourArea)
    if cv2.contourArea(largest) < 1.0:
        return []
    perimeter = cv2.arcLength(largest, True)
    tolerance = max(epsilon_px, 0.001 * perimeter)
    simplified = cv2.approxPolyDP(largest, tolerance, True)
    # approxPolyDP yields a (K, 1, 2) array; flatten it to K (x, y) pairs.
    return [[float(x), float(y)] for x, y in simplified.reshape(-1, 2)]
# Absolute directory containing this module; anchor for locating files that live next to it.
_THIS_DIR = Path(__file__).resolve().parent
class SegBackend(str, Enum):
    """Identifiers of the supported segmentation backends.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value (e.g. ``SegBackend.SAM == "sam"``).
    """

    # Segment Anything
    SAM = "sam"
    # Mask2Former
    MASK2FORMER = "mask2former"
@dataclass
class UnifiedSegConfig:
    """Options selecting which segmentation backend to build."""

    # Backend to use; SAM when not specified.
    backend: SegBackend = SegBackend.SAM
# -----------------------------
# SAM (Segment Anything)
# -----------------------------
def _ensure_sam_on_path() -> Path:
    """Make the vendored ``segment-anything`` checkout importable.

    Looks for a ``segment-anything`` directory next to this module and, if
    its path is not already listed in ``sys.path``, inserts it at the front.

    Returns:
        The checkout directory.

    Raises:
        FileNotFoundError: if the directory does not exist.
    """
    repo_dir = _THIS_DIR / "segment-anything"
    if repo_dir.is_dir():
        repo_entry = str(repo_dir)
        if repo_entry not in sys.path:
            sys.path.insert(0, repo_entry)
        return repo_dir
    raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {repo_dir}")
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
import requests
ckpt_dir = sam_root / "checkpoints"
ckpt_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
if ckpt_path.is_file():
return ckpt_path
url = (
"https://dl.fbaipublicfiles.com/segment_anything/"
"sam_vit_h_4b8939.pth"
)
print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
resp = requests.get(url, stream=True)
resp.raise_for_status()
total = int(resp.headers.get("content-length", "0") or "0")
downloaded = 0
chunk_size = 1024 * 1024
with ckpt_path.open("wb") as f:
for chunk in resp.iter_content(chunk_size=chunk_size):
if not chunk:
continue
f.write(chunk)
downloaded += len(chunk)
if total > 0:
done = int(50 * downloaded / total)
print(
"\r[{}{}] {:.1f}%".format(
"#" * done,
"." * (50 - done),
downloaded * 100 / total,
),
end="",
)
print("\nSAM 权重下载完成。")
return ckpt_path
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build an automatic (prompt-free) SAM segmentation function.

    Loads the vit_h model onto CUDA when available (CPU otherwise) and wraps
    it in a ``SamAutomaticMaskGenerator``.

    Returns:
        A function mapping an RGB image ``(H, W, 3)`` (cast to uint8 if
        needed) to an int32 label map ``(H, W)``. Each generated mask is
        painted with its 1-based index in the generator's output; pixels
        covered by no mask stay 0, and where masks overlap the later one
        wins.
    """
    repo_dir = _ensure_sam_on_path()
    weights_file = _download_sam_checkpoint_if_needed(repo_dir)

    # Imported lazily: segment_anything only becomes importable once the
    # checkout has been put on sys.path by the call above.
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # type: ignore[import]
    import torch

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    build_vit_h = sam_model_registry["vit_h"]
    model = build_vit_h(checkpoint=str(weights_file)).to(target_device)
    generator = SamAutomaticMaskGenerator(model)

    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        frame = image_rgb if image_rgb.dtype == np.uint8 else image_rgb.astype("uint8")
        records = generator.generate(frame)
        height, width, _ = frame.shape
        labels = np.zeros((height, width), dtype="int32")
        for label_id, record in enumerate(records, start=1):
            region = record.get("segmentation")
            if region is not None:
                labels[region.astype(bool)] = label_id
        return labels

    return _predict
# -----------------------------
# Mask2Former (delegates to mask2former_loader)
# -----------------------------
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build the Mask2Former segmentation function.

    Delegates to ``build_mask2former_hf_predictor`` from the sibling
    ``mask2former_loader`` module, which returns a pair; only the first
    element (the predictor) is kept.
    """
    # Imported lazily so the Mask2Former loader is only pulled in when this
    # backend is actually requested.
    from .mask2former_loader import build_mask2former_hf_predictor

    predict_fn, _unused = build_mask2former_hf_predictor()
    return predict_fn
# -----------------------------
# 统一构建函数
# -----------------------------
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """Build the segmentation function for the configured backend.

    Args:
        cfg: backend selection; a default ``UnifiedSegConfig`` is used when
            omitted.

    Returns:
        A pair ``(predictor, backend)``:
        - ``predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W], int32``
        - the backend actually used.

    Raises:
        ValueError: if the configured backend is not supported.
    """
    config = cfg if cfg else UnifiedSegConfig()
    requested = config.backend

    # Equality (not identity or hashing) is used so that a plain string
    # equal to a member's value is accepted as well.
    if requested == SegBackend.SAM:
        predictor = _make_sam_predictor()
        return predictor, SegBackend.SAM
    if requested == SegBackend.MASK2FORMER:
        predictor = _make_mask2former_predictor()
        return predictor, SegBackend.MASK2FORMER
    raise ValueError(f"不支持的分割后端: {requested}")