Files
hfut-bishe/python_server/model/Inpaint/inpaint_loader.py
2026-04-07 20:55:30 +08:00

414 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
"""
统一的补全Inpaint模型加载入口。
当前支持:
- SDXL Inpaintdiffusers AutoPipelineForInpainting
- ControlNet占位暂未统一封装
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import numpy as np
from PIL import Image
from config_loader import (
load_app_config,
get_sdxl_base_model_from_app,
get_controlnet_base_model_from_app,
get_controlnet_model_from_app,
)
class InpaintBackend(str, Enum):
    """Supported inpainting backends; str mixin keeps values config/JSON friendly."""
    SDXL_INPAINT = "sdxl_inpaint"
    CONTROLNET = "controlnet"
@dataclass
class UnifiedInpaintConfig:
    """Configuration for building an inpainting predictor."""
    # Which backend to build; defaults to SDXL inpainting.
    backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
    # Target device ("cuda"/"cpu"); None falls back to the app config.
    device: str | None = None
    # SDXL base model (HF id or local directory); None uses the config.py default.
    sdxl_base_model: str | None = None
@dataclass
class UnifiedDrawConfig:
    """
    Unified drawing configuration:
    - pure text-to-image: image=None
    - image-to-image (mimic an input image): image=<reference image>
    """
    # Target device ("cuda"/"cpu"); None falls back to the app config.
    device: str | None = None
    # SDXL base model (HF id or local directory); None uses the config default.
    sdxl_base_model: str | None = None
def _resolve_device_and_dtype(device: str | None):
import torch
app_cfg = load_app_config()
if device is None:
device = app_cfg.inpaint.device
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
return device, torch_dtype
def _enable_memory_opts(pipe, device: str) -> None:
if device == "cuda":
try:
pipe.enable_attention_slicing()
except Exception:
pass
try:
pipe.enable_vae_slicing()
except Exception:
pass
try:
pipe.enable_vae_tiling()
except Exception:
pass
try:
pipe.enable_model_cpu_offload()
except Exception:
pass
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
run_w, run_h = orig_w, orig_h
if max_side > 0 and max(orig_w, orig_h) > max_side:
scale = max_side / float(max(orig_w, orig_h))
run_w = int(round(orig_w * scale))
run_h = int(round(orig_h * scale))
run_w = max(8, run_w - (run_w % 8))
run_h = max(8, run_h - (run_h % 8))
return run_w, run_h
def _make_sdxl_inpaint_predictor(
    cfg: UnifiedInpaintConfig,
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
    """
    Build the SDXL inpainting predictor.

    The returned callable takes:
    - image (PIL RGB), mask (PIL L/1, white = area to repaint),
      prompt, negative_prompt (plus sampling keyword arguments)
    and returns a PIL RGB image at the original input resolution.
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting

    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    # Consistency fix: reuse the shared helper instead of duplicating the
    # device/dtype resolution logic inline.
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    # Load text2img once and derive the inpaint pipeline from it so both
    # share the same weights.
    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)
    # Memory savers (best-effort; CPU offload is slower but much lighter on VRAM).
    _enable_memory_opts(pipe, device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        max_side: int = 1024,
    ) -> Image.Image:
        """Run one inpainting pass; output matches the input resolution."""
        image = image.convert("RGB")
        # diffusers expects a single-channel mask; white marks the repaint area.
        mask = mask.convert("L")
        # SDXL/diffusers require width/height to be multiples of 8. To avoid
        # OOM the run size is capped at max_side (aspect ratio preserved) via
        # the shared _align_size helper, and the result is resized back so the
        # output resolution matches the input.
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
def _make_controlnet_predictor(cfg: UnifiedInpaintConfig):
    """
    Build a ControlNet inpainting predictor conditioned on canny edges.

    The returned callable takes image (PIL RGB), mask (PIL L/1, white = area
    to repaint), prompt, negative_prompt (plus sampling keyword arguments) and
    returns a PIL RGB image at the original input resolution.
    """
    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

    app_cfg = load_app_config()
    # Consistency fix: parameter renamed from "_" and device/dtype resolution
    # delegated to the shared helper.
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)
    base_model = get_controlnet_base_model_from_app(app_cfg)
    controlnet_id = get_controlnet_model_from_app(app_cfg)

    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        base_model,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    )
    if device == "cuda":
        # Best-effort memory savers. CPU offload manages device placement
        # itself, so the explicit .to(device) move is only the fallback.
        for opt in ("enable_attention_slicing", "enable_vae_slicing", "enable_vae_tiling"):
            try:
                getattr(pipe, opt)()
            except Exception:
                pass
        try:
            pipe.enable_model_cpu_offload()
        except Exception:
            pipe.to(device)
    else:
        pipe.to(device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        controlnet_conditioning_scale: float = 1.0,
        max_side: int = 768,
    ) -> Image.Image:
        """Run one ControlNet inpainting pass; output matches input resolution."""
        import cv2

        image = image.convert("RGB")
        mask = mask.convert("L")
        orig_w, orig_h = image.size
        # Shared helper handles the max_side cap and multiple-of-8 snapping.
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask
        # Control image: canny edges as the structural constraint (most generic).
        # Uses the module-level numpy import; the old shadowing local import
        # was redundant.
        rgb = np.array(image_run, dtype=np.uint8)
        edges = cv2.Canny(rgb, 100, 200)
        control_image = Image.fromarray(np.stack([edges, edges, edges], axis=-1))
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            control_image=control_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
def build_inpaint_predictor(
    cfg: UnifiedInpaintConfig | None = None,
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
    """
    Build the unified inpainting predictor for the configured backend.

    Returns a (predict_fn, backend) pair; raises ValueError for an
    unrecognized backend.
    """
    cfg = cfg or UnifiedInpaintConfig()
    backend = cfg.backend
    if backend == InpaintBackend.CONTROLNET:
        return _make_controlnet_predictor(cfg), InpaintBackend.CONTROLNET
    if backend == InpaintBackend.SDXL_INPAINT:
        return _make_sdxl_inpaint_predictor(cfg), InpaintBackend.SDXL_INPAINT
    raise ValueError(f"不支持的补全后端: {cfg.backend}")
def build_draw_predictor(
    cfg: UnifiedDrawConfig | None = None,
) -> Callable[..., Image.Image]:
    """
    Build the unified drawing function.

    The returned callable supports:
    - text-to-image: draw(prompt, image=None, ...)
    - image-to-image (mimic a reference): draw(prompt, image=ref_image, strength=0.55, ...)
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

    cfg = cfg or UnifiedDrawConfig()
    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    # Load text2img once and derive img2img from it so both pipelines share
    # the same weights in memory.
    txt2img = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    img2img = AutoPipelineForImage2Image.from_pipe(txt2img).to(device)
    _enable_memory_opts(txt2img, device)
    _enable_memory_opts(img2img, device)

    def _draw(
        prompt: str,
        image: Image.Image | None = None,
        negative_prompt: str = "",
        strength: float = 0.55,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        width: int = 1024,
        height: int = 1024,
        max_side: int = 1024,
    ) -> Image.Image:
        """Generate one image; width/height apply to the text-to-image path only."""
        prompt = prompt or ""
        negative_prompt = negative_prompt or ""
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        if image is None:
            # Pure text-to-image: honor the requested size, capped and aligned.
            gen_w, gen_h = _align_size(width, height, max_side=max_side)
            result = txt2img(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=gen_w,
                height=gen_h,
            ).images[0]
            return result.convert("RGB")

        # Image-to-image: run at a capped/aligned size, then restore the
        # reference image's resolution.
        image = image.convert("RGB")
        src_w, src_h = image.size
        gen_w, gen_h = _align_size(src_w, src_h, max_side=max_side)
        if (gen_w, gen_h) == (src_w, src_h):
            reference = image
        else:
            reference = image.resize((gen_w, gen_h), resample=Image.BICUBIC)
        result = img2img(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=reference,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0].convert("RGB")
        if result.size != (src_w, src_h):
            result = result.resize((src_w, src_h), resample=Image.BICUBIC)
        return result

    return _draw