添加模型分割

This commit is contained in:
2026-04-08 14:37:01 +08:00
parent 088dd91e27
commit a79c31a056
17 changed files with 1327 additions and 183 deletions

View File

@@ -10,13 +10,134 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Callable
from typing import Callable, List
import sys
from pathlib import Path
import numpy as np
# -----------------------------
# SAM interactive prompting (SamPredictor): cached separately from AutomaticMaskGenerator
# -----------------------------
# Module-level cache for the prompt-based predictor; filled on first use.
_sam_prompt_predictor = None


def get_sam_prompt_predictor():
    """Return the shared SamPredictor (vit_h) used for point/box prompts.

    The predictor is built on the first call and stored in the module-level
    ``_sam_prompt_predictor``; later calls return that same object.
    The model is moved to CUDA when ``torch.cuda.is_available()`` is true,
    otherwise it stays on the CPU.
    """
    global _sam_prompt_predictor
    if _sam_prompt_predictor is None:
        # The helpers run before the import below; presumably the first one
        # makes the `segment_anything` package importable -- confirm in helper.
        checkpoint = _download_sam_checkpoint_if_needed(_ensure_sam_on_path())

        # Heavy dependencies are imported lazily, only when a predictor
        # actually has to be constructed.
        from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]
        import torch

        target = "cuda" if torch.cuda.is_available() else "cpu"
        build_model = sam_model_registry["vit_h"]
        model = build_model(checkpoint=str(checkpoint)).to(target)
        _sam_prompt_predictor = SamPredictor(model)
    return _sam_prompt_predictor
def run_sam_prompt(
image_rgb: np.ndarray,
point_coords: np.ndarray,
point_labels: np.ndarray,
box_xyxy: np.ndarray | None = None,
) -> np.ndarray:
"""
使用点提示(必选)与可选矩形框在 RGB 图上分割。
参数均为原图像素坐标:
- point_coords: (N, 2) float
- point_labels: (N,) int1=前景点0=背景
- box_xyxy: (4,) [x1, y1, x2, y2] 或 None
返回bool 掩膜 (H, W)
"""
if image_rgb.dtype != np.uint8:
image_rgb = np.ascontiguousarray(image_rgb.astype(np.uint8))
else:
image_rgb = np.ascontiguousarray(image_rgb)
if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
raise ValueError(f"image_rgb 期望 HWC RGB uint8当前 shape={image_rgb.shape}")
predictor = get_sam_prompt_predictor()
predictor.set_image(image_rgb)
pc = np.asarray(point_coords, dtype=np.float32)
pl = np.asarray(point_labels, dtype=np.int64)
if pc.ndim != 2 or pc.shape[1] != 2:
raise ValueError("point_coords 应为 Nx2")
if pl.ndim != 1 or pl.shape[0] != pc.shape[0]:
raise ValueError("point_labels 长度须与 point_coords 行数一致")
box_arg = None
if box_xyxy is not None:
b = np.asarray(box_xyxy, dtype=np.float32).reshape(4)
box_arg = b
masks, scores, _low = predictor.predict(
point_coords=pc,
point_labels=pl,
box=box_arg,
multimask_output=True,
)
# masks: C x H x W
best = int(np.argmax(scores))
m = masks[best]
if m.dtype != np.bool_:
m = m > 0.5
return m
def mask_to_contour_xy(
    mask_bool: np.ndarray,
    epsilon_px: float = 2.0,
) -> List[List[float]]:
    """Extract the largest outer contour of a mask as a simplified polygon.

    The contour is found with OpenCV and simplified with Douglas-Peucker
    (``cv2.approxPolyDP``). When OpenCV is unavailable, or yields no contour,
    the axis-aligned bounding rectangle of the foreground is returned instead.

    Args:
        mask_bool: 2-D mask; any non-zero element counts as foreground.
        epsilon_px: Minimum simplification tolerance in pixels. The tolerance
            actually used is ``max(epsilon_px, 0.001 * perimeter)``.

    Returns:
        ``[[x, y], ...]`` in the mask's own pixel coordinates. The first
        point is not repeated at the end; callers close the polygon
        themselves. Returns ``[]`` for an empty mask or when the largest
        contour encloses an area below 1 px^2.
    """
    # Normalise once so the OpenCV path and the fallback path agree on what
    # counts as foreground (a direct uint8 cast would truncate fractions to 0).
    foreground = np.asarray(mask_bool).astype(bool)
    if not foreground.any():
        return []

    try:
        import cv2  # type: ignore[import]
    except ImportError:
        cv2 = None

    contours = None
    if cv2 is not None:
        binary = np.ascontiguousarray(foreground, dtype=np.uint8) * 255
        try:
            # findContours returns (contours, hierarchy) on OpenCV 4.x but
            # (image, contours, hierarchy) on 3.x; [-2] is contours on both.
            contours = cv2.findContours(
                binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
        except cv2.error:
            contours = None  # best effort: fall back to the bounding box

    if not contours:
        # Minimal fallback: corners of the foreground bounding rectangle.
        ys, xs = np.nonzero(foreground)
        x0, x1 = float(xs.min()), float(xs.max())
        y0, y1 = float(ys.min()), float(ys.max())
        return [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]

    largest = max(contours, key=cv2.contourArea)
    if cv2.contourArea(largest) < 1.0:
        return []

    perimeter = cv2.arcLength(largest, True)
    tolerance = max(epsilon_px, 0.001 * perimeter)
    approx = cv2.approxPolyDP(largest, tolerance, True)
    # approxPolyDP yields shape (K, 1, 2); flatten to K (x, y) pairs.
    return [[float(x), float(y)] for x, y in approx.reshape(-1, 2)]
_THIS_DIR = Path(__file__).resolve().parent