290 lines
8.1 KiB
Python
290 lines
8.1 KiB
Python
from __future__ import annotations
|
||
|
||
"""
|
||
统一的分割模型加载入口。
|
||
|
||
当前支持:
|
||
- SAM (segment-anything)
|
||
- Mask2Former(使用 HuggingFace transformers 的语义分割实现)
|
||
"""
|
||
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from typing import Callable, List
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
|
||
# -----------------------------
|
||
# SAM 交互式(SamPredictor),与 AutomaticMaskGenerator 分流缓存
|
||
# -----------------------------
|
||
|
||
# Module-level cache for the SamPredictor built by get_sam_prompt_predictor();
# stays None until that function is called for the first time.
_sam_prompt_predictor = None
|
||
|
||
|
||
def get_sam_prompt_predictor():
    """Return the shared SamPredictor (vit_h) used for point/box prompts.

    The predictor is created on the first call and stored in the module-level
    ``_sam_prompt_predictor`` variable; later calls return that same object.
    Creation puts the local segment-anything checkout on ``sys.path``, makes
    sure the vit_h checkpoint file is available, and moves the model to CUDA
    when ``torch.cuda.is_available()`` reports a device, otherwise to CPU.
    """
    global _sam_prompt_predictor

    if _sam_prompt_predictor is None:
        repo_dir = _ensure_sam_on_path()
        weights_file = _download_sam_checkpoint_if_needed(repo_dir)

        # Imported here because segment_anything only becomes importable
        # after _ensure_sam_on_path() has extended sys.path.
        from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]
        import torch

        target = "cuda" if torch.cuda.is_available() else "cpu"
        build_vit_h = sam_model_registry["vit_h"]
        model = build_vit_h(checkpoint=str(weights_file)).to(target)
        _sam_prompt_predictor = SamPredictor(model)

    return _sam_prompt_predictor
|
||
|
||
|
||
def run_sam_prompt(
    image_rgb: np.ndarray,
    point_coords: np.ndarray,
    point_labels: np.ndarray,
    box_xyxy: np.ndarray | None = None,
) -> np.ndarray:
    """
    Segment an RGB image with SAM from point prompts and an optional box.

    All prompt coordinates are given in pixel coordinates of ``image_rgb``.

    Args:
        image_rgb: HWC image with 3 channels. Input that is not uint8 is cast
            with ``astype(np.uint8)``; no value rescaling is performed.
        point_coords: (N, 2) point coordinates, converted to float32.
        point_labels: (N,) labels, converted to int64; 1 marks a foreground
            point, 0 a background point.
        box_xyxy: optional 4 values ``[x1, y1, x2, y2]``, or None for no box.

    Returns:
        Boolean mask for the candidate with the highest predicted score
        (shape (H, W) per the predictor's C x H x W mask output).

    Raises:
        ValueError: if the image is not 3-dimensional with 3 channels, if
            ``point_coords`` is not Nx2, if ``point_labels`` does not have
            one entry per point, or if ``box_xyxy`` does not hold 4 values.
    """
    if image_rgb.dtype != np.uint8:
        image_rgb = image_rgb.astype(np.uint8)
    image_rgb = np.ascontiguousarray(image_rgb)

    if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
        raise ValueError(f"image_rgb 期望 HWC RGB uint8,当前 shape={image_rgb.shape}")

    # Validate every prompt BEFORE touching the predictor: obtaining it may
    # download and load the vit_h checkpoint, and set_image() runs the image
    # encoder, so malformed prompts should fail without paying that cost.
    pc = np.asarray(point_coords, dtype=np.float32)
    pl = np.asarray(point_labels, dtype=np.int64)
    if pc.ndim != 2 or pc.shape[1] != 2:
        raise ValueError("point_coords 应为 Nx2")
    if pl.ndim != 1 or pl.shape[0] != pc.shape[0]:
        raise ValueError("point_labels 长度须与 point_coords 行数一致")

    box_arg = None
    if box_xyxy is not None:
        # reshape(4) raises ValueError when the box does not hold 4 values.
        box_arg = np.asarray(box_xyxy, dtype=np.float32).reshape(4)

    predictor = get_sam_prompt_predictor()
    predictor.set_image(image_rgb)

    masks, scores, _low_res = predictor.predict(
        point_coords=pc,
        point_labels=pl,
        box=box_arg,
        multimask_output=True,
    )

    # masks: C x H x W; keep the candidate the model scored highest.
    best = int(np.argmax(scores))
    mask = masks[best]
    if mask.dtype != np.bool_:
        mask = mask > 0.5
    return mask
