Files
hfut-bishe/python_server/server.py
2026-04-09 23:38:14 +08:00

489 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import base64
import datetime
import io
import os
import shutil
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel, Field
from PIL import Image, ImageDraw
from config_loader import load_app_config, get_depth_backend_from_app
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
from model.Seg.seg_loader import (
UnifiedSegConfig,
SegBackend,
build_seg_predictor,
mask_to_contour_xy,
run_sam_prompt,
)
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
from model.Animation.animation_loader import (
UnifiedAnimationConfig,
AnimationBackend,
build_animation_predictor,
)
# Filesystem layout: all generated artifacts are written under <server dir>/outputs.
APP_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = APP_ROOT / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Single FastAPI application instance; routes below attach to it.
app = FastAPI(title="HFUT Model Server", version="0.1.0")
class ImageInput(BaseModel):
    """Common request payload: a base64-encoded image plus an optional model key."""

    # Field descriptions are user-facing API docs served by FastAPI; kept verbatim.
    image_b64: str = Field(..., description="PNG/JPG 编码后的 base64不含 data: 前缀)")
    model_name: Optional[str] = Field(None, description="模型 key来自 /models")
class DepthRequest(ImageInput):
    """Request body for POST /depth (inherits image + optional model key)."""

    pass
class SegmentRequest(ImageInput):
    """Request body for POST /segment (inherits image + optional model key)."""

    pass
class SamPromptSegmentRequest(BaseModel):
    """Request body for POST /segment/sam_prompt (interactive SAM prompting).

    All coordinates are in the cropped image's pixel coordinate system.
    """

    image_b64: str = Field(..., description="裁剪后的 RGB 图 base64PNG/JPG")
    # Optional marker overlay; currently only used to validate that it matches
    # the crop's dimensions (see the handler).
    overlay_b64: Optional[str] = Field(
        None,
        description="与裁剪同尺寸的标记叠加 PNG base64可选当前用于校验尺寸一致",
    )
    point_coords: list[list[float]] = Field(
        ...,
        description="裁剪坐标系下的提示点 [[x,y], ...]",
    )
    # Must be the same length as point_coords; 1 = foreground, 0 = background.
    point_labels: list[int] = Field(
        ...,
        description="与 point_coords 等长1=前景0=背景",
    )
    # Tight bounding box around the user's strokes, exactly 4 numbers.
    box_xyxy: list[float] = Field(
        ...,
        description="裁剪内笔画紧包围盒 [x1,y1,x2,y2](像素)",
        min_length=4,
        max_length=4,
    )
class InpaintRequest(ImageInput):
    """Request body for POST /inpaint."""

    prompt: Optional[str] = Field("", description="补全 prompt")
    strength: float = Field(0.8, ge=0.0, le=1.0)
    negative_prompt: Optional[str] = Field("", description="负向 prompt")
    # Optional mask; white pixels mark the region to repaint.
    mask_b64: Optional[str] = Field(None, description="mask PNG base64可选")
    # Upper bound on the longest image side during inference (avoids OOM).
    max_side: int = Field(1024, ge=128, le=2048)
class AnimateRequest(BaseModel):
    """Request body for POST /animate (text-to-video generation)."""

    model_name: Optional[str] = Field(None, description="模型 key来自 /models")
    prompt: str = Field(..., description="文本提示词")
    negative_prompt: Optional[str] = Field("", description="负向提示词")
    num_inference_steps: int = Field(25, ge=1, le=200)
    guidance_scale: float = Field(8.0, ge=0.0, le=30.0)
    width: int = Field(512, ge=128, le=2048)
    height: int = Field(512, ge=128, le=2048)
    video_length: int = Field(16, ge=1, le=128)
    # -1 means: pick a random seed server-side.
    seed: int = Field(-1, description="-1 表示随机种子")
def _b64_to_pil_image(b64: str) -> Image.Image:
    """Decode a base64 string (no data: prefix) into a PIL image."""
    decoded = base64.b64decode(b64)
    return Image.open(io.BytesIO(decoded))
def _pil_image_to_png_b64(img: Image.Image) -> str:
    """Serialize a PIL image to PNG and return it base64-encoded (ASCII str)."""
    sink = io.BytesIO()
    img.save(sink, format="PNG")
    return base64.b64encode(sink.getvalue()).decode("ascii")
