# NOTE(review): the lines that originally preceded the imports were web-viewer
# chrome (file listing, size info, "Raw Permalink Blame History", ambiguous-
# Unicode warning) scraped together with the source. They are preserved below
# as a comment so the module remains valid Python.
#   hfut-bishe/python_server/test_seg.py — 2026-04-07 20:55:30 +08:00
#   163 lines, 6.6 KiB, Python
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageDraw
from scipy.ndimage import label as nd_label
from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed
# ================= Configuration =================
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png"
OUTPUT_DIR = "outputs/test_seg_v2"
SEG_BACKEND = SegBackend.SAM
# Object-filtering parameters
TARGET_MIN_AREA = 1000 # drop fragments smaller than this (pixels, at original scale)
TARGET_MAX_OBJECTS = 20 # maximum number of objects to extract
SAM_MAX_SIDE = 2048 # long-side limit for SAM inference (memory cap)
# Visualisation settings
MASK_ALPHA = 0.4
BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8) # boundary colour: green
TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8)
SAVE_OBJECT_PNG = True
# =========================================
def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
h, w = arr.shape[:2]
if max_side <= 0 or max(h, w) <= max_side:
return arr, (h, w), (h, w)
scale = float(max_side) / float(max(h, w))
run_w, run_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
resample = Image.NEAREST if is_mask else Image.BICUBIC
pil = Image.fromarray(arr)
out = pil.resize((run_w, run_h), resample=resample)
return np.asarray(out), (h, w), (run_h, run_w)
def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20):
"""
分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点
"""
# 确保是布尔类型
mask_bool = edge_mask > 127
# 连通域标记
labeled_array, num_features = nd_label(mask_bool)
prompts = []
component_info = []
for i in range(1, num_features + 1):
coords = np.argwhere(labeled_array == i)
area = len(coords)
if area < 100: continue # 过滤噪声
# 取几何中心作为引导点
center_y, center_x = np.median(coords, axis=0).astype(int)
component_info.append(((center_x, center_y), area))
# 按面积排序,优先处理大面积线条覆盖的物体
component_info.sort(key=lambda x: x[1], reverse=True)
for pt, _ in component_info[:max_components]:
prompts.append({
"point_coords": np.array([pt], dtype=np.float32),
"point_labels": np.array([1], dtype=np.int32)
})
return prompts
def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str):
    """Export one tightly-cropped RGBA cutout per target label.

    Files are written to ``<out_dir>/<stem>.objects/<stem>.obj_NN.png``;
    the alpha channel is derived from the label mask, so pixels outside
    the object are fully transparent.  Labels with no pixels are skipped.
    """
    dest = out_dir / f"{stem}.objects"
    dest.mkdir(parents=True, exist_ok=True)
    for seq, label_id in enumerate(targets, start=1):
        region = label_map == label_id
        ys, xs = np.where(region)
        if len(ys) == 0:
            continue  # label present in list but absent from the map
        top, bottom = ys.min(), ys.max()
        left, right = xs.min(), xs.max()
        crop = rgb[top:bottom + 1, left:right + 1]
        alpha = (region[top:bottom + 1, left:right + 1].astype(np.uint8) * 255)[..., None]
        rgba = np.concatenate([crop, alpha], axis=-1)
        Image.fromarray(rgba).save(dest / f"{stem}.obj_{seq:02d}.png")
def main():
    """Segment objects in INPUT_IMAGE with SAM, guided by an edge mask.

    Pipeline: load image + edge mask, downscale both for inference,
    derive one point prompt per connected edge cluster, run SamPredictor
    per prompt, paint accepted masks into an integer label map, upscale
    the map back to the original resolution, then export a visualisation
    overlay and (optionally) per-object RGBA crops.
    """
    # 1. Load inputs and prepare the output directory.
    img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK)
    out_dir = Path(__file__).parent / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    rgb_orig = np.asarray(Image.open(img_path).convert("RGB"))
    mask_orig = np.asarray(Image.open(mask_path).convert("L"))

    # 2. Downscale to the inference size to limit GPU memory use.
    rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE)
    mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True)

    # 3. Build SAM via SamPredictor directly: the project wrapper does not
    #    support point/bbox prompts.
    import torch
    _ensure_sam_on_path()
    sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything"
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    predictor = SamPredictor(sam)
    print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}")

    # 4. One guidance point per connected edge cluster.
    prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS)
    print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点")

    # 5. Run one prediction per prompt; later masks overwrite earlier ones
    #    in the label map.
    final_label_map_run = np.zeros(run_hw, dtype=np.int32)
    predictor.set_image(rgb_run)
    # BUGFIX: mask area scales with the SQUARE of the linear resize factor,
    # so the min-area threshold must be scaled by the pixel-count ratio.
    # The previous code used the height ratio alone, under-filtering small
    # fragments at reduced inference resolution.
    area_scale = (run_hw[0] * run_hw[1]) / float(orig_hw[0] * orig_hw[1])
    min_area_run = TARGET_MIN_AREA * area_scale
    for idx, p in enumerate(prompts, start=1):
        # multimask_output=True yields several candidates; keep the one
        # with the highest predicted IoU score — generally more stable.
        masks, scores, _ = predictor.predict(
            point_coords=p["point_coords"],
            point_labels=p["point_labels"],
            multimask_output=True,
        )
        best_mask = masks[np.argmax(scores)]
        if np.sum(best_mask) > min_area_run:
            final_label_map_run[best_mask > 0] = idx

    # 6. Map the label map back to the original resolution; NEAREST keeps
    #    the integer label IDs exact (no interpolated fractional labels).
    label_map = np.asarray(
        Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST)
    )

    # 7. Visualise: tint each object and outline it in green.
    #    (Removed a dead `ImageDraw.Draw(Image.fromarray(...))` here — it
    #    drew on a throwaway copy and was never used.)
    unique_labels = [lb for lb in np.unique(label_map) if lb > 0]
    marked_img = rgb_orig.copy()
    overlay = rgb_orig.astype(np.float32)
    for lb in unique_labels:
        m = label_map == lb
        overlay[m] = overlay[m] * (1 - MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(marked_img, contours, -1, (0, 255, 0), 2)
    final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0))
    final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png")
    if SAVE_OBJECT_PNG:
        _save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem)
    print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}")


if __name__ == "__main__":
    main()