initial commit
This commit is contained in:
268
python_server/model/Animation/animation_loader.py
Normal file
268
python_server/model/Animation/animation_loader.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Unified animation model loading entry.
|
||||
|
||||
Current support:
|
||||
- AnimateDiff (script-based invocation)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
from config_loader import (
|
||||
load_app_config,
|
||||
get_animatediff_root_from_app,
|
||||
get_animatediff_pretrained_model_from_app,
|
||||
get_animatediff_inference_config_from_app,
|
||||
get_animatediff_motion_module_from_app,
|
||||
get_animatediff_dreambooth_model_from_app,
|
||||
get_animatediff_lora_model_from_app,
|
||||
get_animatediff_lora_alpha_from_app,
|
||||
get_animatediff_without_xformers_from_app,
|
||||
)
|
||||
|
||||
|
||||
class AnimationBackend(str, Enum):
    """Animation backends this module can build predictors for.

    Str-valued so members compare equal to their raw string form
    (e.g. ``AnimationBackend.ANIMATEDIFF == "animatediff"``).
    """

    ANIMATEDIFF = "animatediff"
|
||||
|
||||
|
||||
@dataclass
class UnifiedAnimationConfig:
    """Configuration for building an animation predictor.

    Any field left as ``None`` is resolved from the application config
    (via the ``config_loader`` getters) when the predictor is built.
    """

    # Which backend to build a predictor for; only AnimateDiff is supported.
    backend: AnimationBackend = AnimationBackend.ANIMATEDIFF
    # Optional overrides. If None, values come from app config.
    animate_diff_root: str | None = None  # AnimateDiff checkout root (absolute or project-relative)
    pretrained_model_path: str | None = None  # passed as --pretrained-model-path to animate.py
    inference_config: str | None = None  # passed as --inference-config to animate.py
    motion_module: str | None = None  # emitted as motion_module: in the generated YAML
    dreambooth_model: str | None = None  # emitted as dreambooth_path: in the generated YAML
    lora_model: str | None = None  # emitted as lora_model_path: in the generated YAML
    lora_alpha: float | None = None  # emitted as lora_alpha: (only when a lora model is set)
    without_xformers: bool | None = None  # adds --without-xformers to the command line
    # Controlnet settings; predictor falls back to built-in SparseCtrl defaults when None.
    controlnet_path: str | None = None
    controlnet_config: str | None = None
