initial commit
This commit is contained in:
284
python_server/test_inpaint.py
Normal file
284
python_server/test_inpaint.py
Normal file
@@ -0,0 +1,284 @@
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from model.Inpaint.inpaint_loader import (
|
||||
build_inpaint_predictor,
|
||||
UnifiedInpaintConfig,
|
||||
InpaintBackend,
|
||||
build_draw_predictor,
|
||||
UnifiedDrawConfig,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# 配置区(按需修改)
|
||||
# -----------------------------
|
||||
# 任务模式:
|
||||
# - "inpaint": 补全
|
||||
# - "draw": 绘图(文生图 / 图生图)
|
||||
TASK_MODE = "draw"
|
||||
|
||||
# 指向你的输入图像,例如 image_0.png
|
||||
INPUT_IMAGE = "/home/dwh/code/hfut-bishe/python_server/outputs/test_seg_v2/Up the River During Qingming (detail) - Court painters.objects/Up the River During Qingming (detail) - Court painters.obj_06.png"
|
||||
INPUT_MASK = "" # 为空表示不使用手工 mask
|
||||
OUTPUT_DIR = "outputs/test_inpaint_v2"
|
||||
MASK_RECT = (256, 0, 512, 512) # x1, y1, x2, y2 (如果不使用 AUTO_MASK)
|
||||
USE_AUTO_MASK = True # True: 自动从图像推断补全区域
|
||||
AUTO_MASK_BLACK_THRESHOLD = 6 # 自动mask时,接近黑色像素阈值
|
||||
AUTO_MASK_DILATE_ITER = 4 # 增加膨胀,确保树枝边缘被覆盖
|
||||
FALLBACK_TO_RECT_MASK = False # 自动mask为空时是否回退到矩形mask
|
||||
|
||||
# 透明输出控制
|
||||
PRESERVE_TRANSPARENCY = True # 保持输出透明
|
||||
EXPAND_ALPHA_WITH_MASK = True # 将补全区域 alpha 设为不透明
|
||||
NON_BLACK_ALPHA_THRESHOLD = 6 # 无 alpha 输入时,用近黑判定透明背景
|
||||
|
||||
INPAINT_BACKEND = InpaintBackend.SDXL_INPAINT # 推荐使用 SDXL_INPAINT 获得更好效果
|
||||
|
||||
# -----------------------------
|
||||
# 关键 Prompt 修改:只生成树
|
||||
# -----------------------------
|
||||
# 细化 PROMPT,强调树的种类、叶子和风格,使其与原图融合
|
||||
PROMPT = (
|
||||
"A highly detailed traditional Chinese ink brush painting. "
|
||||
"Restore and complete the existing trees naturally. "
|
||||
"Extend the complex tree branches and dense, varied green and teal leaves. "
|
||||
"Add more harmonious foliage and intricate bark texture to match the style and flow of the original image_0.png trees. "
|
||||
"Focus solely on vegetation. "
|
||||
"Maintain the light beige background color and texture."
|
||||
)
|
||||
|
||||
# 使用 NEGATIVE_PROMPT 显式排除不想要的内容
|
||||
NEGATIVE_PROMPT = (
|
||||
"buildings, architecture, houses, pavilions, temples, windows, doors, "
|
||||
"people, figures, persons, characters, figures, clothing, faces, hands, "
|
||||
"text, writing, characters, words, letters, signatures, seals, stamps, "
|
||||
"calligraphy, objects, artifacts, boxes, baskets, tools, "
|
||||
"extra branches crossing unnaturally, bad composition, watermark, signature"
|
||||
)
|
||||
|
||||
STRENGTH = 0.8 # 保持较高强度以进行生成
|
||||
GUIDANCE_SCALE = 8.0 # 稍微增加,更严格遵循 prompt
|
||||
NUM_INFERENCE_STEPS = 35 # 稍微增加,提升细节
|
||||
MAX_SIDE = 1024
|
||||
CONTROLNET_SCALE = 1.0
|
||||
|
||||
# -----------------------------
|
||||
# 绘图(draw)参数
|
||||
# -----------------------------
|
||||
# DRAW_INPUT_IMAGE 为空时:文生图
|
||||
# DRAW_INPUT_IMAGE 不为空时:图生图(按输入图进行模仿/重绘)
|
||||
DRAW_INPUT_IMAGE = ""
|
||||
DRAW_PROMPT = """
|
||||
Chinese ink wash painting, vast snowy river under a pale sky,
|
||||
a small lonely boat at the horizon where water meets sky,
|
||||
an old fisherman wearing a straw hat sits at the stern, fishing quietly,
|
||||
gentle snowfall, misty atmosphere, distant mountains barely visible,
|
||||
minimalist composition, large empty space, soft brush strokes,
|
||||
calm, cold, and silent mood, poetic and serene
|
||||
"""
|
||||
|
||||
DRAW_NEGATIVE_PROMPT = """
|
||||
blurry, low quality, many people, bright colors, modern elements, crowded, noisy
|
||||
"""
|
||||
DRAW_STRENGTH = 0.55 # 仅图生图使用,越大越偏向重绘
|
||||
DRAW_GUIDANCE_SCALE = 9
|
||||
DRAW_STEPS = 64
|
||||
DRAW_WIDTH = 2560
|
||||
DRAW_HEIGHT = 1440
|
||||
DRAW_MAX_SIDE = 2560
|
||||
|
||||
|
||||
def _dilate(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
|
||||
out = mask.astype(bool)
|
||||
for _ in range(max(0, iterations)):
|
||||
p = np.pad(out, ((1, 1), (1, 1)), mode="constant", constant_values=False)
|
||||
out = (
|
||||
p[:-2, :-2] | p[:-2, 1:-1] | p[:-2, 2:]
|
||||
| p[1:-1, :-2] | p[1:-1, 1:-1] | p[1:-1, 2:]
|
||||
| p[2:, :-2] | p[2:, 1:-1] | p[2:, 2:]
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _make_rect_mask(img_size: tuple[int, int]) -> Image.Image:
|
||||
w, h = img_size
|
||||
x1, y1, x2, y2 = MASK_RECT
|
||||
x1 = max(0, min(x1, w - 1))
|
||||
x2 = max(0, min(x2, w - 1))
|
||||
y1 = max(0, min(y1, h - 1))
|
||||
y2 = max(0, min(y2, h - 1))
|
||||
if x2 < x1:
|
||||
x1, x2 = x2, x1
|
||||
if y2 < y1:
|
||||
y1, y2 = y2, y1
|
||||
|
||||
mask = Image.new("L", (w, h), 0)
|
||||
draw = ImageDraw.Draw(mask)
|
||||
draw.rectangle([x1, y1, x2, y2], fill=255)
|
||||
return mask
|
||||
|
||||
|
||||
def _auto_mask_from_image(img_path: Path) -> Image.Image:
|
||||
"""
|
||||
自动推断缺失区域:
|
||||
1) 若输入带 alpha,透明区域作为 mask
|
||||
2) 否则将“接近黑色”区域作为候选缺失区域
|
||||
"""
|
||||
raw = Image.open(img_path)
|
||||
arr = np.asarray(raw)
|
||||
|
||||
if arr.ndim == 3 and arr.shape[2] == 4:
|
||||
alpha = arr[:, :, 3]
|
||||
mask_bool = alpha < 250
|
||||
else:
|
||||
rgb = np.asarray(raw.convert("RGB"), dtype=np.uint8)
|
||||
# 抠图后透明区域常被写成黑色,优先把近黑区域视作缺失
|
||||
dark = np.all(rgb <= AUTO_MASK_BLACK_THRESHOLD, axis=-1)
|
||||
mask_bool = dark
|
||||
|
||||
mask_bool = _dilate(mask_bool, AUTO_MASK_DILATE_ITER)
|
||||
return Image.fromarray((mask_bool.astype(np.uint8) * 255), mode="L")
|
||||
|
||||
|
||||
def _build_alpha_from_input(img_path: Path, img_rgb: Image.Image) -> np.ndarray:
|
||||
"""
|
||||
生成输出 alpha:
|
||||
- 输入若有 alpha,优先沿用
|
||||
- 输入若无 alpha,则把接近黑色区域视作透明背景
|
||||
"""
|
||||
raw = Image.open(img_path)
|
||||
arr = np.asarray(raw)
|
||||
if arr.ndim == 3 and arr.shape[2] == 4:
|
||||
return arr[:, :, 3].astype(np.uint8)
|
||||
|
||||
rgb = np.asarray(img_rgb.convert("RGB"), dtype=np.uint8)
|
||||
non_black = np.any(rgb > NON_BLACK_ALPHA_THRESHOLD, axis=-1)
|
||||
return (non_black.astype(np.uint8) * 255)
|
||||
|
||||
|
||||
def _load_or_make_mask(base_dir: Path, img_path: Path, img_rgb: Image.Image) -> Image.Image:
|
||||
if INPUT_MASK:
|
||||
raw_mask_path = Path(INPUT_MASK)
|
||||
mask_path = raw_mask_path if raw_mask_path.is_absolute() else (base_dir / raw_mask_path)
|
||||
if mask_path.is_file():
|
||||
return Image.open(mask_path).convert("L")
|
||||
|
||||
if USE_AUTO_MASK:
|
||||
auto = _auto_mask_from_image(img_path)
|
||||
auto_arr = np.asarray(auto, dtype=np.uint8)
|
||||
if (auto_arr > 0).any():
|
||||
return auto
|
||||
if not FALLBACK_TO_RECT_MASK:
|
||||
raise ValueError("自动mask为空,请检查输入图像是否存在透明/黑色缺失区域。")
|
||||
|
||||
return _make_rect_mask(img_rgb.size)
|
||||
|
||||
|
||||
def _resolve_path(base_dir: Path, p: str) -> Path:
|
||||
raw = Path(p)
|
||||
return raw if raw.is_absolute() else (base_dir / raw)
|
||||
|
||||
|
||||
def run_inpaint_test(base_dir: Path, out_dir: Path) -> None:
|
||||
img_path = _resolve_path(base_dir, INPUT_IMAGE)
|
||||
if not img_path.is_file():
|
||||
raise FileNotFoundError(f"找不到输入图像,请修改 INPUT_IMAGE: {img_path}")
|
||||
|
||||
predictor, used_backend = build_inpaint_predictor(
|
||||
UnifiedInpaintConfig(backend=INPAINT_BACKEND)
|
||||
)
|
||||
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
mask = _load_or_make_mask(base_dir, img_path, img)
|
||||
mask_out = out_dir / f"{img_path.stem}.mask_used.png"
|
||||
mask.save(mask_out)
|
||||
|
||||
kwargs = dict(
|
||||
strength=STRENGTH,
|
||||
guidance_scale=GUIDANCE_SCALE,
|
||||
num_inference_steps=NUM_INFERENCE_STEPS,
|
||||
max_side=MAX_SIDE,
|
||||
)
|
||||
if used_backend == InpaintBackend.CONTROLNET:
|
||||
kwargs["controlnet_conditioning_scale"] = CONTROLNET_SCALE
|
||||
|
||||
out = predictor(img, mask, PROMPT, NEGATIVE_PROMPT, **kwargs)
|
||||
out_path = out_dir / f"{img_path.stem}.{used_backend.value}.inpaint.png"
|
||||
if PRESERVE_TRANSPARENCY:
|
||||
alpha = _build_alpha_from_input(img_path, img)
|
||||
mask_u8 = np.asarray(mask, dtype=np.uint8)
|
||||
if EXPAND_ALPHA_WITH_MASK:
|
||||
alpha = np.maximum(alpha, mask_u8)
|
||||
|
||||
out_rgb = np.asarray(out.convert("RGB"), dtype=np.uint8)
|
||||
out_rgba = np.concatenate([out_rgb, alpha[..., None]], axis=-1)
|
||||
Image.fromarray(out_rgba, mode="RGBA").save(out_path)
|
||||
else:
|
||||
out.save(out_path)
|
||||
|
||||
ratio = float((np.asarray(mask, dtype=np.uint8) > 0).sum()) / float(mask.size[0] * mask.size[1])
|
||||
print(f"[Inpaint] backend={used_backend.value}, saved={out_path}")
|
||||
print(f"[Mask] saved={mask_out}, ratio={ratio:.4f}")
|
||||
|
||||
|
||||
def run_draw_test(base_dir: Path, out_dir: Path) -> None:
|
||||
"""
|
||||
绘图测试:
|
||||
- 文生图:DRAW_INPUT_IMAGE=""
|
||||
- 图生图:DRAW_INPUT_IMAGE 指向参考图
|
||||
"""
|
||||
draw_predictor = build_draw_predictor(UnifiedDrawConfig())
|
||||
ref_image: Image.Image | None = None
|
||||
mode = "text2img"
|
||||
|
||||
if DRAW_INPUT_IMAGE:
|
||||
ref_path = _resolve_path(base_dir, DRAW_INPUT_IMAGE)
|
||||
if not ref_path.is_file():
|
||||
raise FileNotFoundError(f"找不到参考图,请修改 DRAW_INPUT_IMAGE: {ref_path}")
|
||||
ref_image = Image.open(ref_path).convert("RGB")
|
||||
mode = "img2img"
|
||||
|
||||
out = draw_predictor(
|
||||
prompt=DRAW_PROMPT,
|
||||
image=ref_image,
|
||||
negative_prompt=DRAW_NEGATIVE_PROMPT,
|
||||
strength=DRAW_STRENGTH,
|
||||
guidance_scale=DRAW_GUIDANCE_SCALE,
|
||||
num_inference_steps=DRAW_STEPS,
|
||||
width=DRAW_WIDTH,
|
||||
height=DRAW_HEIGHT,
|
||||
max_side=DRAW_MAX_SIDE,
|
||||
)
|
||||
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = out_dir / f"draw_{mode}_{ts}.png"
|
||||
out.save(out_path)
|
||||
print(f"[Draw] mode={mode}, saved={out_path}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
out_dir = base_dir / OUTPUT_DIR
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if TASK_MODE == "inpaint":
|
||||
run_inpaint_test(base_dir, out_dir)
|
||||
return
|
||||
if TASK_MODE == "draw":
|
||||
run_draw_test(base_dir, out_dir)
|
||||
return
|
||||
|
||||
raise ValueError(f"不支持的 TASK_MODE: {TASK_MODE}(可选: inpaint / draw)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user