添加模型分割
This commit is contained in:
@@ -10,13 +10,134 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
from typing import Callable, List
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
# -----------------------------
|
||||
# SAM 交互式(SamPredictor),与 AutomaticMaskGenerator 分流缓存
|
||||
# -----------------------------
|
||||
|
||||
_sam_prompt_predictor = None
|
||||
|
||||
|
||||
def get_sam_prompt_predictor():
    """Lazily build and cache a SamPredictor (vit_h) for point/box prompt segmentation."""
    global _sam_prompt_predictor
    if _sam_prompt_predictor is None:
        # First call: locate the SAM repo and make sure the checkpoint exists locally.
        repo_root = _ensure_sam_on_path()
        checkpoint = _download_sam_checkpoint_if_needed(repo_root)

        import torch
        from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]

        # Prefer GPU when available; the model is kept resident for later calls.
        target = "cuda" if torch.cuda.is_available() else "cpu"
        model = sam_model_registry["vit_h"](checkpoint=str(checkpoint)).to(target)
        _sam_prompt_predictor = SamPredictor(model)
    return _sam_prompt_predictor
|
||||
|
||||
|
||||
def run_sam_prompt(
    image_rgb: np.ndarray,
    point_coords: np.ndarray,
    point_labels: np.ndarray,
    box_xyxy: np.ndarray | None = None,
) -> np.ndarray:
    """
    Segment an RGB image with SAM using point prompts (required) and an optional box.

    All coordinates are in original-image pixel space:
    - point_coords: (N, 2) float
    - point_labels: (N,) int, 1 = foreground, 0 = background
    - box_xyxy: (4,) [x1, y1, x2, y2] or None

    Returns:
        Boolean mask of shape (H, W).

    Raises:
        ValueError: if the image is not HWC RGB or the prompts are malformed.
    """
    # Normalize to contiguous uint8 HWC, as expected by SamPredictor.set_image.
    if image_rgb.dtype != np.uint8:
        image_rgb = np.ascontiguousarray(image_rgb.astype(np.uint8))
    else:
        image_rgb = np.ascontiguousarray(image_rgb)

    if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
        raise ValueError(f"image_rgb 期望 HWC RGB uint8,当前 shape={image_rgb.shape}")

    # Validate prompts BEFORE touching the predictor: set_image() computes the
    # (expensive) image embedding, so reject malformed inputs first.
    pc = np.asarray(point_coords, dtype=np.float32)
    pl = np.asarray(point_labels, dtype=np.int64)
    if pc.ndim != 2 or pc.shape[1] != 2:
        raise ValueError("point_coords 应为 Nx2")
    if pl.ndim != 1 or pl.shape[0] != pc.shape[0]:
        raise ValueError("point_labels 长度须与 point_coords 行数一致")

    box_arg = None
    if box_xyxy is not None:
        # reshape(4) also raises for boxes that are not exactly 4 numbers.
        box_arg = np.asarray(box_xyxy, dtype=np.float32).reshape(4)

    predictor = get_sam_prompt_predictor()
    predictor.set_image(image_rgb)

    masks, scores, _low = predictor.predict(
        point_coords=pc,
        point_labels=pl,
        box=box_arg,
        multimask_output=True,
    )
    # masks: C x H x W — keep the candidate with the highest predicted score.
    best = int(np.argmax(scores))
    m = masks[best]
    if m.dtype != np.bool_:
        # Some backends return float masks; binarize at 0.5.
        m = m > 0.5
    return m
