initial commit
This commit is contained in:
407
python_server/server.py
Normal file
407
python_server/server.py
Normal file
@@ -0,0 +1,407 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from config_loader import load_app_config, get_depth_backend_from_app
|
||||
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
|
||||
|
||||
from model.Seg.seg_loader import UnifiedSegConfig, SegBackend, build_seg_predictor
|
||||
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
|
||||
from model.Animation.animation_loader import (
|
||||
UnifiedAnimationConfig,
|
||||
AnimationBackend,
|
||||
build_animation_predictor,
|
||||
)
|
||||
|
||||
|
||||
# Filesystem layout: all generated artifacts (depth maps, segment label maps,
# inpaint results, animations) are written under <this file's dir>/outputs.
APP_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = APP_ROOT / "outputs"
# Create the output root eagerly at import time; endpoint handlers create
# their own sub-directories on demand.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title="HFUT Model Server", version="0.1.0")
|
||||
|
||||
|
||||
class ImageInput(BaseModel):
    """Common request payload: a base64-encoded image plus an optional model key.

    # NOTE: Field descriptions are user-facing API docs served by FastAPI's
    # schema; they are kept in Chinese on purpose (frontend is Chinese).
    """

    image_b64: str = Field(..., description="PNG/JPG 编码后的 base64(不含 data: 前缀)")
    model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
|
||||
|
||||
|
||||
class DepthRequest(ImageInput):
    """Request body for POST /depth — no fields beyond ImageInput."""

    pass
|
||||
|
||||
|
||||
class SegmentRequest(ImageInput):
    """Request body for POST /segment — no fields beyond ImageInput."""

    pass
|
||||
|
||||
|
||||
class InpaintRequest(ImageInput):
    """Request body for POST /inpaint."""

    prompt: Optional[str] = Field("", description="补全 prompt")
    strength: float = Field(0.8, ge=0.0, le=1.0)
    negative_prompt: Optional[str] = Field("", description="负向 prompt")
    # Optional mask (white areas are repainted); when absent the server
    # falls back to a default right-half mask (see _default_half_mask).
    mask_b64: Optional[str] = Field(None, description="mask PNG base64(可选)")
    # Upper bound on the longest image side during inference (avoids OOM).
    max_side: int = Field(1024, ge=128, le=2048)
|
||||
|
||||
|
||||
class AnimateRequest(BaseModel):
    """Request body for POST /animate (text-to-video generation).

    Not derived from ImageInput: animation is text-driven, no input image.
    """

    model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
    prompt: str = Field(..., description="文本提示词")
    negative_prompt: Optional[str] = Field("", description="负向提示词")
    num_inference_steps: int = Field(25, ge=1, le=200)
    guidance_scale: float = Field(8.0, ge=0.0, le=30.0)
    width: int = Field(512, ge=128, le=2048)
    height: int = Field(512, ge=128, le=2048)
    # Number of frames to generate.
    video_length: int = Field(16, ge=1, le=128)
    # -1 means "use a random seed" (per the field description below).
    seed: int = Field(-1, description="-1 表示随机种子")
