initial commit
This commit is contained in:
168
python_server/model/Seg/seg_loader.py
Normal file
168
python_server/model/Seg/seg_loader.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的分割模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- SAM (segment-anything)
|
||||
- Mask2Former(使用 HuggingFace transformers 的语义分割实现)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class SegBackend(str, Enum):
    """Identifiers of the segmentation backends this module can build.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (for example ``SegBackend.SAM == "sam"``).
    """

    # Segment Anything (SAM).
    SAM = "sam"

    # Mask2Former.
    MASK2FORMER = "mask2former"
|
||||
|
||||
|
||||
@dataclass
class UnifiedSegConfig:
    """Configuration selecting which segmentation backend to build."""

    # Backend to construct; SAM is used when nothing else is requested.
    backend: SegBackend = SegBackend.SAM
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# SAM (Segment Anything)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _ensure_sam_on_path() -> Path:
    """Make the ``segment-anything`` checkout next to this file importable.

    The directory ``segment-anything`` located beside this module is put at
    the front of ``sys.path`` unless it is already listed there.

    Returns:
        The path of the ``segment-anything`` directory.

    Raises:
        FileNotFoundError: If that directory does not exist.
    """
    repo_dir = _THIS_DIR.joinpath("segment-anything")

    if repo_dir.is_dir():
        entry = str(repo_dir)
        # Prepend only once so repeated calls do not grow sys.path.
        if entry not in sys.path:
            sys.path.insert(0, entry)
        return repo_dir

    raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {repo_dir}")
|
||||
|
||||
|
||||
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
    """Return the local SAM ViT-H checkpoint, downloading it if it is missing.

    The checkpoint lives at ``<sam_root>/checkpoints/sam_vit_h_4b8939.pth``.
    When that file already exists it is returned untouched. Otherwise it is
    streamed from the public URL into a sibling ``.part`` file, which is
    renamed to the final name only after the transfer finished. An interrupted
    download therefore never leaves a truncated file that a later call would
    mistake for a valid checkpoint.

    Args:
        sam_root: Directory under which the ``checkpoints`` folder is created.

    Returns:
        Path of the checkpoint file.

    Raises:
        requests.RequestException: On HTTP errors, connection failures or
            timeouts.
        OSError: If fewer bytes were received than the server announced, or
            if the file cannot be written.
    """
    ckpt_dir = sam_root / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"

    if ckpt_path.is_file():
        return ckpt_path

    # Imported lazily: an existing checkpoint must not require `requests`.
    import requests

    url = (
        "https://dl.fbaipublicfiles.com/segment_anything/"
        "sam_vit_h_4b8939.pth"
    )
    print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")

    # Download next to the target so the final rename stays on one filesystem.
    part_path = ckpt_path.with_name(ckpt_path.name + ".part")
    chunk_size = 1024 * 1024
    downloaded = 0

    try:
        # (connect, read) timeouts keep a stalled connection from hanging
        # forever; the context manager releases the connection in all cases.
        with requests.get(url, stream=True, timeout=(10, 60)) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")

            with part_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        # flush so the "\r" progress bar is drawn immediately
                        # even though no newline is printed.
                        print(
                            "\r[{}{}] {:.1f}%".format(
                                "#" * done,
                                "." * (50 - done),
                                downloaded * 100 / total,
                            ),
                            end="",
                            flush=True,
                        )

        # Only "too few bytes" is treated as an error: a transparently
        # decoded body may legitimately be larger than Content-Length.
        if downloaded < total:
            raise OSError(
                f"SAM 权重下载不完整: 已接收 {downloaded} / {total} 字节"
            )

        # Publish the finished file under its final name.
        part_path.replace(ckpt_path)
    except BaseException:
        # Remove the partial file on any failure, including Ctrl-C.
        if part_path.exists():
            part_path.unlink()
        raise

    print("\nSAM 权重下载完成。")
    return ckpt_path
|
||||
|
||||
|
||||
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build a SAM-based segmentation function.

    The ``segment-anything`` checkout is made importable, the ViT-H
    checkpoint is resolved (and downloaded if absent), and the model is moved
    to CUDA when ``torch.cuda.is_available()`` reports it, otherwise to CPU.

    Returns:
        A callable that takes an RGB image of shape (H, W, 3) and returns an
        int32 label map of shape (H, W). Pixels covered by no mask are 0;
        each generated mask gets an integer id starting at 1, in the order
        the mask generator returns them.
    """
    repo_dir = _ensure_sam_on_path()
    weights = _download_sam_checkpoint_if_needed(repo_dir)

    # Deferred imports: these are only needed when the SAM backend is used.
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # type: ignore[import]
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    build_vit_h = sam_model_registry["vit_h"]
    model = build_vit_h(checkpoint=str(weights)).to(device)
    generator = SamAutomaticMaskGenerator(model)

    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        # Inputs of any other dtype are cast to uint8 (plain cast, no
        # rescaling) before being handed to the mask generator.
        frame = image_rgb if image_rgb.dtype == np.uint8 else image_rgb.astype("uint8")

        proposals = generator.generate(frame)
        height, width, _ = frame.shape
        labels = np.zeros((height, width), dtype="int32")

        # Where masks overlap, a later mask overwrites the earlier id.
        for label_id, proposal in enumerate(proposals, start=1):
            mask = proposal.get("segmentation")
            if mask is not None:
                labels[mask.astype(bool)] = label_id

        return labels

    return _predict
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Mask2Former (HuggingFace transformers semantic segmentation)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build the Mask2Former segmentation function.

    Delegates to ``build_mask2former_hf_predictor`` from the sibling
    ``mask2former_loader`` module and keeps only the first element of the
    pair it returns.

    Returns:
        The predictor callable produced by the loader.
    """
    # Deferred import: the loader is only needed when this backend is chosen.
    from .mask2former_loader import build_mask2former_hf_predictor

    built = build_mask2former_hf_predictor()
    # The second element of the pair is not used here.
    predictor, _ = built
    return predictor
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 统一构建函数
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """Build the segmentation predictor selected by ``cfg``.

    Args:
        cfg: Backend configuration. When omitted, a default
            ``UnifiedSegConfig`` (SAM backend) is used.

    Returns:
        A pair ``(predictor, backend)``:
        - ``predictor`` maps an RGB image (H, W, 3) to a label map (H, W).
        - ``backend`` is the ``SegBackend`` member that was actually built.

    Raises:
        ValueError: If ``cfg.backend`` is not a supported backend.
    """
    config = cfg if cfg else UnifiedSegConfig()
    backend = config.backend

    # Choose the factory first, build afterwards.
    if backend == SegBackend.SAM:
        factory, resolved = _make_sam_predictor, SegBackend.SAM
    elif backend == SegBackend.MASK2FORMER:
        factory, resolved = _make_mask2former_predictor, SegBackend.MASK2FORMER
    else:
        raise ValueError(f"不支持的分割后端: {backend}")

    return factory(), resolved
|
||||
|
||||
Reference in New Issue
Block a user