initial commit
This commit is contained in:
163
python_server/test_seg.py
Normal file
163
python_server/test_seg.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw
|
||||
from scipy.ndimage import label as nd_label
|
||||
|
||||
from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed
|
||||
|
||||
# ================= Configuration =================
# NOTE(review): machine-specific absolute paths — adjust per environment.
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png"
OUTPUT_DIR = "outputs/test_seg_v2"
SEG_BACKEND = SegBackend.SAM

# Object selection parameters
TARGET_MIN_AREA = 1000  # drop fragments smaller than this many pixels
TARGET_MAX_OBJECTS = 20  # maximum number of objects to extract
SAM_MAX_SIDE = 2048  # long-side limit for SAM inference

# Visual appearance
MASK_ALPHA = 0.4  # opacity of the fill colour blended over each object
BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8)  # boundary drawn in green
TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8)  # per-object fill colour
SAVE_OBJECT_PNG = True  # also export one cropped RGBA PNG per object
# =========================================
|
||||
|
||||
def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
|
||||
h, w = arr.shape[:2]
|
||||
if max_side <= 0 or max(h, w) <= max_side:
|
||||
return arr, (h, w), (h, w)
|
||||
scale = float(max_side) / float(max(h, w))
|
||||
run_w, run_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
|
||||
resample = Image.NEAREST if is_mask else Image.BICUBIC
|
||||
pil = Image.fromarray(arr)
|
||||
out = pil.resize((run_w, run_h), resample=resample)
|
||||
return np.asarray(out), (h, w), (run_h, run_w)
|
||||
|
||||
def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20):
|
||||
"""
|
||||
分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点
|
||||
"""
|
||||
# 确保是布尔类型
|
||||
mask_bool = edge_mask > 127
|
||||
# 连通域标记
|
||||
labeled_array, num_features = nd_label(mask_bool)
|
||||
|
||||
prompts = []
|
||||
component_info = []
|
||||
for i in range(1, num_features + 1):
|
||||
coords = np.argwhere(labeled_array == i)
|
||||
area = len(coords)
|
||||
if area < 100: continue # 过滤噪声
|
||||
# 取几何中心作为引导点
|
||||
center_y, center_x = np.median(coords, axis=0).astype(int)
|
||||
component_info.append(((center_x, center_y), area))
|
||||
|
||||
# 按面积排序,优先处理大面积线条覆盖的物体
|
||||
component_info.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
for pt, _ in component_info[:max_components]:
|
||||
prompts.append({
|
||||
"point_coords": np.array([pt], dtype=np.float32),
|
||||
"point_labels": np.array([1], dtype=np.int32)
|
||||
})
|
||||
return prompts
|
||||
|
||||
def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str):
|
||||
obj_dir = out_dir / f"{stem}.objects"
|
||||
obj_dir.mkdir(parents=True, exist_ok=True)
|
||||
for idx, lb in enumerate(targets, start=1):
|
||||
mask = (label_map == lb)
|
||||
ys, xs = np.where(mask)
|
||||
if len(ys) == 0: continue
|
||||
y1, y2, x1, x2 = ys.min(), ys.max(), xs.min(), xs.max()
|
||||
rgb_crop = rgb[y1:y2+1, x1:x2+1]
|
||||
alpha = (mask[y1:y2+1, x1:x2+1].astype(np.uint8) * 255)[..., None]
|
||||
rgba = np.concatenate([rgb_crop, alpha], axis=-1)
|
||||
Image.fromarray(rgba).save(obj_dir / f"{stem}.obj_{idx:02d}.png")
|
||||
|
||||
def main():
    """Run edge-mask-guided SAM segmentation on INPUT_IMAGE and export results.

    Pipeline: load image + edge mask, downscale for inference, derive point
    prompts from the mask's connected components, run SamPredictor per prompt,
    compose a label map, upscale it back, then save a visualisation and
    (optionally) per-object RGBA crops into OUTPUT_DIR.
    """
    # 1. Load resources.
    img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK)
    out_dir = Path(__file__).parent / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)

    rgb_orig = np.asarray(Image.open(img_path).convert("RGB"))
    mask_orig = np.asarray(Image.open(mask_path).convert("L"))

    # 2. Downscale for inference to bound GPU memory usage.
    rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE)
    mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True)

    # 3. Initialise SAM. SamPredictor is used directly because the project
    # wrapper does not support point/bbox prompts.
    import torch

    _ensure_sam_on_path()
    sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything"
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    predictor = SamPredictor(sam)

    print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}")

    # 4. Extract guidance points from the edge mask.
    prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS)
    print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点")

    # 5. Run inference, one prompt at a time.
    final_label_map_run = np.zeros(run_hw, dtype=np.int32)
    predictor.set_image(rgb_run)

    # BUGFIX: pixel counts scale with the *square* of the linear resize
    # ratio, so the area threshold must use the squared factor (the old
    # code used the linear ratio, over-filtering small objects).
    area_scale = (run_hw[0] / orig_hw[0]) ** 2
    for idx, p in enumerate(prompts, start=1):
        # multimask_output=True yields several candidates; keep the best-scored.
        masks, scores, _ = predictor.predict(
            point_coords=p["point_coords"],
            point_labels=p["point_labels"],
            multimask_output=True
        )
        best_mask = masks[np.argmax(scores)]

        # Keep the mask only if it is large enough at inference scale.
        if np.sum(best_mask) > TARGET_MIN_AREA * area_scale:
            # Later masks overwrite earlier ones where they overlap.
            final_label_map_run[best_mask > 0] = idx

    # 6. Map labels back to the original resolution
    # (nearest-neighbour keeps label IDs integral).
    label_map = np.asarray(Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST))

    # 7. Export and visualise.
    unique_labels = [l for l in np.unique(label_map) if l > 0]

    marked_img = rgb_orig.copy()

    # Blend the fill colour into each object and trace its outer contour.
    # Use the configured BOUNDARY_COLOR instead of a hard-coded green.
    boundary_color = tuple(int(c) for c in BOUNDARY_COLOR)
    overlay = rgb_orig.astype(np.float32)
    for lb in unique_labels:
        m = (label_map == lb)
        overlay[m] = overlay[m] * (1 - MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(marked_img, contours, -1, boundary_color, 2)

    final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0))
    final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png")

    if SAVE_OBJECT_PNG:
        _save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem)

    print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}")
|
||||
|
||||
# Script entry point: run the SAM-guided segmentation pipeline.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user