# HFUT model server: FastAPI endpoints for depth, segmentation, inpainting and animation.
from __future__ import annotations
|
||
|
||
import base64
|
||
import datetime
|
||
import io
|
||
import os
|
||
import shutil
|
||
from dataclasses import asdict
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Optional
|
||
|
||
import numpy as np
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.responses import Response
|
||
from pydantic import BaseModel, Field
|
||
from PIL import Image, ImageDraw
|
||
|
||
from config_loader import load_app_config, get_depth_backend_from_app
|
||
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
|
||
|
||
from model.Seg.seg_loader import (
|
||
UnifiedSegConfig,
|
||
SegBackend,
|
||
build_seg_predictor,
|
||
mask_to_contour_xy,
|
||
run_sam_prompt,
|
||
)
|
||
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
|
||
from model.Animation.animation_loader import (
|
||
UnifiedAnimationConfig,
|
||
AnimationBackend,
|
||
build_animation_predictor,
|
||
)
|
||
|
||
|
||
# Directory layout: outputs/ lives next to this file and holds per-endpoint
# result artifacts (segment/, inpaint/, animation/ subdirectories).
APP_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = APP_ROOT / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title="HFUT Model Server", version="0.1.0")
|
||
|
||
|
||
class ImageInput(BaseModel):
    """Base request payload: a base64-encoded image plus an optional model key."""

    # PNG/JPG bytes encoded as base64 (without a "data:" prefix).
    image_b64: str = Field(..., description="PNG/JPG 编码后的 base64(不含 data: 前缀)")
    # Model key as listed by GET /models; None lets the endpoint pick its default.
    model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
|
||
|
||
|
||
class DepthRequest(ImageInput):
    """Request body for POST /depth (no fields beyond ImageInput)."""

    pass
|
||
|
||
|
||
class SegmentRequest(ImageInput):
    """Request body for POST /segment (no fields beyond ImageInput)."""

    pass
|
||
|
||
|
||
class SamPromptSegmentRequest(BaseModel):
    """Request body for POST /segment/sam_prompt.

    A cropped RGB image plus SAM point/box prompts; all coordinates are in the
    cropped image's pixel coordinate system.
    """

    image_b64: str = Field(..., description="裁剪后的 RGB 图 base64(PNG/JPG)")
    # Optional marker-overlay PNG; currently only used by the handler to
    # verify that its size matches the cropped image.
    overlay_b64: Optional[str] = Field(
        None,
        description="与裁剪同尺寸的标记叠加 PNG base64(可选;当前用于校验尺寸一致)",
    )
    # Prompt points [[x, y], ...]; must have the same length as point_labels.
    point_coords: list[list[float]] = Field(
        ...,
        description="裁剪坐标系下的提示点 [[x,y], ...]",
    )
    # One label per point: 1 = foreground, 0 = background.
    point_labels: list[int] = Field(
        ...,
        description="与 point_coords 等长:1=前景,0=背景",
    )
    # Tight bounding box of the user's stroke, [x1, y1, x2, y2] in pixels.
    box_xyxy: list[float] = Field(
        ...,
        description="裁剪内笔画紧包围盒 [x1,y1,x2,y2](像素)",
        min_length=4,
        max_length=4,
    )
|
||
|
||
|
||
class InpaintRequest(ImageInput):
    """Request body for POST /inpaint."""

    prompt: Optional[str] = Field("", description="补全 prompt")
    # Denoising strength passed through to the inpaint backend.
    strength: float = Field(0.8, ge=0.0, le=1.0)
    negative_prompt: Optional[str] = Field("", description="负向 prompt")
    # Optional mask (white areas are repainted); the handler falls back to a
    # right-half mask when absent.
    mask_b64: Optional[str] = Field(None, description="mask PNG base64(可选)")
    # Upper bound on the longest image side during inference (avoids OOM).
    max_side: int = Field(1024, ge=128, le=2048)
|
||
|
||
|
||
class AnimateRequest(BaseModel):
    """Request body for POST /animate (text-to-video generation)."""

    model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
    prompt: str = Field(..., description="文本提示词")
    negative_prompt: Optional[str] = Field("", description="负向提示词")
    num_inference_steps: int = Field(25, ge=1, le=200)
    guidance_scale: float = Field(8.0, ge=0.0, le=30.0)
    width: int = Field(512, ge=128, le=2048)
    height: int = Field(512, ge=128, le=2048)
    # Number of frames in the generated clip.
    video_length: int = Field(16, ge=1, le=128)
    # -1 means "pick a random seed" (interpreted by the animation backend).
    seed: int = Field(-1, description="-1 表示随机种子")
