from __future__ import annotations

from datetime import datetime
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw

from model.Inpaint.inpaint_loader import (
    build_inpaint_predictor,
    UnifiedInpaintConfig,
    InpaintBackend,
    build_draw_predictor,
    UnifiedDrawConfig,
)

# -----------------------------
# Configuration (edit as needed)
# -----------------------------
# Task mode:
#   - "inpaint": complete / fill in the masked region of INPUT_IMAGE
#   - "draw":    generate an image (text-to-image / image-to-image)
TASK_MODE = "draw"

# Path to your input image, e.g. image_0.png.
# Relative paths are resolved against the directory containing this script.
INPUT_IMAGE = "/home/dwh/code/hfut-bishe/python_server/outputs/test_seg_v2/Up the River During Qingming (detail) - Court painters.objects/Up the River During Qingming (detail) - Court painters.obj_06.png"
INPUT_MASK = ""  # Empty string: do not use a hand-made mask.
OUTPUT_DIR = "outputs/test_inpaint_v2"

MASK_RECT = (256, 0, 512, 512)  # x1, y1, x2, y2 (used when the auto mask is not used)
USE_AUTO_MASK = True  # True: infer the region to inpaint from the image itself.
AUTO_MASK_BLACK_THRESHOLD = 6  # Auto mask: per-channel threshold for "near-black" pixels.
AUTO_MASK_DILATE_ITER = 4  # Extra dilation so that the edges of tree branches are covered too.
FALLBACK_TO_RECT_MASK = False  # Fall back to MASK_RECT when the auto mask comes out empty.

# Transparent output control
PRESERVE_TRANSPARENCY = True  # Keep the output transparent.
EXPAND_ALPHA_WITH_MASK = True  # Make the inpainted (masked) region fully opaque.
NON_BLACK_ALPHA_THRESHOLD = 6  # Input without alpha: near-black pixels are treated as transparent background.

INPAINT_BACKEND = InpaintBackend.SDXL_INPAINT  # SDXL_INPAINT is recommended for better results.

# -----------------------------
# Key prompt change: generate trees only
# -----------------------------
# Refined PROMPT: stresses tree type, foliage and painting style so that the
# generated content blends with the original image.
PROMPT = (
    "A highly detailed traditional Chinese ink brush painting. "
    "Restore and complete the existing trees naturally. "
    "Extend the complex tree branches and dense, varied green and teal leaves. "
    "Add more harmonious foliage and intricate bark texture to match the style and flow of the original image_0.png trees. "
    "Focus solely on vegetation. "
    "Maintain the light beige background color and texture."
)

# NEGATIVE_PROMPT explicitly excludes unwanted content.
NEGATIVE_PROMPT = (
    "buildings, architecture, houses, pavilions, temples, windows, doors, "
    "people, figures, persons, characters, figures, clothing, faces, hands, "
    "text, writing, characters, words, letters, signatures, seals, stamps, "
    "calligraphy, objects, artifacts, boxes, baskets, tools, "
    "extra branches crossing unnaturally, bad composition, watermark, signature"
)

STRENGTH = 0.8  # Kept fairly high so that new content is actually generated.
GUIDANCE_SCALE = 8.0  # Slightly raised to follow the prompt more strictly.
NUM_INFERENCE_STEPS = 35  # Slightly raised for more detail.
MAX_SIDE = 1024
CONTROLNET_SCALE = 1.0  # Only passed to the predictor for the CONTROLNET backend.

# -----------------------------
# Drawing ("draw") parameters
# -----------------------------
# DRAW_INPUT_IMAGE empty:     text-to-image
# DRAW_INPUT_IMAGE non-empty: image-to-image (imitate / repaint the reference image)
DRAW_INPUT_IMAGE = ""
DRAW_PROMPT = """
Chinese ink wash painting,
vast snowy river under a pale sky,
a small lonely boat at the horizon where water meets sky,
an old fisherman wearing a straw hat sits at the stern, fishing quietly,
gentle snowfall, misty atmosphere,
distant mountains barely visible,
minimalist composition, large empty space,
soft brush strokes,
calm, cold, and silent mood,
poetic and serene
"""
DRAW_NEGATIVE_PROMPT = """
blurry, low quality, many people, bright colors, modern elements, crowded, noisy
"""
DRAW_STRENGTH = 0.55  # Image-to-image only; larger values repaint more of the reference.
DRAW_GUIDANCE_SCALE = 9
DRAW_STEPS = 64
DRAW_WIDTH = 2560
DRAW_HEIGHT = 1440
DRAW_MAX_SIDE = 2560


