269 lines
9.8 KiB
Python
269 lines
9.8 KiB
Python
from __future__ import annotations
|
|
|
|
"""
|
|
Unified animation model loading entry.
|
|
|
|
Current support:
|
|
- AnimateDiff (script-based invocation)
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from typing import Callable
|
|
|
|
from config_loader import (
|
|
load_app_config,
|
|
get_animatediff_root_from_app,
|
|
get_animatediff_pretrained_model_from_app,
|
|
get_animatediff_inference_config_from_app,
|
|
get_animatediff_motion_module_from_app,
|
|
get_animatediff_dreambooth_model_from_app,
|
|
get_animatediff_lora_model_from_app,
|
|
get_animatediff_lora_alpha_from_app,
|
|
get_animatediff_without_xformers_from_app,
|
|
)
|
|
|
|
|
|
class AnimationBackend(str, Enum):
    """Identifiers for the supported animation backends.

    Subclassing ``str`` makes members compare equal to their plain string
    values, so they serialize naturally (JSON, CLI args, config files).
    """

    # Script-based AnimateDiff invocation — the only backend wired up so far.
    ANIMATEDIFF = "animatediff"
|
|
|
|
|
|
@dataclass
class UnifiedAnimationConfig:
    """Configuration for building a unified animation predictor.

    Every field except ``backend`` is an optional override: when left as
    ``None``, the effective value is read from the application config at
    predictor build time.
    """

    backend: AnimationBackend = AnimationBackend.ANIMATEDIFF
    # Optional overrides. If None, values come from app config.
    animate_diff_root: str | None = None  # path to the AnimateDiff checkout (abs or repo-relative)
    pretrained_model_path: str | None = None  # base pretrained diffusion model path
    inference_config: str | None = None  # inference config file passed to animate.py
    motion_module: str | None = None  # motion-module checkpoint path
    dreambooth_model: str | None = None  # optional DreamBooth model path
    lora_model: str | None = None  # optional LoRA model path
    lora_alpha: float | None = None  # LoRA weight; only written to config when a LoRA is set
    without_xformers: bool | None = None  # True disables xformers attention in the script
    controlnet_path: str | None = None  # SparseCtrl checkpoint; used only with a control image
    controlnet_config: str | None = None  # SparseCtrl inference config; used only with a control image
|
|
|
|
|
|
def _yaml_string(value: str) -> str:
|
|
return json.dumps(value, ensure_ascii=False)
|
|
|
|
|
|
def _resolve_root(root_cfg: str) -> Path:
|
|
root = Path(root_cfg)
|
|
if not root.is_absolute():
|
|
root = Path(__file__).resolve().parents[2] / root_cfg
|
|
return root.resolve()
|
|
|
|
|
|
def _make_animatediff_predictor(
    cfg: UnifiedAnimationConfig,
) -> Callable[..., Path]:
    """Build a predictor callable that runs AnimateDiff via its CLI script.

    Settings missing from *cfg* are filled in from the application config.
    The returned callable writes a temporary prompt-config YAML, invokes
    ``scripts/animate.py`` in a subprocess, and returns the path to the
    generated ``sample.gif`` (or the newest PNG-sequence directory).

    Raises:
        FileNotFoundError: if ``scripts/animate.py`` is not present under
            the resolved AnimateDiff root.
    """
    app_cfg = load_app_config()

    # Resolve the AnimateDiff checkout and make sure its output dir exists.
    root = _resolve_root(cfg.animate_diff_root or get_animatediff_root_from_app(app_cfg))
    script_path = root / "scripts" / "animate.py"
    samples_dir = root / "samples"
    samples_dir.mkdir(parents=True, exist_ok=True)

    if not script_path.is_file():
        raise FileNotFoundError(f"AnimateDiff script not found: {script_path}")

    # Per-setting precedence: explicit cfg override first, else app config.
    pretrained_model_path = (
        cfg.pretrained_model_path or get_animatediff_pretrained_model_from_app(app_cfg)
    )
    inference_config = cfg.inference_config or get_animatediff_inference_config_from_app(app_cfg)
    motion_module = cfg.motion_module or get_animatediff_motion_module_from_app(app_cfg)
    dreambooth_model = cfg.dreambooth_model
    if dreambooth_model is None:
        dreambooth_model = get_animatediff_dreambooth_model_from_app(app_cfg)
    lora_model = cfg.lora_model
    if lora_model is None:
        lora_model = get_animatediff_lora_model_from_app(app_cfg)
    lora_alpha = cfg.lora_alpha
    if lora_alpha is None:
        lora_alpha = get_animatediff_lora_alpha_from_app(app_cfg)
    without_xformers = cfg.without_xformers
    if without_xformers is None:
        without_xformers = get_animatediff_without_xformers_from_app(app_cfg)

    def _predict(
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        guidance_scale: float = 8.0,
        width: int = 512,
        height: int = 512,
        video_length: int = 16,
        seed: int = -1,
        control_image_path: str | None = None,
        output_format: str = "gif",
    ) -> Path:
        """Run one AnimateDiff inference and return the output path.

        Args:
            prompt: Positive text prompt; must be non-empty after strip.
            negative_prompt: Negative prompt; empty string means none.
            num_inference_steps: Diffusion denoising steps.
            guidance_scale: Classifier-free guidance scale.
            width: Output width in pixels.
            height: Output height in pixels.
            video_length: Number of frames (AnimateDiff ``L``).
            seed: RNG seed; -1 lets the script choose.
            control_image_path: Optional conditioning image (SparseCtrl).
            output_format: ``"gif"`` or ``"png_sequence"``.

        Returns:
            Path to ``sample.gif``, or for ``"png_sequence"`` the newest
            frame directory under ``sample_frames/``.

        Raises:
            ValueError: on bad ``output_format`` or empty prompt.
            FileNotFoundError: missing control image or missing outputs.
            RuntimeError: subprocess failure (stdout/stderr included).
        """
        if output_format not in {"gif", "png_sequence"}:
            raise ValueError("output_format must be 'gif' or 'png_sequence'")
        prompt_value = prompt.strip()
        if not prompt_value:
            raise ValueError("prompt must not be empty")

        negative_prompt_value = negative_prompt or ""

        # Optional config lines; each is omitted entirely when unset.
        motion_module_line = (
            f'  motion_module: {_yaml_string(motion_module)}\n' if motion_module else ""
        )
        dreambooth_line = (
            f'  dreambooth_path: {_yaml_string(dreambooth_model)}\n' if dreambooth_model else ""
        )
        lora_path_line = f'  lora_model_path: {_yaml_string(lora_model)}\n' if lora_model else ""
        # FIX: previously gated only on lora_model, which crashed with
        # float(None) when a LoRA was set but no alpha was configured.
        # Omitting the line lets the script fall back to its own default.
        lora_alpha_line = (
            f"  lora_alpha: {float(lora_alpha)}\n"
            if lora_model and lora_alpha is not None
            else ""
        )
        controlnet_path_value = cfg.controlnet_path or "v3_sd15_sparsectrl_rgb.ckpt"
        controlnet_config_value = cfg.controlnet_config or "configs/inference/sparsectrl/image_condition.yaml"
        control_image_line = ""
        if control_image_path:
            control_image = Path(control_image_path).expanduser().resolve()
            if not control_image.is_file():
                raise FileNotFoundError(f"control_image_path not found: {control_image}")
            control_image_line = (
                f'  controlnet_path: {_yaml_string(controlnet_path_value)}\n'
                f'  controlnet_config: {_yaml_string(controlnet_config_value)}\n'
                "  controlnet_images:\n"
                f'    - {_yaml_string(str(control_image))}\n'
                "  controlnet_image_indexs:\n"
                "    - 0\n"
            )

        # Single-entry prompt-config document consumed by animate.py.
        config_text = (
            "- prompt:\n"
            f"    - {_yaml_string(prompt_value)}\n"
            "  n_prompt:\n"
            f"    - {_yaml_string(negative_prompt_value)}\n"
            f"  steps: {int(num_inference_steps)}\n"
            f"  guidance_scale: {float(guidance_scale)}\n"
            f"  W: {int(width)}\n"
            f"  H: {int(height)}\n"
            f"  L: {int(video_length)}\n"
            "  seed:\n"
            f"    - {int(seed)}\n"
            f"{motion_module_line}{dreambooth_line}{lora_path_line}{lora_alpha_line}{control_image_line}"
        )

        # Snapshot existing output dirs so new ones can be identified later.
        before_dirs = {p for p in samples_dir.iterdir() if p.is_dir()}
        cfg_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            prefix="animatediff_cfg_",
            dir=str(root),
            delete=False,
            encoding="utf-8",
        )
        cfg_file_path = Path(cfg_file.name)
        try:
            cfg_file.write(config_text)
            cfg_file.flush()
            cfg_file.close()

            cmd = [
                sys.executable,
                str(script_path),
                "--pretrained-model-path",
                pretrained_model_path,
                "--inference-config",
                inference_config,
                "--config",
                str(cfg_file_path),
                "--L",
                str(int(video_length)),
                "--W",
                str(int(width)),
                "--H",
                str(int(height)),
            ]
            if without_xformers:
                cmd.append("--without-xformers")
            if output_format == "png_sequence":
                cmd.append("--save-png-sequence")

            # Make the AnimateDiff checkout importable by the child process.
            env = dict(os.environ)
            existing_pythonpath = env.get("PYTHONPATH", "")
            root_pythonpath = str(root)
            # FIX: join with os.pathsep (":" on POSIX, ";" on Windows)
            # instead of a hard-coded ":".
            env["PYTHONPATH"] = (
                f"{root_pythonpath}{os.pathsep}{existing_pythonpath}"
                if existing_pythonpath
                else root_pythonpath
            )

            def _run_once(command: list[str]) -> subprocess.CompletedProcess[str]:
                # check=True raises CalledProcessError carrying the captured
                # stdout/stderr, which the handlers below rely on.
                return subprocess.run(
                    command,
                    cwd=str(root),
                    check=True,
                    capture_output=True,
                    text=True,
                    env=env,
                )

            try:
                _run_once(cmd)
            except subprocess.CalledProcessError as first_error:
                stderr_text = first_error.stderr or ""
                # Some GPUs fail inside xformers' memory-efficient attention;
                # retry exactly once with xformers disabled when the stderr
                # matches a known failure signature.
                should_retry_without_xformers = (
                    not without_xformers
                    and "--without-xformers" not in cmd
                    and (
                        "memory_efficient_attention" in stderr_text
                        or "AcceleratorError" in stderr_text
                        or "invalid configuration argument" in stderr_text
                    )
                )
                if not should_retry_without_xformers:
                    raise

                _run_once([*cmd, "--without-xformers"])
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "AnimateDiff inference failed.\n"
                f"stdout:\n{e.stdout}\n"
                f"stderr:\n{e.stderr}"
            ) from e
        finally:
            # FIX: close the handle even when write()/flush() raised before
            # the explicit close above (close() is idempotent), then remove
            # the temp config best-effort.
            try:
                cfg_file.close()
            except Exception:
                pass
            try:
                cfg_file_path.unlink(missing_ok=True)
            except Exception:
                pass

        # Prefer directories created by this run; fall back to any dir that
        # contains a sample.gif. NOTE(review): the mtime-based pick below can
        # race with concurrent runs sharing the same samples/ directory.
        after_dirs = [p for p in samples_dir.iterdir() if p.is_dir() and p not in before_dirs]
        candidates = [p for p in after_dirs if (p / "sample.gif").is_file()]
        if not candidates:
            candidates = [p for p in samples_dir.iterdir() if p.is_dir() and (p / "sample.gif").is_file()]
        if not candidates:
            raise FileNotFoundError("AnimateDiff finished but sample.gif was not found in samples/")

        latest = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)[0]
        if output_format == "png_sequence":
            frames_root = latest / "sample_frames"
            if not frames_root.is_dir():
                raise FileNotFoundError("AnimateDiff finished but sample_frames/ was not found in samples/")
            frame_dirs = sorted(
                [p for p in frames_root.iterdir() if p.is_dir()],
                key=lambda p: p.stat().st_mtime,
                reverse=True,
            )
            if not frame_dirs:
                raise FileNotFoundError("AnimateDiff finished but no PNG sequence directory was found in sample_frames/")
            return frame_dirs[0].resolve()
        return (latest / "sample.gif").resolve()

    return _predict
|
|
|
|
|
|
def build_animation_predictor(
    cfg: UnifiedAnimationConfig | None = None,
) -> tuple[Callable[..., Path], AnimationBackend]:
    """Create an animation predictor for the configured backend.

    Args:
        cfg: Optional configuration; when omitted (or falsy) a
            default-constructed ``UnifiedAnimationConfig`` is used.

    Returns:
        A ``(predictor, backend)`` pair.

    Raises:
        ValueError: if the configured backend is not supported.
    """
    effective = cfg if cfg else UnifiedAnimationConfig()

    # Guard clause: reject anything other than the one supported backend.
    if effective.backend != AnimationBackend.ANIMATEDIFF:
        raise ValueError(f"Unsupported animation backend: {effective.backend}")

    return _make_animatediff_predictor(effective), AnimationBackend.ANIMATEDIFF
|
|
|