|
||||
|
||||
|
||||
def _yaml_string(value: str) -> str:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _resolve_root(root_cfg: str) -> Path:
|
||||
root = Path(root_cfg)
|
||||
if not root.is_absolute():
|
||||
root = Path(__file__).resolve().parents[2] / root_cfg
|
||||
return root.resolve()
|
||||
|
||||
|
||||
def _make_animatediff_predictor(
    cfg: UnifiedAnimationConfig,
) -> Callable[..., Path]:
    """Build a predictor that runs AnimateDiff's ``scripts/animate.py`` in a subprocess.

    Resolution order for every setting: an explicit field on *cfg* wins,
    otherwise the value comes from the application config via the
    ``config_loader`` getters.

    Returns:
        A callable ``_predict(prompt, ...) -> Path`` that writes a temporary
        YAML prompt config next to the AnimateDiff root, invokes the script,
        and returns the path to the generated output (GIF file or PNG-sequence
        directory).

    Raises:
        FileNotFoundError: if ``scripts/animate.py`` does not exist under the
            resolved AnimateDiff root.
    """
    app_cfg = load_app_config()

    root = _resolve_root(cfg.animate_diff_root or get_animatediff_root_from_app(app_cfg))
    script_path = root / "scripts" / "animate.py"
    # AnimateDiff writes each run's output into <root>/samples/<run-dir>/.
    samples_dir = root / "samples"
    samples_dir.mkdir(parents=True, exist_ok=True)

    if not script_path.is_file():
        raise FileNotFoundError(f"AnimateDiff script not found: {script_path}")

    pretrained_model_path = (
        cfg.pretrained_model_path or get_animatediff_pretrained_model_from_app(app_cfg)
    )
    inference_config = cfg.inference_config or get_animatediff_inference_config_from_app(app_cfg)
    motion_module = cfg.motion_module or get_animatediff_motion_module_from_app(app_cfg)
    # The remaining settings fall back only on None (not general falsiness),
    # so explicit values such as 0.0 or False are respected.
    dreambooth_model = cfg.dreambooth_model
    if dreambooth_model is None:
        dreambooth_model = get_animatediff_dreambooth_model_from_app(app_cfg)
    lora_model = cfg.lora_model
    if lora_model is None:
        lora_model = get_animatediff_lora_model_from_app(app_cfg)
    lora_alpha = cfg.lora_alpha
    if lora_alpha is None:
        lora_alpha = get_animatediff_lora_alpha_from_app(app_cfg)
    without_xformers = cfg.without_xformers
    if without_xformers is None:
        without_xformers = get_animatediff_without_xformers_from_app(app_cfg)

    def _predict(
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        guidance_scale: float = 8.0,
        width: int = 512,
        height: int = 512,
        video_length: int = 16,
        seed: int = -1,
        control_image_path: str | None = None,
        output_format: str = "gif",
    ) -> Path:
        """Run one AnimateDiff generation and return the output path.

        Args:
            prompt: Positive prompt; must be non-empty after stripping.
            negative_prompt: Negative prompt; ``None``-like values become "".
            num_inference_steps, guidance_scale, width, height, video_length:
                Forwarded into the generated YAML as ``steps``,
                ``guidance_scale``, ``W``, ``H`` and ``L``.
            seed: Forwarded verbatim; -1 presumably requests a random seed —
                TODO confirm against the AnimateDiff script.
            control_image_path: Optional conditioning image; when given, the
                SparseCtrl controlnet section is added to the config.
            output_format: "gif" or "png_sequence".

        Returns:
            Path to ``sample.gif`` for "gif", or the newest run directory
            under ``sample_frames/`` for "png_sequence".

        Raises:
            ValueError: invalid ``output_format`` or empty prompt.
            FileNotFoundError: missing control image, or expected outputs
                absent after the run.
            RuntimeError: subprocess failure (wraps CalledProcessError with
                its stdout/stderr).
        """
        if output_format not in {"gif", "png_sequence"}:
            raise ValueError("output_format must be 'gif' or 'png_sequence'")
        prompt_value = prompt.strip()
        if not prompt_value:
            raise ValueError("prompt must not be empty")

        negative_prompt_value = negative_prompt or ""

        # Optional YAML lines; empty string when the setting is unset so they
        # simply drop out of the generated config.
        motion_module_line = (
            f'  motion_module: {_yaml_string(motion_module)}\n' if motion_module else ""
        )
        dreambooth_line = (
            f'  dreambooth_path: {_yaml_string(dreambooth_model)}\n' if dreambooth_model else ""
        )
        lora_path_line = f'  lora_model_path: {_yaml_string(lora_model)}\n' if lora_model else ""
        # NOTE(review): lora_alpha is gated on lora_model (alpha is meaningless
        # without a model), but float(lora_alpha) would raise if lora_model is
        # set while lora_alpha resolved to None — confirm the config getter
        # always returns a number.
        lora_alpha_line = f"  lora_alpha: {float(lora_alpha)}\n" if lora_model else ""
        controlnet_path_value = cfg.controlnet_path or "v3_sd15_sparsectrl_rgb.ckpt"
        controlnet_config_value = cfg.controlnet_config or "configs/inference/sparsectrl/image_condition.yaml"
        control_image_line = ""
        if control_image_path:
            control_image = Path(control_image_path).expanduser().resolve()
            if not control_image.is_file():
                raise FileNotFoundError(f"control_image_path not found: {control_image}")
            control_image_line = (
                f'  controlnet_path: {_yaml_string(controlnet_path_value)}\n'
                f'  controlnet_config: {_yaml_string(controlnet_config_value)}\n'
                "  controlnet_images:\n"
                f'  - {_yaml_string(str(control_image))}\n'
                "  controlnet_image_indexs:\n"
                "  - 0\n"
            )

        # Single-item prompt-list YAML in the layout animate.py's --config expects.
        config_text = (
            "- prompt:\n"
            f"  - {_yaml_string(prompt_value)}\n"
            "  n_prompt:\n"
            f"  - {_yaml_string(negative_prompt_value)}\n"
            f"  steps: {int(num_inference_steps)}\n"
            f"  guidance_scale: {float(guidance_scale)}\n"
            f"  W: {int(width)}\n"
            f"  H: {int(height)}\n"
            f"  L: {int(video_length)}\n"
            "  seed:\n"
            f"  - {int(seed)}\n"
            f"{motion_module_line}{dreambooth_line}{lora_path_line}{lora_alpha_line}{control_image_line}"
        )

        # Snapshot existing sample dirs so the new run's directory can be
        # identified by set difference after the subprocess finishes.
        before_dirs = {p for p in samples_dir.iterdir() if p.is_dir()}
        # delete=False: the file must survive close() so the subprocess can
        # read it; it is removed explicitly in the finally block.
        cfg_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            prefix="animatediff_cfg_",
            dir=str(root),
            delete=False,
            encoding="utf-8",
        )
        cfg_file_path = Path(cfg_file.name)
        try:
            cfg_file.write(config_text)
            cfg_file.flush()
            cfg_file.close()

            cmd = [
                sys.executable,
                str(script_path),
                "--pretrained-model-path",
                pretrained_model_path,
                "--inference-config",
                inference_config,
                "--config",
                str(cfg_file_path),
                "--L",
                str(int(video_length)),
                "--W",
                str(int(width)),
                "--H",
                str(int(height)),
            ]
            if without_xformers:
                cmd.append("--without-xformers")
            if output_format == "png_sequence":
                cmd.append("--save-png-sequence")

            env = dict(os.environ)
            existing_pythonpath = env.get("PYTHONPATH", "")
            root_pythonpath = str(root)
            # Prepend the AnimateDiff root so the script can import its package.
            # NOTE(review): ':' is the POSIX path separator; os.pathsep would
            # be needed for Windows support.
            env["PYTHONPATH"] = (
                f"{root_pythonpath}:{existing_pythonpath}" if existing_pythonpath else root_pythonpath
            )

            def _run_once(command: list[str]) -> subprocess.CompletedProcess[str]:
                """Run *command* from the AnimateDiff root, raising on failure."""
                return subprocess.run(
                    command,
                    cwd=str(root),
                    check=True,
                    capture_output=True,
                    text=True,
                    env=env,
                )

            try:
                proc = _run_once(cmd)
            except subprocess.CalledProcessError as first_error:
                # One automatic retry with --without-xformers when stderr looks
                # like an xformers/accelerator failure (and the flag wasn't
                # already passed); anything else re-raises immediately.
                stderr_text = first_error.stderr or ""
                should_retry_without_xformers = (
                    not without_xformers
                    and "--without-xformers" not in cmd
                    and (
                        "memory_efficient_attention" in stderr_text
                        or "AcceleratorError" in stderr_text
                        or "invalid configuration argument" in stderr_text
                    )
                )
                if not should_retry_without_xformers:
                    raise

                retry_cmd = [*cmd, "--without-xformers"]
                proc = _run_once(retry_cmd)
            _ = proc  # result unused; outputs are discovered on disk below
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "AnimateDiff inference failed.\n"
                f"stdout:\n{e.stdout}\n"
                f"stderr:\n{e.stderr}"
            ) from e
        finally:
            # Best-effort cleanup of the temporary YAML config.
            try:
                cfg_file_path.unlink(missing_ok=True)
            except Exception:
                pass

        # Prefer directories created by this run; fall back to any sample dir
        # containing a sample.gif (e.g. if the script reused a directory).
        after_dirs = [p for p in samples_dir.iterdir() if p.is_dir() and p not in before_dirs]
        candidates = [p for p in after_dirs if (p / "sample.gif").is_file()]
        if not candidates:
            candidates = [p for p in samples_dir.iterdir() if p.is_dir() and (p / "sample.gif").is_file()]
        if not candidates:
            raise FileNotFoundError("AnimateDiff finished but sample.gif was not found in samples/")

        # Newest-by-mtime wins among the candidates.
        latest = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)[0]
        if output_format == "png_sequence":
            frames_root = latest / "sample_frames"
            if not frames_root.is_dir():
                raise FileNotFoundError("AnimateDiff finished but sample_frames/ was not found in samples/")
            frame_dirs = sorted([p for p in frames_root.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
            if not frame_dirs:
                raise FileNotFoundError("AnimateDiff finished but no PNG sequence directory was found in sample_frames/")
            return frame_dirs[0].resolve()
        return (latest / "sample.gif").resolve()

    return _predict
|
||||
|
||||
|
||||
def build_animation_predictor(
    cfg: UnifiedAnimationConfig | None = None,
) -> tuple[Callable[..., Path], AnimationBackend]:
    """Build the predictor for the configured backend.

    Returns a ``(predict_callable, backend)`` pair; with no *cfg* a default
    ``UnifiedAnimationConfig`` is used. Raises ``ValueError`` for backends
    this module does not support.
    """
    effective_cfg = UnifiedAnimationConfig() if cfg is None else cfg

    if effective_cfg.backend != AnimationBackend.ANIMATEDIFF:
        raise ValueError(f"Unsupported animation backend: {effective_cfg.backend}")

    predictor = _make_animatediff_predictor(effective_cfg)
    return predictor, AnimationBackend.ANIMATEDIFF
|
||||
|
||||
Reference in New Issue
Block a user