"""
Unified entry point for loading segmentation models.

Currently supported backends:
- SAM (segment-anything, ViT-H checkpoint)
- Mask2Former (semantic segmentation built on HuggingFace transformers; the
  predictor itself comes from the sibling ``mask2former_loader`` module)

Besides the label-map predictors returned by :func:`build_seg_predictor`, the
module exposes prompt-based SAM helpers (:func:`run_sam_prompt`) and a
mask-to-polygon utility (:func:`mask_to_contour_xy`).
"""

from __future__ import annotations

import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Callable, List

import numpy as np

_THIS_DIR = Path(__file__).resolve().parent

# File name and download URL of the SAM ViT-H checkpoint.
_SAM_CKPT_NAME = "sam_vit_h_4b8939.pth"
_SAM_CKPT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/"
    "sam_vit_h_4b8939.pth"
)


class SegBackend(str, Enum):
    """Available segmentation backends.

    Subclasses ``str`` so plain strings such as ``"sam"`` compare equal to
    the corresponding member.
    """

    SAM = "sam"
    MASK2FORMER = "mask2former"


@dataclass
class UnifiedSegConfig:
    """Configuration for :func:`build_seg_predictor`."""

    # Backend to build; SAM by default.
    backend: SegBackend = SegBackend.SAM


# -----------------------------
# Shared helpers
# -----------------------------
def _to_rgb_u8(image_rgb: np.ndarray) -> np.ndarray:
    """Return *image_rgb* as a C-contiguous uint8 array of shape (H, W, 3).

    Non-uint8 input is converted with a plain ``astype(np.uint8)`` cast;
    values are NOT rescaled or clipped.

    Raises:
        ValueError: if the array is not 3-D with exactly 3 channels.
    """
    arr = np.asarray(image_rgb)
    if arr.dtype != np.uint8:
        arr = arr.astype(np.uint8)
    arr = np.ascontiguousarray(arr)
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError(f"image_rgb 期望 HWC RGB uint8,当前 shape={arr.shape}")
    return arr


# -----------------------------
# SAM (Segment Anything): repository path, checkpoint, model
# -----------------------------
def _ensure_sam_on_path() -> Path:
    """Put the vendored ``segment-anything`` checkout on ``sys.path``.

    Returns:
        The repository root (``<this dir>/segment-anything``).

    Raises:
        FileNotFoundError: if the repository directory does not exist.
    """
    sam_root = _THIS_DIR / "segment-anything"
    if not sam_root.is_dir():
        raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {sam_root}")
    sam_path = str(sam_root)
    if sam_path not in sys.path:
        sys.path.insert(0, sam_path)
    return sam_root