|
||||
|
||||
|
||||
def _b64_to_pil_image(b64: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    Robustness improvement over the original: tolerates surrounding
    whitespace/newlines and an accidental ``data:image/...;base64,`` prefix,
    even though the API contract says the prefix should be absent. Inputs
    that already satisfy the contract decode exactly as before.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image.
    """
    payload = b64.strip()
    if payload.startswith("data:"):
        # Keep only the encoded part after the comma of a data URI.
        payload = payload.split(",", 1)[-1]
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw))
|
||||
|
||||
|
||||
def _pil_image_to_png_b64(img: Image.Image) -> str:
    """Serialize *img* as PNG and return the bytes base64-encoded (ASCII)."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("ascii")
|
||||
|
||||
|
||||
def _normalize_depth(depth: np.ndarray) -> np.ndarray:
    """Min-max normalize a depth map to float32 values in [0, 1].

    A constant map (max == min) normalizes to all zeros instead of
    dividing by zero. Extracted because the two PNG encoders below
    duplicated this logic verbatim.
    """
    depth = np.asarray(depth, dtype=np.float32)
    dmin = float(depth.min())
    dmax = float(depth.max())
    if dmax > dmin:
        return (depth - dmin) / (dmax - dmin)
    return np.zeros_like(depth, dtype=np.float32)


def _depth_to_png16_b64(depth: np.ndarray) -> str:
    """Encode a depth map as a 16-bit grayscale PNG, returned as base64."""
    norm = _normalize_depth(depth)
    u16 = (norm * 65535.0).clip(0, 65535).astype(np.uint16)
    img = Image.fromarray(u16, mode="I;16")
    return _pil_image_to_png_b64(img)


def _depth_to_png16_bytes(depth: np.ndarray) -> bytes:
    """Encode a depth map as an *8-bit* grayscale PNG and return raw bytes.

    Frontend/backend contract: farthest = 255, nearest = 0 (8-bit), so the
    normalized value is inverted before quantization.

    NOTE(review): despite the "png16" in the name, this emits 8-bit data —
    the name is kept for caller compatibility.
    """
    norm = _normalize_depth(depth)
    u8 = ((1.0 - norm) * 255.0).clip(0, 255).astype(np.uint8)
    img = Image.fromarray(u8, mode="L")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
|
||||
|
||||
|
||||
def _default_half_mask(img: Image.Image) -> Image.Image:
    """Build the default inpaint mask: white (=repaint) over the right half."""
    width, height = img.size
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle([width // 2, 0, width, height], fill=255)
    return mask
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# /models(给前端/GUI 使用)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@app.get("/models")
def get_models() -> Dict[str, Any]:
    """
    Return a schema compatible with the Qt frontend:
    {
      "models": {
        "depth": { "key": { "name": "..."} ... },
        "segment": { ... },
        "inpaint": { "key": { "name": "...", "params": [...] } ... }
      }
    }
    (an "animation" section with the same shape as "inpaint" is also included)
    """
    # The payload is a hard-coded catalogue; keys must match the model_name
    # values accepted by the /depth, /segment, /inpaint and /animate handlers.
    return {
        "models": {
            "depth": {
                # Backward compatibility with the old config's default value
                "midas": {"name": "MiDaS (default)"},
                "zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
                "zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
                "zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
                "depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
                "depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
                "depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
                "depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
                "dpt_large": {"name": "DPT (large)"},
                "dpt_hybrid": {"name": "DPT (hybrid)"},
                "midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
                "midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
                "midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
                "midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
            },
            "segment": {
                "sam": {"name": "SAM (vit_h)"},
                # Backward compatibility with the old config's default value
                "mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
                "mask2former": {"name": "Mask2Former (not implemented)"},
            },
            "inpaint": {
                # Old-config compatibility default: "copy" means no inpainting
                "copy": {"name": "Copy (no-op)", "params": []},
                "sdxl_inpaint": {
                    "name": "SDXL Inpaint",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
                "controlnet": {
                    "name": "ControlNet Inpaint (canny)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
            },
            "animation": {
                "animatediff": {
                    "name": "AnimateDiff (Text-to-Video)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": False},
                        {"id": "negative_prompt", "label": "负向提示词", "optional": True},
                        {"id": "num_inference_steps", "label": "采样步数", "optional": True},
                        {"id": "guidance_scale", "label": "CFG Scale", "optional": True},
                        {"id": "width", "label": "宽度", "optional": True},
                        {"id": "height", "label": "高度", "optional": True},
                        {"id": "video_length", "label": "帧数", "optional": True},
                        {"id": "seed", "label": "随机种子", "optional": True},
                    ],
                },
            },
        }
    }
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Depth
|
||||
# -----------------------------
|
||||
|
||||
|
||||
# Lazily-initialized depth predictor, shared across requests.
_depth_predictor = None
_depth_backend: DepthBackend | None = None


def _ensure_depth_predictor() -> None:
    """Build the depth predictor from config.py on first use (idempotent).

    The backend string comes from the server-side app config; an invalid
    value raises ValueError (wrapping the enum conversion failure).
    """
    global _depth_predictor, _depth_backend
    if _depth_predictor is None or _depth_backend is None:
        app_cfg = load_app_config()
        backend_str = get_depth_backend_from_app(app_cfg)
        try:
            backend = DepthBackend(backend_str)
        except Exception as e:
            raise ValueError(f"config.py 中 depth.backend 不合法: {backend_str}") from e
        _depth_predictor, _depth_backend = build_depth_predictor(
            UnifiedDepthConfig(backend=backend)
        )
|
||||
|
||||
|
||||
@app.post("/depth")
def depth(req: DepthRequest):
    """
    Compute depth and return the binary PNG directly.

    Contract:
    - The frontend does not send/select a model; the choice is hard-coded
      in the backend's config.py.
    - Success: HTTP 200 + Content-Type: image/png
    - Failure: HTTP 500, with the error message in ``detail``

    NOTE(review): the payload is 8-bit grayscale (see _depth_to_png16_bytes),
    not 16-bit as the original docstring claimed.
    """
    try:
        _ensure_depth_predictor()
        image = _b64_to_pil_image(req.image_b64).convert("RGB")
        prediction = _depth_predictor(image)  # type: ignore[misc]
        png_payload = _depth_to_png16_bytes(np.asarray(prediction))
        return Response(content=png_payload, media_type="image/png")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Segment
|
||||
# -----------------------------
|
||||
|
||||
|
||||
# Cache of segmentation predictors keyed by the *canonical* model name.
_seg_cache: Dict[str, Any] = {}

# Legacy model keys mapped to their canonical replacement.
_SEG_ALIASES = {"mask2former_debug": "sam"}


def _get_seg_predictor(model_name: str):
    """Return the segmentation predictor for *model_name*, building and
    caching it on first use.

    Bug fix: the legacy alias is resolved *before* the cache lookup.
    Previously a legacy key ("mask2former_debug") was looked up in the cache
    under its own name but cached under "sam", so every legacy-key request
    missed the cache and rebuilt the model.

    Raises:
        ValueError: if *model_name* is not a known segment model key.
    """
    model_name = _SEG_ALIASES.get(model_name, model_name)
    if model_name in _seg_cache:
        return _seg_cache[model_name]

    if model_name == "sam":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.SAM))
    elif model_name == "mask2former":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.MASK2FORMER))
    else:
        raise ValueError(f"未知 segment model_name: {model_name}")

    _seg_cache[model_name] = pred
    return pred
