from __future__ import annotations """ 统一的补全(Inpaint)模型加载入口。 当前支持: - SDXL Inpaint(diffusers AutoPipelineForInpainting) - ControlNet(占位,暂未统一封装) """ from dataclasses import dataclass from enum import Enum from typing import Callable import numpy as np from PIL import Image from config_loader import ( load_app_config, get_sdxl_base_model_from_app, get_controlnet_base_model_from_app, get_controlnet_model_from_app, ) class InpaintBackend(str, Enum): SDXL_INPAINT = "sdxl_inpaint" CONTROLNET = "controlnet" @dataclass class UnifiedInpaintConfig: backend: InpaintBackend = InpaintBackend.SDXL_INPAINT device: str | None = None # SDXL base model (HF id 或本地目录),不填则用 config.py 的默认值 sdxl_base_model: str | None = None @dataclass class UnifiedDrawConfig: """ 统一绘图配置: - 纯文生图:image=None - 图生图(模仿输入图):image=某张参考图 """ device: str | None = None sdxl_base_model: str | None = None def _resolve_device_and_dtype(device: str | None): import torch app_cfg = load_app_config() if device is None: device = app_cfg.inpaint.device device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if device == "cuda" else torch.float32 return device, torch_dtype def _enable_memory_opts(pipe, device: str) -> None: if device == "cuda": try: pipe.enable_attention_slicing() except Exception: pass try: pipe.enable_vae_slicing() except Exception: pass try: pipe.enable_vae_tiling() except Exception: pass try: pipe.enable_model_cpu_offload() except Exception: pass def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]: run_w, run_h = orig_w, orig_h if max_side > 0 and max(orig_w, orig_h) > max_side: scale = max_side / float(max(orig_w, orig_h)) run_w = int(round(orig_w * scale)) run_h = int(round(orig_h * scale)) run_w = max(8, run_w - (run_w % 8)) run_h = max(8, run_h - (run_h % 8)) return run_w, run_h def _make_sdxl_inpaint_predictor( cfg: UnifiedInpaintConfig, ) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]: """ 返回补全函数: - 输入:image(PIL RGB), mask(PIL L/1), prompt, negative_prompt - 输出:PIL RGB 结果图 """ import torch from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting app_cfg = load_app_config() base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg) device = cfg.device if device is None: device = app_cfg.inpaint.device device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if device == "cuda" else torch.float32 pipe_t2i = AutoPipelineForText2Image.from_pretrained( base_model, torch_dtype=torch_dtype, variant="fp16" if device == "cuda" else None, use_safetensors=True, ).to(device) pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device) # 省显存设置(尽量不改变输出语义) # 注意:CPU offload 会明显变慢,但能显著降低显存占用。 if device == "cuda": try: pipe.enable_attention_slicing() except Exception: pass try: pipe.enable_vae_slicing() except Exception: pass try: pipe.enable_vae_tiling() except Exception: pass try: pipe.enable_model_cpu_offload() except Exception: pass def _predict( image: Image.Image, mask: Image.Image, prompt: str, negative_prompt: str = "", strength: float = 0.8, guidance_scale: float = 7.5, num_inference_steps: int = 30, max_side: int = 1024, ) -> Image.Image: image = image.convert("RGB") # diffusers 要求 mask 为单通道,白色区域为需要重绘 mask = mask.convert("L") # SDXL / diffusers 通常要求宽高为 8 的倍数;同时为了避免 OOM, # 推理时将图像按比例缩放到不超过 max_side(默认 1024)并对齐到 8 的倍数。 # 推理后再 resize 回原始尺寸,保证输出与原图分辨率一致。 orig_w, orig_h = image.size run_w, run_h = orig_w, orig_h if max(orig_w, orig_h) > max_side: scale = max_side / float(max(orig_w, orig_h)) run_w = int(round(orig_w * scale)) run_h = int(round(orig_h * scale)) run_w = max(8, run_w - (run_w % 8)) run_h = max(8, run_h - (run_h % 8)) if run_w <= 0: run_w = 8 if run_h <= 0: run_h = 8 if (run_w, run_h) != (orig_w, orig_h): image_run = image.resize((run_w, run_h), resample=Image.BICUBIC) mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST) else: image_run = image mask_run = mask if device == "cuda": try: torch.cuda.empty_cache() except Exception: pass out = pipe( prompt=prompt, negative_prompt=negative_prompt, image=image_run, mask_image=mask_run, strength=strength, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=run_w, height=run_h, ).images[0] out = out.convert("RGB") if out.size != (orig_w, orig_h): out = out.resize((orig_w, orig_h), resample=Image.BICUBIC) return out return _predict def _make_controlnet_predictor(_: UnifiedInpaintConfig): import torch from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline app_cfg = load_app_config() device = _.device if device is None: device = app_cfg.inpaint.device device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if device == "cuda" else torch.float32 base_model = get_controlnet_base_model_from_app(app_cfg) controlnet_id = get_controlnet_model_from_app(app_cfg) controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype) pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( base_model, controlnet=controlnet, torch_dtype=torch_dtype, safety_checker=None, ) if device == "cuda": try: pipe.enable_attention_slicing() except Exception: pass try: pipe.enable_vae_slicing() except Exception: pass try: pipe.enable_vae_tiling() except Exception: pass try: pipe.enable_model_cpu_offload() except Exception: pipe.to(device) else: pipe.to(device) def _predict( image: Image.Image, mask: Image.Image, prompt: str, negative_prompt: str = "", strength: float = 0.8, guidance_scale: float = 7.5, num_inference_steps: int = 30, controlnet_conditioning_scale: float = 1.0, max_side: int = 768, ) -> Image.Image: import cv2 import numpy as np image = image.convert("RGB") mask = mask.convert("L") orig_w, orig_h = image.size run_w, run_h = orig_w, orig_h if max(orig_w, orig_h) > max_side: scale = max_side / float(max(orig_w, orig_h)) run_w = int(round(orig_w * scale)) run_h = int(round(orig_h * scale)) run_w = max(8, run_w - (run_w % 8)) run_h = max(8, run_h - (run_h % 8)) if (run_w, run_h) != (orig_w, orig_h): image_run = image.resize((run_w, run_h), resample=Image.BICUBIC) mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST) else: image_run = image mask_run = mask # control image:使用 canny 边缘作为约束(最通用) rgb = np.array(image_run, dtype=np.uint8) edges = cv2.Canny(rgb, 100, 200) edges3 = np.stack([edges, edges, edges], axis=-1) control_image = Image.fromarray(edges3) if device == "cuda": try: torch.cuda.empty_cache() except Exception: pass out = pipe( prompt=prompt, negative_prompt=negative_prompt, image=image_run, mask_image=mask_run, control_image=control_image, strength=strength, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, controlnet_conditioning_scale=controlnet_conditioning_scale, width=run_w, height=run_h, ).images[0] out = out.convert("RGB") if out.size != (orig_w, orig_h): out = out.resize((orig_w, orig_h), resample=Image.BICUBIC) return out return _predict def build_inpaint_predictor( cfg: UnifiedInpaintConfig | None = None, ) -> tuple[Callable[..., Image.Image], InpaintBackend]: """ 统一构建补全预测函数。 """ cfg = cfg or UnifiedInpaintConfig() if cfg.backend == InpaintBackend.SDXL_INPAINT: return _make_sdxl_inpaint_predictor(cfg), InpaintBackend.SDXL_INPAINT if cfg.backend == InpaintBackend.CONTROLNET: return _make_controlnet_predictor(cfg), InpaintBackend.CONTROLNET raise ValueError(f"不支持的补全后端: {cfg.backend}") def build_draw_predictor( cfg: UnifiedDrawConfig | None = None, ) -> Callable[..., Image.Image]: """ 构建统一绘图函数: - 文生图:draw(prompt, image=None, ...) - 图生图:draw(prompt, image=ref_image, strength=0.55, ...) """ import torch from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image cfg = cfg or UnifiedDrawConfig() app_cfg = load_app_config() base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg) device, torch_dtype = _resolve_device_and_dtype(cfg.device) pipe_t2i = AutoPipelineForText2Image.from_pretrained( base_model, torch_dtype=torch_dtype, variant="fp16" if device == "cuda" else None, use_safetensors=True, ).to(device) pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i).to(device) _enable_memory_opts(pipe_t2i, device) _enable_memory_opts(pipe_i2i, device) def _draw( prompt: str, image: Image.Image | None = None, negative_prompt: str = "", strength: float = 0.55, guidance_scale: float = 7.5, num_inference_steps: int = 30, width: int = 1024, height: int = 1024, max_side: int = 1024, ) -> Image.Image: prompt = prompt or "" negative_prompt = negative_prompt or "" if device == "cuda": try: torch.cuda.empty_cache() except Exception: pass if image is None: run_w, run_h = _align_size(width, height, max_side=max_side) out = pipe_t2i( prompt=prompt, negative_prompt=negative_prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=run_w, height=run_h, ).images[0] return out.convert("RGB") image = image.convert("RGB") orig_w, orig_h = image.size run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side) if (run_w, run_h) != (orig_w, orig_h): image_run = image.resize((run_w, run_h), resample=Image.BICUBIC) else: image_run = image out = pipe_i2i( prompt=prompt, negative_prompt=negative_prompt, image=image_run, strength=strength, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, ).images[0].convert("RGB") if out.size != (orig_w, orig_h): out = out.resize((orig_w, orig_h), resample=Image.BICUBIC) return out return _draw