|
||
|
||
|
||
def mask_to_contour_xy(
    mask_bool: np.ndarray,
    epsilon_px: float = 2.0,
) -> List[List[float]]:
    """
    Extract the largest outer contour of a binary mask as a simplified polygon.

    With OpenCV available, the external contour with the largest area is
    simplified with Douglas-Peucker (``cv2.approxPolyDP``) using a tolerance
    of ``max(epsilon_px, 0.001 * perimeter)``. Without OpenCV, or when no
    contour is found, the axis-aligned bounding rectangle of the foreground
    pixels is returned instead.

    Args:
        mask_bool: 2-D mask; every non-zero element counts as foreground.
        epsilon_px: minimum simplification tolerance in pixels.

    Returns:
        ``[[x, y], ...]`` in the pixel coordinates of the mask (the original
        note describes this as the cropped-image frame). The polyline is not
        explicitly closed; callers may close it themselves. An empty list is
        returned for an empty mask or when the largest contour encloses less
        than 1.0 of area.
    """
    # Normalise once so the OpenCV path and the fallback agree on what
    # "foreground" means (a plain uint8 cast would truncate e.g. 0.5 to 0).
    foreground = np.asarray(mask_bool).astype(bool)

    try:
        import cv2  # type: ignore[import]
    except ImportError:
        cv2 = None

    contours = ()
    if cv2 is not None:
        u8 = np.ascontiguousarray(foreground, dtype=np.uint8) * 255
        try:
            # [-2] picks the contour list for both the 2-value (OpenCV 4.x)
            # and the 3-value (OpenCV 3.x) return signature.
            contours = cv2.findContours(
                u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
        except cv2.error:
            # Best effort: fall through to the bounding-box fallback.
            contours = ()

    if not contours:
        # Minimal fallback: bounding rectangle of the foreground pixels,
        # using the inclusive min/max pixel indices as corners.
        ys, xs = np.nonzero(foreground)
        if ys.size == 0:
            return []
        x0, x1 = int(xs.min()), int(xs.max())
        y0, y1 = int(ys.min()), int(ys.max())
        return [
            [float(x0), float(y0)],
            [float(x1), float(y0)],
            [float(x1), float(y1)],
            [float(x0), float(y1)],
        ]

    largest = max(contours, key=cv2.contourArea)
    if cv2.contourArea(largest) < 1.0:
        return []

    perimeter = cv2.arcLength(largest, True)
    tolerance = max(epsilon_px, 0.001 * perimeter)
    approx = cv2.approxPolyDP(largest, tolerance, True)

    # approxPolyDP yields points shaped (K, 1, 2); flatten to K (x, y) pairs.
    return [[float(x), float(y)] for x, y in approx.reshape(-1, 2)]
|
||
|
||
# Directory containing this module; _ensure_sam_on_path() looks for the
# "segment-anything" checkout directly inside it.
_THIS_DIR = Path(__file__).resolve().parent
|
||
|
||
|
||
class SegBackend(str, Enum):
    """Segmentation backends that this module knows how to build.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value.
    """

    SAM = "sam"
    MASK2FORMER = "mask2former"
|
||
|
||
|
||
@dataclass
class UnifiedSegConfig:
    """Configuration accepted by build_seg_predictor()."""

    # Backend to construct; SAM is used when nothing else is requested.
    backend: SegBackend = SegBackend.SAM
|
||
|
||
|
||
# -----------------------------
|
||
# SAM (Segment Anything)
|
||
# -----------------------------
|
||
|
||
|
||
def _ensure_sam_on_path() -> Path:
    """Make the local segment-anything checkout importable.

    The checkout is expected in a ``segment-anything`` directory next to
    this module. Its path is inserted at the front of ``sys.path`` unless it
    is already listed there.

    Returns:
        Path of the segment-anything directory.

    Raises:
        FileNotFoundError: if that directory does not exist.
    """
    repo_dir = _THIS_DIR / "segment-anything"

    if repo_dir.is_dir():
        entry = str(repo_dir)
        if entry not in sys.path:
            sys.path.insert(0, entry)
        return repo_dir

    raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {repo_dir}")
|
||
|
||
|
||
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
|
||
import requests
|
||
|
||
ckpt_dir = sam_root / "checkpoints"
|
||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||
ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
|
||
|
||
if ckpt_path.is_file():
|
||
return ckpt_path
|
||
|
||
url = (
|
||
"https://dl.fbaipublicfiles.com/segment_anything/"
|
||
"sam_vit_h_4b8939.pth"
|
||
)
|
||
print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
|
||
|
||
resp = requests.get(url, stream=True)
|
||
resp.raise_for_status()
|
||
|
||
total = int(resp.headers.get("content-length", "0") or "0")
|
||
downloaded = 0
|
||
chunk_size = 1024 * 1024
|
||
|
||
with ckpt_path.open("wb") as f:
|
||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||
if not chunk:
|
||
continue
|
||
f.write(chunk)
|
||
downloaded += len(chunk)
|
||
if total > 0:
|
||
done = int(50 * downloaded / total)
|
||
print(
|
||
"\r[{}{}] {:.1f}%".format(
|
||
"#" * done,
|
||
"." * (50 - done),
|
||
downloaded * 100 / total,
|
||
),
|
||
end="",
|
||
)
|
||
print("\nSAM 权重下载完成。")
|
||
return ckpt_path
|
||
|
||
|
||
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """
    Build a whole-image segmentation function backed by SAM (vit_h).

    The returned callable:
    - takes an RGB image of shape (H, W, 3); input that is not uint8 is cast
      with ``astype("uint8")``
    - returns an int32 label map of shape (H, W) in which the pixels of the
      i-th generated mask (counting from 1) hold the value i, and pixels
      covered by no mask stay 0. Where masks overlap, the later mask wins.
    """
    repo_dir = _ensure_sam_on_path()
    weights_file = _download_sam_checkpoint_if_needed(repo_dir)

    # Imported here because segment_anything only becomes importable after
    # _ensure_sam_on_path() has extended sys.path.
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # type: ignore[import]
    import torch

    target = "cuda" if torch.cuda.is_available() else "cpu"
    build_vit_h = sam_model_registry["vit_h"]
    model = build_vit_h(checkpoint=str(weights_file)).to(target)
    generator = SamAutomaticMaskGenerator(model)

    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        frame = image_rgb if image_rgb.dtype == np.uint8 else image_rgb.astype("uint8")

        proposals = generator.generate(frame)
        height, width, _ = frame.shape
        labels = np.zeros((height, width), dtype="int32")

        label_id = 0
        for proposal in proposals:
            # The id advances even for entries without a mask, so ids match
            # the 1-based position in the generator output.
            label_id += 1
            region = proposal.get("segmentation")
            if region is not None:
                labels[region.astype(bool)] = label_id

        return labels

    return _predict
|
||
|
||
|
||
# -----------------------------
|
||
# Mask2Former (HuggingFace transformers implementation)
|
||
# -----------------------------
|
||
|
||
|
||
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build the Mask2Former segmentation function.

    Delegates to ``build_mask2former_hf_predictor`` from the sibling
    ``mask2former_loader`` module, which returns a pair; only the first
    element (the predictor) is handed back and the second is discarded.
    """
    # Imported lazily so that the loader module is only pulled in when this
    # backend is actually requested.
    from .mask2former_loader import build_mask2former_hf_predictor

    segment_fn, _extra = build_mask2former_hf_predictor()
    return segment_fn
|
||
|
||
|
||
# -----------------------------
|
||
# 统一构建函数
|
||
# -----------------------------
|
||
|
||
|
||
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """
    Build the segmentation function for the configured backend.

    Args:
        cfg: configuration to use; a default ``UnifiedSegConfig()`` is
            created when omitted.

    Returns:
        A pair of:
        - predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W], int32
        - the backend that was actually used

    Raises:
        ValueError: if ``cfg.backend`` matches none of the known backends.
    """
    active = cfg if cfg else UnifiedSegConfig()
    requested = active.backend

    if requested == SegBackend.SAM:
        predictor = _make_sam_predictor()
        return predictor, SegBackend.SAM

    if requested == SegBackend.MASK2FORMER:
        predictor = _make_mask2former_predictor()
        return predictor, SegBackend.MASK2FORMER

    raise ValueError(f"不支持的分割后端: {requested}")
|
||
|