initial commit
This commit is contained in:
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
@@ -0,0 +1,413 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
统一的补全(Inpaint)模型加载入口。
|
||||
|
||||
当前支持:
|
||||
- SDXL Inpaint(diffusers AutoPipelineForInpainting)
|
||||
- ControlNet(占位,暂未统一封装)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
get_sdxl_base_model_from_app,
|
||||
get_controlnet_base_model_from_app,
|
||||
get_controlnet_model_from_app,
|
||||
)
|
||||
|
||||
|
||||
class InpaintBackend(str, Enum):
    """Identifiers for the supported inpainting backends."""

    SDXL_INPAINT = "sdxl_inpaint"
    CONTROLNET = "controlnet"
|
||||
|
||||
|
||||
@dataclass
class UnifiedInpaintConfig:
    """Configuration for building a unified inpaint predictor."""

    # Which backend implementation to build.
    backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
    # Target device; None falls back to the app-config default.
    device: str | None = None
    # SDXL base model (HF id or local directory); None uses the default from config.py.
    sdxl_base_model: str | None = None
|
||||
|
||||
|
||||
@dataclass
class UnifiedDrawConfig:
    """
    Unified drawing configuration:
    - pure text-to-image: image=None
    - image-to-image (imitate the input image): image=<reference image>
    """
    # Target device; None falls back to the app-config default.
    device: str | None = None
    # SDXL base model (HF id or local directory); None uses the default from config.py.
    sdxl_base_model: str | None = None
|
||||
|
||||
|
||||
def _resolve_device_and_dtype(device: str | None):
    """Resolve the effective torch device string and dtype.

    Falls back to the app-config inpaint device when *device* is None, and
    downgrades to CPU/float32 whenever CUDA is unavailable.
    """
    import torch

    app_cfg = load_app_config()
    requested = app_cfg.inpaint.device if device is None else device
    use_cuda = requested.startswith("cuda") and torch.cuda.is_available()
    resolved = "cuda" if use_cuda else "cpu"
    return resolved, (torch.float16 if use_cuda else torch.float32)