def _depth_to_png16_b64(depth: np.ndarray) -> str:
    """Normalize a depth map to [0, 65535] and return it as a base64 16-bit PNG."""
    arr = np.asarray(depth, dtype=np.float32)
    lo = float(arr.min())
    hi = float(arr.max())
    if hi > lo:
        scaled = (arr - lo) / (hi - lo)
    else:
        # Constant depth map: emit all-zero pixels instead of dividing by zero.
        scaled = np.zeros_like(arr, dtype=np.float32)
    u16 = (scaled * 65535.0).clip(0, 65535).astype(np.uint16)
    return _pil_image_to_png_b64(Image.fromarray(u16, mode="I;16"))
def _depth_to_png16_bytes(depth: np.ndarray) -> bytes:
    """Encode a depth map as PNG bytes for the /depth endpoint.

    NOTE(review): despite the "png16" in the name this writes an *8-bit*
    grayscale PNG. The frontend/backend contract is: farthest = 255,
    nearest = 0.
    """
    arr = np.asarray(depth, dtype=np.float32)
    lo = float(arr.min())
    hi = float(arr.max())
    if hi > lo:
        norm = (arr - lo) / (hi - lo)
    else:
        # Degenerate (constant) map: all zeros after normalization.
        norm = np.zeros_like(arr, dtype=np.float32)
    # Invert so far pixels come out bright (255) and near pixels dark (0).
    gray = ((1.0 - norm) * 255.0).clip(0, 255).astype(np.uint8)
    sink = io.BytesIO()
    Image.fromarray(gray, mode="L").save(sink, format="PNG")
    return sink.getvalue()
