Files
hfut-bishe/python_server/model/Seg/mask2former_loader.py
2026-04-07 20:55:30 +08:00

60 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Literal, Tuple
import numpy as np
from PIL import Image
@dataclass
class Mask2FormerHFConfig:
    """Settings for Mask2Former semantic segmentation via HuggingFace transformers.

    Attributes:
        model_id: HuggingFace model id passed to ``from_pretrained``; the
            default is an ADE20K semantic-segmentation checkpoint.
        device: Requested torch device, ``"cuda"`` or ``"cpu"``.
            ``build_mask2former_hf_predictor`` falls back to CPU when CUDA
            is requested but not available.
    """

    model_id: str = "facebook/mask2former-swin-large-ade-semantic"
    device: str = "cuda"
def _as_uint8_rgb(image_rgb: np.ndarray) -> np.ndarray:
    """Return *image_rgb* as an ``(H, W, 3)`` uint8 array.

    Non-uint8 input is clipped to ``[0, 255]`` before the cast so that
    out-of-range values saturate instead of wrapping around (a bare
    ``astype(uint8)`` maps e.g. 256 -> 0 and -1 -> 255). Float values are
    truncated, not rescaled: a float image in ``[0, 1]`` must be scaled to
    ``[0, 255]`` by the caller.

    Raises:
        ValueError: if the array is not shaped ``(H, W, 3)``.
    """
    image = np.asarray(image_rgb)
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            f"expected an (H, W, 3) RGB image, got array of shape {image.shape}"
        )
    if image.dtype != np.uint8:
        # Go through float64 so clipping works for any numeric/bool dtype;
        # values inside [0, 255] round-trip exactly.
        image = np.clip(image.astype(np.float64), 0, 255).astype(np.uint8)
    return image


def build_mask2former_hf_predictor(
    cfg: Mask2FormerHFConfig | None = None,
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
    """Load a HuggingFace Mask2Former model and return a segmentation predictor.

    Args:
        cfg: Model id and requested device. ``None`` uses the
            ``Mask2FormerHFConfig()`` defaults.

    Returns:
        ``(predictor, resolved_cfg)``.

        ``predictor(image_rgb)`` takes an ``(H, W, 3)`` RGB array (uint8, or
        any numeric dtype that is clipped to ``[0, 255]``). It returns an int32
        label map at the input's height and width, as produced by the
        processor's ``post_process_semantic_segmentation``. The labels are
        presumably class indices into ``model.config.id2label``
        (TODO confirm). It raises ``ValueError`` for non-(H, W, 3) input.

        ``resolved_cfg`` is a new config whose ``device`` is the device
        actually used. That is ``"cpu"`` if CUDA was requested but is
        unavailable, or if a non-CUDA device was requested.
    """
    cfg = cfg or Mask2FormerHFConfig()

    # Imported here rather than at module level, so importing this module
    # does not pull in torch/transformers.
    import torch
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    # Keep the full CUDA device string (e.g. "cuda:1"). Collapsing it to
    # "cuda" would silently place the model on the default GPU instead.
    if cfg.device.startswith("cuda") and torch.cuda.is_available():
        device = cfg.device
    else:
        device = "cpu"

    processor = AutoImageProcessor.from_pretrained(cfg.model_id)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
    model.to(device).eval()
    resolved_cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)

    @torch.no_grad()
    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        """Segment one RGB image and return its int32 label map."""
        # Mode is inferred as "RGB" from the validated (H, W, 3) uint8 array;
        # the explicit ``mode=`` argument is deprecated in recent Pillow.
        pil = Image.fromarray(_as_uint8_rgb(image_rgb))
        inputs = processor(images=pil, return_tensors="pt").to(device)
        outputs = model(**inputs)
        # Post-process at the original (height, width) so the label map
        # lines up with the input pixels rather than the processor's resize.
        seg = processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(pil.height, pil.width)]
        )[0]
        return seg.detach().cpu().numpy().astype(np.int32)

    return _predict, resolved_cfg