414 lines
12 KiB
Python
414 lines
12 KiB
Python
from __future__ import annotations
|
||
|
||
"""
|
||
统一的补全(Inpaint)模型加载入口。
|
||
|
||
当前支持:
|
||
- SDXL Inpaint(diffusers AutoPipelineForInpainting)
|
||
- ControlNet(占位,暂未统一封装)
|
||
"""
|
||
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from typing import Callable
|
||
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
from config_loader import (
|
||
load_app_config,
|
||
get_sdxl_base_model_from_app,
|
||
get_controlnet_base_model_from_app,
|
||
get_controlnet_model_from_app,
|
||
)
|
||
|
||
|
||
class InpaintBackend(str, Enum):
    """Supported inpainting backends.

    Inherits ``str`` so members compare equal to their string values,
    which keeps config round-tripping simple.
    """

    # SDXL inpainting via diffusers AutoPipelineForInpainting.
    SDXL_INPAINT = "sdxl_inpaint"
    # SD ControlNet inpainting (canny-edge conditioned).
    CONTROLNET = "controlnet"
|
||
|
||
|
||
@dataclass
class UnifiedInpaintConfig:
    """Configuration for building a unified inpainting predictor."""

    # Which backend implementation to build.
    backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
    # Requested device ("cuda..."/"cpu"); None falls back to the app config.
    device: str | None = None
    # SDXL base model (HF id or local directory); None uses the config.py default.
    sdxl_base_model: str | None = None
|
||
|
||
|
||
@dataclass
class UnifiedDrawConfig:
    """
    Unified drawing configuration:
    - pure text-to-image: image=None
    - image-to-image (imitating an input image): image=<reference image>
    """

    # Requested device ("cuda..."/"cpu"); None falls back to the app config.
    device: str | None = None
    # SDXL base model (HF id or local directory); None uses the config.py default.
    sdxl_base_model: str | None = None
|
||
|
||
|
||
def _resolve_device_and_dtype(device: str | None):
    """Normalize a device request to ("cuda" | "cpu") plus a matching dtype.

    Falls back to the app-config inpaint device when *device* is None, and
    downgrades any "cuda*" request to "cpu" when CUDA is unavailable.
    fp16 is used on CUDA, fp32 on CPU.
    """
    import torch

    # App config is always loaded (it may be cached upstream); it is only
    # consulted when no explicit device was requested.
    app_cfg = load_app_config()
    requested = app_cfg.inpaint.device if device is None else device
    use_cuda = requested.startswith("cuda") and torch.cuda.is_available()
    resolved = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32
    return resolved, dtype
