from __future__ import annotations import base64 import datetime import io import os import shutil from dataclasses import asdict from pathlib import Path from typing import Any, Dict, Optional import numpy as np from fastapi import FastAPI, HTTPException from fastapi.responses import Response from pydantic import BaseModel, Field from PIL import Image, ImageDraw from config_loader import load_app_config, get_depth_backend_from_app from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor from model.Seg.seg_loader import ( UnifiedSegConfig, SegBackend, build_seg_predictor, mask_to_contour_xy, run_sam_prompt, ) from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor from model.Animation.animation_loader import ( UnifiedAnimationConfig, AnimationBackend, build_animation_predictor, ) APP_ROOT = Path(__file__).resolve().parent OUTPUT_DIR = APP_ROOT / "outputs" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) app = FastAPI(title="HFUT Model Server", version="0.1.0") class ImageInput(BaseModel): image_b64: str = Field(..., description="PNG/JPG 编码后的 base64(不含 data: 前缀)") model_name: Optional[str] = Field(None, description="模型 key(来自 /models)") class DepthRequest(ImageInput): pass class SegmentRequest(ImageInput): pass class SamPromptSegmentRequest(BaseModel): image_b64: str = Field(..., description="裁剪后的 RGB 图 base64(PNG/JPG)") overlay_b64: Optional[str] = Field( None, description="与裁剪同尺寸的标记叠加 PNG base64(可选;当前用于校验尺寸一致)", ) point_coords: list[list[float]] = Field( ..., description="裁剪坐标系下的提示点 [[x,y], ...]", ) point_labels: list[int] = Field( ..., description="与 point_coords 等长:1=前景,0=背景", ) box_xyxy: list[float] = Field( ..., description="裁剪内笔画紧包围盒 [x1,y1,x2,y2](像素)", min_length=4, max_length=4, ) class InpaintRequest(ImageInput): prompt: Optional[str] = Field("", description="补全 prompt") strength: float = Field(0.8, ge=0.0, le=1.0) negative_prompt: Optional[str] = Field("", description="负向 
prompt") # 可选 mask(白色区域为重绘) mask_b64: Optional[str] = Field(None, description="mask PNG base64(可选)") # 推理缩放上限(避免 OOM) max_side: int = Field(1024, ge=128, le=2048) class AnimateRequest(BaseModel): model_name: Optional[str] = Field(None, description="模型 key(来自 /models)") prompt: str = Field(..., description="文本提示词") negative_prompt: Optional[str] = Field("", description="负向提示词") num_inference_steps: int = Field(25, ge=1, le=200) guidance_scale: float = Field(8.0, ge=0.0, le=30.0) width: int = Field(512, ge=128, le=2048) height: int = Field(512, ge=128, le=2048) video_length: int = Field(16, ge=1, le=128) seed: int = Field(-1, description="-1 表示随机种子") def _b64_to_pil_image(b64: str) -> Image.Image: raw = base64.b64decode(b64) return Image.open(io.BytesIO(raw)) def _pil_image_to_png_b64(img: Image.Image) -> str: buf = io.BytesIO() img.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode("ascii") def _depth_to_png16_b64(depth: np.ndarray) -> str: depth = np.asarray(depth, dtype=np.float32) dmin = float(depth.min()) dmax = float(depth.max()) if dmax > dmin: norm = (depth - dmin) / (dmax - dmin) else: norm = np.zeros_like(depth, dtype=np.float32) u16 = (norm * 65535.0).clip(0, 65535).astype(np.uint16) img = Image.fromarray(u16, mode="I;16") return _pil_image_to_png_b64(img) def _depth_to_png16_bytes(depth: np.ndarray) -> bytes: depth = np.asarray(depth, dtype=np.float32) dmin = float(depth.min()) dmax = float(depth.max()) if dmax > dmin: norm = (depth - dmin) / (dmax - dmin) else: norm = np.zeros_like(depth, dtype=np.float32) # 前后端约定:最远=255,最近=0(8-bit) u8 = ((1.0 - norm) * 255.0).clip(0, 255).astype(np.uint8) img = Image.fromarray(u8, mode="L") buf = io.BytesIO() img.save(buf, format="PNG") return buf.getvalue() def _default_half_mask(img: Image.Image) -> Image.Image: w, h = img.size mask = Image.new("L", (w, h), 0) draw = ImageDraw.Draw(mask) draw.rectangle([w // 2, 0, w, h], fill=255) return mask # ----------------------------- # /models(给前端/GUI 使用) # 
@app.get("/models")
def get_models() -> Dict[str, Any]:
    """
    Return a schema compatible with the Qt frontend:
    {
      "models": {
        "depth":   { "key": { "name": "..."} ... },
        "segment": { ... },
        "inpaint": { "key": { "name": "...", "params": [...] } ... }
      }
    }
    """
    return {
        "models": {
            "depth": {
                # Backward-compatible default key from the old config
                "midas": {"name": "MiDaS (default)"},
                "zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
                "zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
                "zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
                "depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
                "depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
                "depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
                "depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
                "dpt_large": {"name": "DPT (large)"},
                "dpt_hybrid": {"name": "DPT (hybrid)"},
                "midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
                "midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
                "midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
                "midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
            },
            "segment": {
                "sam": {"name": "SAM (vit_h)"},
                # Backward-compatible default key from the old config
                "mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
                "mask2former": {"name": "Mask2Former (not implemented)"},
            },
            "inpaint": {
                # Backward-compatible default key: "copy" means no inpainting
                "copy": {"name": "Copy (no-op)", "params": []},
                "sdxl_inpaint": {
                    "name": "SDXL Inpaint",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
                "controlnet": {
                    "name": "ControlNet Inpaint (canny)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
            },
            "animation": {
                "animatediff": {
                    "name": "AnimateDiff (Text-to-Video)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": False},
                        {"id": "negative_prompt", "label": "负向提示词", "optional": True},
                        {"id": "num_inference_steps", "label": "采样步数", "optional": True},
                        {"id": "guidance_scale", "label": "CFG Scale", "optional": True},
                        {"id": "width", "label": "宽度", "optional": True},
                        {"id": "height", "label": "高度", "optional": True},
                        {"id": "video_length", "label": "帧数", "optional": True},
                        {"id": "seed", "label": "随机种子", "optional": True},
                    ],
                },
            },
        }
    }


# -----------------------------
# Depth
# -----------------------------
_depth_predictor = None
_depth_backend: DepthBackend | None = None


def _ensure_depth_predictor() -> None:
    """Lazily build the depth predictor selected in config.py (built once)."""
    global _depth_predictor, _depth_backend
    if _depth_predictor is not None and _depth_backend is not None:
        return
    app_cfg = load_app_config()
    backend_str = get_depth_backend_from_app(app_cfg)
    try:
        backend = DepthBackend(backend_str)
    except Exception as e:
        raise ValueError(f"config.py 中 depth.backend 不合法: {backend_str}") from e
    _depth_predictor, _depth_backend = build_depth_predictor(UnifiedDepthConfig(backend=backend))


@app.post("/depth")
def depth(req: DepthRequest):
    """
    Compute depth and return binary PNG directly (8-bit grayscale, see
    _depth_to_png16_bytes for the far=255/near=0 convention).

    Contract:
    - The frontend does not pick a model; the backend config.py decides.
    - Success: HTTP 200 + Content-Type: image/png
    - Failure: HTTP 500, detail carries the error message
    """
    try:
        _ensure_depth_predictor()
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        depth_arr = _depth_predictor(pil)  # type: ignore[misc]
        png_bytes = _depth_to_png16_bytes(np.asarray(depth_arr))
        return Response(content=png_bytes, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# -----------------------------
# Segment
# -----------------------------
_seg_cache: Dict[str, Any] = {}

# model_name -> backend enum for the segmentation predictors we can build.
_SEG_BACKENDS = {
    "sam": SegBackend.SAM,
    "mask2former": SegBackend.MASK2FORMER,
}


def _get_seg_predictor(model_name: str):
    """Return a cached segmentation predictor for ``model_name``.

    The legacy key "mask2former_debug" is an alias for "sam". The alias is
    resolved BEFORE the cache lookup: the original code checked the cache
    under the alias first (which was never populated) and therefore rebuilt
    the SAM predictor on every aliased call.
    """
    if model_name == "mask2former_debug":  # legacy default key
        model_name = "sam"
    if model_name in _seg_cache:
        return _seg_cache[model_name]
    backend = _SEG_BACKENDS.get(model_name)
    if backend is None:
        raise ValueError(f"未知 segment model_name: {model_name}")
    pred, _ = build_seg_predictor(UnifiedSegConfig(backend=backend))
    _seg_cache[model_name] = pred
    return pred


@app.post("/segment")
def segment(req: SegmentRequest) -> Dict[str, Any]:
    """Run whole-image segmentation and save the label map under outputs/."""
    try:
        model_name = req.model_name or "sam"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)
        predictor = _get_seg_predictor(model_name)
        label_map = predictor(rgb).astype(np.int32)
        out_dir = OUTPUT_DIR / "segment"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_label.png"
        # Saved as 8-bit grayscale (labels above 255 are clipped; SAM label
        # maps are normally small enough for this not to matter).
        Image.fromarray(np.clip(label_map, 0, 255).astype(np.uint8), mode="L").save(out_path)
        return {"success": True, "label_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}


@app.post("/segment/sam_prompt")
def segment_sam_prompt(req: SamPromptSegmentRequest) -> Dict[str, Any]:
    """
    Interactive SAM: cropped image + point/box prompts; returns the mask's
    outer contour as a point list in crop pixel coordinates.
    """
    try:
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)
        h, w = rgb.shape[0], rgb.shape[1]
        # The overlay (if supplied) is only used to validate size agreement.
        if req.overlay_b64:
            ov = _b64_to_pil_image(req.overlay_b64)
            if ov.size != (w, h):
                return {
                    "success": False,
                    "error": f"overlay 尺寸 {ov.size} 与 image {w}x{h} 不一致",
                    "contour": [],
                }
        if len(req.point_coords) != len(req.point_labels):
            return {
                "success": False,
                "error": "point_coords 与 point_labels 长度不一致",
                "contour": [],
            }
        if len(req.point_coords) < 1:
            return {"success": False, "error": "至少需要一个提示点", "contour": []}
        pc = np.array(req.point_coords, dtype=np.float32)
        if pc.ndim != 2 or pc.shape[1] != 2:
            return {"success": False, "error": "point_coords 每项须为 [x,y]", "contour": []}
        pl = np.array(req.point_labels, dtype=np.int64)
        box = np.array(req.box_xyxy, dtype=np.float32)
        mask = run_sam_prompt(rgb, pc, pl, box_xyxy=box)
        if not np.any(mask):
            return {"success": False, "error": "SAM 未产生有效掩膜", "contour": []}
        contour = mask_to_contour_xy(mask, epsilon_px=2.0)
        if len(contour) < 3:
            return {"success": False, "error": "轮廓点数不足", "contour": []}
        return {"success": True, "contour": contour, "error": None}
    except Exception as e:
        return {"success": False, "error": str(e), "contour": []}


# -----------------------------
# Inpaint
# -----------------------------
_inpaint_cache: Dict[str, Any] = {}


def _get_inpaint_predictor(model_name: str):
    """Return (and cache) the inpaint predictor for ``model_name``."""
    cached = _inpaint_cache.get(model_name)
    if cached is not None:
        return cached

    if model_name == "copy":
        # "copy" is a no-op backend: hand the input image back unchanged.
        def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
            return image.convert("RGB")

        predictor = _copy
    elif model_name == "sdxl_inpaint":
        predictor, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.SDXL_INPAINT))
    elif model_name == "controlnet":
        predictor, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.CONTROLNET))
    else:
        raise ValueError(f"未知 inpaint model_name: {model_name}")

    _inpaint_cache[model_name] = predictor
    return predictor


