initial commit

This commit is contained in:
2026-04-07 20:55:30 +08:00
commit 81d1fb7856
84 changed files with 11929 additions and 0 deletions

163
python_server/test_seg.py Normal file
View File

@@ -0,0 +1,163 @@
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageDraw
from scipy.ndimage import label as nd_label
from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed
# ================= Configuration =================
# Absolute path of the input painting (RGB source image).
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
# Edge mask produced by the depth pipeline; drives the SAM point prompts.
INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png"
OUTPUT_DIR = "outputs/test_seg_v2"
SEG_BACKEND = SegBackend.SAM
# Object-filtering parameters
TARGET_MIN_AREA = 1000 # filter out fragments smaller than this (pixels, original scale)
TARGET_MAX_OBJECTS = 20 # maximum number of objects to extract
SAM_MAX_SIDE = 2048 # long-side limit for SAM inference (saves GPU memory)
# Visualization settings
MASK_ALPHA = 0.4
BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8) # boundary drawn in green
TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8)
SAVE_OBJECT_PNG = True
# =========================================
def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
h, w = arr.shape[:2]
if max_side <= 0 or max(h, w) <= max_side:
return arr, (h, w), (h, w)
scale = float(max_side) / float(max(h, w))
run_w, run_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
resample = Image.NEAREST if is_mask else Image.BICUBIC
pil = Image.fromarray(arr)
out = pil.resize((run_w, run_h), resample=resample)
return np.asarray(out), (h, w), (run_h, run_w)
def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20):
"""
分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点
"""
# 确保是布尔类型
mask_bool = edge_mask > 127
# 连通域标记
labeled_array, num_features = nd_label(mask_bool)
prompts = []
component_info = []
for i in range(1, num_features + 1):
coords = np.argwhere(labeled_array == i)
area = len(coords)
if area < 100: continue # 过滤噪声
# 取几何中心作为引导点
center_y, center_x = np.median(coords, axis=0).astype(int)
component_info.append(((center_x, center_y), area))
# 按面积排序,优先处理大面积线条覆盖的物体
component_info.sort(key=lambda x: x[1], reverse=True)
for pt, _ in component_info[:max_components]:
prompts.append({
"point_coords": np.array([pt], dtype=np.float32),
"point_labels": np.array([1], dtype=np.int32)
})
return prompts
def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str):
    """Write one tightly-cropped RGBA cutout PNG per target label.

    Each label in *targets* is cropped to its bounding box; pixels
    outside the label become fully transparent.  Files are written to
    ``<out_dir>/<stem>.objects/<stem>.obj_NN.png``.
    """
    dest = out_dir / f"{stem}.objects"
    dest.mkdir(parents=True, exist_ok=True)
    for ordinal, label_id in enumerate(targets, start=1):
        hit = label_map == label_id
        rows, cols = np.where(hit)
        if rows.size == 0:
            continue  # label absent from the map; nothing to export
        top, bottom = rows.min(), rows.max()
        left, right = cols.min(), cols.max()
        crop = rgb[top:bottom + 1, left:right + 1]
        # Alpha channel: 255 inside the object, 0 elsewhere.
        alpha_plane = (hit[top:bottom + 1, left:right + 1].astype(np.uint8) * 255)[..., None]
        rgba = np.concatenate([crop, alpha_plane], axis=-1)
        Image.fromarray(rgba).save(dest / f"{stem}.obj_{ordinal:02d}.png")
def main():
    """Guided SAM segmentation driven by edge-mask point prompts.

    Pipeline: load the image and its edge mask, downscale for inference,
    derive point prompts from the mask's connected components, run
    SamPredictor once per prompt, merge the winning masks into a label
    map, then export a blended visualization and per-object RGBA cutouts.
    """
    # 1. Load inputs and prepare the output directory.
    img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK)
    out_dir = Path(__file__).parent / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    rgb_orig = np.asarray(Image.open(img_path).convert("RGB"))
    mask_orig = np.asarray(Image.open(mask_path).convert("L"))
    # 2. Downscale to the SAM inference size to save GPU memory.
    rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE)
    mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True)
    # 3. Build SAM via SamPredictor directly — the project wrapper does not
    #    expose point/bbox prompts.
    import torch
    _ensure_sam_on_path()
    sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything"
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    predictor = SamPredictor(sam)
    print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}")
    # 4. Extract one guidance point per edge cluster.
    prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS)
    print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点")
    # 5. Run inference. Later prompts overwrite earlier ones in the map.
    final_label_map_run = np.zeros(run_hw, dtype=np.int32)
    predictor.set_image(rgb_run)
    # Pixel AREA scales with the SQUARE of the linear scale factor, so the
    # minimum-area threshold must be scaled quadratically (the old linear
    # scaling over-filtered small objects on downscaled runs). Hoisted out
    # of the loop because it is invariant.
    scale = run_hw[0] / orig_hw[0]
    min_area_run = TARGET_MIN_AREA * scale * scale
    for idx, p in enumerate(prompts, start=1):
        # multimask_output=True yields 3 candidates; keep the best-scoring one.
        masks, scores, _ = predictor.predict(
            point_coords=p["point_coords"],
            point_labels=p["point_labels"],
            multimask_output=True
        )
        best_mask = masks[np.argmax(scores)]
        # Keep only sufficiently large masks.
        if np.sum(best_mask) > min_area_run:
            final_label_map_run[best_mask > 0] = idx
    # 6. Map the label map back to the original resolution
    #    (NEAREST so label IDs stay integral).
    label_map = np.asarray(Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST))
    # 7. Export and visualize.
    unique_labels = [lb for lb in np.unique(label_map) if lb > 0]
    marked_img = rgb_orig.copy()
    # Alpha-blend each object's fill colour, then outline it in green.
    # (Removed dead code: an unused ImageDraw.Draw on a throwaway copy of
    # marked_img that could never affect the output.)
    overlay = rgb_orig.astype(np.float32)
    for lb in unique_labels:
        m = (label_map == lb)
        overlay[m] = overlay[m] * (1 - MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(marked_img, contours, -1, (0, 255, 0), 2)
    final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0))
    final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png")
    if SAVE_OBJECT_PNG:
        _save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem)
    print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}")
if __name__ == "__main__":
    main()