|
||
|
||
|
||
def _b64_to_pil_image(b64: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    The documented contract is plain base64 without a ``data:`` prefix, but for
    robustness this also tolerates a leading ``data:*;base64,`` URI header and
    surrounding whitespace, which some clients send anyway. Existing callers
    passing plain base64 are unaffected.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not a readable image.
    """
    payload = b64.strip()
    # Strip an optional data-URI header such as "data:image/png;base64,".
    if payload.startswith("data:"):
        _, _, payload = payload.partition(",")
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw))
|
||
|
||
|
||
def _pil_image_to_png_b64(img: Image.Image) -> str:
    """Encode *img* as PNG and return the bytes as an ASCII base64 string."""
    out = io.BytesIO()
    img.save(out, format="PNG")
    encoded = base64.b64encode(out.getvalue())
    return encoded.decode("ascii")
|
||
|
||
|
||
def _depth_to_png16_b64(depth: np.ndarray) -> str:
    """Normalize a depth map to [0, 1] and return it as a base64 16-bit grayscale PNG."""
    arr = np.asarray(depth, dtype=np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    # A constant map normalizes to all zeros instead of dividing by zero.
    if hi > lo:
        unit = (arr - lo) / (hi - lo)
    else:
        unit = np.zeros_like(arr, dtype=np.float32)
    u16 = (unit * 65535.0).clip(0, 65535).astype(np.uint16)
    png16 = Image.fromarray(u16, mode="I;16")
    buf = io.BytesIO()
    png16.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")
|
||
|
||
def _depth_to_png16_bytes(depth: np.ndarray) -> bytes:
    """Render a depth map as 8-bit grayscale PNG bytes (far = 255, near = 0).

    NOTE(review): despite the "png16" name, the output is 8-bit by the
    front/back-end convention encoded below.
    """
    arr = np.asarray(depth, dtype=np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if hi > lo:
        unit = (arr - lo) / (hi - lo)
    else:
        # Constant map: emit all-black rather than dividing by zero.
        unit = np.zeros_like(arr, dtype=np.float32)
    # Client convention: farthest = 255, nearest = 0 (8-bit), hence the
    # inversion before quantizing.
    gray = ((1.0 - unit) * 255.0).clip(0, 255).astype(np.uint8)
    stream = io.BytesIO()
    Image.fromarray(gray, mode="L").save(stream, format="PNG")
    return stream.getvalue()
|
||
|
||
|
||
def _default_half_mask(img: Image.Image) -> Image.Image:
    """Build the fallback inpaint mask: right half white (repaint), left half black."""
    width, height = img.size
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle([width // 2, 0, width, height], fill=255)
    return mask
|
||
|
||
|
||
# -----------------------------
|
||
# /models(给前端/GUI 使用)
|
||
# -----------------------------
|
||
|
||
|
||
@app.get("/models")
def get_models() -> Dict[str, Any]:
    """
    Return a schema compatible with the Qt front end:

    {
      "models": {
        "depth": { "key": { "name": "..."} ... },
        "segment": { ... },
        "inpaint": { "key": { "name": "...", "params": [...] } ... }
      }
    }
    """
    return {
        "models": {
            "depth": {
                # Compatibility defaults carried over from the legacy config
                "midas": {"name": "MiDaS (default)"},
                "zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
                "zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
                "zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
                "depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
                "depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
                "depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
                "depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
                "dpt_large": {"name": "DPT (large)"},
                "dpt_hybrid": {"name": "DPT (hybrid)"},
                "midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
                "midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
                "midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
                "midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
            },
            "segment": {
                "sam": {"name": "SAM (vit_h)"},
                # Legacy default key, kept as an alias of SAM
                "mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
                "mask2former": {"name": "Mask2Former (not implemented)"},
            },
            "inpaint": {
                # Legacy default: "copy" means no inpainting is performed
                "copy": {"name": "Copy (no-op)", "params": []},
                "sdxl_inpaint": {
                    "name": "SDXL Inpaint",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
                "controlnet": {
                    "name": "ControlNet Inpaint (canny)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
            },
            "animation": {
                "animatediff": {
                    "name": "AnimateDiff (Text-to-Video)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": False},
                        {"id": "negative_prompt", "label": "负向提示词", "optional": True},
                        {"id": "num_inference_steps", "label": "采样步数", "optional": True},
                        {"id": "guidance_scale", "label": "CFG Scale", "optional": True},
                        {"id": "width", "label": "宽度", "optional": True},
                        {"id": "height", "label": "高度", "optional": True},
                        {"id": "video_length", "label": "帧数", "optional": True},
                        {"id": "seed", "label": "随机种子", "optional": True},
                    ],
                },
            },
        }
    }
|
||
|
||
|
||
# -----------------------------
|
||
# Depth
|
||
# -----------------------------
|
||
|
||
|
||
# Lazily-initialized depth predictor; built once from config.py on first use
# (see _ensure_depth_predictor).
_depth_predictor = None
_depth_backend: DepthBackend | None = None
|
||
|
||
|
||
def _ensure_depth_predictor() -> None:
    """Build the module-level depth predictor on first use; later calls are no-ops.

    Raises:
        ValueError: when the backend string from config.py does not name a
            valid DepthBackend member.
    """
    global _depth_predictor, _depth_backend
    # Already initialized -> nothing to do.
    if _depth_predictor is not None and _depth_backend is not None:
        return

    cfg = load_app_config()
    backend_str = get_depth_backend_from_app(cfg)
    try:
        backend = DepthBackend(backend_str)
    except Exception as exc:
        raise ValueError(f"config.py 中 depth.backend 不合法: {backend_str}") from exc

    _depth_predictor, _depth_backend = build_depth_predictor(
        UnifiedDepthConfig(backend=backend)
    )
|
||
|
||
|
||
@app.post("/depth")
def depth(req: DepthRequest):
    """
    Compute depth and return a binary grayscale PNG directly.

    Contract:
    - The front end does not pass/select a model; the choice is fixed in the
      backend's config.py
    - Success: HTTP 200 + Content-Type: image/png
    - Failure: HTTP 500 with the error message in ``detail``
    """
    try:
        _ensure_depth_predictor()
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        depth_arr = _depth_predictor(pil)  # type: ignore[misc]
        # NOTE(review): despite the helper's name, the response PNG is 8-bit
        # (far = 255, near = 0) -- see _depth_to_png16_bytes.
        png_bytes = _depth_to_png16_bytes(np.asarray(depth_arr))
        return Response(content=png_bytes, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# -----------------------------
|
||
# Segment
|
||
# -----------------------------
|
||
|
||
|
||
# Segmentation predictors cached per model name (built once per process).
_seg_cache: Dict[str, Any] = {}
|
||
|
||
|
||
def _get_seg_predictor(model_name: str):
    """Return (and cache) the segmentation predictor for *model_name*.

    The legacy key "mask2former_debug" is an alias for "sam". It is normalized
    BEFORE the cache lookup: the original code checked the cache first and then
    remapped, so the predictor built for the alias was stored under "sam" and
    every subsequent alias call missed the cache and rebuilt the model.

    Raises:
        ValueError: for an unrecognized model name.
    """
    # Normalize the legacy default key first so alias calls share one cache entry.
    if model_name == "mask2former_debug":
        model_name = "sam"

    if model_name in _seg_cache:
        return _seg_cache[model_name]

    if model_name == "sam":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.SAM))
        _seg_cache[model_name] = pred
        return pred

    if model_name == "mask2former":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.MASK2FORMER))
        _seg_cache[model_name] = pred
        return pred

    raise ValueError(f"未知 segment model_name: {model_name}")
|
||
|
||
|
||
@app.post("/segment")
def segment(req: SegmentRequest) -> Dict[str, Any]:
    """Run full-image segmentation and save the label map under outputs/segment.

    Errors are reported in-band as {"success": False, "error": ...} rather
    than as HTTP error statuses.
    """
    try:
        model_name = req.model_name or "sam"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)

        predictor = _get_seg_predictor(model_name)
        label_map = predictor(rgb).astype(np.int32)

        out_dir = OUTPUT_DIR / "segment"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_label.png"
        # Saved as 8-bit grayscale (labels above 255 are clamped; SAM label
        # maps are normally small enough for this not to matter).
        Image.fromarray(np.clip(label_map, 0, 255).astype(np.uint8), mode="L").save(out_path)

        return {"success": True, "label_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
|
||
|
||
|
||
@app.post("/segment/sam_prompt")
def segment_sam_prompt(req: SamPromptSegmentRequest) -> Dict[str, Any]:
    """
    Interactive SAM: cropped image + point/box prompts; returns the mask's
    outer contour as a point list in cropped-pixel coordinates.
    """
    try:
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)
        h, w = rgb.shape[0], rgb.shape[1]

        # The overlay, when provided, is only used as a size sanity check.
        if req.overlay_b64:
            ov = _b64_to_pil_image(req.overlay_b64)
            if ov.size != (w, h):
                return {
                    "success": False,
                    "error": f"overlay 尺寸 {ov.size} 与 image {w}x{h} 不一致",
                    "contour": [],
                }

        # Validate the prompt arrays before handing them to SAM.
        if len(req.point_coords) != len(req.point_labels):
            return {
                "success": False,
                "error": "point_coords 与 point_labels 长度不一致",
                "contour": [],
            }
        if len(req.point_coords) < 1:
            return {"success": False, "error": "至少需要一个提示点", "contour": []}

        pc = np.array(req.point_coords, dtype=np.float32)
        if pc.ndim != 2 or pc.shape[1] != 2:
            return {"success": False, "error": "point_coords 每项须为 [x,y]", "contour": []}

        pl = np.array(req.point_labels, dtype=np.int64)
        box = np.array(req.box_xyxy, dtype=np.float32)

        mask = run_sam_prompt(rgb, pc, pl, box_xyxy=box)
        if not np.any(mask):
            return {"success": False, "error": "SAM 未产生有效掩膜", "contour": []}

        # Simplify the mask boundary to a polygon with 2 px tolerance.
        contour = mask_to_contour_xy(mask, epsilon_px=2.0)
        if len(contour) < 3:
            return {"success": False, "error": "轮廓点数不足", "contour": []}

        return {"success": True, "contour": contour, "error": None}
    except Exception as e:
        return {"success": False, "error": str(e), "contour": []}
|
||
|
||
|
||
# -----------------------------
|
||
# Inpaint
|
||
# -----------------------------
|
||
|
||
|
||
# Inpaint predictors cached per model name (built once per process).
_inpaint_cache: Dict[str, Any] = {}
|
||
|
||
|
||
def _get_inpaint_predictor(model_name: str):
    """Return (and cache) the inpaint predictor registered under *model_name*.

    Raises:
        ValueError: for an unrecognized model name.
    """
    cached = _inpaint_cache.get(model_name)
    if cached is not None:
        return cached

    if model_name == "copy":
        # No-op backend: hand the source image back unchanged (RGB-converted).
        def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
            return image.convert("RGB")

        _inpaint_cache[model_name] = _copy
        return _copy

    # Real diffusion backends, dispatched by name.
    backend_by_name = {
        "sdxl_inpaint": InpaintBackend.SDXL_INPAINT,
        "controlnet": InpaintBackend.CONTROLNET,
    }
    backend = backend_by_name.get(model_name)
    if backend is not None:
        pred, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=backend))
        _inpaint_cache[model_name] = pred
        return pred

    raise ValueError(f"未知 inpaint model_name: {model_name}")
|
||
|
||
|
||
@app.post("/inpaint")
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
    """Inpaint the masked region and save the result under outputs/inpaint.

    Errors are reported in-band as {"success": False, "error": ...}.
    """
    try:
        model_name = req.model_name or "sdxl_inpaint"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")

        # White mask areas are repainted; without a mask, default to the
        # right half of the image.
        if req.mask_b64:
            mask = _b64_to_pil_image(req.mask_b64).convert("L")
        else:
            mask = _default_half_mask(pil)

        predictor = _get_inpaint_predictor(model_name)
        out = predictor(
            pil,
            mask,
            req.prompt or "",
            req.negative_prompt or "",
            strength=req.strength,
            max_side=req.max_side,
        )

        out_dir = OUTPUT_DIR / "inpaint"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_inpaint.png"
        out.save(out_path)

        # Qt front-end compatibility: return the result image inline so the
        # client never has to read a server-side disk path.
        return {
            "success": True,
            "output_path": str(out_path),
            "output_image_b64": _pil_image_to_png_b64(out),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
|
||
|
||
|
||
# Animation predictors cached per model name (built once per process).
_animation_cache: Dict[str, Any] = {}
|
||
|
||
|
||
def _get_animation_predictor(model_name: str):
    """Return (and cache) the animation predictor for *model_name*.

    Raises:
        ValueError: for an unrecognized model name.
    """
    try:
        return _animation_cache[model_name]
    except KeyError:
        pass

    if model_name != "animatediff":
        raise ValueError(f"未知 animation model_name: {model_name}")

    pred, _ = build_animation_predictor(
        UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
    )
    _animation_cache[model_name] = pred
    return pred
|
||
|
||
|
||
@app.post("/animate")
def animate(req: AnimateRequest) -> Dict[str, Any]:
    """Generate a text-to-video GIF and copy it into outputs/animation.

    Errors are reported in-band as {"success": False, "error": ...}.
    """
    try:
        model_name = req.model_name or "animatediff"
        predictor = _get_animation_predictor(model_name)

        # The predictor writes its own output file and returns the path.
        result_path = predictor(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or "",
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
            width=req.width,
            height=req.height,
            video_length=req.video_length,
            seed=req.seed,
        )

        out_dir = OUTPUT_DIR / "animation"
        out_dir.mkdir(parents=True, exist_ok=True)
        # Timestamped copy so repeated runs do not overwrite each other.
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = out_dir / f"{model_name}_{ts}.gif"
        shutil.copy2(result_path, out_path)

        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
|
||
|
||
|
||
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always reports OK."""
    payload = {"status": "ok"}
    return payload
|
||
|
||
|
||
if __name__ == "__main__":
    # Local import so uvicorn is only required when running as a script.
    import uvicorn

    # Port is configurable via MODEL_SERVER_PORT; defaults to 8000.
    port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
||
|