Files
hfut-bishe/python_server/test_inpaint.py
2026-04-07 20:55:30 +08:00

285 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from datetime import datetime
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
from model.Inpaint.inpaint_loader import (
build_inpaint_predictor,
UnifiedInpaintConfig,
InpaintBackend,
build_draw_predictor,
UnifiedDrawConfig,
)
# -----------------------------
# Configuration (edit as needed)
# -----------------------------
# Task mode:
# - "inpaint": fill in missing regions of an existing image
# - "draw": generate an image (text-to-image / image-to-image)
TASK_MODE = "draw"
# Path to the input image, e.g. image_0.png
INPUT_IMAGE = "/home/dwh/code/hfut-bishe/python_server/outputs/test_seg_v2/Up the River During Qingming (detail) - Court painters.objects/Up the River During Qingming (detail) - Court painters.obj_06.png"
INPUT_MASK = ""  # empty means no hand-made mask is used
OUTPUT_DIR = "outputs/test_inpaint_v2"
MASK_RECT = (256, 0, 512, 512)  # x1, y1, x2, y2 (used when AUTO_MASK is not used)
USE_AUTO_MASK = True  # True: infer the region to inpaint from the image itself
AUTO_MASK_BLACK_THRESHOLD = 6  # per-channel "near black" threshold used by the automatic mask
AUTO_MASK_DILATE_ITER = 4  # extra dilation so that tree-branch edges are covered
FALLBACK_TO_RECT_MASK = False  # whether to fall back to the rectangular mask when the automatic mask is empty
# Transparent-output controls
PRESERVE_TRANSPARENCY = True  # keep the output transparent
EXPAND_ALPHA_WITH_MASK = True  # make the inpainted region opaque in the output alpha
NON_BLACK_ALPHA_THRESHOLD = 6  # for inputs without alpha, near-black pixels are treated as transparent background
INPAINT_BACKEND = InpaintBackend.SDXL_INPAINT  # SDXL_INPAINT is recommended for better results
# -----------------------------
# Key prompt tweak: generate trees only
# -----------------------------
# Detailed PROMPT that stresses tree species, leaves and style so the result blends with the original image
PROMPT = (
    "A highly detailed traditional Chinese ink brush painting. "
    "Restore and complete the existing trees naturally. "
    "Extend the complex tree branches and dense, varied green and teal leaves. "
    "Add more harmonious foliage and intricate bark texture to match the style and flow of the original image_0.png trees. "
    "Focus solely on vegetation. "
    "Maintain the light beige background color and texture."
)
# NEGATIVE_PROMPT explicitly excludes unwanted content
NEGATIVE_PROMPT = (
    "buildings, architecture, houses, pavilions, temples, windows, doors, "
    "people, figures, persons, characters, figures, clothing, faces, hands, "
    "text, writing, characters, words, letters, signatures, seals, stamps, "
    "calligraphy, objects, artifacts, boxes, baskets, tools, "
    "extra branches crossing unnaturally, bad composition, watermark, signature"
)
STRENGTH = 0.8  # kept fairly high so that new content is generated
GUIDANCE_SCALE = 8.0  # slightly raised to follow the prompt more strictly
NUM_INFERENCE_STEPS = 35  # slightly raised for more detail
MAX_SIDE = 1024
CONTROLNET_SCALE = 1.0
# -----------------------------
# Draw-mode parameters
# -----------------------------
# DRAW_INPUT_IMAGE empty: text-to-image
# DRAW_INPUT_IMAGE set: image-to-image (imitate / repaint the given image)
DRAW_INPUT_IMAGE = ""
DRAW_PROMPT = """
Chinese ink wash painting, vast snowy river under a pale sky,
a small lonely boat at the horizon where water meets sky,
an old fisherman wearing a straw hat sits at the stern, fishing quietly,
gentle snowfall, misty atmosphere, distant mountains barely visible,
minimalist composition, large empty space, soft brush strokes,
calm, cold, and silent mood, poetic and serene
"""
DRAW_NEGATIVE_PROMPT = """
blurry, low quality, many people, bright colors, modern elements, crowded, noisy
"""
DRAW_STRENGTH = 0.55  # image-to-image only; larger values lean further towards repainting
DRAW_GUIDANCE_SCALE = 9
DRAW_STEPS = 64
DRAW_WIDTH = 2560
DRAW_HEIGHT = 1440
DRAW_MAX_SIDE = 2560
def _dilate(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
out = mask.astype(bool)
for _ in range(max(0, iterations)):
p = np.pad(out, ((1, 1), (1, 1)), mode="constant", constant_values=False)
out = (
p[:-2, :-2] | p[:-2, 1:-1] | p[:-2, 2:]
| p[1:-1, :-2] | p[1:-1, 1:-1] | p[1:-1, 2:]
| p[2:, :-2] | p[2:, 1:-1] | p[2:, 2:]
)
return out
def _make_rect_mask(img_size: tuple[int, int]) -> Image.Image:
    """Build an "L" mask of ``img_size`` with the MASK_RECT area set to 255.

    The rectangle corners are clamped into the image and put in ascending
    order before drawing; everything outside the rectangle is 0.
    """
    width, height = img_size
    left, top, right, bottom = MASK_RECT

    def _clamp(value: int, upper: int) -> int:
        # Keep a coordinate inside [0, upper].
        return max(0, min(value, upper))

    left, right = sorted((_clamp(left, width - 1), _clamp(right, width - 1)))
    top, bottom = sorted((_clamp(top, height - 1), _clamp(bottom, height - 1)))

    canvas = Image.new("L", (width, height), 0)
    ImageDraw.Draw(canvas).rectangle([left, top, right, bottom], fill=255)
    return canvas
def _auto_mask_from_image(img_path: Path) -> Image.Image:
    """Infer the region to inpaint from the image itself.

    1) If the image has an alpha band, pixels whose alpha is below 250 form
       the mask.
    2) Otherwise pixels whose R, G and B are all at most
       AUTO_MASK_BLACK_THRESHOLD ("near black") form the mask.

    Alpha is detected from the band names rather than from the channel
    count, so a four-channel image without alpha (e.g. CMYK) takes the
    near-black path, and two-band images with alpha (e.g. "LA") take the
    alpha path.

    The mask is dilated AUTO_MASK_DILATE_ITER times and returned as an "L"
    image with values 0 / 255.
    """
    # The context manager releases the file handle once the pixels are read.
    with Image.open(img_path) as raw:
        if "A" in raw.getbands():
            alpha = np.asarray(raw.getchannel("A"), dtype=np.uint8)
            mask_bool = alpha < 250
        else:
            rgb = np.asarray(raw.convert("RGB"), dtype=np.uint8)
            # Cut-outs often have their transparent area written as black,
            # so near-black pixels are treated as missing.
            mask_bool = np.all(rgb <= AUTO_MASK_BLACK_THRESHOLD, axis=-1)
    mask_bool = _dilate(mask_bool, AUTO_MASK_DILATE_ITER)
    return Image.fromarray((mask_bool.astype(np.uint8) * 255), mode="L")
def _build_alpha_from_input(img_path: Path, img_rgb: Image.Image) -> np.ndarray:
    """Build the alpha plane for the output image.

    - If the file at ``img_path`` has an alpha band, that band is returned.
    - Otherwise every pixel of ``img_rgb`` with any channel above
      NON_BLACK_ALPHA_THRESHOLD becomes opaque (255) and the rest
      transparent (0).

    Alpha is detected from the band names rather than from the channel
    count, so a four-channel image without alpha (e.g. CMYK) is not
    misread as RGBA.

    Returns a 2-D uint8 array.
    """
    # The context manager releases the file handle once the pixels are read.
    with Image.open(img_path) as raw:
        if "A" in raw.getbands():
            # np.array copies, so the caller gets a writable array.
            return np.array(raw.getchannel("A"), dtype=np.uint8)
    rgb = np.asarray(img_rgb.convert("RGB"), dtype=np.uint8)
    non_black = np.any(rgb > NON_BLACK_ALPHA_THRESHOLD, axis=-1)
    return non_black.astype(np.uint8) * 255
def _load_or_make_mask(base_dir: Path, img_path: Path, img_rgb: Image.Image) -> Image.Image:
    """Pick the mask for inpainting.

    Order of preference:
    1. The file named by INPUT_MASK (relative paths are resolved against
       ``base_dir``), converted to "L", when it exists.
    2. The automatic mask from ``img_path`` when USE_AUTO_MASK is set and
       the mask is not empty.
    3. The rectangular mask sized like ``img_rgb``.

    Raises:
        ValueError: the automatic mask is empty and FALLBACK_TO_RECT_MASK
            is off.
    """
    if INPUT_MASK:
        candidate = Path(INPUT_MASK)
        if not candidate.is_absolute():
            candidate = base_dir / candidate
        if candidate.is_file():
            return Image.open(candidate).convert("L")

    if not USE_AUTO_MASK:
        return _make_rect_mask(img_rgb.size)

    inferred = _auto_mask_from_image(img_path)
    if np.asarray(inferred, dtype=np.uint8).any():
        return inferred
    if FALLBACK_TO_RECT_MASK:
        return _make_rect_mask(img_rgb.size)
    raise ValueError("自动mask为空请检查输入图像是否存在透明/黑色缺失区域。")
def _resolve_path(base_dir: Path, p: str) -> Path:
raw = Path(p)
return raw if raw.is_absolute() else (base_dir / raw)
def _fit_plane(plane: np.ndarray, size: tuple[int, int]) -> np.ndarray:
    """Return a 2-D uint8 plane resized (nearest neighbour) to ``size`` = (width, height).

    The input is returned untouched when it already has that size.
    """
    if (plane.shape[1], plane.shape[0]) == tuple(size):
        return plane
    resized = Image.fromarray(plane).resize(size, Image.NEAREST)
    return np.asarray(resized, dtype=np.uint8)


def run_inpaint_test(base_dir: Path, out_dir: Path) -> None:
    """Run one inpainting pass on INPUT_IMAGE and save the result.

    Writes two files into ``out_dir``: the mask that was used
    (``<stem>.mask_used.png``) and the inpainted image
    (``<stem>.<backend>.inpaint.png``). With PRESERVE_TRANSPARENCY the
    result is saved as RGBA using the alpha from the input, optionally
    made opaque over the masked region (EXPAND_ALPHA_WITH_MASK).

    Raises:
        FileNotFoundError: INPUT_IMAGE does not point to a file.
    """
    img_path = _resolve_path(base_dir, INPUT_IMAGE)
    if not img_path.is_file():
        raise FileNotFoundError(f"找不到输入图像,请修改 INPUT_IMAGE: {img_path}")

    # Load the image and settle the mask before the predictor is built, so
    # that a bad input or an empty mask fails before any backend is set up.
    img = Image.open(img_path).convert("RGB")
    mask = _load_or_make_mask(base_dir, img_path, img)
    mask_out = out_dir / f"{img_path.stem}.mask_used.png"
    mask.save(mask_out)
    mask_u8 = np.asarray(mask, dtype=np.uint8)

    predictor, used_backend = build_inpaint_predictor(
        UnifiedInpaintConfig(backend=INPAINT_BACKEND)
    )
    kwargs = dict(
        strength=STRENGTH,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_INFERENCE_STEPS,
        max_side=MAX_SIDE,
    )
    if used_backend == InpaintBackend.CONTROLNET:
        kwargs["controlnet_conditioning_scale"] = CONTROLNET_SCALE
    out = predictor(img, mask, PROMPT, NEGATIVE_PROMPT, **kwargs)

    out_path = out_dir / f"{img_path.stem}.{used_backend.value}.inpaint.png"
    if PRESERVE_TRANSPARENCY:
        # NOTE(review): max_side is passed to the predictor, so the result
        # may not have the input's size - confirm against the predictor.
        # The planes are fitted to the output size so the channel stack
        # below cannot fail on a shape mismatch after inference.
        alpha = _fit_plane(_build_alpha_from_input(img_path, img), out.size)
        if EXPAND_ALPHA_WITH_MASK:
            alpha = np.maximum(alpha, _fit_plane(mask_u8, out.size))
        out_rgb = np.asarray(out.convert("RGB"), dtype=np.uint8)
        out_rgba = np.concatenate([out_rgb, alpha[..., None]], axis=-1)
        Image.fromarray(out_rgba, mode="RGBA").save(out_path)
    else:
        out.save(out_path)

    # Fraction of mask pixels that are set.
    ratio = float((mask_u8 > 0).sum()) / float(mask.size[0] * mask.size[1])
    print(f"[Inpaint] backend={used_backend.value}, saved={out_path}")
    print(f"[Mask] saved={mask_out}, ratio={ratio:.4f}")
def run_draw_test(base_dir: Path, out_dir: Path) -> None:
    """Run one draw pass and save the result into ``out_dir``.

    - Text-to-image when DRAW_INPUT_IMAGE is empty.
    - Image-to-image when DRAW_INPUT_IMAGE points to a reference image.

    The output file is named ``draw_<mode>_<timestamp>.png``.

    Raises:
        FileNotFoundError: DRAW_INPUT_IMAGE is set but is not a file.
    """
    generate = build_draw_predictor(UnifiedDrawConfig())

    reference: Image.Image | None = None
    if DRAW_INPUT_IMAGE:
        reference_path = _resolve_path(base_dir, DRAW_INPUT_IMAGE)
        if not reference_path.is_file():
            raise FileNotFoundError(f"找不到参考图,请修改 DRAW_INPUT_IMAGE: {reference_path}")
        reference = Image.open(reference_path).convert("RGB")
    # The mode tag follows from whether a reference image was loaded.
    mode_tag = "text2img" if reference is None else "img2img"

    call_args = {
        "prompt": DRAW_PROMPT,
        "image": reference,
        "negative_prompt": DRAW_NEGATIVE_PROMPT,
        "strength": DRAW_STRENGTH,
        "guidance_scale": DRAW_GUIDANCE_SCALE,
        "num_inference_steps": DRAW_STEPS,
        "width": DRAW_WIDTH,
        "height": DRAW_HEIGHT,
        "max_side": DRAW_MAX_SIDE,
    }
    result = generate(**call_args)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = out_dir / f"draw_{mode_tag}_{stamp}.png"
    result.save(target)
    print(f"[Draw] mode={mode_tag}, saved={target}")
def main() -> None:
    """Run the test selected by TASK_MODE.

    Outputs go to OUTPUT_DIR, resolved next to this file; the directory is
    created if needed. The mode is validated before the directory is
    created, so an unsupported mode leaves nothing behind.

    Raises:
        ValueError: TASK_MODE is neither "inpaint" nor "draw".
    """
    runners = {
        "inpaint": run_inpaint_test,
        "draw": run_draw_test,
    }
    runner = runners.get(TASK_MODE)
    if runner is None:
        raise ValueError(f"不支持的 TASK_MODE: {TASK_MODE}(可选: inpaint / draw)")

    base_dir = Path(__file__).resolve().parent
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    runner(base_dir, out_dir)


if __name__ == "__main__":
    main()