|
||||
|
||||
|
||||
@app.post("/segment")
def segment(req: SegmentRequest) -> Dict[str, Any]:
    """Segment the input image and save its label map as an 8-bit PNG.

    Returns ``{"success": True, "label_path": ...}`` on success, or
    ``{"success": False, "error": ...}`` on any failure (no HTTP error).
    """
    try:
        model_name = req.model_name or "sam"
        image = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(image, dtype=np.uint8)

        predictor = _get_seg_predictor(model_name)
        label_map = predictor(rgb).astype(np.int32)

        out_dir = OUTPUT_DIR / "segment"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_label.png"
        # Saved as 8-bit grayscale: labels above 255 are clipped (SAM label
        # counts are usually small enough for this not to matter).
        labels_u8 = np.clip(label_map, 0, 255).astype(np.uint8)
        Image.fromarray(labels_u8, mode="L").save(out_path)

        return {"success": True, "label_path": str(out_path)}
    except Exception as exc:
        return {"success": False, "error": str(exc)}
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Inpaint
|
||||
# -----------------------------
|
||||
|
||||
|
||||
# Cache of inpaint predictors keyed by model name.
_inpaint_cache: Dict[str, Any] = {}


def _get_inpaint_predictor(model_name: str):
    """Return the inpaint predictor for *model_name*, building and caching
    it on first use.

    "copy" is a no-op predictor kept for old-config compatibility: it
    returns the input image (converted to RGB) unchanged.

    Raises:
        ValueError: if *model_name* is not a known inpaint model key.
    """
    cached = _inpaint_cache.get(model_name)
    if cached is not None:
        return cached

    if model_name == "copy":
        def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
            return image.convert("RGB")
        pred = _copy
    elif model_name == "sdxl_inpaint":
        pred, _ = build_inpaint_predictor(
            UnifiedInpaintConfig(backend=InpaintBackend.SDXL_INPAINT)
        )
    elif model_name == "controlnet":
        pred, _ = build_inpaint_predictor(
            UnifiedInpaintConfig(backend=InpaintBackend.CONTROLNET)
        )
    else:
        raise ValueError(f"未知 inpaint model_name: {model_name}")

    _inpaint_cache[model_name] = pred
    return pred