def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
    """Return the local ViT-H checkpoint path, downloading it on first use.

    The file is streamed into ``<name>.part`` and only renamed to its final
    name after the whole body has been received. An interrupted download
    therefore never leaves a truncated checkpoint that later runs would
    mistake for a complete one.

    Args:
        sam_root: root of the segment-anything checkout; the checkpoint
            lives in ``<sam_root>/checkpoints``.

    Returns:
        Path to the checkpoint file.

    Raises:
        requests.HTTPError: on a non-2xx response.
        OSError: if fewer bytes than the advertised Content-Length arrived.
    """
    import requests

    ckpt_dir = sam_root / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / _SAM_CKPT_NAME
    if ckpt_path.is_file():
        return ckpt_path

    url = _SAM_CKPT_URL
    print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
    tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
    try:
        # (connect, read) timeout: the read timeout applies between received
        # chunks, not to the whole multi-GB transfer.
        with requests.get(url, stream=True, timeout=(10, 60)) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print(
                            "\r[{}{}] {:.1f}%".format(
                                "#" * done,
                                "." * (50 - done),
                                downloaded * 100 / total,
                            ),
                            end="",
                        )
        if total > 0 and downloaded < total:
            raise OSError(f"SAM 权重下载不完整: {downloaded}/{total} bytes")
        tmp_path.replace(ckpt_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
        try:
            tmp_path.unlink()
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise
    print("\nSAM 权重下载完成。")
    return ckpt_path


# Lazily-built SAM ViT-H model (the object returned by
# ``sam_model_registry["vit_h"](...)``), shared by every SAM predictor in this
# module so the large checkpoint is loaded at most once per process.
_sam_model = None


def _load_sam_model():
    """Build the ViT-H SAM model on first call (CUDA if available) and cache it."""
    global _sam_model
    if _sam_model is None:
        sam_root = _ensure_sam_on_path()
        ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
        from segment_anything import sam_model_registry  # type: ignore[import]
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _sam_model = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    return _sam_model


# -----------------------------
# Interactive SAM (SamPredictor), cached separately from the
# AutomaticMaskGenerator predictor (both reuse the same underlying model)
# -----------------------------
_sam_prompt_predictor = None


def get_sam_prompt_predictor():
    """Lazily create the cached ``SamPredictor`` (ViT-H) used for point/box prompts."""
    global _sam_prompt_predictor
    if _sam_prompt_predictor is None:
        model = _load_sam_model()  # also puts segment-anything on sys.path
        from segment_anything import SamPredictor  # type: ignore[import]

        _sam_prompt_predictor = SamPredictor(model)
    return _sam_prompt_predictor


def run_sam_prompt(
    image_rgb: np.ndarray,
    point_coords: np.ndarray,
    point_labels: np.ndarray,
    box_xyxy: np.ndarray | None = None,
) -> np.ndarray:
    """Segment an RGB image from point prompts and an optional box.

    All coordinates are in original-image pixels.

    Args:
        image_rgb: (H, W, 3) RGB image; non-uint8 input is cast to uint8.
        point_coords: (N, 2) float point coordinates (x, y).
        point_labels: (N,) int labels, 1 = foreground, 0 = background.
        box_xyxy: optional ``[x1, y1, x2, y2]`` box (anything reshapable to 4).

    Returns:
        Boolean mask of shape (H, W): the highest-scoring of SAM's
        multimask outputs.

    Raises:
        ValueError: on malformed image/point/label/box shapes. All inputs are
            validated before the model is loaded or the image is embedded.
    """
    image_rgb = _to_rgb_u8(image_rgb)

    pc = np.asarray(point_coords, dtype=np.float32)
    pl = np.asarray(point_labels, dtype=np.int64)
    if pc.ndim != 2 or pc.shape[1] != 2:
        raise ValueError("point_coords 应为 Nx2")
    if pl.ndim != 1 or pl.shape[0] != pc.shape[0]:
        raise ValueError("point_labels 长度须与 point_coords 行数一致")
    box_arg = None
    if box_xyxy is not None:
        box_arg = np.asarray(box_xyxy, dtype=np.float32).reshape(4)

    predictor = get_sam_prompt_predictor()
    predictor.set_image(image_rgb)  # expensive: computes the image embedding
    masks, scores, _low_res = predictor.predict(
        point_coords=pc,
        point_labels=pl,
        box=box_arg,
        multimask_output=True,
    )
    # masks: C x H x W, one candidate per score.
    best = int(np.argmax(scores))
    mask = masks[best]
    if mask.dtype != np.bool_:
        mask = mask > 0.5
    return mask


def _bbox_polygon(mask: np.ndarray) -> List[List[float]]:
    """Return the axis-aligned bounding rectangle of *mask* as 4 corners, or []."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return []
    x0, x1 = float(xs.min()), float(xs.max())
    y0, y1 = float(ys.min()), float(ys.max())
    return [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]


def mask_to_contour_xy(
    mask_bool: np.ndarray,
    epsilon_px: float = 2.0,
) -> List[List[float]]:
    """Extract the largest outer contour of a 2-D mask as a simplified polygon.

    Any non-zero value counts as foreground. With OpenCV available, the
    largest external contour is simplified with Douglas-Peucker
    (``cv2.approxPolyDP``) using ``max(epsilon_px, 0.001 * perimeter)``.
    Without OpenCV, the bounding rectangle of the foreground is returned
    instead.

    Args:
        mask_bool: 2-D mask (bool or numeric).
        epsilon_px: minimum simplification tolerance in pixels.

    Returns:
        ``[[x, y], ...]`` in the mask's pixel coordinates, as an open
        polyline (the first point is not repeated; callers close it
        themselves). Returns ``[]`` for an empty mask or, on the OpenCV path,
        when the largest contour encloses less than 1 px² of area.
    """
    mask = np.asarray(mask_bool) != 0
    u8 = mask.astype(np.uint8) * np.uint8(255)

    try:
        import cv2  # type: ignore[import]
    except ImportError:
        cv2 = None

    contours = ()
    if cv2 is not None:
        try:
            # OpenCV 3 returns (image, contours, hierarchy); OpenCV 4 returns
            # (contours, hierarchy). The contours are second-to-last in both.
            contours = cv2.findContours(
                u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
        except cv2.error:
            contours = ()

    if not contours:
        # No OpenCV (or nothing found): minimal fallback, bounding rectangle.
        return _bbox_polygon(mask)

    cnt = max(contours, key=cv2.contourArea)
    if cv2.contourArea(cnt) < 1.0:
        return []
    peri = cv2.arcLength(cnt, True)
    eps = max(epsilon_px, 0.001 * peri)
    approx = cv2.approxPolyDP(cnt, eps, True)
    return [[float(x), float(y)] for x, y in approx.reshape(-1, 2)]


def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build a label-map predictor on top of ``SamAutomaticMaskGenerator``.

    The returned function takes an (H, W, 3) RGB image (cast to uint8 if
    needed) and returns an int32 (H, W) label map. Each generated mask is
    painted with its 1-based index in the generator's output list; 0 means
    background.

    NOTE: masks are painted in output order, so where masks overlap the
    later one wins. An id can therefore be partly or fully hidden. Masks
    without a ``"segmentation"`` entry are skipped, but their index is still
    consumed.
    """
    model = _load_sam_model()  # also puts segment-anything on sys.path
    from segment_anything import SamAutomaticMaskGenerator  # type: ignore[import]

    mask_generator = SamAutomaticMaskGenerator(model)

    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        image_u8 = _to_rgb_u8(image_rgb)
        masks = mask_generator.generate(image_u8)
        h, w = image_u8.shape[:2]
        label_map = np.zeros((h, w), dtype="int32")
        for idx, m in enumerate(masks, start=1):
            seg = m.get("segmentation")
            if seg is None:
                continue
            label_map[seg.astype(bool)] = idx
        return label_map

    return _predict


# -----------------------------
# Mask2Former (HuggingFace transformers)
# -----------------------------
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build the Mask2Former predictor via ``mask2former_loader``.

    Only the first element of ``build_mask2former_hf_predictor()``'s returned
    pair is used. Requires this module to be imported as part of a package,
    because of the relative import.
    """
    from .mask2former_loader import build_mask2former_hf_predictor

    predictor, _ = build_mask2former_hf_predictor()
    return predictor


# -----------------------------
# Unified builder
# -----------------------------
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """Build a segmentation predictor for the configured backend.

    Args:
        cfg: backend selection; defaults to ``UnifiedSegConfig()`` (SAM).

    Returns:
        A pair ``(predictor, backend)``:
        - ``predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W]``
          (the SAM predictor returns int32)
        - the backend actually used

    Raises:
        ValueError: if ``cfg.backend`` is not a supported backend.
    """
    if cfg is None:
        cfg = UnifiedSegConfig()
    if cfg.backend == SegBackend.SAM:
        return _make_sam_predictor(), SegBackend.SAM
    if cfg.backend == SegBackend.MASK2FORMER:
        return _make_mask2former_predictor(), SegBackend.MASK2FORMER
    raise ValueError(f"不支持的分割后端: {cfg.backend}")