@app.post("/inpaint")
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
    """Run inpainting; returns the saved path plus the result image as base64."""
    try:
        model_name = req.model_name or "sdxl_inpaint"
        image = _b64_to_pil_image(req.image_b64).convert("RGB")
        mask = (
            _b64_to_pil_image(req.mask_b64).convert("L")
            if req.mask_b64
            else _default_half_mask(image)
        )
        predictor = _get_inpaint_predictor(model_name)
        result = predictor(
            image,
            mask,
            req.prompt or "",
            req.negative_prompt or "",
            strength=req.strength,
            max_side=req.max_side,
        )
        out_dir = OUTPUT_DIR / "inpaint"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_inpaint.png"
        result.save(out_path)
        # Qt frontend compatibility: return the image inline so the client
        # never needs to read the server's disk path.
        return {
            "success": True,
            "output_path": str(out_path),
            "output_image_b64": _pil_image_to_png_b64(result),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}


_animation_cache: Dict[str, Any] = {}


def _get_animation_predictor(model_name: str):
    """Return (and cache) the animation predictor for ``model_name``."""
    cached = _animation_cache.get(model_name)
    if cached is not None:
        return cached
    if model_name != "animatediff":
        raise ValueError(f"未知 animation model_name: {model_name}")
    predictor, _ = build_animation_predictor(
        UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
    )
    _animation_cache[model_name] = predictor
    return predictor


@app.post("/animate")
def animate(req: AnimateRequest) -> Dict[str, Any]:
    """Generate a text-to-video clip and copy it into the outputs folder."""
    try:
        model_name = req.model_name or "animatediff"
        predictor = _get_animation_predictor(model_name)
        result_path = predictor(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or "",
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
            width=req.width,
            height=req.height,
            video_length=req.video_length,
            seed=req.seed,
        )
        out_dir = OUTPUT_DIR / "animation"
        out_dir.mkdir(parents=True, exist_ok=True)
        # Timestamped filename so successive runs never overwrite each other.
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = out_dir / f"{model_name}_{ts}.gif"
        shutil.copy2(result_path, out_path)
        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}


@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe."""
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)