|
||||
|
||||
|
||||
def _enable_memory_opts(pipe, device: str) -> None:
|
||||
if device == "cuda":
|
||||
try:
|
||||
pipe.enable_attention_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_slicing()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_vae_tiling()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pipe.enable_model_cpu_offload()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
|
||||
run_w, run_h = orig_w, orig_h
|
||||
if max_side > 0 and max(orig_w, orig_h) > max_side:
|
||||
scale = max_side / float(max(orig_w, orig_h))
|
||||
run_w = int(round(orig_w * scale))
|
||||
run_h = int(round(orig_h * scale))
|
||||
run_w = max(8, run_w - (run_w % 8))
|
||||
run_h = max(8, run_h - (run_h % 8))
|
||||
return run_w, run_h
|
||||
|
||||
|
||||
def _make_sdxl_inpaint_predictor(
    cfg: UnifiedInpaintConfig,
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
    """
    Build the SDXL inpainting function.

    Returned callable:
    - input: image (PIL RGB), mask (PIL L/1), prompt, negative_prompt
    - output: PIL RGB result image at the original resolution
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting

    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)

    # Reuse the shared helper instead of duplicating device/dtype resolution.
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)

    # Memory-saving settings (kept semantically neutral as far as possible).
    # NOTE: CPU offload is noticeably slower but greatly reduces VRAM usage.
    _enable_memory_opts(pipe, device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        max_side: int = 1024,
    ) -> Image.Image:
        """Inpaint *image* where *mask* is white; returns an RGB image sized
        like the input."""
        image = image.convert("RGB")
        # diffusers expects a single-channel mask; white marks the repaint area.
        mask = mask.convert("L")

        # SDXL / diffusers generally require width/height to be multiples of 8.
        # To avoid OOM, the image is scaled so its longest side is <= max_side
        # (default 1024) and aligned to multiples of 8, then the result is
        # resized back so the output matches the original resolution.
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)

        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            # NEAREST keeps the mask hard-edged (no gray interpolation fringe).
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
|
||||
|
||||
|
||||
def _make_controlnet_predictor(_: UnifiedInpaintConfig):
    """
    Build a ControlNet inpainting function conditioned on canny edges.

    Returned callable mirrors the SDXL predictor (image/mask/prompt in,
    PIL RGB out) with an extra ``controlnet_conditioning_scale`` knob.
    """
    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

    app_cfg = load_app_config()

    # Reuse the shared helper instead of duplicating device/dtype resolution.
    device, torch_dtype = _resolve_device_and_dtype(_.device)

    base_model = get_controlnet_base_model_from_app(app_cfg)
    controlnet_id = get_controlnet_model_from_app(app_cfg)

    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        base_model,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    )

    if device == "cuda":
        # Optional, per-pipeline memory optimizations (best-effort).
        for opt_name in (
            "enable_attention_slicing",
            "enable_vae_slicing",
            "enable_vae_tiling",
        ):
            try:
                getattr(pipe, opt_name)()
            except Exception:
                pass
        try:
            pipe.enable_model_cpu_offload()
        except Exception:
            # Offload unavailable: the pipeline was never moved, so move the
            # whole model onto the GPU instead.
            pipe.to(device)
    else:
        pipe.to(device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        controlnet_conditioning_scale: float = 1.0,
        max_side: int = 768,
    ) -> Image.Image:
        """Inpaint *image* where *mask* is white, constrained by the canny
        edges of the input; returns an RGB image at the original size."""
        import cv2

        image = image.convert("RGB")
        mask = mask.convert("L")

        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)

        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask

        # Control image: canny edges as the constraint (the most generic choice).
        rgb = np.array(image_run, dtype=np.uint8)
        edges = cv2.Canny(rgb, 100, 200)
        control_image = Image.fromarray(np.stack([edges, edges, edges], axis=-1))

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            control_image=control_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            width=run_w,
            height=run_h,
        ).images[0]

        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
|
||||
|
||||
|
||||
def build_inpaint_predictor(
    cfg: UnifiedInpaintConfig | None = None,
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
    """
    Build the unified inpaint prediction function for the configured backend.

    Returns (predict_fn, backend). Raises ValueError for unknown backends.
    """
    if cfg is None:
        cfg = UnifiedInpaintConfig()

    factories = {
        InpaintBackend.SDXL_INPAINT: _make_sdxl_inpaint_predictor,
        InpaintBackend.CONTROLNET: _make_controlnet_predictor,
    }
    factory = factories.get(cfg.backend)
    if factory is None:
        raise ValueError(f"不支持的补全后端: {cfg.backend}")
    return factory(cfg), cfg.backend
|
||||
|
||||
|
||||
def build_draw_predictor(
    cfg: UnifiedDrawConfig | None = None,
) -> Callable[..., Image.Image]:
    """
    Build the unified drawing function:
    - text-to-image: draw(prompt, image=None, ...)
    - image-to-image: draw(prompt, image=ref_image, strength=0.55, ...)
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

    cfg = cfg or UnifiedDrawConfig()
    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)

    # Load the text-to-image pipeline once, then derive the image-to-image
    # pipeline from it (shares weights instead of loading the model twice).
    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i).to(device)

    _enable_memory_opts(pipe_t2i, device)
    _enable_memory_opts(pipe_i2i, device)

    def _draw(
        prompt: str,
        image: Image.Image | None = None,
        negative_prompt: str = "",
        strength: float = 0.55,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        width: int = 1024,
        height: int = 1024,
        max_side: int = 1024,
    ) -> Image.Image:
        """Generate an image. With image=None, runs text-to-image at the
        aligned (width, height); otherwise runs image-to-image and resizes
        the result back to the reference image's size."""
        prompt = prompt or ""
        negative_prompt = negative_prompt or ""

        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        if image is None:
            # NOTE(review): the t2i branch returns the aligned size, which can
            # differ slightly from the requested width/height — confirm callers
            # accept that.
            run_w, run_h = _align_size(width, height, max_side=max_side)
            out = pipe_t2i(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=run_w,
                height=run_h,
            ).images[0]
            return out.convert("RGB")

        # Image-to-image: downscale/align the reference to limit memory use,
        # then resize the output back to the original resolution.
        image = image.convert("RGB")
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
        else:
            image_run = image

        out = pipe_i2i(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0].convert("RGB")

        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _draw
|
||||
|
||||
Reference in New Issue
Block a user