60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass
|
||
from typing import Callable, Literal, Tuple
|
||
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
|
||
@dataclass
class Mask2FormerHFConfig:
    """
    Configuration for semantic segmentation with the HuggingFace
    transformers port of Mask2Former.

    model_id: HuggingFace model id (defaults to the ADE20K semantic checkpoint)
    device: "cuda" | "cpu"
    """

    # Hub identifier of the pretrained checkpoint to load.
    model_id: str = "facebook/mask2former-swin-large-ade-semantic"
    # Requested device; the builder falls back to "cpu" when CUDA is unavailable.
    device: str = "cuda"
|
||
|
||
|
||
def build_mask2former_hf_predictor(
    cfg: Mask2FormerHFConfig | None = None,
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
    """
    Build a semantic-segmentation predictor from a HuggingFace Mask2Former checkpoint.

    Args:
        cfg: Optional configuration; a default ``Mask2FormerHFConfig`` is used
            when ``None``.

    Returns:
        ``(predictor, resolved_cfg)`` where ``predictor(image_rgb_uint8)``
        returns an ``int32`` label map with the same (H, W) as the input, and
        ``resolved_cfg`` records the device that was actually selected.
    """
    # Explicit None check (dataclass instances are always truthy, so `cfg or ...`
    # happens to work, but this states the intent directly).
    if cfg is None:
        cfg = Mask2FormerHFConfig()

    # Heavy dependencies are imported lazily so importing this module stays cheap.
    import torch
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    # Honor the requested CUDA device only when CUDA is actually available.
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"

    processor = AutoImageProcessor.from_pretrained(cfg.model_id)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
    model.to(device).eval()

    # Re-create the config so callers see the device that was actually chosen.
    cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)

    @torch.no_grad()
    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        """Segment one RGB image; returns an (H, W) int32 class-id map."""
        if image_rgb.dtype != np.uint8:
            image_rgb_u8 = image_rgb.astype("uint8")
        else:
            image_rgb_u8 = image_rgb
        # PIL requires a C-contiguous buffer; crops/views of larger arrays
        # may not be, so normalize before handing the array to PIL.
        image_rgb_u8 = np.ascontiguousarray(image_rgb_u8)

        pil = Image.fromarray(image_rgb_u8, mode="RGB")
        inputs = processor(images=pil, return_tensors="pt").to(device)
        outputs = model(**inputs)

        # Post-process the per-query masks back to the original resolution.
        target_sizes = [(pil.height, pil.width)]
        seg = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
        return seg.detach().to("cpu").numpy().astype("int32")

    return _predict, cfg
|
||
|