initial commit
This commit is contained in:
59
python_server/model/Seg/mask2former_loader.py
Normal file
59
python_server/model/Seg/mask2former_loader.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mask2FormerHFConfig:
|
||||
"""
|
||||
使用 HuggingFace transformers 版本的 Mask2Former 语义分割。
|
||||
|
||||
model_id: HuggingFace 模型 id(默认 ADE20K semantic)
|
||||
device: "cuda" | "cpu"
|
||||
"""
|
||||
|
||||
model_id: str = "facebook/mask2former-swin-large-ade-semantic"
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
def build_mask2former_hf_predictor(
|
||||
cfg: Mask2FormerHFConfig | None = None,
|
||||
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
|
||||
"""
|
||||
返回 predictor(image_rgb_uint8) -> label_map(int32)。
|
||||
"""
|
||||
cfg = cfg or Mask2FormerHFConfig()
|
||||
|
||||
import torch
|
||||
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
|
||||
|
||||
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained(cfg.model_id)
|
||||
model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
|
||||
model.to(device).eval()
|
||||
|
||||
cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def _predict(image_rgb: np.ndarray) -> np.ndarray:
|
||||
if image_rgb.dtype != np.uint8:
|
||||
image_rgb_u8 = image_rgb.astype("uint8")
|
||||
else:
|
||||
image_rgb_u8 = image_rgb
|
||||
|
||||
pil = Image.fromarray(image_rgb_u8, mode="RGB")
|
||||
inputs = processor(images=pil, return_tensors="pt").to(device)
|
||||
outputs = model(**inputs)
|
||||
|
||||
# post-process to original size
|
||||
target_sizes = [(pil.height, pil.width)]
|
||||
seg = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
|
||||
return seg.detach().to("cpu").numpy().astype("int32")
|
||||
|
||||
return _predict, cfg
|
||||
|
||||
Reference in New Issue
Block a user