|
||||
|
||||
|
||||
def mask_to_contour_xy(
    mask_bool: np.ndarray,
    epsilon_px: float = 2.0,
) -> List[List[float]]:
    """
    Extract the largest external contour of a binary mask and simplify it with
    Douglas-Peucker (cv2.approxPolyDP).

    Falls back to the mask's axis-aligned bounding box when OpenCV is
    unavailable or yields no contour. Returns [[x, y], ...] in crop-image pixel
    coordinates; an empty list for an empty/degenerate mask. The polyline is
    left open — the client may close it itself.
    """
    u8 = (np.asarray(mask_bool, dtype=np.uint8) * 255).astype(np.uint8)
    try:
        import cv2  # type: ignore[import]

        contours, _h = cv2.findContours(u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    except Exception:
        # OpenCV missing (or an incompatible findContours signature, e.g. the
        # 3-tuple return of OpenCV 3.x) — treat as "no contours found".
        contours = None

    if not contours:
        # Minimal fallback without OpenCV: the mask's bounding box.
        ys, xs = np.nonzero(mask_bool)
        if ys.size == 0:
            return []
        x0, x1 = int(xs.min()), int(xs.max())
        y0, y1 = int(ys.min()), int(ys.max())
        return [
            [float(x0), float(y0)],
            [float(x1), float(y0)],
            [float(x1), float(y1)],
            [float(x0), float(y1)],
        ]

    cnt = max(contours, key=cv2.contourArea)
    if cv2.contourArea(cnt) < 1.0:
        # Sub-pixel contour: nothing useful to return.
        return []
    peri = cv2.arcLength(cnt, True)
    # Simplification tolerance: at least epsilon_px, growing with perimeter.
    eps = max(epsilon_px, 0.001 * peri)
    approx = cv2.approxPolyDP(cnt, eps, True)
    # approx has shape (K, 1, 2); flatten to [[x, y], ...].
    return [[float(p[0][0]), float(p[0][1])] for p in approx]
|
||||
|
||||
# Absolute (symlink-resolved) directory containing this module.
_THIS_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,13 @@ from PIL import Image, ImageDraw
|
||||
from config_loader import load_app_config, get_depth_backend_from_app
|
||||
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
|
||||
|
||||
from model.Seg.seg_loader import UnifiedSegConfig, SegBackend, build_seg_predictor
|
||||
from model.Seg.seg_loader import (
|
||||
UnifiedSegConfig,
|
||||
SegBackend,
|
||||
build_seg_predictor,
|
||||
mask_to_contour_xy,
|
||||
run_sam_prompt,
|
||||
)
|
||||
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
|
||||
from model.Animation.animation_loader import (
|
||||
UnifiedAnimationConfig,
|
||||
@@ -47,6 +53,28 @@ class SegmentRequest(ImageInput):
|
||||
pass
|
||||
|
||||
|
||||
class SamPromptSegmentRequest(BaseModel):
    # Base64 (PNG/JPG) of the cropped RGB image to segment.
    image_b64: str = Field(..., description="裁剪后的 RGB 图 base64(PNG/JPG)")
    # Optional marker-overlay PNG of the same size as the crop; per the handler,
    # it is currently only used to validate size consistency.
    overlay_b64: Optional[str] = Field(
        None,
        description="与裁剪同尺寸的标记叠加 PNG base64(可选;当前用于校验尺寸一致)",
    )
    # Prompt points [[x, y], ...] in crop pixel coordinates.
    point_coords: list[list[float]] = Field(
        ...,
        description="裁剪坐标系下的提示点 [[x,y], ...]",
    )
    # Must match point_coords in length: 1 = foreground, 0 = background.
    point_labels: list[int] = Field(
        ...,
        description="与 point_coords 等长:1=前景,0=背景",
    )
    # Tight stroke bounding box [x1, y1, x2, y2] in crop pixels (required).
    # NOTE(review): min_length/max_length on list fields is pydantic v2 syntax —
    # confirm the project targets v2 (v1 spells these min_items/max_items).
    box_xyxy: list[float] = Field(
        ...,
        description="裁剪内笔画紧包围盒 [x1,y1,x2,y2](像素)",
        min_length=4,
        max_length=4,
    )
|
||||
|
||||
|
||||
class InpaintRequest(ImageInput):
|
||||
prompt: Optional[str] = Field("", description="补全 prompt")
|
||||
strength: float = Field(0.8, ge=0.0, le=1.0)
|
||||
@@ -287,6 +315,54 @@ def segment(req: SegmentRequest) -> Dict[str, Any]:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
@app.post("/segment/sam_prompt")
def segment_sam_prompt(req: SamPromptSegmentRequest) -> Dict[str, Any]:
    """
    Interactive SAM: cropped image plus point/box prompts; returns the mask's
    outer contour as a point list in crop pixel coordinates.
    """
    try:
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)
        h, w = rgb.shape[0], rgb.shape[1]

        # Overlay is only checked for size consistency; PIL .size is (width, height).
        if req.overlay_b64:
            ov = _b64_to_pil_image(req.overlay_b64)
            if ov.size != (w, h):
                return {
                    "success": False,
                    "error": f"overlay 尺寸 {ov.size} 与 image {w}x{h} 不一致",
                    "contour": [],
                }

        # Prompt validation: labels pair 1:1 with points; at least one point required.
        if len(req.point_coords) != len(req.point_labels):
            return {
                "success": False,
                "error": "point_coords 与 point_labels 长度不一致",
                "contour": [],
            }
        if len(req.point_coords) < 1:
            return {"success": False, "error": "至少需要一个提示点", "contour": []}

        pc = np.array(req.point_coords, dtype=np.float32)
        if pc.ndim != 2 or pc.shape[1] != 2:
            return {"success": False, "error": "point_coords 每项须为 [x,y]", "contour": []}

        pl = np.array(req.point_labels, dtype=np.int64)
        box = np.array(req.box_xyxy, dtype=np.float32)

        mask = run_sam_prompt(rgb, pc, pl, box_xyxy=box)
        if not np.any(mask):
            return {"success": False, "error": "SAM 未产生有效掩膜", "contour": []}

        # Simplified outer contour; need at least a triangle to be usable.
        contour = mask_to_contour_xy(mask, epsilon_px=2.0)
        if len(contour) < 3:
            return {"success": False, "error": "轮廓点数不足", "contour": []}

        return {"success": True, "contour": contour, "error": None}
    except Exception as e:
        # API boundary: surface any failure as a structured error payload
        # instead of a 500, matching the other endpoints in this file.
        return {"success": False, "error": str(e), "contour": []}
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Inpaint
|
||||
# -----------------------------
|
||||
|
||||
Reference in New Issue
Block a user