def _dilate(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
    """Binary-dilate a 2-D ``mask`` with a 3x3 square neighbourhood.

    Each iteration grows the True region by one pixel in all 8 directions.
    Pixels outside the image count as False (constant False padding).
    ``iterations <= 0`` returns the mask unchanged, cast to bool.
    """
    out = mask.astype(bool)
    for _ in range(max(0, iterations)):
        p = np.pad(out, ((1, 1), (1, 1)), mode="constant", constant_values=False)
        # OR of the 9 shifted copies == max over the 3x3 neighbourhood.
        out = (
            p[:-2, :-2] | p[:-2, 1:-1] | p[:-2, 2:]
            | p[1:-1, :-2] | p[1:-1, 1:-1] | p[1:-1, 2:]
            | p[2:, :-2] | p[2:, 1:-1] | p[2:, 2:]
        )
    return out


def _make_rect_mask(img_size: tuple[int, int]) -> Image.Image:
    """Build an "L" mask of ``img_size`` (w, h) with MASK_RECT filled white (255).

    Coordinates are clamped into the image, and swapped if given in reverse
    order. PIL's ``rectangle`` treats (x2, y2) as inclusive.
    """
    w, h = img_size
    x1, y1, x2, y2 = MASK_RECT
    x1 = max(0, min(x1, w - 1))
    x2 = max(0, min(x2, w - 1))
    y1 = max(0, min(y1, h - 1))
    y2 = max(0, min(y2, h - 1))
    if x2 < x1:
        x1, x2 = x2, x1
    if y2 < y1:
        y1, y2 = y2, y1

    mask = Image.new("L", (w, h), 0)
    draw = ImageDraw.Draw(mask)
    draw.rectangle([x1, y1, x2, y2], fill=255)
    return mask


def _read_alpha(img: Image.Image) -> np.ndarray | None:
    """Return the alpha channel of ``img`` as a writable uint8 (H, W) array, or None.

    An image counts as "having alpha" when it has an "A" band (RGBA, LA, PA)
    or carries palette / colour-key transparency in ``img.info``.
    Checking the band names, not the channel count, matters. A CMYK image
    also has 4 channels, and its K channel must not be mistaken for alpha.
    """
    if "A" not in img.getbands() and "transparency" not in img.info:
        return None
    # convert("RGBA") also applies palette / colour-key transparency.
    return np.array(img.convert("RGBA").getchannel("A"), dtype=np.uint8)


def _auto_mask_from_image(img_path: Path) -> Image.Image:
    """Infer the region to inpaint from the image at ``img_path``.

    1) If the input has alpha, pixels with alpha < 250 form the mask.
    2) Otherwise, pixels whose R, G and B are all <= AUTO_MASK_BLACK_THRESHOLD
       ("near-black") are treated as missing.
    The result is dilated AUTO_MASK_DILATE_ITER times and returned as an
    "L" image (0 = keep, 255 = inpaint).
    """
    with Image.open(img_path) as raw:
        alpha = _read_alpha(raw)
        if alpha is not None:
            mask_bool = alpha < 250
        else:
            rgb = np.asarray(raw.convert("RGB"), dtype=np.uint8)
            # After matting, transparent areas are often written out as black,
            # so near-black pixels are treated as missing.
            mask_bool = np.all(rgb <= AUTO_MASK_BLACK_THRESHOLD, axis=-1)

    mask_bool = _dilate(mask_bool, AUTO_MASK_DILATE_ITER)
    # A 2-D uint8 array is inferred as mode "L"; the `mode=` argument is not needed.
    return Image.fromarray(mask_bool.astype(np.uint8) * 255)


def _build_alpha_from_input(img_path: Path, img_rgb: Image.Image) -> np.ndarray:
    """Build the uint8 (H, W) alpha channel used for the output image.

    - If the input file has alpha, that alpha is reused.
    - Otherwise, pixels of ``img_rgb`` where every channel is
      <= NON_BLACK_ALPHA_THRESHOLD become transparent (0); all others get 255.
    """
    with Image.open(img_path) as raw:
        alpha = _read_alpha(raw)
    if alpha is not None:
        return alpha

    rgb = np.asarray(img_rgb.convert("RGB"), dtype=np.uint8)
    non_black = np.any(rgb > NON_BLACK_ALPHA_THRESHOLD, axis=-1)
    return non_black.astype(np.uint8) * 255


def _resolve_path(base_dir: Path, p: str) -> Path:
    """Return ``p`` unchanged if absolute, otherwise ``base_dir / p``."""
    raw = Path(p)
    return raw if raw.is_absolute() else (base_dir / raw)


def _load_or_make_mask(base_dir: Path, img_path: Path, img_rgb: Image.Image) -> Image.Image:
    """Return the inpaint mask ("L", same size as ``img_rgb``) by priority.

    Priority order:
    1) INPUT_MASK, if set and the file exists. It is resized with nearest
       neighbour when its size differs from the image.
    2) The automatic mask (USE_AUTO_MASK), if it is non-empty.
    3) The MASK_RECT rectangle. This is used when USE_AUTO_MASK is False,
       or when the auto mask is empty and FALLBACK_TO_RECT_MASK is True.

    Raises:
        ValueError: the auto mask is empty and FALLBACK_TO_RECT_MASK is False.
    """
    if INPUT_MASK:
        mask_path = _resolve_path(base_dir, INPUT_MASK)
        if mask_path.is_file():
            with Image.open(mask_path) as raw_mask:
                mask = raw_mask.convert("L")
            if mask.size != img_rgb.size:
                # A size mismatch would otherwise break the alpha merge later on.
                mask = mask.resize(img_rgb.size, Image.NEAREST)
            return mask
        # Previously this fell through silently; make the fallback visible.
        print(f"[Mask] INPUT_MASK not found, falling back to auto/rect mask: {mask_path}")

    if USE_AUTO_MASK:
        auto = _auto_mask_from_image(img_path)
        if np.asarray(auto, dtype=np.uint8).any():
            return auto
        if not FALLBACK_TO_RECT_MASK:
            raise ValueError("自动mask为空,请检查输入图像是否存在透明/黑色缺失区域。")

    return _make_rect_mask(img_rgb.size)


def _compose_transparent_output(
    out: Image.Image, alpha: np.ndarray, mask: Image.Image
) -> Image.Image:
    """Attach ``alpha`` to the predictor output ``out`` and return an RGBA image.

    When EXPAND_ALPHA_WITH_MASK is set, masked (inpainted) pixels are made at
    least as opaque as the mask value.
    """
    if EXPAND_ALPHA_WITH_MASK:
        alpha = np.maximum(alpha, np.asarray(mask, dtype=np.uint8))

    out_rgb = np.asarray(out.convert("RGB"), dtype=np.uint8)
    if alpha.shape != out_rgb.shape[:2]:
        # NOTE(review): the predictor may return a different size than its
        # input (e.g. MAX_SIDE resizing); resample alpha so the stack below works.
        alpha = np.asarray(Image.fromarray(alpha).resize(out.size, Image.NEAREST))

    out_rgba = np.concatenate([out_rgb, alpha[..., None]], axis=-1)
    # An (H, W, 4) uint8 array is inferred as mode "RGBA".
    return Image.fromarray(out_rgba)


def run_inpaint_test(base_dir: Path, out_dir: Path) -> None:
    """Run one inpaint pass on INPUT_IMAGE and write results into ``out_dir``.

    Writes ``<stem>.mask_used.png`` and ``<stem>.<backend>.inpaint.png``.

    Raises:
        FileNotFoundError: INPUT_IMAGE does not exist.
        ValueError: propagated from ``_load_or_make_mask`` when the auto mask is empty.
    """
    img_path = _resolve_path(base_dir, INPUT_IMAGE)
    if not img_path.is_file():
        raise FileNotFoundError(f"找不到输入图像,请修改 INPUT_IMAGE: {img_path}")

    with Image.open(img_path) as src:
        img = src.convert("RGB")

    # Resolve the mask before loading the (expensive) model, so that a bad
    # mask configuration fails fast.
    mask = _load_or_make_mask(base_dir, img_path, img)
    mask_out = out_dir / f"{img_path.stem}.mask_used.png"
    mask.save(mask_out)

    predictor, used_backend = build_inpaint_predictor(
        UnifiedInpaintConfig(backend=INPAINT_BACKEND)
    )

    kwargs = dict(
        strength=STRENGTH,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_INFERENCE_STEPS,
        max_side=MAX_SIDE,
    )
    if used_backend == InpaintBackend.CONTROLNET:
        kwargs["controlnet_conditioning_scale"] = CONTROLNET_SCALE

    out = predictor(img, mask, PROMPT, NEGATIVE_PROMPT, **kwargs)

    out_path = out_dir / f"{img_path.stem}.{used_backend.value}.inpaint.png"
    if PRESERVE_TRANSPARENCY:
        alpha = _build_alpha_from_input(img_path, img)
        _compose_transparent_output(out, alpha, mask).save(out_path)
    else:
        out.save(out_path)

    # Fraction of pixels selected for inpainting.
    ratio = float(np.count_nonzero(np.asarray(mask, dtype=np.uint8))) / float(mask.width * mask.height)
    print(f"[Inpaint] backend={used_backend.value}, saved={out_path}")
    print(f"[Mask] saved={mask_out}, ratio={ratio:.4f}")


def run_draw_test(base_dir: Path, out_dir: Path) -> None:
    """Run one drawing pass and save ``draw_<mode>_<timestamp>.png`` into ``out_dir``.

    - Text-to-image: DRAW_INPUT_IMAGE = ""
    - Image-to-image: DRAW_INPUT_IMAGE points to the reference image

    Raises:
        FileNotFoundError: DRAW_INPUT_IMAGE is set but does not exist.
    """
    draw_predictor = build_draw_predictor(UnifiedDrawConfig())

    ref_image: Image.Image | None = None
    mode = "text2img"
    if DRAW_INPUT_IMAGE:
        ref_path = _resolve_path(base_dir, DRAW_INPUT_IMAGE)
        if not ref_path.is_file():
            raise FileNotFoundError(f"找不到参考图,请修改 DRAW_INPUT_IMAGE: {ref_path}")
        with Image.open(ref_path) as ref:
            ref_image = ref.convert("RGB")
        mode = "img2img"

    out = draw_predictor(
        prompt=DRAW_PROMPT,
        image=ref_image,
        negative_prompt=DRAW_NEGATIVE_PROMPT,
        strength=DRAW_STRENGTH,
        guidance_scale=DRAW_GUIDANCE_SCALE,
        num_inference_steps=DRAW_STEPS,
        width=DRAW_WIDTH,
        height=DRAW_HEIGHT,
        max_side=DRAW_MAX_SIDE,
    )

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = out_dir / f"draw_{mode}_{ts}.png"
    out.save(out_path)
    print(f"[Draw] mode={mode}, saved={out_path}")


def main() -> None:
    """Create OUTPUT_DIR next to this script and dispatch on TASK_MODE.

    Raises:
        ValueError: TASK_MODE is neither "inpaint" nor "draw".
    """
    base_dir = Path(__file__).resolve().parent
    out_dir = _resolve_path(base_dir, OUTPUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)

    if TASK_MODE == "inpaint":
        run_inpaint_test(base_dir, out_dir)
        return
    if TASK_MODE == "draw":
        run_draw_test(base_dir, out_dir)
        return
    raise ValueError(f"不支持的 TASK_MODE: {TASK_MODE}(可选: inpaint / draw)")


if __name__ == "__main__":
    main()