from __future__ import annotations import os from pathlib import Path import numpy as np import cv2 from PIL import Image, ImageDraw from scipy.ndimage import label as nd_label from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed # ================= 配置区 ================= INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg" INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png" OUTPUT_DIR = "outputs/test_seg_v2" SEG_BACKEND = SegBackend.SAM # 目标筛选参数 TARGET_MIN_AREA = 1000 # 过滤太小的碎片 TARGET_MAX_OBJECTS = 20 # 最多提取多少个物体 SAM_MAX_SIDE = 2048 # SAM 推理时的长边限制 # 视觉效果 MASK_ALPHA = 0.4 BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8) # 边界绿色 TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8) SAVE_OBJECT_PNG = True # ========================================= def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]: h, w = arr.shape[:2] if max_side <= 0 or max(h, w) <= max_side: return arr, (h, w), (h, w) scale = float(max_side) / float(max(h, w)) run_w, run_h = max(1, int(round(w * scale))), max(1, int(round(h * scale))) resample = Image.NEAREST if is_mask else Image.BICUBIC pil = Image.fromarray(arr) out = pil.resize((run_w, run_h), resample=resample) return np.asarray(out), (h, w), (run_h, run_w) def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20): """ 分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点 """ # 确保是布尔类型 mask_bool = edge_mask > 127 # 连通域标记 labeled_array, num_features = nd_label(mask_bool) prompts = [] component_info = [] for i in range(1, num_features + 1): coords = np.argwhere(labeled_array == i) area = len(coords) if area < 100: continue # 过滤噪声 # 取几何中心作为引导点 center_y, center_x = np.median(coords, axis=0).astype(int) component_info.append(((center_x, center_y), area)) # 按面积排序,优先处理大面积线条覆盖的物体 component_info.sort(key=lambda x: x[1], reverse=True) for pt, _ in component_info[:max_components]: prompts.append({ "point_coords": np.array([pt], dtype=np.float32), "point_labels": np.array([1], dtype=np.int32) }) return prompts def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str): obj_dir = out_dir / f"{stem}.objects" obj_dir.mkdir(parents=True, exist_ok=True) for idx, lb in enumerate(targets, start=1): mask = (label_map == lb) ys, xs = np.where(mask) if len(ys) == 0: continue y1, y2, x1, x2 = ys.min(), ys.max(), xs.min(), xs.max() rgb_crop = rgb[y1:y2+1, x1:x2+1] alpha = (mask[y1:y2+1, x1:x2+1].astype(np.uint8) * 255)[..., None] rgba = np.concatenate([rgb_crop, alpha], axis=-1) Image.fromarray(rgba).save(obj_dir / f"{stem}.obj_{idx:02d}.png") def main(): # 1. 加载资源 img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK) out_dir = Path(__file__).parent / OUTPUT_DIR out_dir.mkdir(parents=True, exist_ok=True) rgb_orig = np.asarray(Image.open(img_path).convert("RGB")) mask_orig = np.asarray(Image.open(mask_path).convert("L")) # 2. 准备推理尺寸 (缩放以节省显存) rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE) mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True) # 3. 初始化 SAM(直接使用 SamPredictor,避免 wrapper 接口不支持 point/bbox prompt) import torch _ensure_sam_on_path() sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything" ckpt_path = _download_sam_checkpoint_if_needed(sam_root) from segment_anything import sam_model_registry, SamPredictor # type: ignore[import] device = "cuda" if torch.cuda.is_available() else "cpu" sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device) predictor = SamPredictor(sam) print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}") # 4. 提取引导点 prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS) print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点") # 5. 执行推理 final_label_map_run = np.zeros(run_hw, dtype=np.int32) predictor.set_image(rgb_run) for idx, p in enumerate(prompts, start=1): # multimask_output=True 可以获得更稳定的结果 masks, scores, _ = predictor.predict( point_coords=p["point_coords"], point_labels=p["point_labels"], multimask_output=True ) # 挑选得分最高的 mask best_mask = masks[np.argmax(scores)] # 只有面积足够才保留 if np.sum(best_mask) > TARGET_MIN_AREA * (run_hw[0] / orig_hw[0]): # 将新 mask 覆盖到 label_map 上,后续的覆盖前面的 final_label_map_run[best_mask > 0] = idx # 6. 后处理与映射回原图 # 映射回原图尺寸 (Nearest 保证 label ID 不会产生小数) label_map = np.asarray(Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST)) # 7. 导出与可视化 unique_labels = [l for l in np.unique(label_map) if l > 0] # 绘制可视化图 marked_img = rgb_orig.copy() draw = ImageDraw.Draw(Image.fromarray(marked_img)) # 这里只是为了画框方便 # 混合颜色显示 overlay = rgb_orig.astype(np.float32) for lb in unique_labels: m = (label_map == lb) overlay[m] = overlay[m] * (1-MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA # 简单边缘处理 contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(marked_img, contours, -1, (0, 255, 0), 2) final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0)) final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png") if SAVE_OBJECT_PNG: _save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem) print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}") if __name__ == "__main__": main()