initial commit

This commit is contained in:
2026-04-07 20:55:30 +08:00
commit 81d1fb7856
84 changed files with 11929 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Literal, Tuple
import numpy as np
from PIL import Image
@dataclass
class Mask2FormerHFConfig:
"""
使用 HuggingFace transformers 版本的 Mask2Former 语义分割。
model_id: HuggingFace 模型 id默认 ADE20K semantic
device: "cuda" | "cpu"
"""
model_id: str = "facebook/mask2former-swin-large-ade-semantic"
device: str = "cuda"
def build_mask2former_hf_predictor(
cfg: Mask2FormerHFConfig | None = None,
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
"""
返回 predictor(image_rgb_uint8) -> label_map(int32)。
"""
cfg = cfg or Mask2FormerHFConfig()
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
processor = AutoImageProcessor.from_pretrained(cfg.model_id)
model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
model.to(device).eval()
cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)
@torch.no_grad()
def _predict(image_rgb: np.ndarray) -> np.ndarray:
if image_rgb.dtype != np.uint8:
image_rgb_u8 = image_rgb.astype("uint8")
else:
image_rgb_u8 = image_rgb
pil = Image.fromarray(image_rgb_u8, mode="RGB")
inputs = processor(images=pil, return_tensors="pt").to(device)
outputs = model(**inputs)
# post-process to original size
target_sizes = [(pil.height, pil.width)]
seg = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
return seg.detach().to("cpu").numpy().astype("int32")
return _predict, cfg