def _default_half_mask(img: Image.Image) -> Image.Image:
    """Fallback inpaint mask: right half white (repaint), left half black."""
    width, height = img.size
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle([width // 2, 0, width, height], fill=255)
    return mask
# -----------------------------
# /models给前端/GUI 使用)
# -----------------------------
@app.get("/models")
def get_models() -> Dict[str, Any]:
    """
    Return the model catalog for the frontend/GUI in a Qt-compatible schema:

    {
        "models": {
            "depth": { "key": { "name": "..."} ... },
            "segment": { ... },
            "inpaint": { "key": { "name": "...", "params": [...] } ... }
        }
    }

    The catalog is static; model availability is not probed here.
    """
    return {
        "models": {
            "depth": {
                # Legacy config default, kept for backward compatibility.
                "midas": {"name": "MiDaS (default)"},
                "zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
                "zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
                "zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
                "depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
                "depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
                "depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
                "depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
                "dpt_large": {"name": "DPT (large)"},
                "dpt_hybrid": {"name": "DPT (hybrid)"},
                "midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
                "midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
                "midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
                "midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
            },
            "segment": {
                "sam": {"name": "SAM (vit_h)"},
                # Legacy config default, kept for compatibility (handler maps it to SAM).
                "mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
                "mask2former": {"name": "Mask2Former (not implemented)"},
            },
            "inpaint": {
                # Legacy config default; "copy" performs no inpainting at all.
                "copy": {"name": "Copy (no-op)", "params": []},
                "sdxl_inpaint": {
                    "name": "SDXL Inpaint",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
                "controlnet": {
                    "name": "ControlNet Inpaint (canny)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
            },
            "animation": {
                "animatediff": {
                    "name": "AnimateDiff (Text-to-Video)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": False},
                        {"id": "negative_prompt", "label": "负向提示词", "optional": True},
                        {"id": "num_inference_steps", "label": "采样步数", "optional": True},
                        {"id": "guidance_scale", "label": "CFG Scale", "optional": True},
                        {"id": "width", "label": "宽度", "optional": True},
                        {"id": "height", "label": "高度", "optional": True},
                        {"id": "video_length", "label": "帧数", "optional": True},
                        {"id": "seed", "label": "随机种子", "optional": True},
                    ],
                },
            },
        }
    }
# -----------------------------
# Depth
# -----------------------------
# Lazily-initialized depth predictor; the backend is fixed in config.py.
_depth_predictor = None
_depth_backend: DepthBackend | None = None


def _ensure_depth_predictor() -> None:
    """Build the depth predictor on first use, per the backend named in config.py.

    Raises ValueError when config.py names an unknown backend.
    """
    global _depth_predictor, _depth_backend
    if _depth_predictor is None or _depth_backend is None:
        app_cfg = load_app_config()
        backend_name = get_depth_backend_from_app(app_cfg)
        try:
            backend = DepthBackend(backend_name)
        except Exception as exc:
            raise ValueError(f"config.py 中 depth.backend 不合法: {backend_name}") from exc
        _depth_predictor, _depth_backend = build_depth_predictor(
            UnifiedDepthConfig(backend=backend)
        )
@app.post("/depth")
def depth(req: DepthRequest):
    """
    Compute a depth map and return it directly as a binary PNG.

    Contract:
    - The client does not pick a model; the backend is fixed in config.py.
    - Success: HTTP 200 + Content-Type: image/png. Note the payload is the
      8-bit grayscale encoding produced by _depth_to_png16_bytes
      (farthest = 255, nearest = 0), despite that helper's name.
    - Failure: HTTP 500 with detail carrying the error message.
    """
    try:
        _ensure_depth_predictor()
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        depth_arr = _depth_predictor(pil)  # type: ignore[misc]
        png_bytes = _depth_to_png16_bytes(np.asarray(depth_arr))
        return Response(content=png_bytes, media_type="image/png")
    except Exception as e:
        # Chain the cause so server logs keep the original traceback (B904).
        raise HTTPException(status_code=500, detail=str(e)) from e
# -----------------------------
# Segment
# -----------------------------
_seg_cache: Dict[str, Any] = {}
def _get_seg_predictor(model_name: str):
if model_name in _seg_cache:
return _seg_cache[model_name]
# 兼容旧默认 key
if model_name == "mask2former_debug":
model_name = "sam"
if model_name == "sam":
pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.SAM))
_seg_cache[model_name] = pred
return pred
if model_name == "mask2former":
pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.MASK2FORMER))
_seg_cache[model_name] = pred
return pred
raise ValueError(f"未知 segment model_name: {model_name}")
@app.post("/segment")
def segment(req: SegmentRequest) -> Dict[str, Any]:
    """Whole-image segmentation: saves the label map to disk, returns its path.

    Errors are reported as {"success": False, "error": ...} with HTTP 200.
    """
    try:
        name = req.model_name or "sam"
        image = _b64_to_pil_image(req.image_b64).convert("RGB")
        pixels = np.array(image, dtype=np.uint8)
        labels = _get_seg_predictor(name)(pixels).astype(np.int32)
        target_dir = OUTPUT_DIR / "segment"
        target_dir.mkdir(parents=True, exist_ok=True)
        label_file = target_dir / f"{name}_label.png"
        # Stored as 8-bit grayscale: label ids above 255 are clipped (SAM
        # usually stays well below that in practice).
        Image.fromarray(np.clip(labels, 0, 255).astype(np.uint8), mode="L").save(label_file)
        return {"success": True, "label_path": str(label_file)}
    except Exception as e:
        return {"success": False, "error": str(e)}
@app.post("/segment/sam_prompt")
def segment_sam_prompt(req: SamPromptSegmentRequest) -> Dict[str, Any]:
    """
    Interactive SAM: cropped image + point/box prompts in, mask outline out.

    Returns the mask's outer contour as a point list in crop-pixel
    coordinates. All failures are reported as {"success": False, ...}
    with HTTP 200 rather than raised, so the client handles one shape.
    """
    try:
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)
        h, w = rgb.shape[0], rgb.shape[1]
        # The overlay is currently only used to verify the client marked the
        # same crop size; its pixel content is not consumed here.
        if req.overlay_b64:
            ov = _b64_to_pil_image(req.overlay_b64)
            if ov.size != (w, h):
                return {
                    "success": False,
                    "error": f"overlay 尺寸 {ov.size} 与 image {w}x{h} 不一致",
                    "contour": [],
                }
        if len(req.point_coords) != len(req.point_labels):
            return {
                "success": False,
                "error": "point_coords 与 point_labels 长度不一致",
                "contour": [],
            }
        if len(req.point_coords) < 1:
            return {"success": False, "error": "至少需要一个提示点", "contour": []}
        pc = np.array(req.point_coords, dtype=np.float32)
        if pc.ndim != 2 or pc.shape[1] != 2:
            return {"success": False, "error": "point_coords 每项须为 [x,y]", "contour": []}
        pl = np.array(req.point_labels, dtype=np.int64)
        box = np.array(req.box_xyxy, dtype=np.float32)
        mask = run_sam_prompt(rgb, pc, pl, box_xyxy=box)
        if not np.any(mask):
            return {"success": False, "error": "SAM 未产生有效掩膜", "contour": []}
        # epsilon_px=2.0: presumably a polygon-simplification tolerance in
        # pixels — confirm against mask_to_contour_xy in seg_loader.
        contour = mask_to_contour_xy(mask, epsilon_px=2.0)
        if len(contour) < 3:
            return {"success": False, "error": "轮廓点数不足", "contour": []}
        return {"success": True, "contour": contour, "error": None}
    except Exception as e:
        return {"success": False, "error": str(e), "contour": []}