|
||
|
||
|
||
def _enable_memory_opts(pipe, device: str) -> None:
|
||
if device == "cuda":
|
||
try:
|
||
pipe.enable_attention_slicing()
|
||
except Exception:
|
||
pass
|
||
try:
|
||
pipe.enable_vae_slicing()
|
||
except Exception:
|
||
pass
|
||
try:
|
||
pipe.enable_vae_tiling()
|
||
except Exception:
|
||
pass
|
||
try:
|
||
pipe.enable_model_cpu_offload()
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
|
||
run_w, run_h = orig_w, orig_h
|
||
if max_side > 0 and max(orig_w, orig_h) > max_side:
|
||
scale = max_side / float(max(orig_w, orig_h))
|
||
run_w = int(round(orig_w * scale))
|
||
run_h = int(round(orig_h * scale))
|
||
run_w = max(8, run_w - (run_w % 8))
|
||
run_h = max(8, run_h - (run_h % 8))
|
||
return run_w, run_h
|
||
|
||
|
||
def _make_sdxl_inpaint_predictor(
    cfg: UnifiedInpaintConfig,
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
    """
    Build the SDXL inpainting function.

    Returns a predictor:
    - input: image (PIL RGB), mask (PIL L/1), prompt, negative_prompt
    - output: PIL RGB result resized back to the original resolution
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting

    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)

    # Use the shared helper instead of re-implementing the cuda-availability
    # fallback inline (keeps this function consistent with build_draw_predictor).
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    # Reuse the already-loaded components for inpainting (no second download).
    pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)

    # Memory-saving settings (should not change output semantics).
    # NOTE: CPU offload is noticeably slower but greatly reduces VRAM usage.
    _enable_memory_opts(pipe, device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        max_side: int = 1024,
    ) -> Image.Image:
        """Inpaint *image* where *mask* is white; returns an RGB image."""
        image = image.convert("RGB")
        # diffusers expects a single-channel mask; white = region to repaint.
        mask = mask.convert("L")

        # SDXL / diffusers generally require width/height to be multiples of 8;
        # to avoid OOM, run at most max_side per side, then resize the result
        # back so the output matches the input resolution. _align_size also
        # guards max_side <= 0 (the old inline copy degenerated to 8x8 there).
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)

        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            # NEAREST keeps the mask hard-edged (no gray halo from filtering).
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
|
||
|
||
|
||
def _make_controlnet_predictor(cfg: UnifiedInpaintConfig):
    """
    Build a ControlNet inpainting function conditioned on canny edges.

    Returns a predictor with the same shape as the SDXL one, plus a
    ``controlnet_conditioning_scale`` knob.
    """
    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

    app_cfg = load_app_config()

    # Shared device/dtype resolution (app-config fallback, then CPU downgrade).
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    base_model = get_controlnet_base_model_from_app(app_cfg)
    controlnet_id = get_controlnet_model_from_app(app_cfg)

    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        base_model,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    )

    if device == "cuda":
        # Best-effort memory savers; each unsupported toggle is skipped.
        for opt_name in ("enable_attention_slicing", "enable_vae_slicing", "enable_vae_tiling"):
            try:
                getattr(pipe, opt_name)()
            except Exception:
                pass
        try:
            # CPU offload manages device placement itself; if it is
            # unavailable, fall back to moving the pipeline onto the GPU.
            pipe.enable_model_cpu_offload()
        except Exception:
            pipe.to(device)
    else:
        pipe.to(device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        controlnet_conditioning_scale: float = 1.0,
        max_side: int = 768,
    ) -> Image.Image:
        """Inpaint *image* where *mask* is white, constrained by canny edges."""
        import cv2

        image = image.convert("RGB")
        mask = mask.convert("L")

        # Run at most max_side per side (8-aligned), resize back afterwards.
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)

        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            # NEAREST keeps the mask hard-edged.
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask

        # Control image: canny edges as the constraint (most general choice).
        rgb = np.array(image_run, dtype=np.uint8)
        edges = cv2.Canny(rgb, 100, 200)
        edges3 = np.stack([edges, edges, edges], axis=-1)
        control_image = Image.fromarray(edges3)

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            control_image=control_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            width=run_w,
            height=run_h,
        ).images[0]

        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
|
||
|
||
|
||
def build_inpaint_predictor(
    cfg: UnifiedInpaintConfig | None = None,
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
    """
    Build the inpainting predictor for the configured backend.

    Returns (predict_fn, backend); raises ValueError for unknown backends.
    """
    cfg = cfg or UnifiedInpaintConfig()

    # Dispatch table kept as (backend, factory) pairs compared with ==,
    # so str-valued backends that compare equal to the enum still match.
    dispatch = (
        (InpaintBackend.SDXL_INPAINT, _make_sdxl_inpaint_predictor),
        (InpaintBackend.CONTROLNET, _make_controlnet_predictor),
    )
    for backend, factory in dispatch:
        if cfg.backend == backend:
            return factory(cfg), backend

    raise ValueError(f"不支持的补全后端: {cfg.backend}")
|
||
|
||
|
||
def build_draw_predictor(
    cfg: UnifiedDrawConfig | None = None,
) -> Callable[..., Image.Image]:
    """
    Build the unified drawing function:
    - text-to-image: draw(prompt, image=None, ...)
    - image-to-image: draw(prompt, image=ref_image, strength=0.55, ...)
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

    cfg = cfg or UnifiedDrawConfig()
    app_cfg = load_app_config()
    # Explicit model from cfg wins; otherwise fall back to the app-config default.
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    # from_pipe shares the already-loaded components between both pipelines.
    pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i).to(device)

    _enable_memory_opts(pipe_t2i, device)
    _enable_memory_opts(pipe_i2i, device)

    def _draw(
        prompt: str,
        image: Image.Image | None = None,
        negative_prompt: str = "",
        strength: float = 0.55,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        width: int = 1024,
        height: int = 1024,
        max_side: int = 1024,
    ) -> Image.Image:
        """Generate an image: text-to-image when *image* is None, else image-to-image.

        width/height apply only to the text-to-image path; the image-to-image
        path derives its working size from the reference image instead.
        """
        prompt = prompt or ""
        negative_prompt = negative_prompt or ""

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        if image is None:
            # Text-to-image: clamp/8-align the requested canvas. NOTE: the
            # result is returned at the aligned run size, not resized back
            # to the requested width/height.
            run_w, run_h = _align_size(width, height, max_side=max_side)
            out = pipe_t2i(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=run_w,
                height=run_h,
            ).images[0]
            return out.convert("RGB")

        # Image-to-image: run at a clamped/aligned size, then resize the
        # output back to the reference image's original resolution.
        image = image.convert("RGB")
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
        else:
            image_run = image

        out = pipe_i2i(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0].convert("RGB")

        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _draw
|
||
|