|
||||
|
||||
|
||||
@app.post("/inpaint")
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
    """Inpaint the masked region of the input image and save the result.

    When no mask is supplied, the default right-half mask is used
    (white = repaint). Returns ``{"success": True, "output_path": ...}``
    or ``{"success": False, "error": ...}``.
    """
    try:
        model_name = req.model_name or "sdxl_inpaint"
        image = _b64_to_pil_image(req.image_b64).convert("RGB")

        mask = (
            _b64_to_pil_image(req.mask_b64).convert("L")
            if req.mask_b64
            else _default_half_mask(image)
        )

        predictor = _get_inpaint_predictor(model_name)
        result = predictor(
            image,
            mask,
            req.prompt or "",
            req.negative_prompt or "",
            strength=req.strength,
            max_side=req.max_side,
        )

        out_dir = OUTPUT_DIR / "inpaint"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_inpaint.png"
        result.save(out_path)

        return {"success": True, "output_path": str(out_path)}
    except Exception as exc:
        return {"success": False, "error": str(exc)}
|
||||
|
||||
|
||||
# Cache of animation predictors keyed by model name.
_animation_cache: Dict[str, Any] = {}


def _get_animation_predictor(model_name: str):
    """Return the animation predictor for *model_name*, building and
    caching it on first use.

    Raises:
        ValueError: if *model_name* is not a known animation model key.
    """
    try:
        return _animation_cache[model_name]
    except KeyError:
        pass

    if model_name != "animatediff":
        raise ValueError(f"未知 animation model_name: {model_name}")

    pred, _ = build_animation_predictor(
        UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
    )
    _animation_cache[model_name] = pred
    return pred
|
||||
|
||||
|
||||
@app.post("/animate")
def animate(req: AnimateRequest) -> Dict[str, Any]:
    """Generate a text-to-video animation and copy it under outputs/animation.

    Returns ``{"success": True, "output_path": ...}`` or
    ``{"success": False, "error": ...}``.
    """
    try:
        model_name = req.model_name or "animatediff"
        predictor = _get_animation_predictor(model_name)

        result_path = predictor(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or "",
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
            width=req.width,
            height=req.height,
            video_length=req.video_length,
            seed=req.seed,
        )

        out_dir = OUTPUT_DIR / "animation"
        out_dir.mkdir(parents=True, exist_ok=True)
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # NOTE(review): the .gif suffix assumes the predictor renders a GIF —
        # confirm against the animation backend's actual output format.
        destination = out_dir / f"{model_name}_{stamp}.gif"
        shutil.copy2(result_path, destination)

        return {"success": True, "output_path": str(destination)}
    except Exception as exc:
        return {"success": False, "error": str(exc)}
|
||||
|
||||
|
||||
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe; always reports an "ok" status."""
    return dict(status="ok")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Local entry point; the port is overridable via MODEL_SERVER_PORT.
    import uvicorn

    listen_port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
||||
|
||||
Reference in New Issue
Block a user