# -----------------------------
# Inpaint
# -----------------------------
_inpaint_cache: Dict[str, Any] = {}
def _get_inpaint_predictor(model_name: str):
if model_name in _inpaint_cache:
return _inpaint_cache[model_name]
if model_name == "copy":
def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
return image.convert("RGB")
_inpaint_cache[model_name] = _copy
return _copy
if model_name == "sdxl_inpaint":
pred, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.SDXL_INPAINT))
_inpaint_cache[model_name] = pred
return pred
if model_name == "controlnet":
pred, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.CONTROLNET))
_inpaint_cache[model_name] = pred
return pred
raise ValueError(f"未知 inpaint model_name: {model_name}")
@app.post("/inpaint")
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
    """Run inpainting: saves the result to disk and also returns it inline.

    Errors are reported as {"success": False, "error": ...} with HTTP 200.
    """
    try:
        name = req.model_name or "sdxl_inpaint"
        source = _b64_to_pil_image(req.image_b64).convert("RGB")
        # White mask pixels mark the region to repaint; without a client mask
        # we fall back to repainting the right half of the image.
        if req.mask_b64:
            mask = _b64_to_pil_image(req.mask_b64).convert("L")
        else:
            mask = _default_half_mask(source)
        result = _get_inpaint_predictor(name)(
            source,
            mask,
            req.prompt or "",
            req.negative_prompt or "",
            strength=req.strength,
            max_side=req.max_side,
        )
        target_dir = OUTPUT_DIR / "inpaint"
        target_dir.mkdir(parents=True, exist_ok=True)
        target = target_dir / f"{name}_inpaint.png"
        result.save(target)
        # Also return the image inline so the Qt client never has to read
        # the server's disk paths.
        return {
            "success": True,
            "output_path": str(target),
            "output_image_b64": _pil_image_to_png_b64(result),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
_animation_cache: Dict[str, Any] = {}
def _get_animation_predictor(model_name: str):
if model_name in _animation_cache:
return _animation_cache[model_name]
if model_name == "animatediff":
pred, _ = build_animation_predictor(
UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
)
_animation_cache[model_name] = pred
return pred
raise ValueError(f"未知 animation model_name: {model_name}")
@app.post("/animate")
def animate(req: AnimateRequest) -> Dict[str, Any]:
    """Generate a text-to-video GIF and copy it into the outputs directory.

    Errors are reported as {"success": False, "error": ...} with HTTP 200.
    """
    try:
        name = req.model_name or "animatediff"
        generate = _get_animation_predictor(name)
        produced = generate(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or "",
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
            width=req.width,
            height=req.height,
            video_length=req.video_length,
            seed=req.seed,
        )
        target_dir = OUTPUT_DIR / "animation"
        target_dir.mkdir(parents=True, exist_ok=True)
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        destination = target_dir / f"{name}_{stamp}.gif"
        # copy2 preserves the predictor output file's metadata/timestamps.
        shutil.copy2(produced, destination)
        return {"success": True, "output_path": str(destination)}
    except Exception as e:
        return {"success": False, "error": str(e)}
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always reports ok while the process is up."""
    return {"status": "ok"}
if __name__ == "__main__":
    import uvicorn

    # Port is configurable via MODEL_SERVER_PORT; defaults to 8000.
    listen_port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)