initial commit

This commit is contained in:
2026-04-07 20:55:30 +08:00
commit 81d1fb7856
84 changed files with 11929 additions and 0 deletions

4
.clang-tidy Normal file
View File

@@ -0,0 +1,4 @@
Checks: >
clang-diagnostic-unused-variable,
clang-diagnostic-unused-parameter,
clang-diagnostic-unused-lambda-capture

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
/build
/.venv
__pycache__/

20
.vscode/c_cpp_properties.json vendored Normal file
View File

@@ -0,0 +1,20 @@
{
"configurations": [
{
"name": "Linux",
"includePath": [
"${workspaceFolder}/**",
"/usr/include/x86_64-linux-gnu/qt6/QtGui",
"/usr/include/x86_64-linux-gnu/qt6",
"/usr/include/x86_64-linux-gnu/qt6/QtCore",
"/usr/include/x86_64-linux-gnu/qt6/QtWidgets"
],
"defines": [],
"compilerPath": "/usr/bin/clang++",
"cStandard": "c23",
"cppStandard": "c++26",
"intelliSenseMode": "linux-clang-x64"
}
],
"version": 4
}

7
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,7 @@
{
// 使用 IntelliSense 了解相关属性。
// 悬停以查看现有属性的描述。
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": []
}

5
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,5 @@
{
"C_Cpp.codeAnalysis.clangTidy.enabled": true,
"C_Cpp.codeAnalysis.clangTidy.useBuildPath": true,
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
}

46
CMakeLists.txt Normal file
View File

@@ -0,0 +1,46 @@
cmake_minimum_required(VERSION 3.16)
project(LandscapeInteractiveTool
    VERSION 0.1.0
    LANGUAGES CXX
)
set(CMAKE_CXX_STANDARD 26)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Enable fairly strict compiler warnings (unused variables/parameters, shadowing, conversions, etc.).
if (MSVC)
    add_compile_options(/W4 /permissive-)
else()
    add_compile_options(
        -Wall
        -Wextra
        -Wpedantic
        -Wconversion
        -Wsign-conversion
        -Wunused
        -Wunused-const-variable=2
        -Wunused-parameter
        -Wshadow
    )
endif()
# Prefer Qt6; fall back to Qt5 when Qt6 is not available.
set(QT_REQUIRED_COMPONENTS Widgets Gui Core Network)
find_package(Qt6 COMPONENTS ${QT_REQUIRED_COMPONENTS} QUIET)
if(Qt6_FOUND)
    message(STATUS "Configuring with Qt6")
    set(QT_PACKAGE Qt6)
else()
    find_package(Qt5 COMPONENTS ${QT_REQUIRED_COMPONENTS} REQUIRED)
    message(STATUS "Configuring with Qt5")
    set(QT_PACKAGE Qt5)
endif()
# Automatic Qt code generation: moc for Q_OBJECT, uic for .ui forms, rcc for resources.
set(CMAKE_AUTOMOC ON)
set(CMAKE_AUTOUIC ON)
set(CMAKE_AUTORCC ON)
add_subdirectory(client)

10
README.md Normal file
View File

@@ -0,0 +1,10 @@
```
cmake -S . -B build
cmake --build build -j
./build/client/gui/landscape_tool
```
```
cmake --build build --target update_translations
```

4
client/CMakeLists.txt Normal file
View File

@@ -0,0 +1,4 @@
# Client build tree: shared source root, the `core` static library, and the GUI executable.
set(SRC_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
add_subdirectory(core)
add_subdirectory(gui)

24
client/README.md Normal file
View File

@@ -0,0 +1,24 @@
# Client(Qt 桌面端)
## 目录结构(按模块)
**`core/`**(静态库 `core`,include 根为 `client/core`)
- `domain/` — 领域模型:`Project`、`Entity`
- `persistence/` — `PersistentBinaryObject`(统一二进制头与原子写)、`EntityPayloadBinary`(`.hfe` / 旧 `.anim`)
- `workspace/` — 项目目录、索引 JSON、撤销栈:`ProjectWorkspace`
- `depth/` — 假深度图生成:`DepthService`
- `animation/` — 关键帧采样(Hold / 线性插值):`AnimationSampling`
**`gui/`**(可执行程序 `landscape_tool`,额外 include 根为 `client/gui`)
- `app/` — 入口 `main.cpp`
- `main_window/` — 主窗口与时间轴等:`MainWindow`
- `editor/` — 编辑画布:`EditorCanvas`
- `dialogs/``AboutWindow``ImageCropDialog`
引用方式示例:`#include "core/workspace/ProjectWorkspace.h"`(以 `client/` 为根)、`#include "editor/EditorCanvas.h"`(以 `client/gui/` 为根)。
## 界面语言
界面文案为中文(无运行时语言切换)。

View File

@@ -0,0 +1,39 @@
# core static library: domain, persistence, workspace, depth, animation (time sampling).
set(CORE_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
set(CORE_SOURCES
    ${CORE_ROOT}/domain/Project.cpp
    ${CORE_ROOT}/workspace/ProjectWorkspace.cpp
    ${CORE_ROOT}/persistence/PersistentBinaryObject.cpp
    ${CORE_ROOT}/persistence/EntityPayloadBinary.cpp
    ${CORE_ROOT}/animation/AnimationSampling.cpp
    ${CORE_ROOT}/depth/DepthService.cpp
    ${CORE_ROOT}/net/ModelServerClient.cpp
)
# Headers are listed so IDEs/AUTOMOC see them as part of the target.
set(CORE_HEADERS
    ${CORE_ROOT}/domain/Project.h
    ${CORE_ROOT}/workspace/ProjectWorkspace.h
    ${CORE_ROOT}/persistence/PersistentBinaryObject.h
    ${CORE_ROOT}/persistence/EntityPayloadBinary.h
    ${CORE_ROOT}/animation/AnimationSampling.h
    ${CORE_ROOT}/depth/DepthService.h
    ${CORE_ROOT}/net/ModelServerClient.h
)
add_library(core STATIC
    ${CORE_SOURCES}
    ${CORE_HEADERS}
)
# Public include root is client/core, so consumers write e.g. "workspace/ProjectWorkspace.h".
target_include_directories(core
    PUBLIC
        ${CORE_ROOT}
)
target_link_libraries(core
    PUBLIC
        ${QT_PACKAGE}::Core
        ${QT_PACKAGE}::Gui
        ${QT_PACKAGE}::Network
)

View File

@@ -0,0 +1,191 @@
#include "animation/AnimationSampling.h"
#include <algorithm>
namespace core {
namespace {
// Sort keyframes into ascending frame order. `getFrame` projects a key to its frame
// number. Ties between equal frames keep no particular order (std::sort is unstable).
template <typename KeyT, typename FrameGetter>
void sortKeysByFrame(QVector<KeyT>& keys, FrameGetter getFrame) {
    const auto byFrame = [&getFrame](const KeyT& lhs, const KeyT& rhs) {
        return getFrame(lhs) < getFrame(rhs);
    };
    std::sort(keys.begin(), keys.end(), byFrame);
}
} // namespace
// Sample an entity's origin at `frame` from its location keyframes.
// A copy of the keys is sorted internally, so callers may pass them unsorted.
// - No keys: returns fallbackOrigin.
// - Hold: value of the latest key at or before `frame` (fallback before the first key).
// - Linear: clamped to the end values outside the key range; per-component linear
//   interpolation between the two neighbouring keys inside it.
QPointF sampleLocation(const QVector<Project::Entity::KeyframeVec2>& keys,
                       int frame,
                       const QPointF& fallbackOrigin,
                       KeyInterpolation mode) {
    QVector<Project::Entity::KeyframeVec2> sorted = keys;
    sortKeysByFrame(sorted, [](const Project::Entity::KeyframeVec2& k) { return k.frame; });
    if (sorted.isEmpty()) {
        return fallbackOrigin;
    }
    if (mode == KeyInterpolation::Hold) {
        // Walk the sorted keys and keep the last one not after `frame`.
        // Fix: the previous implementation tracked a best frame initialized to -1,
        // which silently ignored keys placed on frames below -1.
        QPointF out = fallbackOrigin;
        for (const auto& k : sorted) {
            if (k.frame > frame) {
                break;
            }
            out = k.value;
        }
        return out;
    }
    // Linear: clamp outside the keyed range, lerp x/y between adjacent keys inside.
    const auto& first = sorted.front();
    const auto& last = sorted.back();
    if (frame <= first.frame) {
        return first.value;
    }
    if (frame >= last.frame) {
        return last.value;
    }
    for (int i = 0; i + 1 < sorted.size(); ++i) {
        const int f0 = sorted[i].frame;
        const int f1 = sorted[i + 1].frame;
        if (frame > f1) {
            continue; // interval ends before `frame`
        }
        if (f1 == f0 || frame == f0) {
            return sorted[i].value; // degenerate interval or exact hit on the left key
        }
        const double t = static_cast<double>(frame - f0) / static_cast<double>(f1 - f0);
        const QPointF& a = sorted[i].value;
        const QPointF& b = sorted[i + 1].value;
        return QPointF(a.x() + (b.x() - a.x()) * t, a.y() + (b.y() - a.y()) * t);
    }
    return last.value; // defensive; the loop always returns for in-range frames
}
// Sample the 0..1 depth-scale track at `frame`. A copy of the keys is sorted
// internally; every result (including the fallback) is clamped to [0, 1].
// - No keys: clamped fallback01.
// - Hold: value of the latest key at or before `frame` (fallback before the first key).
// - Linear: clamped to the end values outside the key range, linear in between.
double sampleDepthScale01(const QVector<Project::Entity::KeyframeFloat01>& keys,
                          int frame,
                          double fallback01,
                          KeyInterpolation mode) {
    QVector<Project::Entity::KeyframeFloat01> sorted = keys;
    sortKeysByFrame(sorted, [](const Project::Entity::KeyframeFloat01& k) { return k.frame; });
    const double fb = std::clamp(fallback01, 0.0, 1.0);
    if (sorted.isEmpty()) {
        return fb;
    }
    if (mode == KeyInterpolation::Hold) {
        // Last key not after `frame` wins. Fix: the previous implementation tracked a
        // best frame initialized to -1, which silently ignored keys below frame -1.
        double out = fb;
        for (const auto& k : sorted) {
            if (k.frame > frame) {
                break;
            }
            out = k.value;
        }
        return std::clamp(out, 0.0, 1.0);
    }
    const auto& first = sorted.front();
    const auto& last = sorted.back();
    if (frame <= first.frame) {
        return std::clamp(first.value, 0.0, 1.0);
    }
    if (frame >= last.frame) {
        return std::clamp(last.value, 0.0, 1.0);
    }
    for (int i = 0; i + 1 < sorted.size(); ++i) {
        const int f0 = sorted[i].frame;
        const int f1 = sorted[i + 1].frame;
        if (frame > f1) {
            continue; // interval ends before `frame`
        }
        if (f1 == f0 || frame == f0) {
            return std::clamp(sorted[i].value, 0.0, 1.0);
        }
        const double t = static_cast<double>(frame - f0) / static_cast<double>(f1 - f0);
        const double a = sorted[i].value;
        const double b = sorted[i + 1].value;
        return std::clamp(a + (b - a) * t, 0.0, 1.0);
    }
    return std::clamp(last.value, 0.0, 1.0); // defensive; loop returns for in-range frames
}
// Sample the manual user-scale track at `frame`. A copy of the keys is sorted
// internally; every result (including the fallback) is floored at 1e-6 so the
// scale never collapses to zero or goes negative.
// - No keys: floored fallback.
// - Hold: value of the latest key at or before `frame` (fallback before the first key).
// - Linear: clamped to the end values outside the key range, linear in between.
double sampleUserScale(const QVector<Project::Entity::KeyframeDouble>& keys,
                       int frame,
                       double fallback,
                       KeyInterpolation mode) {
    QVector<Project::Entity::KeyframeDouble> sorted = keys;
    sortKeysByFrame(sorted, [](const Project::Entity::KeyframeDouble& k) { return k.frame; });
    const double fb = std::max(fallback, 1e-6);
    if (sorted.isEmpty()) {
        return fb;
    }
    if (mode == KeyInterpolation::Hold) {
        // Last key not after `frame` wins. Fix: the previous implementation tracked a
        // best frame initialized to -1, which silently ignored keys below frame -1.
        double out = fb;
        for (const auto& k : sorted) {
            if (k.frame > frame) {
                break;
            }
            out = k.value;
        }
        return std::max(out, 1e-6);
    }
    const auto& first = sorted.front();
    const auto& last = sorted.back();
    if (frame <= first.frame) {
        return std::max(first.value, 1e-6);
    }
    if (frame >= last.frame) {
        return std::max(last.value, 1e-6);
    }
    for (int i = 0; i + 1 < sorted.size(); ++i) {
        const int f0 = sorted[i].frame;
        const int f1 = sorted[i + 1].frame;
        if (frame > f1) {
            continue; // interval ends before `frame`
        }
        if (f1 == f0 || frame == f0) {
            return std::max(sorted[i].value, 1e-6);
        }
        const double t = static_cast<double>(frame - f0) / static_cast<double>(f1 - f0);
        const double a = sorted[i].value;
        const double b = sorted[i + 1].value;
        return std::max(a + (b - a) * t, 1e-6);
    }
    return std::max(last.value, 1e-6); // defensive; loop returns for in-range frames
}
// Sample the image-swap track (hold semantics only): return the path of the latest
// image frame at or before `frame` that has a non-empty path, or fallbackPath when
// none applies. Empty-path entries never override an earlier match.
// Fix: the previous implementation tracked a best frame initialized to -1, which
// silently ignored entries placed on frames below -1.
QString sampleImagePath(const QVector<Project::Entity::ImageFrame>& frames,
                        int frame,
                        const QString& fallbackPath) {
    QVector<Project::Entity::ImageFrame> sorted = frames;
    sortKeysByFrame(sorted, [](const Project::Entity::ImageFrame& k) { return k.frame; });
    QString out = fallbackPath;
    for (const auto& k : sorted) {
        if (k.frame > frame) {
            break; // sorted: every remaining entry is also after `frame`
        }
        if (!k.imagePath.isEmpty()) {
            out = k.imagePath;
        }
    }
    return out;
}
} // namespace core

View File

@@ -0,0 +1,33 @@
#pragma once
#include "domain/Project.h"
#include <QPointF>
#include <QString>
#include <QVector>
namespace core {
// How values between keyframes are derived: Hold keeps the previous key's value,
// Linear interpolates between neighbouring keys.
enum class KeyInterpolation { Hold, Linear };
// Sampling helpers for the per-entity animation tracks. Each function sorts an
// internal copy of the keys by frame, so callers may pass them unsorted.
// Samples the entity origin; returns fallbackOrigin when there are no keys.
[[nodiscard]] QPointF sampleLocation(const QVector<Project::Entity::KeyframeVec2>& keys,
                                     int frame,
                                     const QPointF& fallbackOrigin,
                                     KeyInterpolation mode);
// Samples the depth-scale track; result and fallback are clamped to [0, 1].
[[nodiscard]] double sampleDepthScale01(const QVector<Project::Entity::KeyframeFloat01>& keys,
                                        int frame,
                                        double fallback01,
                                        KeyInterpolation mode);
// Samples the manual scale track; result and fallback are floored at a small positive value.
[[nodiscard]] double sampleUserScale(const QVector<Project::Entity::KeyframeDouble>& keys,
                                     int frame,
                                     double fallback,
                                     KeyInterpolation mode);
// Samples the image-swap track (hold semantics); returns fallbackPath when no
// non-empty entry precedes `frame`.
[[nodiscard]] QString sampleImagePath(const QVector<Project::Entity::ImageFrame>& frames,
                                      int frame,
                                      const QString& fallbackPath);
} // namespace core

View File

@@ -0,0 +1,58 @@
#include "depth/DepthService.h"
#include <algorithm>
namespace core {
// Build an all-zero 8-bit depth map of the requested size ("fake" depth:
// every pixel is at the farthest plane). Returns a null image for invalid
// sizes or when allocation fails.
QImage DepthService::computeFakeDepth(const QSize& size) {
    const bool sizeValid = !size.isEmpty() && size.width() > 0 && size.height() > 0;
    if (!sizeValid) {
        return {};
    }
    QImage depth(size, QImage::Format_Grayscale8);
    if (depth.isNull()) {
        return {}; // allocation failure
    }
    depth.fill(0); // 0 == farthest everywhere
    return depth;
}
// Convenience wrapper: fake depth map sized to match the background image.
// A null background yields a null result.
QImage DepthService::computeFakeDepthFromBackground(const QImage& background) {
    return background.isNull() ? QImage{} : computeFakeDepth(background.size());
}
// Map an 8-bit depth image to a false-colour overlay for on-screen blending.
// Convention: depth=0 (farthest) -> blue, depth=255 (nearest) -> red, linear ramp.
// `alpha` (clamped to 0..255) is applied uniformly.
// Fix: the output format is Format_ARGB32_Premultiplied, which stores colour
// channels premultiplied by alpha; the previous code wrote plain qRgba() values,
// producing over-bright colours whenever alpha < 255. Premultiply before storing.
QImage DepthService::depthToColormapOverlay(const QImage& depth8, int alpha) {
    if (depth8.isNull()) {
        return {};
    }
    // Normalize to Grayscale8 so scanlines are one byte per pixel.
    const QImage src = (depth8.format() == QImage::Format_Grayscale8) ? depth8 : depth8.convertToFormat(QImage::Format_Grayscale8);
    if (src.isNull()) {
        return {};
    }
    const int a = std::clamp(alpha, 0, 255);
    QImage out(src.size(), QImage::Format_ARGB32_Premultiplied);
    if (out.isNull()) {
        return {};
    }
    for (int y = 0; y < src.height(); ++y) {
        const uchar* row = src.constScanLine(y);
        QRgb* dst = reinterpret_cast<QRgb*>(out.scanLine(y));
        for (int x = 0; x < src.width(); ++x) {
            const int d = static_cast<int>(row[x]); // 0..255
            // depth=0 -> blue, depth=255 -> red.
            const int r = d;
            const int g = 0;
            const int b = 255 - d;
            dst[x] = qPremultiply(qRgba(r, g, b, a));
        }
    }
    return out;
}
} // namespace core

View File

@@ -0,0 +1,20 @@
#pragma once
#include <QImage>
#include <QSize>
namespace core {
// Stateless helpers that produce and visualize depth maps.
class DepthService final {
public:
    // Produce an 8-bit depth map (0 = farthest, 255 = nearest).
    // Current implementation returns an all-zero image ("fake" depth).
    static QImage computeFakeDepth(const QSize& size);
    static QImage computeFakeDepthFromBackground(const QImage& background);
    // Map an 8-bit depth image (Grayscale8) to a false-colour ARGB32 overlay with
    // the given alpha, intended for blended display.
    // Convention: depth=0 (farthest) -> blue, depth=255 (nearest) -> red (linear ramp).
    static QImage depthToColormapOverlay(const QImage& depth8, int alpha /*0-255*/);
};
} // namespace core

View File

@@ -0,0 +1,5 @@
#include "domain/Project.h"
namespace core {
// Project is currently header-only (see domain/Project.h). This translation unit
// exists so the build has a .cpp to attach future out-of-line definitions to.
} // namespace core

View File

@@ -0,0 +1,96 @@
#pragma once
#include <QString>
#include <QPointF>
#include <QVector>
#include <algorithm>
namespace core {
// In-memory model of one project: name, background/depth asset paths, timeline
// settings, and the list of editable entities. Pure value type; persistence lives
// in the workspace/persistence modules.
class Project {
public:
    // Project display name.
    void setName(const QString& name) { m_name = name; }
    const QString& name() const { return m_name; }
    // Background image path relative to the project directory, e.g. "assets/background.png".
    void setBackgroundImagePath(const QString& relativePath) { m_backgroundImagePath = relativePath; }
    const QString& backgroundImagePath() const { return m_backgroundImagePath; }
    // Background visibility in the viewport/preview (shown by default).
    void setBackgroundVisible(bool on) { m_backgroundVisible = on; }
    bool backgroundVisible() const { return m_backgroundVisible; }
    // Whether a depth map has been computed for this project.
    void setDepthComputed(bool on) { m_depthComputed = on; }
    bool depthComputed() const { return m_depthComputed; }
    // Depth map path relative to the project directory, e.g. "assets/depth.png".
    void setDepthMapPath(const QString& relativePath) { m_depthMapPath = relativePath; }
    const QString& depthMapPath() const { return m_depthMapPath; }
    // Timeline frame range.
    void setFrameStart(int f) { m_frameStart = f; }
    int frameStart() const { return m_frameStart; }
    void setFrameEnd(int f) { m_frameEnd = f; }
    int frameEnd() const { return m_frameEnd; }
    // Playback rate; floored at 1 fps.
    void setFps(int fps) { m_fps = std::max(1, fps); }
    int fps() const { return m_fps; }
    struct Entity {
        QString id;
        QString displayName; // display name (UI falls back to id when empty)
        bool visible = true; // Outliner "eye": shown by default
        // Movable entity shape, stored in local coordinates (relative to originWorld).
        QVector<QPointF> polygonLocal;
        // Hole cut out of the background: fixed at the creation-time world position,
        // does not follow the entity when it moves.
        QVector<QPointF> cutoutPolygonWorld;
        QPointF originWorld;
        int depth = 0; // 0..255
        QString imagePath; // relative path, e.g. "assets/entities/entity-1.png"
        QPointF imageTopLeftWorld; // world position of the texture's top-left corner
        // Manual uniform scale, multiplied with the depth-driven distance scale
        // (canvas uses visualScale = distanceScale * userScale).
        double userScale = 1.0;
        struct KeyframeVec2 {
            int frame = 0;
            QPointF value;
        };
        struct KeyframeFloat01 {
            int frame = 0;
            double value = 0.5; // 0..1; default 0.5 -> scale=1.0 (0.5..1.5 mapping)
        };
        struct KeyframeDouble {
            int frame = 0;
            double value = 1.0;
        };
        struct ImageFrame {
            int frame = 0;
            QString imagePath; // relative path
        };
        // v2: project.json stores only id + payload; geometry and animation live in
        // the .hfe file at entityPayloadPath.
        QString entityPayloadPath; // e.g. "assets/entities/entity-1.hfe"
        // Filled from the JSON animationBundle only when opening a v1 project, used
        // to merge the legacy .anim; should be empty before saving as v2.
        QString legacyAnimSidecarPath;
        QVector<KeyframeVec2> locationKeys;
        QVector<KeyframeFloat01> depthScaleKeys;
        QVector<KeyframeDouble> userScaleKeys;
        QVector<ImageFrame> imageFrames;
    };
    void setEntities(const QVector<Entity>& entities) { m_entities = entities; }
    const QVector<Entity>& entities() const { return m_entities; }
private:
    QString m_name;
    QString m_backgroundImagePath;
    bool m_backgroundVisible = true;
    bool m_depthComputed = false;
    QString m_depthMapPath;
    int m_frameStart = 0;
    int m_frameEnd = 600;
    int m_fps = 60;
    QVector<Entity> m_entities;
};
} // namespace core

View File

@@ -0,0 +1,139 @@
#include "net/ModelServerClient.h"
#include <QEventLoop>
#include <QJsonDocument>
#include <QJsonObject>
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QNetworkRequest>
#include <QTimer>
#include <QUrl>
namespace core {
// The access manager is parented to this client, so QObject ownership disposes of it.
ModelServerClient::ModelServerClient(QObject* parent)
    : QObject(parent)
    , m_nam(new QNetworkAccessManager(this)) {
}
// Base URL of the model server (e.g. http://127.0.0.1:8000). No validation here;
// each request checks validity before posting.
void ModelServerClient::setBaseUrl(const QUrl& baseUrl) {
    m_baseUrl = baseUrl;
}
QUrl ModelServerClient::baseUrl() const {
    return m_baseUrl;
}
// Fire-and-return variant: POST the image (base64-wrapped in JSON) to <baseUrl>/depth.
// Returns the QNetworkReply (owned by the access manager; the caller connects
// finished()/error handling), or nullptr when the base URL or input is invalid —
// in which case *outImmediateError (if provided) carries a user-facing reason.
QNetworkReply* ModelServerClient::computeDepthPng8Async(const QByteArray& imageBytes, QString* outImmediateError) {
    if (outImmediateError) {
        outImmediateError->clear();
    }
    if (!m_baseUrl.isValid() || m_baseUrl.isEmpty()) {
        if (outImmediateError) *outImmediateError = QStringLiteral("后端地址无效。");
        return nullptr;
    }
    if (imageBytes.isEmpty()) {
        if (outImmediateError) *outImmediateError = QStringLiteral("输入图像为空。");
        return nullptr;
    }
    const QUrl url = m_baseUrl.resolved(QUrl(QStringLiteral("/depth")));
    QNetworkRequest req(url);
    req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
    // Request body: {"image_b64": "<base64>"} — matches the backend contract.
    const QByteArray imageB64 = imageBytes.toBase64();
    const QJsonObject payload{
        {QStringLiteral("image_b64"), QString::fromLatin1(imageB64)},
    };
    const QByteArray body = QJsonDocument(payload).toJson(QJsonDocument::Compact);
    return m_nam->post(req, body);
}
// Synchronous depth request: POST the image (base64 in JSON) to <baseUrl>/depth and
// return the raw response body (expected: 8-bit grayscale PNG) in outPngBytes.
// Blocks by spinning a local QEventLoop until the reply finishes or `timeoutMs`
// elapses (<= 0 selects the 30 s default). Returns false with a user-facing
// outError on invalid input, timeout, network error, non-200 status, or empty body.
// NOTE(review): the nested event loop re-enters application event processing while
// waiting — confirm callers on the GUI thread tolerate that re-entrancy.
bool ModelServerClient::computeDepthPng8(
    const QByteArray& imageBytes,
    QByteArray& outPngBytes,
    QString& outError,
    int timeoutMs
) {
    outPngBytes.clear();
    outError.clear();
    if (!m_baseUrl.isValid() || m_baseUrl.isEmpty()) {
        outError = QStringLiteral("后端地址无效。");
        return false;
    }
    if (imageBytes.isEmpty()) {
        outError = QStringLiteral("输入图像为空。");
        return false;
    }
    const QUrl url = m_baseUrl.resolved(QUrl(QStringLiteral("/depth")));
    QNetworkRequest req(url);
    req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
    // Request body: {"image_b64": "<base64>"} — matches the backend contract.
    const QByteArray imageB64 = imageBytes.toBase64();
    const QJsonObject payload{
        {QStringLiteral("image_b64"), QString::fromLatin1(imageB64)},
    };
    const QByteArray body = QJsonDocument(payload).toJson(QJsonDocument::Compact);
    QNetworkReply* reply = m_nam->post(req, body);
    if (!reply) {
        outError = QStringLiteral("创建网络请求失败。");
        return false;
    }
    // Wait for whichever fires first: the reply's finished() or the single-shot timer.
    QEventLoop loop;
    QTimer timer;
    timer.setSingleShot(true);
    const int t = (timeoutMs <= 0) ? 30000 : timeoutMs;
    QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
    QObject::connect(&timer, &QTimer::timeout, &loop, &QEventLoop::quit);
    timer.start(t);
    loop.exec();
    // Timer expired (no longer active) while the reply is still pending -> timeout.
    if (timer.isActive() == false && reply->isFinished() == false) {
        reply->abort();
        reply->deleteLater();
        outError = QStringLiteral("请求超时(%1ms").arg(t);
        return false;
    }
    // Capture everything needed before releasing the reply.
    const int httpStatus = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
    const QByteArray raw = reply->readAll();
    const auto netErr = reply->error();
    const QString netErrStr = reply->errorString();
    reply->deleteLater();
    if (netErr != QNetworkReply::NoError) {
        outError = QStringLiteral("网络错误:%1").arg(netErrStr);
        return false;
    }
    if (httpStatus != 200) {
        // FastAPI HTTPException returns {"detail": "..."} by default; surface it if present.
        QString detail;
        const QJsonDocument jd = QJsonDocument::fromJson(raw);
        if (jd.isObject()) {
            const auto obj = jd.object();
            detail = obj.value(QStringLiteral("detail")).toString();
        }
        outError = detail.isEmpty()
            ? QStringLiteral("后端返回HTTP %1。").arg(httpStatus)
            : QStringLiteral("后端错误HTTP %1%2").arg(httpStatus).arg(detail);
        return false;
    }
    if (raw.isEmpty()) {
        outError = QStringLiteral("后端返回空数据。");
        return false;
    }
    outPngBytes = raw;
    return true;
}
} // namespace core

View File

@@ -0,0 +1,36 @@
#pragma once
#include <QByteArray>
#include <QObject>
#include <QString>
#include <QUrl>
#include <QNetworkReply>
class QNetworkAccessManager;
class QUrl;
class QNetworkReply;
namespace core {
// Thin HTTP client for the depth-inference model server.
class ModelServerClient final : public QObject {
    Q_OBJECT
public:
    explicit ModelServerClient(QObject* parent = nullptr);
    // Base URL of the model server, e.g. http://127.0.0.1:8000.
    void setBaseUrl(const QUrl& baseUrl);
    QUrl baseUrl() const;
    // Synchronous call: POST the background image to /depth; on success returns the
    // PNG (8-bit grayscale) bytes. timeoutMs <= 0 selects the default timeout (30 s).
    bool computeDepthPng8(const QByteArray& imageBytes, QByteArray& outPngBytes, QString& outError, int timeoutMs = 30000);
    // Asynchronous call: start POST /depth and return the reply (lifetime managed by
    // Qt; the caller connects finished()/error handling).
    // Returns nullptr when invalid arguments/URL prevent the request from starting.
    QNetworkReply* computeDepthPng8Async(const QByteArray& imageBytes, QString* outImmediateError = nullptr);
private:
    QNetworkAccessManager* m_nam = nullptr; // owned via QObject parenting
    QUrl m_baseUrl;
};
} // namespace core

View File

@@ -0,0 +1,324 @@
#include "persistence/EntityPayloadBinary.h"
#include "persistence/PersistentBinaryObject.h"
#include "domain/Project.h"
#include <QDataStream>
#include <QFile>
#include <QtGlobal>
#include <algorithm>
namespace core {
namespace {
// Sort a keyframe container into ascending frame order. One generic template
// replaces the four previous hand-written, byte-identical overloads — every
// element type used here (KeyframeVec2, KeyframeFloat01, KeyframeDouble,
// ImageFrame) exposes an int `frame` member, which is all the comparator needs.
// Existing call sites (`sortByFrame(out.locationKeys)` etc.) are unaffected.
template <typename KeyframeVec>
void sortByFrame(KeyframeVec& v) {
    std::sort(v.begin(), v.end(), [](const auto& a, const auto& b) { return a.frame < b.frame; });
}
// Deserialize the animation tracks into `out`, replacing any existing tracks.
// Wire format per track: qint32 count, then `count` fixed-layout records, in this
// order: location (frame, x, y), depth-scale (frame, value), optionally user-scale
// (frame, value) when hasUserScaleKeys is true (payload v3+; legacy .anim files
// have none), then image frames (frame, path).
// Counts are sanity-capped at 1,000,000 so a corrupt file cannot force a huge
// allocation. Tracks are re-sorted by frame afterwards; image frames with empty
// paths are dropped. Returns false on any stream error or invalid count (`out`
// may then be partially filled — callers discard it).
bool readAnimationBlock(QDataStream& ds, Project::Entity& out, bool hasUserScaleKeys) {
    out.locationKeys.clear();
    out.depthScaleKeys.clear();
    out.userScaleKeys.clear();
    out.imageFrames.clear();
    // Location track: (frame, x, y).
    qint32 nLoc = 0;
    ds >> nLoc;
    if (ds.status() != QDataStream::Ok || nLoc < 0 || nLoc > 1000000) {
        return false;
    }
    out.locationKeys.reserve(nLoc);
    for (qint32 i = 0; i < nLoc; ++i) {
        qint32 frame = 0;
        double x = 0.0;
        double y = 0.0;
        ds >> frame >> x >> y;
        if (ds.status() != QDataStream::Ok) {
            return false;
        }
        out.locationKeys.push_back(Project::Entity::KeyframeVec2{frame, QPointF(x, y)});
    }
    // Depth-scale track: (frame, value), value nominally in [0, 1].
    qint32 nDepth = 0;
    ds >> nDepth;
    if (ds.status() != QDataStream::Ok || nDepth < 0 || nDepth > 1000000) {
        return false;
    }
    out.depthScaleKeys.reserve(nDepth);
    for (qint32 i = 0; i < nDepth; ++i) {
        qint32 frame = 0;
        double v = 0.5;
        ds >> frame >> v;
        if (ds.status() != QDataStream::Ok) {
            return false;
        }
        out.depthScaleKeys.push_back(Project::Entity::KeyframeFloat01{frame, v});
    }
    // User-scale track only exists in newer formats.
    if (hasUserScaleKeys) {
        qint32 nUser = 0;
        ds >> nUser;
        if (ds.status() != QDataStream::Ok || nUser < 0 || nUser > 1000000) {
            return false;
        }
        out.userScaleKeys.reserve(nUser);
        for (qint32 i = 0; i < nUser; ++i) {
            qint32 frame = 0;
            double v = 1.0;
            ds >> frame >> v;
            if (ds.status() != QDataStream::Ok) {
                return false;
            }
            out.userScaleKeys.push_back(Project::Entity::KeyframeDouble{frame, v});
        }
    }
    // Image-swap track: (frame, relative path); empty paths are skipped.
    qint32 nImg = 0;
    ds >> nImg;
    if (ds.status() != QDataStream::Ok || nImg < 0 || nImg > 1000000) {
        return false;
    }
    out.imageFrames.reserve(nImg);
    for (qint32 i = 0; i < nImg; ++i) {
        qint32 frame = 0;
        QString path;
        ds >> frame >> path;
        if (ds.status() != QDataStream::Ok) {
            return false;
        }
        if (!path.isEmpty()) {
            out.imageFrames.push_back(Project::Entity::ImageFrame{frame, path});
        }
    }
    // Normalize ordering so downstream sampling can rely on sorted tracks.
    sortByFrame(out.locationKeys);
    sortByFrame(out.depthScaleKeys);
    sortByFrame(out.userScaleKeys);
    sortByFrame(out.imageFrames);
    return true;
}
// Serialize the animation tracks in the wire order expected by readAnimationBlock:
// location, depth-scale, optional user-scale (payload v3+; omitted for legacy-
// compatible output), image frames. Each track is a qint32 count followed by its
// fixed-layout records.
void writeAnimationBlock(QDataStream& ds, const Project::Entity& entity, bool writeUserScaleKeys) {
    ds << qint32(entity.locationKeys.size());
    for (const auto& k : entity.locationKeys) {
        ds << qint32(k.frame) << double(k.value.x()) << double(k.value.y());
    }
    ds << qint32(entity.depthScaleKeys.size());
    for (const auto& k : entity.depthScaleKeys) {
        ds << qint32(k.frame) << double(k.value);
    }
    if (writeUserScaleKeys) {
        ds << qint32(entity.userScaleKeys.size());
        for (const auto& k : entity.userScaleKeys) {
            ds << qint32(k.frame) << double(k.value);
        }
    }
    ds << qint32(entity.imageFrames.size());
    for (const auto& k : entity.imageFrames) {
        ds << qint32(k.frame) << k.imagePath;
    }
}
// Read the version-1 core of an entity payload: id, depth, image path, origin and
// image top-left world coordinates, the local polygon, the cutout polygon, then the
// animation block (user-scale track only when hasUserScaleKeys, i.e. payload v3+).
// Returns false on stream error, out-of-range counts, or an empty id/polygon.
// NOTE(review): the stream status is first checked at the nLocal read — QDataStream
// status is sticky, so earlier failures are still caught there, but the scalar
// fields hold unspecified values until that check.
bool readEntityPayloadV1(QDataStream& ds, Project::Entity& tmp, bool hasUserScaleKeys) {
    ds >> tmp.id;
    qint32 depth = 0;
    ds >> depth;
    tmp.depth = static_cast<int>(depth);
    ds >> tmp.imagePath;
    double ox = 0.0;
    double oy = 0.0;
    double itlx = 0.0;
    double itly = 0.0;
    ds >> ox >> oy >> itlx >> itly;
    tmp.originWorld = QPointF(ox, oy);
    tmp.imageTopLeftWorld = QPointF(itlx, itly);
    // Local-space outline of the movable entity.
    qint32 nLocal = 0;
    ds >> nLocal;
    if (ds.status() != QDataStream::Ok || nLocal < 0 || nLocal > 1000000) {
        return false;
    }
    tmp.polygonLocal.reserve(nLocal);
    for (qint32 i = 0; i < nLocal; ++i) {
        double x = 0.0;
        double y = 0.0;
        ds >> x >> y;
        if (ds.status() != QDataStream::Ok) {
            return false;
        }
        tmp.polygonLocal.push_back(QPointF(x, y));
    }
    // World-space cutout polygon (the hole punched in the background).
    qint32 nCut = 0;
    ds >> nCut;
    if (ds.status() != QDataStream::Ok || nCut < 0 || nCut > 1000000) {
        return false;
    }
    tmp.cutoutPolygonWorld.reserve(nCut);
    for (qint32 i = 0; i < nCut; ++i) {
        double x = 0.0;
        double y = 0.0;
        ds >> x >> y;
        if (ds.status() != QDataStream::Ok) {
            return false;
        }
        tmp.cutoutPolygonWorld.push_back(QPointF(x, y));
    }
    if (!readAnimationBlock(ds, tmp, hasUserScaleKeys)) {
        return false;
    }
    // Minimal semantic validation: an entity needs an id and a shape.
    if (tmp.id.isEmpty() || tmp.polygonLocal.isEmpty()) {
        return false;
    }
    return true;
}
// Adapter binding a Project::Entity to the PersistentBinaryObject save/load
// machinery for the .hfe payload format (magic 'HFEP', current version v3).
// Constructed either around a const source entity (write direction) or a mutable
// destination entity (read direction); the unused pointer stays null and the
// Q_ASSERTs guard against calling the wrong direction.
// v3 body layout: the v1 core (id, depth, image path, origin, image top-left,
// polygons, animation tracks incl. user-scale) followed by displayName and userScale.
class EntityBinaryRecord final : public PersistentBinaryObject {
public:
    explicit EntityBinaryRecord(const Project::Entity& e) : m_src(&e), m_dst(nullptr) {}
    explicit EntityBinaryRecord(Project::Entity& e) : m_src(nullptr), m_dst(&e) {}
    quint32 recordMagic() const override { return EntityPayloadBinary::kMagicPayload; }
    quint32 recordFormatVersion() const override { return EntityPayloadBinary::kPayloadVersion; }
    void writeBody(QDataStream& ds) const override {
        Q_ASSERT(m_src != nullptr);
        const Project::Entity& entity = *m_src;
        ds << entity.id;
        ds << qint32(entity.depth);
        ds << entity.imagePath;
        ds << double(entity.originWorld.x()) << double(entity.originWorld.y());
        ds << double(entity.imageTopLeftWorld.x()) << double(entity.imageTopLeftWorld.y());
        ds << qint32(entity.polygonLocal.size());
        for (const auto& pt : entity.polygonLocal) {
            ds << double(pt.x()) << double(pt.y());
        }
        ds << qint32(entity.cutoutPolygonWorld.size());
        for (const auto& pt : entity.cutoutPolygonWorld) {
            ds << double(pt.x()) << double(pt.y());
        }
        writeAnimationBlock(ds, entity, true);
        // v2+ trailer: display name and manual scale.
        ds << entity.displayName << double(entity.userScale);
    }
    bool readBody(QDataStream& ds) override {
        Q_ASSERT(m_dst != nullptr);
        // Build into a temporary so the destination stays untouched on failure.
        Project::Entity tmp;
        if (!readEntityPayloadV1(ds, tmp, true)) {
            return false;
        }
        QString dn;
        double us = 1.0;
        ds >> dn >> us;
        if (ds.status() != QDataStream::Ok) {
            return false;
        }
        tmp.displayName = dn;
        tmp.userScale = std::clamp(us, 1e-3, 1e3); // keep scale in a sane range
        *m_dst = std::move(tmp);
        return true;
    }
private:
    const Project::Entity* m_src; // write direction (null when reading)
    Project::Entity* m_dst;       // read direction (null when writing)
};
// Adapter reading a legacy standalone .anim sidecar (magic 'HFTA', v1), which
// contains only the animation tracks (and no user-scale track). Reads into a
// scratch copy first so a malformed file leaves the target entity untouched.
// Writing is intentionally a no-op: legacy sidecars are never produced anymore.
class LegacyAnimSidecarRecord final : public PersistentBinaryObject {
public:
    explicit LegacyAnimSidecarRecord(Project::Entity& e) : m_entity(&e) {}
    quint32 recordMagic() const override { return EntityPayloadBinary::kMagicLegacyAnim; }
    quint32 recordFormatVersion() const override { return EntityPayloadBinary::kLegacyAnimVersion; }
    void writeBody(QDataStream& ds) const override { Q_UNUSED(ds); }
    bool readBody(QDataStream& ds) override {
        // Copy so non-animation fields are preserved and failure is side-effect free.
        Project::Entity tmp = *m_entity;
        if (!readAnimationBlock(ds, tmp, false)) {
            return false;
        }
        m_entity->locationKeys = std::move(tmp.locationKeys);
        m_entity->depthScaleKeys = std::move(tmp.depthScaleKeys);
        m_entity->userScaleKeys = std::move(tmp.userScaleKeys);
        m_entity->imageFrames = std::move(tmp.imageFrames);
        return true;
    }
private:
    Project::Entity* m_entity;
};
} // namespace
// Persist an entity as a .hfe payload (current format version) at `absolutePath`.
// Requires a non-empty target path and entity id; returns false otherwise or on I/O failure.
bool EntityPayloadBinary::save(const QString& absolutePath, const Project::Entity& entity) {
    const bool argsValid = !absolutePath.isEmpty() && !entity.id.isEmpty();
    if (!argsValid) {
        return false;
    }
    EntityBinaryRecord record(entity);
    return record.saveToFile(absolutePath);
}
// Read an entity payload (.hfe), accepting format versions 1..3.
// Header: magic 'HFEP' + version, then the v1 core (the user-scale track is only
// present from v3 on); v2+ append displayName and userScale, which are defaulted
// for v1 files. The result is built in a temporary and only assigned to `entity`
// on full success, so a failed load leaves `entity` unchanged.
bool EntityPayloadBinary::load(const QString& absolutePath, Project::Entity& entity) {
    QFile f(absolutePath);
    if (!f.open(QIODevice::ReadOnly)) {
        return false;
    }
    QDataStream ds(&f);
    // Pin the stream version so the on-disk encoding is stable across Qt releases.
    ds.setVersion(QDataStream::Qt_5_15);
    quint32 magic = 0;
    quint32 ver = 0;
    ds >> magic >> ver;
    if (ds.status() != QDataStream::Ok || magic != kMagicPayload) {
        return false;
    }
    // Multi-version reader; hand-rolled here (instead of EntityBinaryRecord) because
    // PersistentBinaryObject::loadFromFile only accepts the current version.
    if (ver != 1 && ver != 2 && ver != 3) {
        return false;
    }
    Project::Entity tmp;
    if (!readEntityPayloadV1(ds, tmp, ver >= 3)) {
        return false;
    }
    if (ver >= 2) {
        QString dn;
        double us = 1.0;
        ds >> dn >> us;
        if (ds.status() != QDataStream::Ok) {
            return false;
        }
        tmp.displayName = dn;
        tmp.userScale = std::clamp(us, 1e-3, 1e3); // keep scale in a sane range
    } else {
        // v1 predates these fields; use defaults.
        tmp.displayName.clear();
        tmp.userScale = 1.0;
    }
    entity = std::move(tmp);
    return true;
}
bool EntityPayloadBinary::loadLegacyAnimFile(const QString& absolutePath, Project::Entity& entity) {
return LegacyAnimSidecarRecord(entity).loadFromFile(absolutePath);
}
} // namespace core

View File

@@ -0,0 +1,30 @@
#pragma once
#include "domain/Project.h"
#include <QString>
namespace core {
// 实体完整数据(几何 + 贴图路径 + 动画轨道)的二进制格式,与 project.json v2 的 payload 字段对应。
// 贴图 PNG 仍单独存放在 assets/entities/,本文件不嵌入像素。
// 具体读写通过继承 PersistentBinaryObject 的适配器类完成(见 EntityPayloadBinary.cpp
// Binary format for an entity's full data (geometry + texture path + animation
// tracks), matching the payload field of project.json v2. Texture PNGs still live
// separately under assets/entities/; this file never embeds pixels.
// Actual reading/writing is performed by PersistentBinaryObject adapter classes
// (see EntityPayloadBinary.cpp).
class EntityPayloadBinary {
public:
    static constexpr quint32 kMagicPayload = 0x48464550; // 'HFEP'
    static constexpr quint32 kPayloadVersion = 3; // v3: adds the userScaleKeys animation track
    // Legacy standalone animation file (still read when opening v1 projects).
    static constexpr quint32 kMagicLegacyAnim = 0x48465441; // 'HFTA'
    static constexpr quint32 kLegacyAnimVersion = 1;
    static bool save(const QString& absolutePath, const Project::Entity& entity);
    // Overwrites entity fields on success; on failure tries to leave entity unchanged.
    static bool load(const QString& absolutePath, Project::Entity& entity);
    // Reads only a legacy .anim (HFTA) file and fills the entity's animation tracks.
    static bool loadLegacyAnimFile(const QString& absolutePath, Project::Entity& entity);
};
} // namespace core

View File

@@ -0,0 +1,57 @@
#include "persistence/PersistentBinaryObject.h"
#include <QDataStream>
#include <QDir>
#include <QFile>
#include <QFileInfo>
namespace core {
// Serialize this record to `absolutePath` via a ".tmp" sibling: write the header
// (magic + version) and body, then replace the destination. Parent directories are
// created on demand. Returns false on any I/O failure.
// NOTE(review): remove() followed by rename() is not one atomic step — a crash
// between the two calls can leave neither the old nor the new file in place;
// QSaveFile would close that window. TODO confirm whether that guarantee matters here.
bool PersistentBinaryObject::saveToFile(const QString& absolutePath) const {
    if (absolutePath.isEmpty()) {
        return false;
    }
    const auto parent = QFileInfo(absolutePath).absolutePath();
    if (!QFileInfo(parent).exists()) {
        QDir().mkpath(parent);
    }
    const QString tmpPath = absolutePath + QStringLiteral(".tmp");
    QFile f(tmpPath);
    if (!f.open(QIODevice::WriteOnly | QIODevice::Truncate)) {
        return false;
    }
    QDataStream ds(&f);
    // Pin the stream version so the on-disk encoding is stable across Qt releases.
    ds.setVersion(QDataStream::Qt_5_15);
    ds << quint32(recordMagic()) << quint32(recordFormatVersion());
    writeBody(ds);
    // Close first: buffered write errors surface on flush, then check the file state.
    f.close();
    if (f.error() != QFile::NoError) {
        QFile::remove(tmpPath);
        return false;
    }
    // Remove the destination so rename() succeeds on platforms that refuse to overwrite.
    QFile::remove(absolutePath);
    return QFile::rename(tmpPath, absolutePath);
}
bool PersistentBinaryObject::loadFromFile(const QString& absolutePath) {
QFile f(absolutePath);
if (!f.open(QIODevice::ReadOnly)) {
return false;
}
QDataStream ds(&f);
ds.setVersion(QDataStream::Qt_5_15);
quint32 magic = 0;
quint32 version = 0;
ds >> magic >> version;
if (ds.status() != QDataStream::Ok || magic != recordMagic() || version != recordFormatVersion()) {
return false;
}
return readBody(ds);
}
} // namespace core

View File

@@ -0,0 +1,27 @@
#pragma once
#include <QString>
class QDataStream;
namespace core {
// 二进制记录的统一持久化基类:魔数/版本、QDataStream 版本、.tmp 原子替换、父目录创建。
//
// 领域类型(如 Project::Entity应保持为可拷贝的值类型不要继承本类为每种存储格式写一个
// 小的适配器类(如 EntityBinaryRecord继承本类并实现 writeBody/readBody 即可。
// Common persistence base for binary records: magic/version header, pinned
// QDataStream version, ".tmp"-then-rename replacement, parent-directory creation.
//
// Domain types (such as Project::Entity) should remain copyable value types and
// NOT inherit this class; instead write one small adapter class per storage format
// (e.g. EntityBinaryRecord) that derives from this class and implements
// writeBody/readBody.
class PersistentBinaryObject {
public:
    virtual ~PersistentBinaryObject() = default;
    [[nodiscard]] bool saveToFile(const QString& absolutePath) const;
    [[nodiscard]] bool loadFromFile(const QString& absolutePath);
protected:
    // Per-format magic number written/verified at the start of the file.
    virtual quint32 recordMagic() const = 0;
    // Per-format version written/verified right after the magic.
    virtual quint32 recordFormatVersion() const = 0;
    // Serialize the record body (everything after the header).
    virtual void writeBody(QDataStream& ds) const = 0;
    // Deserialize the record body; return false on malformed input.
    virtual bool readBody(QDataStream& ds) = 0;
};
} // namespace core

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,148 @@
#pragma once
#include "domain/Project.h"
#include <QImage>
#include <QJsonObject>
#include <QRect>
#include <QString>
#include <QStringList>
#include <QVector>
namespace core {
// Owns a project directory on disk: a project.json index plus an assets/
// folder holding the background, the depth map and entity payloads. All
// mutations are funneled through this class so they can be recorded in a
// bounded undo/redo history.
class ProjectWorkspace {
public:
    static constexpr const char* kProjectIndexFileName = "project.json";
    static constexpr const char* kAssetsDirName = "assets";
    // Version written into project.json's "version" field; version 1 files
    // (embedded entities + optional .anim sidecars) remain readable.
    static constexpr int kProjectIndexFormatVersion = 2;
    ProjectWorkspace() = default;
    // Create a new project:
    // - `parentDir` is the parent directory (the one chosen in the file dialog)
    // - a new project directory is created underneath it (named after the
    //   project by default; a suffix is appended on name collisions)
    // - resulting layout (v2):
    //     <projectDir>/project.json  (index: background/depth paths + entity ids and .hfe paths)
    //     <projectDir>/assets/background.png
    //     <projectDir>/assets/entities/*.png / *.hfe
    bool createNew(const QString& parentDir, const QString& name, const QString& backgroundImageSourcePath);
    bool createNew(const QString& parentDir, const QString& name, const QString& backgroundImageSourcePath,
                   const QRect& cropRectInSourceImage);
    bool openExisting(const QString& projectDir);
    void close();
    bool isOpen() const { return !m_projectDir.isEmpty(); }
    const QString& projectDir() const { return m_projectDir; }
    QString indexFilePath() const;
    QString assetsDirPath() const;
    bool hasBackground() const { return !m_project.backgroundImagePath().isEmpty(); }
    QString backgroundAbsolutePath() const;
    bool backgroundVisible() const { return m_project.backgroundVisible(); }
    bool setBackgroundVisible(bool on);
    bool hasDepth() const;
    QString depthAbsolutePath() const;
    // Writes the "name" field of project.json (undoable).
    bool setProjectTitle(const QString& title);
    Project& project() { return m_project; }
    const Project& project() const { return m_project; }
    // Operation history (at most 30 steps), Blender-style undo/redo stacks.
    bool canUndo() const;
    bool canRedo() const;
    bool undo();
    bool redo();
    QStringList historyLabelsNewestFirst() const;
    // Records an "import and set background image" operation: copies the
    // image into assets/ and stores it as the project background (enters the
    // history).
    bool importBackgroundImage(const QString& backgroundImageSourcePath);
    bool importBackgroundImage(const QString& backgroundImageSourcePath, const QRect& cropRectInSourceImage);
    // Computes and writes a fake depth map (assets/depth.png) and updates
    // project.json (depthComputed/depthMapPath).
    bool computeFakeDepthForProject();
    // Requests depth from the backend and persists it (assets/depth.png),
    // updating project.json (depthComputed/depthMapPath).
    // - empty serverBaseUrl: falls back to the MODEL_SERVER_URL environment
    //   variable, then to http://127.0.0.1:8000
    // - outError (optional): receives the failure reason
    bool computeDepthForProjectFromServer(const QString& serverBaseUrl, QString* outError = nullptr, int timeoutMs = 30000);
    // Saves raw depth-map PNG bytes to assets/depth.png and updates project.json.
    bool saveDepthMapPngBytes(const QByteArray& pngBytes, QString* outError = nullptr);
    const QVector<Project::Entity>& entities() const { return m_project.entities(); }
    bool addEntity(const Project::Entity& entity, const QImage& image);
    bool setEntityVisible(const QString& id, bool on);
    bool setEntityDisplayName(const QString& id, const QString& displayName);
    bool setEntityUserScale(const QString& id, double userScale);
    // Translates the polygon centroid to targetCentroidWorld (rigid move);
    // sTotal must match the canvas.
    bool moveEntityCentroidTo(const QString& id, int frame, const QPointF& targetCentroidWorld, double sTotal,
                              bool autoKeyLocation);
    // Moves the pivot point while keeping the shape unchanged; sTotal must
    // match the canvas (distance scale x global scale).
    bool reanchorEntityPivot(const QString& id, int frame, const QPointF& newPivotWorld, double sTotal);
    bool reorderEntitiesById(const QStringList& idsInOrder);
    // currentFrame: with auto-keyframing, the location curve is written;
    // ignored when autoKeyLocation is false.
    bool moveEntityBy(const QString& id, const QPointF& delta, int currentFrame, bool autoKeyLocation);
    bool setEntityLocationKey(const QString& id, int frame, const QPointF& originWorld);
    bool setEntityDepthScaleKey(const QString& id, int frame, double value01);
    bool setEntityUserScaleKey(const QString& id, int frame, double userScale);
    bool setEntityImageFrame(const QString& id, int frame, const QImage& image, QString* outRelPath = nullptr);
    bool removeEntityLocationKey(const QString& id, int frame);
    bool removeEntityDepthScaleKey(const QString& id, int frame);
    bool removeEntityUserScaleKey(const QString& id, int frame);
    bool removeEntityImageFrame(const QString& id, int frame);
private:
    bool writeIndexJson();
    bool readIndexJson(const QString& indexPath);
    bool syncEntityPayloadsToDisk();
    bool hydrateEntityPayloadsFromDisk();
    void loadV1LegacyAnimationSidecars();
    static QJsonObject projectToJson(const Project& project);
    static bool projectFromJson(const QJsonObject& root, Project& outProject, int* outFileVersion);
    static QString asRelativeUnderProject(const QString& relativePath);
    static QString fileSuffixWithDot(const QString& path);
    static QString asOptionalRelativeUnderProject(const QString& relativePath);
    static QJsonObject entityToJson(const Project::Entity& e);
    static bool entityFromJsonV1(const QJsonObject& o, Project::Entity& out);
    static bool entityStubFromJsonV2(const QJsonObject& o, Project::Entity& out);
    // One reversible history step: before/after snapshots of whatever state
    // the operation type touches.
    struct Operation {
        enum class Type { ImportBackground, SetEntities, SetProjectTitle };
        Type type {Type::ImportBackground};
        QString label;
        QString beforeBackgroundPath;
        QString afterBackgroundPath;
        QVector<Project::Entity> beforeEntities;
        QVector<Project::Entity> afterEntities;
        QString beforeProjectTitle;
        QString afterProjectTitle;
    };
    static constexpr int kMaxHistorySteps = 30;
    void pushOperation(const Operation& op);
    bool applyBackgroundPath(const QString& relativePath, bool recordHistory, const QString& label);
    bool applyEntities(const QVector<Project::Entity>& entities, bool recordHistory, const QString& label);
    QString copyIntoAssetsAsBackground(const QString& sourceFilePath, const QRect& cropRectInSourceImage);
    bool writeDepthMap(const QImage& depth8);
    bool writeDepthMapBytes(const QByteArray& pngBytes);
    QString ensureEntitiesDir() const;
    bool writeEntityImage(const QString& entityId, const QImage& image, QString& outRelPath);
    bool writeEntityFrameImage(const QString& entityId, int frame, const QImage& image, QString& outRelPath);
private:
    QString m_projectDir;              // absolute path of the open project dir; empty when closed
    Project m_project;                 // in-memory project state
    QVector<Operation> m_undoStack;    // newest last; bounded by kMaxHistorySteps
    QVector<Operation> m_redoStack;    // cleared when a new operation is pushed
};
} // namespace core

62
client/gui/CMakeLists.txt Normal file
View File

@@ -0,0 +1,62 @@
# 模块app入口、main_window主窗口与时间轴等、editor画布、dialogs裁剪/关于)
set(GUI_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
set(GUI_SOURCES
${GUI_ROOT}/app/main.cpp
${GUI_ROOT}/main_window/MainWindow.cpp
${GUI_ROOT}/main_window/RecentProjectHistory.cpp
${GUI_ROOT}/dialogs/AboutWindow.cpp
${GUI_ROOT}/dialogs/ImageCropDialog.cpp
${GUI_ROOT}/dialogs/FrameAnimationDialog.cpp
${GUI_ROOT}/dialogs/CancelableTaskDialog.cpp
${GUI_ROOT}/editor/EditorCanvas.cpp
${GUI_ROOT}/params/ParamControls.cpp
${GUI_ROOT}/props/BackgroundPropertySection.cpp
${GUI_ROOT}/props/EntityPropertySection.cpp
${GUI_ROOT}/timeline/TimelineWidget.cpp
)
set(GUI_HEADERS
${GUI_ROOT}/main_window/MainWindow.h
${GUI_ROOT}/main_window/RecentProjectHistory.h
${GUI_ROOT}/dialogs/AboutWindow.h
${GUI_ROOT}/dialogs/ImageCropDialog.h
${GUI_ROOT}/dialogs/FrameAnimationDialog.h
${GUI_ROOT}/dialogs/CancelableTaskDialog.h
${GUI_ROOT}/editor/EditorCanvas.h
${GUI_ROOT}/params/ParamControls.h
${GUI_ROOT}/props/BackgroundPropertySection.h
${GUI_ROOT}/props/EntityPropertySection.h
${GUI_ROOT}/props/PropertySectionWidget.h
${GUI_ROOT}/timeline/TimelineWidget.h
)
if(QT_PACKAGE STREQUAL "Qt6")
qt_add_executable(LandscapeInteractiveToolApp
${GUI_SOURCES}
${GUI_HEADERS}
)
else()
add_executable(LandscapeInteractiveToolApp
${GUI_SOURCES}
${GUI_HEADERS}
)
endif()
target_include_directories(LandscapeInteractiveToolApp
PRIVATE
${SRC_ROOT}
${GUI_ROOT}
)
target_link_libraries(LandscapeInteractiveToolApp
PRIVATE
${QT_PACKAGE}::Core
${QT_PACKAGE}::Gui
${QT_PACKAGE}::Widgets
core
)
set_target_properties(LandscapeInteractiveToolApp PROPERTIES
OUTPUT_NAME "landscape_tool"
)

13
client/gui/app/main.cpp Normal file
View File

@@ -0,0 +1,13 @@
#include "main_window/MainWindow.h"
#include <QApplication>
int main(int argc, char* argv[]) {
QApplication app(argc, argv);
app.setApplicationName(QStringLiteral("landscape tool"));
MainWindow window;
window.show();
return app.exec();
}

View File

@@ -0,0 +1,66 @@
#include "dialogs/AboutWindow.h"
#include <QVBoxLayout>
#include <QLabel>
#include <QPushButton>
#include <QHBoxLayout>
#include <QDesktopServices>
#include <QUrl>
#include <QFont>
// Builds the fixed-size "About" dialog: title, version, author, description
// and a Close button.
AboutWindow::AboutWindow(QWidget* parent)
    : QDialog(parent)
{
    setWindowTitle("About");
    setFixedSize(400, 300);
    // ===== Title =====
    titleLabel = new QLabel("Landscape Interactive Tool");
    QFont titleFont;
    titleFont.setPointSize(16);
    titleFont.setBold(true);
    titleLabel->setFont(titleFont);
    titleLabel->setAlignment(Qt::AlignCenter);
    // ===== Version =====
    versionLabel = new QLabel("Version: 1.0.0");
    versionLabel->setAlignment(Qt::AlignCenter);
    // ===== Author =====
    authorLabel = new QLabel("Author: 丁伟豪");
    authorLabel->setAlignment(Qt::AlignCenter);
    // ===== Description =====
    descLabel = new QLabel("An interactive tool for landscape visualization.\n"
                           "Built with Qt.");
    descLabel->setAlignment(Qt::AlignCenter);
    descLabel->setWordWrap(true);
    // // ===== GitHub button (kept for a future repository link) =====
    // githubButton = new QPushButton("GitHub");
    // connect(githubButton, &QPushButton::clicked, []() {
    //     QDesktopServices::openUrl(QUrl("https://github.com/your_repo"));
    // });
    // ===== Close button =====
    closeButton = new QPushButton("Close");
    connect(closeButton, &QPushButton::clicked, this, &QDialog::accept);
    // ===== Button row =====
    QHBoxLayout* buttonLayout = new QHBoxLayout;
    buttonLayout->addStretch();
    // buttonLayout->addWidget(githubButton);
    buttonLayout->addWidget(closeButton);
    // ===== Main layout =====
    // Passing `this` to the QVBoxLayout constructor installs the layout on
    // the dialog, so the former redundant setLayout(layout) call is dropped.
    QVBoxLayout* layout = new QVBoxLayout(this);
    layout->addWidget(titleLabel);
    layout->addWidget(versionLabel);
    layout->addWidget(authorLabel);
    layout->addSpacing(10);
    layout->addWidget(descLabel);
    layout->addStretch();
    layout->addLayout(buttonLayout);
}

View File

@@ -0,0 +1,20 @@
#pragma once
#include <QDialog>
class QLabel;
class QPushButton;
// Simple modal "About" dialog: application title, version, author and a
// short description, closed via a single button.
class AboutWindow : public QDialog
{
    Q_OBJECT
public:
    explicit AboutWindow(QWidget* parent = nullptr);
private:
    QLabel* titleLabel;    // bold application-title heading
    QLabel* versionLabel;  // version string
    QLabel* authorLabel;   // author credit
    QLabel* descLabel;     // word-wrapped description
    // QPushButton* githubButton;  // kept for a future GitHub link
    QPushButton* closeButton;      // accepts (closes) the dialog
};

View File

@@ -0,0 +1,50 @@
#include "dialogs/CancelableTaskDialog.h"
#include <QBoxLayout>
#include <QLabel>
#include <QProgressBar>
#include <QPushButton>
// Builds the modal dialog: a word-wrapped message label, an indeterminate
// progress bar and a cancel button wired to onCancel().
CancelableTaskDialog::CancelableTaskDialog(const QString& title,
                                           const QString& message,
                                           QWidget* parent)
    : QDialog(parent) {
    setWindowTitle(title);
    setModal(true);
    setMinimumWidth(420);
    auto* mainLayout = new QVBoxLayout(this);
    mainLayout->setContentsMargins(14, 14, 14, 14);
    mainLayout->setSpacing(10);
    m_label = new QLabel(message, this);
    m_label->setWordWrap(true);
    mainLayout->addWidget(m_label);
    m_bar = new QProgressBar(this);
    m_bar->setRange(0, 0); // 0..0 range => indeterminate ("busy") bar
    mainLayout->addWidget(m_bar);
    auto* buttonRow = new QHBoxLayout();
    buttonRow->addStretch(1);
    m_btnCancel = new QPushButton(QStringLiteral("取消"), this);
    buttonRow->addWidget(m_btnCancel);
    mainLayout->addLayout(buttonRow);
    connect(m_btnCancel, &QPushButton::clicked, this, &CancelableTaskDialog::onCancel);
}
// Replaces the message text; no-op if the label does not exist.
void CancelableTaskDialog::setMessage(const QString& message) {
    if (m_label == nullptr) {
        return;
    }
    m_label->setText(message);
}
// Latches the canceled flag and emits canceled() exactly once; further
// clicks are ignored.
void CancelableTaskDialog::onCancel() {
    if (!m_canceled) {
        m_canceled = true;
        emit canceled();
    }
}

View File

@@ -0,0 +1,35 @@
#pragma once
#include <QDialog>
#include <QString>
class QLabel;
class QProgressBar;
class QPushButton;
// Reusable "long-running task" dialog: shows a message, an indeterminate
// progress bar and a cancel button.
// - The task itself is started by the caller (e.g. a network request or a
//   background thread).
// - On cancel, the caller should abort the task and call reject()/close().
class CancelableTaskDialog final : public QDialog {
    Q_OBJECT
public:
    explicit CancelableTaskDialog(const QString& title,
                                  const QString& message,
                                  QWidget* parent = nullptr);
    // Replaces the message text shown above the progress bar.
    void setMessage(const QString& message);
    // True once the user has pressed cancel (latched; never resets).
    bool wasCanceled() const { return m_canceled; }
signals:
    // Emitted on the first cancel click only.
    void canceled();
private slots:
    void onCancel();
private:
    QLabel* m_label = nullptr;
    QProgressBar* m_bar = nullptr;
    QPushButton* m_btnCancel = nullptr;
    bool m_canceled = false;
};

View File

@@ -0,0 +1,252 @@
#include "dialogs/FrameAnimationDialog.h"
#include "core/animation/AnimationSampling.h"
#include "core/workspace/ProjectWorkspace.h"
#include <QBoxLayout>
#include <QDir>
#include <QFileDialog>
#include <QFileInfo>
#include <QImage>
#include <QLabel>
#include <QListWidget>
#include <QMessageBox>
#include <QPixmap>
#include <QPushButton>
namespace {
// Resolves the absolute image path used for `frame`: the per-frame override
// if one is sampled, otherwise the entity's default image path; empty when
// neither exists.
QString resolvedImageAbsForFrame(const core::ProjectWorkspace& ws,
                                 const core::Project::Entity& e,
                                 int frame) {
    const QString relPath = core::sampleImagePath(e.imageFrames, frame, e.imagePath);
    if (relPath.isEmpty()) {
        return {};
    }
    return QDir(ws.projectDir()).filePath(relPath);
}
} // namespace
// Builds the modal "interval animation frames" dialog for entity `entityId`
// over the inclusive frame range [startFrame, endFrame] (argument order does
// not matter). Left: frame list; right: preview + per-frame replace/clear
// and batch-import actions.
FrameAnimationDialog::FrameAnimationDialog(core::ProjectWorkspace& workspace,
                                           const QString& entityId,
                                           int startFrame,
                                           int endFrame,
                                           QWidget* parent)
    : QDialog(parent)
    , m_workspace(workspace)
    , m_entityId(entityId) {
    setWindowTitle(QStringLiteral("区间动画帧"));
    setModal(true);
    setMinimumSize(720, 420);
    // Normalize so that m_start <= m_end.
    m_start = std::min(startFrame, endFrame);
    m_end = std::max(startFrame, endFrame);
    auto* root = new QVBoxLayout(this);
    root->setContentsMargins(12, 12, 12, 12);
    root->setSpacing(10);
    m_title = new QLabel(this);
    m_title->setText(QStringLiteral("实体 %1 | 区间 [%2, %3]").arg(m_entityId).arg(m_start).arg(m_end));
    root->addWidget(m_title);
    auto* mid = new QHBoxLayout();
    root->addLayout(mid, 1);
    // Frame list (left column).
    m_list = new QListWidget(this);
    m_list->setMinimumWidth(240);
    mid->addWidget(m_list, 0);
    auto* right = new QVBoxLayout();
    mid->addLayout(right, 1);
    // Preview of the selected frame's image.
    m_preview = new QLabel(this);
    m_preview->setMinimumSize(320, 240);
    m_preview->setFrameShape(QFrame::StyledPanel);
    m_preview->setAlignment(Qt::AlignCenter);
    m_preview->setText(QStringLiteral("选择一帧"));
    right->addWidget(m_preview, 1);
    // Per-frame actions: replace / clear.
    auto* row = new QHBoxLayout();
    right->addLayout(row);
    m_btnReplace = new QPushButton(QStringLiteral("替换此帧…"), this);
    m_btnClear = new QPushButton(QStringLiteral("清除此帧(恢复默认)"), this);
    row->addWidget(m_btnReplace);
    row->addWidget(m_btnClear);
    // Batch import: multi-file selection or a whole folder.
    auto* row2 = new QHBoxLayout();
    right->addLayout(row2);
    m_btnImportFiles = new QPushButton(QStringLiteral("批量导入(多选图片)…"), this);
    m_btnImportFolder = new QPushButton(QStringLiteral("批量导入(文件夹)…"), this);
    row2->addWidget(m_btnImportFiles);
    row2->addWidget(m_btnImportFolder);
    row2->addStretch(1);
    auto* closeRow = new QHBoxLayout();
    root->addLayout(closeRow);
    closeRow->addStretch(1);
    auto* btnClose = new QPushButton(QStringLiteral("关闭"), this);
    closeRow->addWidget(btnClose);
    connect(btnClose, &QPushButton::clicked, this, &QDialog::accept);
    connect(m_list, &QListWidget::currentRowChanged, this, [this](int) { onSelectFrame(); });
    connect(m_btnReplace, &QPushButton::clicked, this, &FrameAnimationDialog::onReplaceCurrentFrame);
    connect(m_btnClear, &QPushButton::clicked, this, &FrameAnimationDialog::onClearCurrentFrame);
    connect(m_btnImportFiles, &QPushButton::clicked, this, &FrameAnimationDialog::onBatchImportFiles);
    connect(m_btnImportFolder, &QPushButton::clicked, this, &FrameAnimationDialog::onBatchImportFolder);
    // Populate and select the first frame so the preview is not empty.
    rebuildFrameList();
    if (m_list->count() > 0) {
        m_list->setCurrentRow(0);
    }
}
void FrameAnimationDialog::rebuildFrameList() {
m_list->clear();
if (!m_workspace.isOpen()) return;
const auto& ents = m_workspace.entities();
const core::Project::Entity* hit = nullptr;
for (const auto& e : ents) {
if (e.id == m_entityId) {
hit = &e;
break;
}
}
if (!hit) return;
// 默认贴图(用于 UI 提示)
m_defaultImageAbs.clear();
if (!hit->imagePath.isEmpty()) {
const QString abs = QDir(m_workspace.projectDir()).filePath(hit->imagePath);
if (QFileInfo::exists(abs)) {
m_defaultImageAbs = abs;
}
}
for (int f = m_start; f <= m_end; ++f) {
bool hasCustom = false;
for (const auto& k : hit->imageFrames) {
if (k.frame == f) {
hasCustom = true;
break;
}
}
auto* it = new QListWidgetItem(QStringLiteral("%1%2").arg(f).arg(hasCustom ? QStringLiteral(" *") : QString()));
it->setData(Qt::UserRole, f);
m_list->addItem(it);
}
}
// Refreshes the preview for the frame currently selected in the list.
void FrameAnimationDialog::onSelectFrame() {
    if (auto* item = m_list->currentItem()) {
        updatePreviewForFrame(item->data(Qt::UserRole).toInt());
    }
}
void FrameAnimationDialog::updatePreviewForFrame(int frame) {
if (!m_workspace.isOpen()) return;
const auto& ents = m_workspace.entities();
const core::Project::Entity* hit = nullptr;
for (const auto& e : ents) {
if (e.id == m_entityId) {
hit = &e;
break;
}
}
if (!hit) return;
const QString abs = resolvedImageAbsForFrame(m_workspace, *hit, frame);
if (abs.isEmpty() || !QFileInfo::exists(abs)) {
m_preview->setText(QStringLiteral("无图像"));
return;
}
QPixmap pm(abs);
if (pm.isNull()) {
m_preview->setText(QStringLiteral("加载失败"));
return;
}
m_preview->setPixmap(pm.scaled(m_preview->size(), Qt::KeepAspectRatio, Qt::SmoothTransformation));
}
bool FrameAnimationDialog::applyImageToFrame(int frame, const QString& absImagePath) {
QImage img(absImagePath);
if (img.isNull()) {
return false;
}
if (img.format() != QImage::Format_ARGB32_Premultiplied) {
img = img.convertToFormat(QImage::Format_ARGB32_Premultiplied);
}
return m_workspace.setEntityImageFrame(m_entityId, frame, img);
}
void FrameAnimationDialog::onReplaceCurrentFrame() {
auto* it = m_list->currentItem();
if (!it) return;
const int f = it->data(Qt::UserRole).toInt();
const QString path = QFileDialog::getOpenFileName(
this,
QStringLiteral("选择该帧图像"),
QString(),
QStringLiteral("Images (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)"));
if (path.isEmpty()) return;
if (!applyImageToFrame(f, path)) {
QMessageBox::warning(this, QStringLiteral("动画帧"), QStringLiteral("写入该帧失败。"));
return;
}
rebuildFrameList();
updatePreviewForFrame(f);
}
// Removes the custom image keyed on the selected frame (falling back to the
// entity's default image) and refreshes list + preview.
void FrameAnimationDialog::onClearCurrentFrame() {
    auto* item = m_list->currentItem();
    if (!item) return;
    const int frame = item->data(Qt::UserRole).toInt();
    if (!m_workspace.removeEntityImageFrame(m_entityId, frame)) {
        return;
    }
    rebuildFrameList();
    updatePreviewForFrame(frame);
}
void FrameAnimationDialog::onBatchImportFiles() {
const QStringList paths = QFileDialog::getOpenFileNames(
this,
QStringLiteral("选择逐帧动画图片(按文件名排序)"),
QString(),
QStringLiteral("Images (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)"));
if (paths.isEmpty()) return;
QStringList sorted = paths;
sorted.sort(Qt::CaseInsensitive);
const int need = m_end - m_start + 1;
const int count = std::min(need, static_cast<int>(sorted.size()));
for (int i = 0; i < count; ++i) {
applyImageToFrame(m_start + i, sorted[i]);
}
rebuildFrameList();
onSelectFrame();
}
void FrameAnimationDialog::onBatchImportFolder() {
const QString dir = QFileDialog::getExistingDirectory(this, QStringLiteral("选择逐帧动画图片文件夹"));
if (dir.isEmpty()) return;
QDir d(dir);
d.setFilter(QDir::Files | QDir::Readable);
d.setSorting(QDir::Name);
const QStringList filters = {QStringLiteral("*.png"),
QStringLiteral("*.jpg"),
QStringLiteral("*.jpeg"),
QStringLiteral("*.bmp"),
QStringLiteral("*.webp")};
const QStringList files = d.entryList(filters, QDir::Files, QDir::Name);
if (files.isEmpty()) return;
const int need = m_end - m_start + 1;
const int count = std::min(need, static_cast<int>(files.size()));
for (int i = 0; i < count; ++i) {
applyImageToFrame(m_start + i, d.filePath(files[i]));
}
rebuildFrameList();
onSelectFrame();
}

View File

@@ -0,0 +1,52 @@
#pragma once
#include <QDialog>
#include <QString>
#include <QStringList>
namespace core {
class ProjectWorkspace;
}
class QLabel;
class QListWidget;
class QPushButton;
class FrameAnimationDialog final : public QDialog {
Q_OBJECT
public:
FrameAnimationDialog(core::ProjectWorkspace& workspace,
const QString& entityId,
int startFrame,
int endFrame,
QWidget* parent = nullptr);
private slots:
void onSelectFrame();
void onReplaceCurrentFrame();
void onClearCurrentFrame();
void onBatchImportFiles();
void onBatchImportFolder();
private:
void rebuildFrameList();
void updatePreviewForFrame(int frame);
bool applyImageToFrame(int frame, const QString& absImagePath);
private:
core::ProjectWorkspace& m_workspace;
QString m_entityId;
int m_start = 0;
int m_end = 0;
QLabel* m_title = nullptr;
QListWidget* m_list = nullptr;
QLabel* m_preview = nullptr;
QPushButton* m_btnReplace = nullptr;
QPushButton* m_btnClear = nullptr;
QPushButton* m_btnImportFiles = nullptr;
QPushButton* m_btnImportFolder = nullptr;
QString m_defaultImageAbs;
};

View File

@@ -0,0 +1,209 @@
#include "dialogs/ImageCropDialog.h"
#include <QBoxLayout>
#include <QDialogButtonBox>
#include <QLabel>
#include <QMouseEvent>
#include <QPainter>
#include <QPushButton>
#include <QtMath>
// Interactive crop view: letterboxes the image into the widget, lets the
// user drag out a rubber-band selection, and maps that selection back into
// image pixel coordinates.
class ImageCropDialog::CropView final : public QWidget {
public:
    explicit CropView(QWidget* parent = nullptr)
        : QWidget(parent) {
        setMouseTracking(true);
        setMinimumSize(480, 320);
    }
    // Replaces the displayed image and clears any existing selection.
    void setImage(const QImage& img) {
        m_image = img;
        m_selection = {};
        updateGeometry();
        update();
    }
    bool hasSelection() const { return !m_selection.isNull() && m_selection.width() > 0 && m_selection.height() > 0; }
    // Maps the current view-space selection into image pixel coordinates,
    // clamped to the image bounds. Empty rect when there is no selection or
    // the view transform cannot be inverted.
    QRect selectionInImagePixels() const {
        if (m_image.isNull() || !hasSelection()) {
            return {};
        }
        const auto map = viewToImageTransform();
        // m_selection is in view coordinates; map to image pixel coordinates.
        const QRectF selF = QRectF(m_selection).normalized();
        bool invertible = false;
        const QTransform inv = map.inverted(&invertible);
        if (!invertible) {
            return {};
        }
        const QPointF topLeftImg = inv.map(selF.topLeft());
        const QPointF bottomRightImg = inv.map(selF.bottomRight());
        // floor/ceil so rounding can never collapse width/height to 0.
        const int left = qFloor(std::min(topLeftImg.x(), bottomRightImg.x()));
        const int top = qFloor(std::min(topLeftImg.y(), bottomRightImg.y()));
        const int right = qCeil(std::max(topLeftImg.x(), bottomRightImg.x()));
        const int bottom = qCeil(std::max(topLeftImg.y(), bottomRightImg.y()));
        QRect r(QPoint(left, top), QPoint(right, bottom));
        r = r.normalized().intersected(QRect(0, 0, m_image.width(), m_image.height()));
        return r;
    }
    void resetSelection() {
        m_selection = {};
        update();
    }
protected:
    void paintEvent(QPaintEvent*) override {
        QPainter p(this);
        p.fillRect(rect(), palette().window());
        if (m_image.isNull()) {
            p.setPen(palette().text().color());
            p.drawText(rect(), Qt::AlignCenter, QStringLiteral("无法加载图片"));
            return;
        }
        const auto map = viewToImageTransform();
        p.setRenderHint(QPainter::SmoothPixmapTransform, true);
        p.setTransform(map);
        p.drawImage(QPoint(0, 0), m_image);
        p.resetTransform();
        if (hasSelection()) {
            // CompositionMode_Clear misbehaves on some platforms/styles, so
            // highlight the crop area by shading four rectangles around it.
            const QRect sel = m_selection.normalized().intersected(rect());
            const QColor shade(0, 0, 0, 120);
            // QRect::bottom()/right() are inclusive (top+height-1 / left+width-1),
            // so the shaded regions below and to the right must start one pixel
            // past them — otherwise the selection's last row/column is darkened.
            const int belowY = sel.bottom() + 1;
            const int rightX = sel.right() + 1;
            // Above the selection.
            p.fillRect(QRect(0, 0, width(), sel.top()), shade);
            // Below.
            p.fillRect(QRect(0, belowY, width(), height() - belowY), shade);
            // Left.
            p.fillRect(QRect(0, sel.top(), sel.left(), sel.height()), shade);
            // Right.
            p.fillRect(QRect(rightX, sel.top(), width() - rightX, sel.height()), shade);
            p.setPen(QPen(QColor(255, 255, 255, 220), 2));
            p.drawRect(sel);
        }
    }
    // Left press anchors a fresh selection.
    void mousePressEvent(QMouseEvent* e) override {
        if (m_image.isNull() || e->button() != Qt::LeftButton) {
            return;
        }
        m_dragging = true;
        m_anchor = e->position().toPoint();
        m_selection = QRect(m_anchor, m_anchor);
        update();
    }
    void mouseMoveEvent(QMouseEvent* e) override {
        if (!m_dragging) {
            return;
        }
        const QPoint cur = e->position().toPoint();
        m_selection = QRect(m_anchor, cur).normalized();
        update();
    }
    void mouseReleaseEvent(QMouseEvent* e) override {
        if (e->button() != Qt::LeftButton) {
            return;
        }
        m_dragging = false;
        update();
    }
private:
    // Uniformly scales the image to fit the view, centered (letterboxed).
    // Despite the name, the transform maps image coordinates to view
    // coordinates; callers invert it for view -> image.
    QTransform viewToImageTransform() const {
        const QSizeF viewSize = size();
        const QSizeF imgSize = m_image.size();
        const qreal sx = viewSize.width() / imgSize.width();
        const qreal sy = viewSize.height() / imgSize.height();
        const qreal s = std::min(sx, sy);
        const qreal drawW = imgSize.width() * s;
        const qreal drawH = imgSize.height() * s;
        const qreal offsetX = (viewSize.width() - drawW) / 2.0;
        const qreal offsetY = (viewSize.height() - drawH) / 2.0;
        QTransform t;
        t.translate(offsetX, offsetY);
        t.scale(s, s);
        return t;
    }
private:
    QImage m_image;          // source image being cropped
    bool m_dragging = false; // true while a left-button drag is in progress
    QPoint m_anchor;         // view-space point where the drag started
    QRect m_selection;       // current selection, view coordinates
};
// Loads `imagePath` and builds the crop UI. If the image fails to load the
// dialog is marked rejected (see loadImageOrClose()).
ImageCropDialog::ImageCropDialog(const QString& imagePath, QWidget* parent)
    : QDialog(parent),
      m_imagePath(imagePath) {
    setWindowTitle(QStringLiteral("裁剪图片"));
    setModal(true);
    resize(900, 600);
    // Load first: rebuildUi() hands m_image to the crop view.
    loadImageOrClose();
    rebuildUi();
}
// Loads the source image into m_image; on failure, rejects the dialog.
// NOTE(review): reject() is called from the constructor here, before the
// dialog is shown — a later exec() will still open the dialog regardless.
// Callers should probably check the image/result themselves; confirm intent.
void ImageCropDialog::loadImageOrClose() {
    m_image = QImage(m_imagePath);
    if (m_image.isNull()) {
        reject();
    }
}
// Assembles the dialog UI: hint label, crop view, and an Ok/Cancel button
// box extended with a "reset selection" action button.
void ImageCropDialog::rebuildUi() {
    auto* mainLayout = new QVBoxLayout(this);
    auto* hintLabel = new QLabel(QStringLiteral("拖拽选择裁剪区域(不选则使用整张图)。"), this);
    mainLayout->addWidget(hintLabel);
    m_view = new CropView(this);
    m_view->setImage(m_image);
    mainLayout->addWidget(m_view, 1);
    auto* buttonBox = new QDialogButtonBox(QDialogButtonBox::Ok | QDialogButtonBox::Cancel, this);
    m_okButton = buttonBox->button(QDialogButtonBox::Ok);
    auto* resetButton = new QPushButton(QStringLiteral("重置选择"), this);
    buttonBox->addButton(resetButton, QDialogButtonBox::ActionRole);
    connect(resetButton, &QPushButton::clicked, this, &ImageCropDialog::onReset);
    connect(buttonBox, &QDialogButtonBox::accepted, this, &ImageCropDialog::onOk);
    connect(buttonBox, &QDialogButtonBox::rejected, this, &ImageCropDialog::reject);
    mainLayout->addWidget(buttonBox);
}
// True when the crop view exists and holds a non-empty selection.
bool ImageCropDialog::hasValidSelection() const {
    return m_view != nullptr && m_view->hasSelection();
}
// Selection mapped into source-image pixel coordinates; empty when no view.
QRect ImageCropDialog::selectedRectInImagePixels() const {
    return m_view ? m_view->selectionInImagePixels() : QRect{};
}
// Clears the current selection in the crop view, if any.
void ImageCropDialog::onReset() {
    if (!m_view) {
        return;
    }
    m_view->resetSelection();
}
// Accepts the dialog; callers then query hasValidSelection() /
// selectedRectInImagePixels() for the result.
void ImageCropDialog::onOk() {
    accept();
}

View File

@@ -0,0 +1,34 @@
#pragma once
#include <QDialog>
#include <QImage>
#include <QRect>
class QLabel;
class QPushButton;
// Modal dialog for cropping an image loaded from disk: the user drags a
// rectangle on a preview; the selection is reported in image pixels.
class ImageCropDialog final : public QDialog {
    Q_OBJECT
public:
    explicit ImageCropDialog(const QString& imagePath, QWidget* parent = nullptr);
    // True when the user dragged out a non-empty selection.
    bool hasValidSelection() const;
    // Selection mapped into source-image pixel coordinates (empty if none).
    QRect selectedRectInImagePixels() const;
private slots:
    void onReset();
    void onOk();
private:
    void loadImageOrClose();
    void rebuildUi();
private:
    class CropView;                 // internal interactive crop widget
    CropView* m_view = nullptr;
    QPushButton* m_okButton = nullptr;
    QString m_imagePath;            // absolute path of the image being cropped
    QImage m_image;                 // loaded image (null if load failed)
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,179 @@
#pragma once
#include "core/domain/Project.h"
#include <QPixmap>
#include <QPointF>
#include <QImage>
#include <QPainterPath>
#include <QVector>
#include <QWidget>
#include <QElapsedTimer>
// Central editing canvas: renders the background, optional depth overlay and
// entities in world coordinates, and implements the Move/Zoom/CreateEntity
// tools, entity selection/dragging and a presentation preview mode.
class EditorCanvas final : public QWidget {
    Q_OBJECT
public:
    enum class Tool { Move, Zoom, CreateEntity };
    Q_ENUM(Tool)
    explicit EditorCanvas(QWidget* parent = nullptr);
    void setBackgroundImagePath(const QString& absolutePath);
    QString backgroundImagePath() const { return m_bgAbsPath; }
    void setBackgroundVisible(bool on);
    bool backgroundVisible() const { return m_backgroundVisible; }
    void setDepthMapPath(const QString& absolutePath);
    void setDepthOverlayEnabled(bool on);
    bool depthOverlayEnabled() const { return m_depthOverlayEnabled; }
    void setTool(Tool tool);
    Tool tool() const { return m_tool; }
    void resetView();
    void zoomToFit();
    void setWorldAxesVisible(bool on);
    bool worldAxesVisible() const { return m_worldAxesVisible; }
    void setAxisLabelsVisible(bool on);
    bool axisLabelsVisible() const { return m_axisLabelsVisible; }
    void setGizmoLabelsVisible(bool on);
    bool gizmoLabelsVisible() const { return m_gizmoLabelsVisible; }
    void setGridVisible(bool on);
    bool gridVisible() const { return m_gridVisible; }
    void setCheckerboardVisible(bool on);
    bool checkerboardVisible() const { return m_checkerboardVisible; }
    // Presentation preview: full background + all entities (visibility
    // toggles ignored), editing helpers hidden; only pan/zoom are available.
    void setPresentationPreviewMode(bool on);
    bool presentationPreviewMode() const { return m_presentationPreviewMode; }
    void setEntities(const QVector<core::Project::Entity>& entities, const QString& projectDirAbs);
    void setCurrentFrame(int frame);
    int currentFrame() const { return m_currentFrame; }
    bool isDraggingEntity() const { return m_draggingEntity; }
    void selectEntityById(const QString& id);
    void clearEntitySelection();
    // Origin/scale consistent with animation evaluation (used for keyframe
    // insertion and auto-keyframing).
    QPointF selectedAnimatedOriginWorld() const;
    double selectedDepthScale01() const;
    QPointF selectedEntityCentroidWorld() const;
    double selectedDistanceScaleMultiplier() const;
    double selectedUserScale() const;
    double selectedCombinedScale() const;
    enum class DragMode { None, Free, AxisX, AxisY };
signals:
    void hoveredWorldPosChanged(const QPointF& worldPos);
    void hoveredWorldPosDepthChanged(const QPointF& worldPos, int depthZ);
    void selectedEntityChanged(bool hasSelection, const QString& id, int depth, const QPointF& originWorld);
    void requestAddEntity(const core::Project::Entity& entity, const QImage& image);
    void requestMoveEntity(const QString& id, const QPointF& delta);
    void entityDragActiveChanged(bool on);
    void selectedEntityPreviewChanged(const QString& id, int depth, const QPointF& originWorld);
protected:
    void paintEvent(QPaintEvent* e) override;
    void resizeEvent(QResizeEvent* e) override;
    void mousePressEvent(QMouseEvent* e) override;
    void mouseMoveEvent(QMouseEvent* e) override;
    void mouseReleaseEvent(QMouseEvent* e) override;
    void wheelEvent(QWheelEvent* e) override;
private:
    void ensurePixmapLoaded() const;
    void invalidatePixmap();
    void updateCursor();
    QPointF viewToWorld(const QPointF& v) const;
    QPointF worldToView(const QPointF& w) const;
    QRectF worldRectOfBackground() const;
private:
    // Canvas-side cache of one project entity, in world coordinates.
    struct Entity {
        QString id;
        QRectF rect; // world coordinates (used for dragging and constraints)
        QVector<QPointF> polygonWorld; // polygon is used when non-empty
        QPainterPath pathWorld; // cached world path for polygonWorld (avoids rebuilding every frame)
        QVector<QPointF> cutoutPolygonWorld;
        QColor color;
        // Per-entity data:
        int depth = 0; // 0..255, average depth of the carved region
        QImage image; // cut-out entity image (with transparency)
        QPointF imageTopLeft; // world top-left corner matching `image`
        double visualScale = 1.0; // entity scale in world coordinates (used when drawing the texture)
        double userScale = 1.0; // multiplied with the depth distance scale
        QPointF animatedOriginWorld;
        double animatedDepthScale01 = 0.5;
        // When hidden in edit mode: not hit-testable and not drawn, unless
        // currently selected (so hidden entities can be picked via the tree).
        bool hiddenInEditMode = false;
    };
    int hitTestEntity(const QPointF& worldPos) const;
private:
    QString m_bgAbsPath;
    bool m_backgroundVisible = true;
    mutable QPixmap m_bgPixmap;
    mutable bool m_pixmapDirty = true;
    mutable QImage m_bgImage; // original background (used for cutout/fill)
    mutable QImage m_bgImageCutout; // background after cutout (entity regions filled black)
    mutable bool m_bgImageDirty = true;
    mutable bool m_bgCutoutDirty = true;
    QString m_depthAbsPath;
    mutable QImage m_depthImage8;
    mutable bool m_depthDirty = true;
    bool m_depthOverlayEnabled = false;
    int m_depthOverlayAlpha = 110;
    bool m_worldAxesVisible = true;
    bool m_axisLabelsVisible = true;
    bool m_gizmoLabelsVisible = true;
    bool m_gridVisible = true;
    bool m_checkerboardVisible = true;
    bool m_presentationPreviewMode = false;
    Tool m_tool = Tool::Move;
    qreal m_scale = 1.0;
    QPointF m_pan; // view-space offset of the world origin (view = world*scale + pan)
    bool m_dragging = false;
    bool m_draggingEntity = false;
    bool m_drawingEntity = false;
    QPointF m_lastMouseView;
    // Dragging is anchored on the entity origin animatedOriginWorld, so that
    // scaling cannot make rect/topLeft jitter.
    QPointF m_entityDragOffsetOriginWorld;
    QPointF m_entityDragStartAnimatedOrigin;
    // Drag performance: instead of mutating polygonWorld point-by-point while
    // dragging, keep the base shape + incremental parameters and apply the
    // transform as a preview at paint time.
    bool m_dragPreviewActive = false;
    QVector<QPointF> m_dragPolyBase;
    QPainterPath m_dragPathBase;
    QPointF m_dragImageTopLeftBase;
    QRectF m_dragRectBase;
    QPointF m_dragOriginBase;
    QPointF m_dragDelta; // pure translation
    QPointF m_dragCentroidBase;
    double m_dragScaleBase = 1.0; // visualScale at drag start
    double m_dragScaleRatio = 1.0; // scale ratio relative to m_dragScaleBase (driven by depth recomputation)
    QElapsedTimer m_previewEmitTimer;
    qint64 m_lastPreviewEmitMs = 0;
    qint64 m_lastDepthScaleRecalcMs = 0;
    int m_selectedEntity = -1;
    DragMode m_dragMode = DragMode::None;
    QPointF m_dragStartMouseWorld;
    QVector<Entity> m_entities;
    QVector<QPointF> m_strokeWorld;
    int m_currentFrame = 0;
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,177 @@
#pragma once
#include "core/workspace/ProjectWorkspace.h"
#include "main_window/RecentProjectHistory.h"
#include <QMainWindow>
#include <QPointF>
#include <QFrame>
#include <QIcon>
#include <QTimer>
class QAction;
class QCheckBox;
class QComboBox;
class QDockWidget;
class QFormLayout;
class QLabel;
class QMenu;
class QFrame;
class QIcon;
class QPushButton;
class QSlider;
class QStackedWidget;
class QToolButton;
class QTreeWidget;
class QTreeWidgetItem;
class QWidget;
class EditorCanvas;
class TimelineWidget;
namespace gui {
class BackgroundPropertySection;
class EntityPropertySection;
}
// Top-level application window. Owns the central page stack (welcome /
// editor / preview), all dock widgets, the menus, and the project workspace.
class MainWindow : public QMainWindow {
    Q_OBJECT
public:
    explicit MainWindow(QWidget* parent = nullptr);

protected:
    // Filters events of child widgets (canvas host, docks, welcome list, ...).
    bool eventFilter(QObject* watched, QEvent* event) override;

private slots:
    // File menu slots
    void onNewProject();
    void onOpenProject();
    void onSaveProject();
    void onCloseProject();
    // Edit menu slots
    void onUndo();
    void onRedo();
    void onCopyObject();
    void onPasteObject();
    // Help menu slots
    void onAbout();
    void onComputeDepth();
    void onTogglePlay(bool on);
    void onInsertCombinedKey(); // position + userScale
    void onProjectTreeItemClicked(QTreeWidgetItem* item, int column);

private:
    void computeDepthAsync();

    // The UI has three states:
    // - Welcome: no project open. Only the welcome page is visible, every
    //   dock is hidden and the View-menu toggles are disabled.
    // - Editor: a project is open. The editor page is shown; docks follow
    //   the default rules and can be toggled by the user via the View menu.
    // - Preview: presentation mode for the finished pipeline (requires an
    //   open project with a non-empty background).
    enum class UiMode { Welcome, Editor, Preview };

    void createMenus();       // menus and toolbars
    void createFileMenu();    // File menu
    void createEditMenu();    // Edit menu
    void createHelpMenu();    // Help menu
    void createViewMenu();    // View menu
    void createProjectTreeDock();
    void createTimelineDock();
    void refreshProjectTree();
    void updateUiEnabledState(); // update enabled/checked/default-visibility only; no business logic here
    void applyUiMode(UiMode mode); // central welcome/editor visibility policy
    UiMode currentUiMode() const;  // derived from the workspace state
    void syncCanvasViewMenuFromState();
    void showProjectRootContextMenu(const QPoint& globalPos);
    void showBackgroundContextMenu(const QPoint& globalPos);
    void rebuildCentralPages();
    void showWelcomePage();
    void showEditorPage();
    void showPreviewPage();
    void refreshWelcomeRecentList();
    void openProjectFromPath(const QString& dir);
    void refreshPreviewPage();
    void refreshEditorPage();
    void applyTimelineFromProject();
    void refreshDopeSheet();
    void setPreviewRequested(bool preview);

    // Central page stack and its pages.
    QStackedWidget* m_centerStack = nullptr;
    QWidget* m_pageWelcome = nullptr;
    QTreeWidget* m_welcomeRecentTree = nullptr;
    QLabel* m_welcomeRecentEmptyLabel = nullptr;
    QWidget* m_pageEditor = nullptr;
    QWidget* m_canvasHost = nullptr;

    // Floating tool panels overlaid on the canvas.
    QFrame* m_floatingModeDock = nullptr;
    QFrame* m_floatingToolDock = nullptr;
    QComboBox* m_modeSelector = nullptr;

    // Property dock contents: one page per selected object kind.
    QStackedWidget* m_propertyStack = nullptr;
    gui::BackgroundPropertySection* m_bgPropertySection = nullptr;
    gui::EntityPropertySection* m_entityPropertySection = nullptr;
    QToolButton* m_btnCreateEntity = nullptr;
    QToolButton* m_btnToggleDepthOverlay = nullptr;
    EditorCanvas* m_editorCanvas = nullptr;

    // Project tree and dock widgets.
    QTreeWidget* m_projectTree = nullptr;
    QDockWidget* m_dockProjectTree = nullptr;
    QDockWidget* m_dockProperties = nullptr;
    QDockWidget* m_dockTimeline = nullptr;
    QTreeWidgetItem* m_itemBackground = nullptr;

    // Menu actions (edit / view / preview toggles).
    QAction* m_actionUndo = nullptr;
    QAction* m_actionRedo = nullptr;
    QAction* m_actionCopy = nullptr;
    QAction* m_actionPaste = nullptr;
    QAction* m_actionToggleProjectTree = nullptr;
    QAction* m_actionToggleProperties = nullptr;
    QAction* m_actionToggleTimeline = nullptr;
    QAction* m_actionEnterPreview = nullptr;
    QAction* m_actionBackToEditor = nullptr;
    QAction* m_actionCanvasWorldAxes = nullptr;
    QAction* m_actionCanvasAxisValues = nullptr;
    QAction* m_actionCanvasGrid = nullptr;
    QAction* m_actionCanvasCheckerboard = nullptr;
    QAction* m_actionCanvasDepthOverlay = nullptr;
    QAction* m_actionCanvasGizmoLabels = nullptr;

    core::ProjectWorkspace m_workspace;
    RecentProjectHistory m_recentHistory;
    bool m_previewRequested = false;
    /// Auto-collapsed because the right-hand panel became too narrow; cleared
    /// when the user re-opens the docks through the View menu.
    bool m_rightDocksNarrowHidden = false;

    // Cached canvas / selection state mirrored from the editor canvas.
    QPointF m_lastWorldPos;
    int m_lastWorldZ = -1;
    bool m_hasSelectedEntity = false;
    bool m_syncingTreeSelection = false;
    int m_selectedEntityDepth = 0;
    QPointF m_selectedEntityOrigin;
    QString m_selectedEntityId;
    QString m_selectedEntityDisplayNameCache;
    QString m_bgAbsCache;
    QString m_bgSizeTextCache;

    void updateStatusBarText();
    void refreshPropertyPanel();
    void refreshEntityPropertyPanelFast();
    void syncProjectTreeFromCanvasSelection();

    // Playback / timeline state.
    bool m_timelineScrubbing = false;
    bool m_entityDragging = false;
    QTimer* m_propertySyncTimer = nullptr;
    int m_currentFrame = 0;
    bool m_playing = false;
    QTimer* m_playTimer = nullptr;
    TimelineWidget* m_timeline = nullptr;
    QToolButton* m_btnPlay = nullptr;
    QLabel* m_frameLabel = nullptr;
    // Timeline interval selection (used for per-frame sprite animation).
    int m_timelineRangeStart = -1;
    int m_timelineRangeEnd = -1;
    QCheckBox* m_chkAutoKeyframe = nullptr;
    // The legacy dope sheet was removed; these remain as placeholders for a
    // future interval UI (e.g. a custom widget).
    QTreeWidget* m_dopeTree = nullptr;
    QPushButton* m_btnDopeDeleteKey = nullptr;
};

View File

@@ -0,0 +1,100 @@
#include "main_window/RecentProjectHistory.h"
#include <QDir>
#include <QFile>
#include <QFileInfo>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonValue>
#include <QDebug>
#include <QStandardPaths>
// Absolute path of the on-disk recent-project cache file, located inside the
// user's generic cache directory.
QString RecentProjectHistory::cacheFilePath() {
    const QDir cacheRoot(QStandardPaths::writableLocation(QStandardPaths::GenericCacheLocation));
    return cacheRoot.filePath(QStringLiteral("landscape_tool/recent_projects.cache"));
}
// Canonicalize a project path so duplicates compare equal. Falls back to a
// cleaned absolute path when the target does not exist (canonicalFilePath()
// returns an empty string in that case).
QString RecentProjectHistory::normalizePath(const QString& path) {
    if (path.isEmpty()) {
        return {};
    }
    const QFileInfo info(path);
    const QString canonical = info.canonicalFilePath();
    if (!canonical.isEmpty()) {
        return canonical;
    }
    return QDir::cleanPath(info.absoluteFilePath());
}
// Normalize every entry, drop empties and duplicates (keeping the first —
// i.e. newest — occurrence) and cap the result at kMaxEntries.
QStringList RecentProjectHistory::dedupeNewestFirst(const QStringList& paths) {
    QStringList result;
    result.reserve(paths.size());
    for (const QString& raw : paths) {
        if (result.size() >= kMaxEntries) {
            break; // list is full; remaining entries are older anyway
        }
        const QString normalized = normalizePath(raw);
        if (normalized.isEmpty() || result.contains(normalized)) {
            continue;
        }
        result.append(normalized);
    }
    return result;
}
// Read the cached JSON array of project paths. Any I/O or parse failure is
// treated as "no history" and yields an empty list.
QStringList RecentProjectHistory::load() const {
    QFile file(cacheFilePath());
    if (!file.open(QIODevice::ReadOnly)) {
        return {};
    }
    const QJsonDocument doc = QJsonDocument::fromJson(file.readAll());
    if (!doc.isArray()) {
        return {};
    }
    QStringList entries;
    const QJsonArray arr = doc.array();
    for (const QJsonValue& value : arr) {
        if (value.isString()) {
            entries.append(value.toString());
        }
    }
    return dedupeNewestFirst(entries);
}
bool RecentProjectHistory::save(const QStringList& paths) const {
const QString filePath = cacheFilePath();
const QFileInfo fi(filePath);
QDir().mkpath(fi.absolutePath());
QJsonArray arr;
for (const QString& p : dedupeNewestFirst(paths)) {
arr.append(p);
}
const QJsonDocument doc(arr);
QFile f(filePath);
if (!f.open(QIODevice::WriteOnly | QIODevice::Truncate)) {
qWarning() << "RecentProjectHistory: cannot write" << filePath;
return false;
}
f.write(doc.toJson(QJsonDocument::Compact));
return true;
}
void RecentProjectHistory::addAndSave(const QString& projectDir) {
const QString n = normalizePath(projectDir);
if (n.isEmpty()) {
return;
}
QStringList paths = load();
paths.removeAll(n);
paths.prepend(n);
save(paths);
}
// Remove the given project from the history and persist.
//
// Improvements over the original: an empty / unnormalizable path is now a
// no-op, and the cache file is rewritten only when an entry was actually
// removed (the original unconditionally re-saved an unchanged list).
void RecentProjectHistory::removeAndSave(const QString& projectDir) {
    const QString n = normalizePath(projectDir);
    if (n.isEmpty()) {
        return; // nothing sensible to remove; avoid touching the cache file
    }
    QStringList paths = load();
    if (paths.removeAll(n) > 0) {
        save(paths);
    }
}

View File

@@ -0,0 +1,21 @@
#pragma once
#include <QString>
#include <QStringList>
// Persistent most-recently-used project list, stored as a JSON array in the
// user's generic cache directory. Newest entries come first and the list is
// capped at kMaxEntries.
class RecentProjectHistory {
public:
    // Maximum number of entries kept on disk and returned by load().
    static constexpr int kMaxEntries = 15;
    // Absolute path of the cache file.
    static QString cacheFilePath();
    // Read the stored list; returns an empty list on any read/parse failure.
    QStringList load() const;
    // Write the (deduped, capped) list; false if the file cannot be written.
    bool save(const QStringList& paths) const;
    // Move/insert projectDir to the front of the history and persist.
    void addAndSave(const QString& projectDir);
    // Remove projectDir from the history (if present) and persist.
    void removeAndSave(const QString& projectDir);
    // Canonicalize a path so duplicates compare equal; empty input stays empty.
    static QString normalizePath(const QString& path);
private:
    // Normalize, drop empties/duplicates (newest wins), cap at kMaxEntries.
    static QStringList dedupeNewestFirst(const QStringList& paths);
};

View File

@@ -0,0 +1,127 @@
#include "params/ParamControls.h"
#include <algorithm>
#include <cmath>
#include <QDoubleSpinBox>
#include <QHBoxLayout>
#include <QSlider>
namespace gui {
namespace {
// The slider maps the [0, 1] value range onto integer ticks [0, kSliderTicks].
constexpr int kSliderTicks = 1000;
} // namespace

// Builds the slider + spin-box row; either editor drives the other, and
// m_block suppresses the resulting feedback loop.
Float01ParamControl::Float01ParamControl(QWidget* parent)
    : QWidget(parent) {
    auto* row = new QHBoxLayout(this);
    row->setContentsMargins(0, 0, 0, 0);
    row->setSpacing(8);

    m_slider = new QSlider(Qt::Horizontal, this);
    m_slider->setRange(0, kSliderTicks);
    m_slider->setSingleStep(1);
    m_slider->setPageStep(10);

    m_spin = new QDoubleSpinBox(this);
    m_spin->setRange(0.0, 1.0);
    m_spin->setDecimals(3);
    m_spin->setSingleStep(0.01);
    m_spin->setMinimumWidth(84);

    row->addWidget(m_slider, 1);
    row->addWidget(m_spin);

    connect(m_slider, &QSlider::valueChanged,
            this, &Float01ParamControl::syncFromSlider);
    connect(m_spin, qOverload<double>(&QDoubleSpinBox::valueChanged),
            this, &Float01ParamControl::syncFromSpin);

    setValue01(0.5);
}

// Enable/disable the whole row, including both child editors.
void Float01ParamControl::setEnabled(bool on) {
    QWidget::setEnabled(on);
    if (m_slider != nullptr) {
        m_slider->setEnabled(on);
    }
    if (m_spin != nullptr) {
        m_spin->setEnabled(on);
    }
}

// Current value; 0.5 if the spin box has not been created.
double Float01ParamControl::value01() const {
    return (m_spin != nullptr) ? m_spin->value() : 0.5;
}

// Programmatic set: clamps to [0, 1] and updates both editors without
// emitting valueChanged01 (m_block suppresses the sync handlers).
void Float01ParamControl::setValue01(double v) {
    const double value = std::clamp(v, 0.0, 1.0);
    m_block = true;
    if (m_spin != nullptr) {
        m_spin->setValue(value);
    }
    if (m_slider != nullptr) {
        m_slider->setValue(static_cast<int>(std::lround(value * kSliderTicks)));
    }
    m_block = false;
}

// Slider moved by the user: mirror into the spin box, then notify.
void Float01ParamControl::syncFromSlider() {
    if (m_block || m_slider == nullptr || m_spin == nullptr) {
        return;
    }
    const double value =
        static_cast<double>(m_slider->value()) / static_cast<double>(kSliderTicks);
    m_block = true;
    m_spin->setValue(value);
    m_block = false;
    emit valueChanged01(value);
}

// Spin box edited by the user: mirror onto the slider, then notify.
void Float01ParamControl::syncFromSpin() {
    if (m_block || m_slider == nullptr || m_spin == nullptr) {
        return;
    }
    const double value = m_spin->value();
    m_block = true;
    m_slider->setValue(static_cast<int>(std::lround(value * kSliderTicks)));
    m_block = false;
    emit valueChanged01(value);
}

// Builds the X/Y spin-box row; both boxes share identical configuration.
Vec2ParamControl::Vec2ParamControl(QWidget* parent)
    : QWidget(parent) {
    auto* row = new QHBoxLayout(this);
    row->setContentsMargins(0, 0, 0, 0);
    row->setSpacing(8);

    const auto makeSpin = [this]() {
        auto* spin = new QDoubleSpinBox(this);
        spin->setRange(-1e9, 1e9);
        spin->setDecimals(2);
        spin->setSingleStep(1.0);
        spin->setMinimumWidth(88);
        return spin;
    };
    m_x = makeSpin();
    m_y = makeSpin();
    row->addWidget(m_x, 1);
    row->addWidget(m_y, 1);

    connect(m_x, qOverload<double>(&QDoubleSpinBox::valueChanged),
            this, &Vec2ParamControl::emitIfChanged);
    connect(m_y, qOverload<double>(&QDoubleSpinBox::valueChanged),
            this, &Vec2ParamControl::emitIfChanged);

    setValue(0.0, 0.0);
}

// Enable/disable both coordinate editors.
void Vec2ParamControl::setEnabled(bool on) {
    QWidget::setEnabled(on);
    if (m_x != nullptr) {
        m_x->setEnabled(on);
    }
    if (m_y != nullptr) {
        m_y->setEnabled(on);
    }
}

// Programmatic set: updates both boxes and the last-published pair without
// emitting valueChanged.
void Vec2ParamControl::setValue(double x, double y) {
    m_block = true;
    if (m_x != nullptr) {
        m_x->setValue(x);
    }
    if (m_y != nullptr) {
        m_y->setValue(y);
    }
    m_lastX = x;
    m_lastY = y;
    m_block = false;
}

double Vec2ParamControl::x() const { return (m_x != nullptr) ? m_x->value() : 0.0; }
double Vec2ParamControl::y() const { return (m_y != nullptr) ? m_y->value() : 0.0; }

// Emit valueChanged only when either coordinate actually differs from the
// last published pair (exact float comparison is intended here: values come
// straight from the spin boxes, not from arithmetic).
void Vec2ParamControl::emitIfChanged() {
    if (m_block || m_x == nullptr || m_y == nullptr) {
        return;
    }
    const double newX = m_x->value();
    const double newY = m_y->value();
    if (newX == m_lastX && newY == m_lastY) {
        return;
    }
    m_lastX = newX;
    m_lastY = newY;
    emit valueChanged(newX, newY);
}
} // namespace gui

View File

@@ -0,0 +1,60 @@
#pragma once
#include <QWidget>
class QDoubleSpinBox;
class QSlider;
class QLabel;
namespace gui {
// Reusable control for a float parameter in [0, 1]: a slider plus a
// double spin box that stay in sync with each other.
class Float01ParamControl final : public QWidget {
    Q_OBJECT
public:
    explicit Float01ParamControl(QWidget* parent = nullptr);
    // Programmatic set (clamped to [0, 1]); does not emit valueChanged01.
    void setValue01(double v);
    double value01() const;
    // Enables/disables the row including both child editors.
    void setEnabled(bool on);
signals:
    // Emitted on user edits only (slider drag or spin-box change).
    void valueChanged01(double v);
private:
    void syncFromSlider();
    void syncFromSpin();
    QSlider* m_slider = nullptr;
    QDoubleSpinBox* m_spin = nullptr;
    bool m_block = false; // suppresses re-entrant sync while mirroring values
};
// Reusable control for a 2-D vector parameter: two double spin boxes (X, Y).
class Vec2ParamControl final : public QWidget {
    Q_OBJECT
public:
    explicit Vec2ParamControl(QWidget* parent = nullptr);
    // Programmatic set; does not emit valueChanged.
    void setValue(double x, double y);
    double x() const;
    double y() const;
    void setEnabled(bool on);
signals:
    // Emitted on user edits when either coordinate actually changed.
    void valueChanged(double x, double y);
private:
    void emitIfChanged();
    QDoubleSpinBox* m_x = nullptr;
    QDoubleSpinBox* m_y = nullptr;
    bool m_block = false;  // suppresses emission during programmatic updates
    double m_lastX = 0.0;  // last published pair, used to dedupe emissions
    double m_lastY = 0.0;
};
} // namespace gui

View File

@@ -0,0 +1,77 @@
#include "props/BackgroundPropertySection.h"
#include <QCheckBox>
#include <QFormLayout>
#include <QLabel>
#include <QVBoxLayout>
namespace gui {

// Builds the background property page: size label plus two checkboxes.
BackgroundPropertySection::BackgroundPropertySection(QWidget* parent)
    : PropertySectionWidget(parent) {
    auto* outer = new QVBoxLayout(this);
    outer->setContentsMargins(0, 0, 0, 0);
    outer->setSpacing(6);

    auto* form = new QFormLayout();
    form->setContentsMargins(0, 0, 0, 0);
    form->setSpacing(6);

    // Read-only size display; selectable so the user can copy it.
    m_sizeLabel = new QLabel(QStringLiteral("-"), this);
    m_sizeLabel->setTextInteractionFlags(Qt::TextSelectableByMouse);
    form->addRow(QStringLiteral("背景尺寸"), m_sizeLabel);

    m_showBackground = new QCheckBox(QStringLiteral("显示背景"), this);
    m_showBackground->setToolTip(QStringLiteral("是否绘制背景图"));
    form->addRow(QString(), m_showBackground);

    m_depthOverlay = new QCheckBox(QStringLiteral("叠加深度"), this);
    m_depthOverlay->setToolTip(QStringLiteral("在背景上叠加深度伪彩图"));
    form->addRow(QString(), m_depthOverlay);

    outer->addLayout(form);
    outer->addStretch(1);

    // Re-emit the checkbox toggles as section-level signals.
    connect(m_showBackground, &QCheckBox::toggled,
            this, &BackgroundPropertySection::backgroundVisibleToggled);
    connect(m_depthOverlay, &QCheckBox::toggled,
            this, &BackgroundPropertySection::depthOverlayToggled);
}

// Update the read-only background size text.
void BackgroundPropertySection::setBackgroundSizeText(const QString& text) {
    if (m_sizeLabel == nullptr) {
        return;
    }
    m_sizeLabel->setText(text);
}

// Mirror the external background-visibility state into the checkbox without
// re-emitting backgroundVisibleToggled.
void BackgroundPropertySection::syncBackgroundVisible(bool visible, bool controlsEnabled) {
    if (m_showBackground == nullptr) {
        return;
    }
    m_showBackground->blockSignals(true);
    m_showBackground->setChecked(visible);
    m_showBackground->setEnabled(controlsEnabled);
    m_showBackground->blockSignals(false);
}

// Mirror the depth-overlay state into its checkbox without re-emitting.
void BackgroundPropertySection::syncDepthOverlayChecked(bool on) {
    if (m_depthOverlay == nullptr) {
        return;
    }
    m_depthOverlay->blockSignals(true);
    m_depthOverlay->setChecked(on);
    m_depthOverlay->blockSignals(false);
}

void BackgroundPropertySection::setDepthOverlayCheckEnabled(bool on) {
    if (m_depthOverlay != nullptr) {
        m_depthOverlay->setEnabled(on);
    }
}

// Neutral appearance used while no project is open.
void BackgroundPropertySection::setProjectClosedAppearance() {
    setBackgroundSizeText(QStringLiteral("-"));
    syncBackgroundVisible(/*visible=*/true, /*controlsEnabled=*/false);
    syncDepthOverlayChecked(false);
    setDepthOverlayCheckEnabled(false);
}
} // namespace gui

View File

@@ -0,0 +1,32 @@
#pragma once
#include "props/PropertySectionWidget.h"
class QLabel;
class QCheckBox;
namespace gui {
// Background-related properties: size, visibility and depth overlay.
// Designed to be embedded as one page of a QStackedWidget.
class BackgroundPropertySection final : public PropertySectionWidget {
    Q_OBJECT
public:
    explicit BackgroundPropertySection(QWidget* parent = nullptr);
    // Read-only size text shown next to the "background size" label.
    void setBackgroundSizeText(const QString& text);
    // Mirror external state into the checkboxes without re-emitting signals.
    void syncBackgroundVisible(bool visible, bool controlsEnabled);
    void syncDepthOverlayChecked(bool on);
    void setDepthOverlayCheckEnabled(bool on);
    // Neutral appearance used while no project is open.
    void setProjectClosedAppearance();
signals:
    // Emitted on user interaction with the corresponding checkbox.
    void backgroundVisibleToggled(bool on);
    void depthOverlayToggled(bool on);
private:
    QLabel* m_sizeLabel = nullptr;
    QCheckBox* m_showBackground = nullptr;
    QCheckBox* m_depthOverlay = nullptr;
};
} // namespace gui

View File

@@ -0,0 +1,108 @@
#include "props/EntityPropertySection.h"
#include "params/ParamControls.h"
#include <QDoubleSpinBox>
#include <QFormLayout>
#include <QLabel>
#include <QLineEdit>
#include <QVBoxLayout>
namespace gui {

// Builds the entity property page: editable name, read-only depth/scale
// labels, pivot and centroid vector editors, and the user-scale spin box.
EntityPropertySection::EntityPropertySection(QWidget* parent)
    : PropertySectionWidget(parent) {
    auto* outer = new QVBoxLayout(this);
    outer->setContentsMargins(0, 0, 0, 0);
    outer->setSpacing(6);

    auto* form = new QFormLayout();
    form->setContentsMargins(0, 0, 0, 0);
    form->setSpacing(6);

    m_name = new QLineEdit(this);
    m_name->setPlaceholderText(QStringLiteral("显示名称"));
    m_name->setToolTip(QStringLiteral("仅显示用;内部 id 不变"));
    form->addRow(QStringLiteral("名称"), m_name);

    // Read-only info labels; selectable for copy & paste.
    m_depth = new QLabel(QStringLiteral("-"), this);
    m_depth->setTextInteractionFlags(Qt::TextSelectableByMouse);
    m_distScale = new QLabel(QStringLiteral("-"), this);
    m_distScale->setTextInteractionFlags(Qt::TextSelectableByMouse);
    form->addRow(QStringLiteral("深度"), m_depth);
    form->addRow(QStringLiteral("距离缩放"), m_distScale);

    m_pivot = new Vec2ParamControl(this);
    m_pivot->setToolTip(QStringLiteral("枢轴在世界坐标中的位置(限制在轮廓包络内),用于重定位局部原点"));
    form->addRow(QStringLiteral("中心坐标"), m_pivot);

    m_centroid = new Vec2ParamControl(this);
    m_centroid->setToolTip(QStringLiteral("实体几何质心的世界坐标;修改将整体平移实体"));
    form->addRow(QStringLiteral("位置"), m_centroid);

    m_userScale = new QDoubleSpinBox(this);
    m_userScale->setRange(0.05, 20.0);
    m_userScale->setDecimals(3);
    m_userScale->setSingleStep(0.05);
    m_userScale->setValue(1.0);
    m_userScale->setToolTip(QStringLiteral("人为整体缩放,与深度距离缩放相乘"));
    form->addRow(QStringLiteral("整体缩放"), m_userScale);

    outer->addLayout(form);
    outer->addStretch(1);

    // Commit the display name only when editing finishes, not per keystroke.
    connect(m_name, &QLineEdit::editingFinished, this, [this]() {
        if (m_name != nullptr) {
            emit displayNameCommitted(m_name->text());
        }
    });
    connect(m_pivot, &Vec2ParamControl::valueChanged,
            this, &EntityPropertySection::pivotEdited);
    connect(m_centroid, &Vec2ParamControl::valueChanged,
            this, &EntityPropertySection::centroidEdited);
    connect(m_userScale, qOverload<double>(&QDoubleSpinBox::valueChanged),
            this, &EntityPropertySection::userScaleEdited);
}

// Reset the section to its "no entity selected" appearance.
void EntityPropertySection::clearDisconnected() {
    setEditingEnabled(false);
    if (m_name != nullptr) {
        m_name->blockSignals(true);
        m_name->clear();
        m_name->blockSignals(false);
    }
    if (m_depth != nullptr) {
        m_depth->setText(QStringLiteral("-"));
    }
    if (m_distScale != nullptr) {
        m_distScale->setText(QStringLiteral("-"));
    }
    if (m_pivot != nullptr) {
        m_pivot->setValue(0.0, 0.0);
    }
    if (m_centroid != nullptr) {
        m_centroid->setValue(0.0, 0.0);
    }
    if (m_userScale != nullptr) {
        m_userScale->blockSignals(true);
        m_userScale->setValue(1.0);
        m_userScale->blockSignals(false);
    }
}

// Populate every control from a UI-state snapshot without emitting the
// *Edited signals (signal blocking / non-emitting setters throughout).
void EntityPropertySection::applyState(const EntityPropertyUiState& s) {
    setEditingEnabled(true);
    if (m_name != nullptr) {
        m_name->blockSignals(true);
        m_name->setText(s.displayName);
        m_name->blockSignals(false);
    }
    if (m_depth != nullptr) {
        m_depth->setText(QString::number(s.depthZ));
    }
    if (m_distScale != nullptr) {
        m_distScale->setText(s.distanceScaleText);
    }
    if (m_pivot != nullptr) {
        m_pivot->setValue(s.pivot.x(), s.pivot.y());
    }
    if (m_centroid != nullptr) {
        m_centroid->setValue(s.centroid.x(), s.centroid.y());
    }
    if (m_userScale != nullptr) {
        m_userScale->blockSignals(true);
        m_userScale->setValue(s.userScale);
        m_userScale->blockSignals(false);
    }
}

// Toggle the editable controls; the read-only labels are unaffected.
void EntityPropertySection::setEditingEnabled(bool on) {
    if (m_name != nullptr) m_name->setEnabled(on);
    if (m_pivot != nullptr) m_pivot->setEnabled(on);
    if (m_centroid != nullptr) m_centroid->setEnabled(on);
    if (m_userScale != nullptr) m_userScale->setEnabled(on);
}
} // namespace gui

View File

@@ -0,0 +1,52 @@
#pragma once
#include "props/PropertySectionWidget.h"
#include <QPointF>
#include <QString>
class QLabel;
class QLineEdit;
class QDoubleSpinBox;
namespace gui {
class Vec2ParamControl;
}
namespace gui {
// Plain snapshot of everything the entity property page displays; filled by
// the main window and applied in one call to avoid partial updates.
struct EntityPropertyUiState {
    QString displayName;        // display-only name (internal id unchanged)
    int depthZ = 0;             // depth layer, shown read-only
    QString distanceScaleText;  // preformatted distance-scale text
    QPointF pivot;              // pivot position in world coordinates
    QPointF centroid;           // geometric centroid in world coordinates
    double userScale = 1.0;     // manual scale, multiplied with depth scale
};
// Entity-related properties (embeddable as one page of a QStackedWidget).
class EntityPropertySection final : public PropertySectionWidget {
    Q_OBJECT
public:
    explicit EntityPropertySection(QWidget* parent = nullptr);
    // Reset to the "no entity selected" appearance (controls disabled).
    void clearDisconnected();
    // Populate all controls from the snapshot without emitting edit signals.
    void applyState(const EntityPropertyUiState& s);
    void setEditingEnabled(bool on);
signals:
    // Emitted on user edits only.
    void displayNameCommitted(const QString& text);
    void pivotEdited(double x, double y);
    void centroidEdited(double x, double y);
    void userScaleEdited(double value);
private:
    QLineEdit* m_name = nullptr;
    QLabel* m_depth = nullptr;
    QLabel* m_distScale = nullptr;
    Vec2ParamControl* m_pivot = nullptr;
    Vec2ParamControl* m_centroid = nullptr;
    QDoubleSpinBox* m_userScale = nullptr;
};
} // namespace gui

View File

@@ -0,0 +1,13 @@
#pragma once
#include <QWidget>
namespace gui {
// Common base class for the switchable "sections" inside the property dock.
// Exists so more object types (lights, cameras, ...) can be added later as
// additional pages sharing one type.
class PropertySectionWidget : public QWidget {
public:
    explicit PropertySectionWidget(QWidget* parent = nullptr) : QWidget(parent) {}
};
} // namespace gui

View File

@@ -0,0 +1,310 @@
#include "timeline/TimelineWidget.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <QMouseEvent>
#include <QPainter>
#include <QWheelEvent>
namespace {
// Clamp f into the inclusive range [a, b]; tolerates a reversed (a > b) range.
int clampFrame(int f, int a, int b) {
    const int lo = std::min(a, b);
    const int hi = std::max(a, b);
    return std::clamp(f, lo, hi);
}
} // namespace
// Timeline scrubber widget: draws a rail, per-track keyframe dots and the
// playhead caret; handles scrubbing, key/interval selection and wheel steps.
TimelineWidget::TimelineWidget(QWidget* parent)
    : QWidget(parent) {
    setFocusPolicy(Qt::StrongFocus); // accept keyboard focus
    setMinimumHeight(28);
    setMouseTracking(true);          // receive move events without a pressed button
}
// Set the playable frame range; keeps the current frame inside the new range.
void TimelineWidget::setFrameRange(int start, int end) {
    const bool unchanged = (m_start == start) && (m_end == end);
    if (unchanged) {
        return;
    }
    m_start = start;
    m_end = end;
    m_currentFrame = clampFrame(m_currentFrame, m_start, m_end);
    update();
}
// External (non-committing) frame change, e.g. driven by playback.
void TimelineWidget::setCurrentFrame(int frame) {
    setFrameInternal(frame, /*commit=*/false);
}
// Select the closed interval [start, end]; pass (-1, -1) to clear.
void TimelineWidget::setSelectionRange(int start, int end) {
    if (start < 0 || end < 0) {
        m_selStart = -1;
        m_selEnd = -1;
    } else {
        const int lo = std::min(start, end);
        const int hi = std::max(start, end);
        m_selStart = clampFrame(lo, m_start, m_end);
        m_selEnd = clampFrame(hi, m_start, m_end);
    }
    update();
}
// Snapshot the entity's keyframe tracks into sorted, deduplicated frame lists
// (or clear them when e is null). If the currently selected keyframe no
// longer exists in the fresh snapshot, the selection is dropped and
// keyframeSelectionChanged is emitted.
void TimelineWidget::setKeyframeTracks(const core::Project::Entity* e) {
    m_locFrames.clear();
    m_scaleFrames.clear();
    m_imgFrames.clear();
    if (e == nullptr) {
        update();
        return;
    }

    const auto snapshot = [](const auto& keys, QVector<int>& out) {
        out.reserve(keys.size());
        for (const auto& k : keys) {
            out.push_back(k.frame);
        }
        std::sort(out.begin(), out.end());
        out.erase(std::unique(out.begin(), out.end()), out.end());
    };
    snapshot(e->locationKeys, m_locFrames);
    snapshot(e->userScaleKeys, m_scaleFrames);
    snapshot(e->imageFrames, m_imgFrames);

    // Validate the current keyframe selection against the fresh tracks.
    const QVector<int>* track = nullptr;
    switch (m_selKeyKind) {
        case KeyKind::Location:  track = &m_locFrames;   break;
        case KeyKind::UserScale: track = &m_scaleFrames; break;
        case KeyKind::Image:     track = &m_imgFrames;   break;
        case KeyKind::None:      break;
    }
    const bool stillValid =
        (track == nullptr)
        || std::binary_search(track->begin(), track->end(), m_selKeyFrame);
    if (!stillValid) {
        m_selKeyKind = KeyKind::None;
        m_selKeyFrame = -1;
        emit keyframeSelectionChanged(m_selKeyKind, m_selKeyFrame);
    }
    update();
}
// Full-height strip the timeline occupies, with 8 px of side padding.
QRect TimelineWidget::trackRect() const {
    constexpr int kPad = 8;
    const int usableWidth = std::max(1, width() - kPad * 2);
    return QRect(kPad, 0, usableWidth, height());
}
// Map a widget-space x coordinate to the nearest frame in [m_start, m_end].
int TimelineWidget::xToFrame(int x) const {
    const QRect r = trackRect();
    if (r.width() <= 1) {
        return m_start; // degenerate track: everything maps to the start
    }
    const double t =
        std::clamp(double(x - r.left()) / double(r.width() - 1), 0.0, 1.0);
    const int span = std::max(1, m_end - m_start);
    return clampFrame(m_start + int(std::round(t * span)), m_start, m_end);
}
// Map a frame index to a widget-space x coordinate on the track.
int TimelineWidget::frameToX(int frame) const {
    const QRect r = trackRect();
    if (r.width() <= 1) {
        return r.left(); // degenerate track
    }
    const int span = std::max(1, m_end - m_start);
    const int f = clampFrame(frame, m_start, m_end);
    const double t = double(f - m_start) / double(span);
    return r.left() + int(std::round(t * (r.width() - 1)));
}
// Central frame setter. Emits frameScrubbed whenever the (clamped) frame is
// applied, and additionally frameCommitted when commit is true — note that a
// commit fires even if the frame did not change, so a click re-confirms the
// current frame.
void TimelineWidget::setFrameInternal(int frame, bool commit) {
    const int f = clampFrame(frame, m_start, m_end);
    const bool changed = (f != m_currentFrame);
    if (!changed && !commit) {
        return;
    }
    m_currentFrame = f;
    update();
    emit frameScrubbed(f);
    if (commit) {
        emit frameCommitted(f);
    }
}
// Paints, in back-to-front order: background, rounded rail, selected
// interval, three lanes of keyframe dots, and the playhead caret.
void TimelineWidget::paintEvent(QPaintEvent*) {
    QPainter p(this);
    p.setRenderHint(QPainter::Antialiasing, true);
    const QRect r = rect();
    p.fillRect(r, palette().base());
    // Rounded "rail" the playhead travels along, inset 8 px vertically.
    const QRect tr = trackRect().adjusted(0, 8, 0, -8);
    const QColor rail = palette().mid().color();
    p.setPen(Qt::NoPen);
    p.setBrush(rail);
    p.drawRoundedRect(tr, 6, 6);
    // selection range
    if (m_selStart >= 0 && m_selEnd >= 0) {
        const int x0 = frameToX(m_selStart);
        const int x1 = frameToX(m_selEnd);
        QRect sel(QPoint(std::min(x0, x1), tr.top()), QPoint(std::max(x0, x1), tr.bottom()));
        sel = sel.adjusted(0, 2, 0, -2);
        QColor c = palette().highlight().color();
        c.setAlpha(50); // translucent so the rail stays visible underneath
        p.setBrush(c);
        p.drawRoundedRect(sel, 4, 4);
    }
    // One dot per keyframe; the selected key gets a highlight ring. The
    // `&frames == &m_...Frames` address comparison identifies which member
    // track the lambda was invoked with (the reference binds to the member).
    auto drawDots = [&](const QVector<int>& frames, const QColor& c, int y) {
        p.setBrush(c);
        p.setPen(Qt::NoPen);
        for (int f : frames) {
            if (f < m_start || f > m_end) continue; // skip off-range keys
            const int x = frameToX(f);
            const bool sel =
                (m_selKeyFrame == f)
                && ((m_selKeyKind == KeyKind::Image && &frames == &m_imgFrames)
                    || (m_selKeyKind == KeyKind::Location && &frames == &m_locFrames)
                    || (m_selKeyKind == KeyKind::UserScale && &frames == &m_scaleFrames));
            if (sel) {
                p.setPen(QPen(palette().highlight().color(), 2.0));
                p.setBrush(c);
                p.drawEllipse(QPointF(x, y), 4.4, 4.4);
                p.setPen(Qt::NoPen);
            } else {
                p.drawEllipse(QPointF(x, y), 2.6, 2.6);
            }
        }
    };
    // Three stacked lanes around the rail midline: image keys above,
    // location keys on, user-scale keys below. Offsets must match the lane
    // positions used for hit testing in mouseReleaseEvent().
    const int yMid = tr.center().y();
    drawDots(m_imgFrames, QColor(80, 160, 255, 230), yMid - 6);
    drawDots(m_locFrames, QColor(255, 120, 0, 230), yMid);
    drawDots(m_scaleFrames, QColor(140, 220, 140, 230), yMid + 6);
    // current frame caret
    const int cx = frameToX(m_currentFrame);
    p.setPen(QPen(palette().highlight().color(), 2.0));
    p.drawLine(QPoint(cx, tr.top() - 6), QPoint(cx, tr.bottom() + 6));
}
static bool hitDot(const QPoint& pos, int dotX, int dotY, int radiusPx) {
const int dx = pos.x() - dotX;
const int dy = pos.y() - dotY;
return (dx * dx + dy * dy) <= (radiusPx * radiusPx);
}
// Nearest keyframe to `frame` within a sorted track; -1 when the track is
// empty. Equidistant neighbours resolve to the earlier frame.
static int findNearestFrameInTrack(const QVector<int>& frames, int frame) {
    if (frames.isEmpty()) {
        return -1;
    }
    const auto it = std::lower_bound(frames.begin(), frames.end(), frame);
    if (it == frames.begin()) {
        return *it; // frame precedes every key
    }
    if (it == frames.end()) {
        return frames.back(); // frame follows every key
    }
    const int before = *(it - 1);
    const int after = *it;
    return (frame - before <= after - frame) ? before : after;
}
// Find the keyframe interval [outA, outB] surrounding `frame` in a sorted
// union of all track frames. Yields (-1, -1) when fewer than two keys exist
// or when frame falls before the first key / after the last.
static void findIntervalAround(const QVector<int>& allFrames, int frame, int& outA, int& outB) {
    outA = -1;
    outB = -1;
    if (allFrames.size() < 2) {
        return;
    }
    const auto upper = std::upper_bound(allFrames.begin(), allFrames.end(), frame);
    const bool inside = (upper != allFrames.begin()) && (upper != allFrames.end());
    if (!inside) {
        return;
    }
    outA = *(upper - 1);
    outB = *upper;
}
// Right click: request a context menu at the clicked frame.
// Left click: begin a scrub gesture (the commit happens on release).
void TimelineWidget::mousePressEvent(QMouseEvent* e) {
    switch (e->button()) {
        case Qt::RightButton:
            emit contextMenuRequested(mapToGlobal(e->pos()), xToFrame(e->pos().x()));
            return;
        case Qt::LeftButton:
            m_pressPos = e->pos();
            m_moved = false;
            m_dragging = true;
            setFrameInternal(xToFrame(e->pos().x()), /*commit=*/false);
            e->accept();
            return;
        default:
            QWidget::mousePressEvent(e);
            return;
    }
}
// While scrubbing, follow the cursor; the gesture counts as a drag once the
// cursor travels beyond a small dead zone (3 px Manhattan distance).
void TimelineWidget::mouseMoveEvent(QMouseEvent* e) {
    if (!m_dragging) {
        QWidget::mouseMoveEvent(e);
        return;
    }
    const int travelled = (e->pos() - m_pressPos).manhattanLength();
    if (travelled > 3) {
        m_moved = true;
    }
    setFrameInternal(xToFrame(e->pos().x()), /*commit=*/false);
    e->accept();
}
// Ends a left-button scrub: commits the frame, and — if the gesture was a
// click rather than a drag — resolves the click into either a keyframe-dot
// selection or an interval selection. Keyframe and interval selection are
// mutually exclusive; clearing the other side emits its change signal.
void TimelineWidget::mouseReleaseEvent(QMouseEvent* e) {
    if (m_dragging && e->button() == Qt::LeftButton) {
        m_dragging = false;
        const int f = xToFrame(e->pos().x());
        setFrameInternal(f, true);
        // On a click (no drag) perform selection: a keyframe or an interval.
        if (!m_moved) {
            const QRect tr = trackRect().adjusted(0, 8, 0, -8);
            const int yMid = tr.center().y();
            // Lane y offsets must match the ones used in paintEvent().
            const int yImg = yMid - 6;
            const int yLoc = yMid;
            const int ySc = yMid + 6;
            const int rad = 7; // hit radius, slightly larger than the drawn dot
            // Try to select the nearest key of one track if the click lands
            // on its dot; returns true on a successful selection.
            auto trySelectKey = [&](KeyKind kind, const QVector<int>& frames, int laneY) -> bool {
                const int nearest = findNearestFrameInTrack(frames, f);
                if (nearest < 0) return false;
                const int x = frameToX(nearest);
                if (hitDot(e->pos(), x, laneY, rad)) {
                    m_selKeyKind = kind;
                    m_selKeyFrame = nearest;
                    emit keyframeSelectionChanged(m_selKeyKind, m_selKeyFrame);
                    update();
                    return true;
                }
                return false;
            };
            // First try to hit a keyframe dot (lane priority: image first,
            // then location, then user scale).
            if (trySelectKey(KeyKind::Image, m_imgFrames, yImg)
                || trySelectKey(KeyKind::Location, m_locFrames, yLoc)
                || trySelectKey(KeyKind::UserScale, m_scaleFrames, ySc)) {
                // Selecting a keyframe clears any interval selection.
                if (m_selStart >= 0 && m_selEnd >= 0) {
                    m_selStart = -1;
                    m_selEnd = -1;
                    emit intervalSelectionChanged(m_selStart, m_selEnd);
                }
            } else {
                // No keyframe hit: try to select the interval delimited by
                // the keyframes, using the union of all three tracks.
                QVector<int> all = m_locFrames;
                all += m_scaleFrames;
                all += m_imgFrames;
                std::sort(all.begin(), all.end());
                all.erase(std::unique(all.begin(), all.end()), all.end());
                int a = -1, b = -1;
                findIntervalAround(all, f, a, b);
                if (a >= 0 && b >= 0) {
                    setSelectionRange(a, b);
                    emit intervalSelectionChanged(m_selStart, m_selEnd);
                    // Selecting an interval clears any keyframe selection.
                    if (m_selKeyKind != KeyKind::None) {
                        m_selKeyKind = KeyKind::None;
                        m_selKeyFrame = -1;
                        emit keyframeSelectionChanged(m_selKeyKind, m_selKeyFrame);
                    }
                }
            }
        }
        e->accept();
        return;
    }
    QWidget::mouseReleaseEvent(e);
}
// Step the playhead one frame per wheel notch (up = forward).
//
// Fix: the original mapped a zero vertical delta — e.g. a purely horizontal
// or trackpad scroll — to a -1 step, silently nudging the frame backwards.
// Such events are now consumed without moving the frame.
void TimelineWidget::wheelEvent(QWheelEvent* e) {
    const int dy = e->angleDelta().y();
    if (dy != 0) {
        const int step = (dy > 0) ? 1 : -1;
        setFrameInternal(m_currentFrame + step, /*commit=*/true);
    }
    e->accept();
}

View File

@@ -0,0 +1,69 @@
#pragma once
#include "core/domain/Project.h"
#include <QWidget>
// Horizontal timeline / scrubber widget. Shows the playhead, an optional
// selected interval and per-track keyframe markers for the selected entity;
// supports drag scrubbing, dot/interval selection and wheel stepping.
class TimelineWidget final : public QWidget {
    Q_OBJECT
public:
    explicit TimelineWidget(QWidget* parent = nullptr);
    void setFrameRange(int start, int end);
    void setCurrentFrame(int frame);
    int currentFrame() const { return m_currentFrame; }
    void setSelectionRange(int start, int end); // (-1, -1) clears
    int selectionStart() const { return m_selStart; }
    int selectionEnd() const { return m_selEnd; }
    // Show only the keyframes of the currently selected entity (null clears).
    void setKeyframeTracks(const core::Project::Entity* entityOrNull);
    enum class KeyKind { None, Location, UserScale, Image };
    KeyKind selectedKeyKind() const { return m_selKeyKind; }
    int selectedKeyFrame() const { return m_selKeyFrame; }
    bool hasSelectedKeyframe() const { return m_selKeyKind != KeyKind::None && m_selKeyFrame >= 0; }
signals:
    void frameScrubbed(int frame);  // fired continuously while dragging (live preview)
    void frameCommitted(int frame); // fired on release/click confirm (heavier refresh)
    void contextMenuRequested(const QPoint& globalPos, int frame);
    void keyframeSelectionChanged(KeyKind kind, int frame);
    void intervalSelectionChanged(int start, int end);
protected:
    void paintEvent(QPaintEvent*) override;
    void mousePressEvent(QMouseEvent*) override;
    void mouseMoveEvent(QMouseEvent*) override;
    void mouseReleaseEvent(QMouseEvent*) override;
    void wheelEvent(QWheelEvent*) override;
private:
    int xToFrame(int x) const;       // widget x -> frame index
    int frameToX(int frame) const;   // frame index -> widget x
    QRect trackRect() const;         // padded strip the rail occupies
    void setFrameInternal(int frame, bool commit);
private:
    int m_start = 0;
    int m_end = 600;
    int m_currentFrame = 0;
    // Selected interval; -1/-1 means no selection.
    int m_selStart = -1;
    int m_selEnd = -1;
    // Scrub-gesture state.
    bool m_dragging = false;
    QPoint m_pressPos;
    bool m_moved = false; // true once the press turned into a drag
    // Snapshot of the keyframe tracks, so painting does not have to walk the
    // workspace repeatedly.
    QVector<int> m_locFrames;
    QVector<int> m_scaleFrames;
    QVector<int> m_imgFrames;
    KeyKind m_selKeyKind = KeyKind::None;
    int m_selKeyFrame = -1;
};

347
doc/editor-workflow.md Normal file
View File

@@ -0,0 +1,347 @@
# 编辑界面功能流程说明
本文档用于定义编辑界面的核心模块、用户操作流程和各模块之间的数据流,作为实现与联调依据。
## 1. 功能模块
系统包含以下功能模块:
1. 编辑界面
2. 深度估计
3. 分层与遮罩
4. 补全与纹理延展
5. 漫游渲染与预览
6. 热点与叙事运行
其中,`热点与叙事运行` 用于内部部分场景生成多帧动态化效果,作为静态场景的局部增强能力。
## 2. 总体流程(用户视角)
用户在编辑界面的标准操作流程如下:
1. 打开程序后选择图片。
2. 对图片进行裁剪,确定编辑画布范围。
3. 对裁剪后的全部图像执行深度估计。
4. 等待深度估计完成(显示进度,失败时可重试)。
5. 估计完成后展示深度叠加图,并支持开关切换(显示/隐藏叠加)。
6. 进入圈选与点选阶段,系统先基于深度图自动选出深度变化较明显的候选区域,再由人工圈画/点选补充与修正,确定前景目标与背景区域。
7. 将结果输入分层程序,按深度关系分离目标物体,未分离区域作为底层背景。
8. 对被分层出的区域,支持右键触发补全,输入提示词进行内容补全与纹理延展。
9. 对被分层出的区域,支持右键输入提示词进行叙事帧生成。
10. 在漫游渲染与预览中检查效果并迭代调整。
## 3. 模块详细要求
## 3.1 编辑界面
- 支持图片加载与初始化展示。
- 提供裁剪工具,输出裁剪后的工作区域。
- 深度估计默认作用于裁剪后的全部图像。
- 深度估计完成后支持深度叠加图开关。
- 支持圈选与点选两类交互,用于前景目标确认与细化。
- 右键菜单需至少包含:`补全` 与 `叙事帧生成`。
### 3.1.1 GUI 总体信息架构(窗口与面板)
建议采用“单主窗口 + 多 Dock + 少量模态对话框”的结构,保证复杂流程可见且可回退。
1. 主窗口(`MainWindow`
- 顶部:菜单栏、主工具栏、阶段切换按钮。
- 中央:主画布(编辑/预览共用)。
- 左侧:工程树 + 图层树(可 tab
- 右侧:流程控制、参数设置、属性面板(可 tab
- 底部:任务状态栏与日志条。
2. 模态/半模态对话框
- 裁剪对话框:用于确认裁剪框、比例锁定。
- 提示词对话框:用于补全/叙事输入 Prompt。
- 任务详情对话框:查看失败原因与重试入口。
- 偏好设置对话框:模型端点、默认参数、快捷键。
3. 关键 Dock建议
- `流程控制 Dock`:按步骤触发(裁剪 -> 深度 -> 候选区域 -> 分层 -> 补全/叙事)。
- `图层 Dock`:背景层/对象层/补全层/叙事层管理。
- `属性 Dock`:当前选中区域或图层的参数。
- `工程树 Dock`:场景、热点、叙事节点。
- `预览 Dock`:漫游参数、播放控制、帧速率信息。
### 3.1.2 主窗口布局与阶段条(详细)
主窗口建议增加明确的“阶段条Step Bar放在工具栏下方
1. `S1 导入与裁剪`
2. `S2 深度估计(整图)`
3. `S3 候选区域与遮罩修正`
4. `S4 分层`
5. `S5 补全/纹理延展`
6. `S6 热点与叙事帧`
7. `S7 漫游预览`
每个阶段显示 4 种状态:`未开始`、`进行中`、`完成`、`失败`。
仅允许“当前阶段 + 已完成阶段”可编辑,避免跨阶段误操作。
### 3.1.3 画布交互设计(核心)
画布(建议继续以 `CanvasWidget` 为核心)需支持以下显示层与交互模式:
1. 显示层(可开关)
- 原图层
- 深度叠加层(支持透明度 0~100
- 自动候选区域层(描边 + 半透明填充)
- 人工遮罩层(新增/擦除轨迹)
- 分层结果层(前景、中景、背景)
- 叙事帧预览层(时间轴预览时启用)
2. 交互模式
- 浏览模式:平移、缩放、查看。
- 裁剪模式:拖拽裁剪框、锁定比例、确认/取消。
- 圈选模式:套索或画笔新增区域。
- 点选模式:点击候选区域进行“收录/排除”。
- 擦除模式:从最终遮罩中删除误选区域。
- 热点编辑模式:绘制热点框并绑定叙事节点。
3. 右键菜单(在画布选区上触发)
- `补全...`(打开 Prompt 对话框)
- `叙事帧生成...`(打开 Prompt 对话框 + 帧数参数)
- `加入前景层` / `加入背景层`
- `从遮罩中移除`
- `复制遮罩到新图层`
### 3.1.4 流程控制 Dock按钮与状态
流程控制 Dock 建议替代“单个开始处理”按钮,拆为分步触发:
1. `开始裁剪` / `应用裁剪`
2. `执行深度估计(整图)`
3. `生成自动候选区域`
4. `确认遮罩并分层`
5. `对选中区域补全`
6. `对选中区域生成叙事帧`
7. `进入漫游预览`
每个按钮旁应有状态灯(灰/蓝/绿/红)与耗时。
失败时在同一行提供 `重试` 与 `查看详情`。
### 3.1.5 图层 Dock建议新增
图层 Dock 至少包含以下节点:
- `Base`(裁剪后的底图)
- `DepthOverlay`(仅可视化)
- `AutoCandidates`(自动候选区域)
- `MaskFinal`(自动 + 人工修正结果)
- `Foreground_i`(可多个)
- `Background`
- `Inpaint_i`
- `NarrativeFrames_i`(关联某一热点/区域)
每个图层支持:可见性、锁定、重命名、透明度、删除(受阶段约束)。
### 3.1.5A 工程树对象模型(本次需求)
工程树中的所有节点统一称为“对象Object按层级表达空间远近关系与遮挡关系。
1. 对象类型定义
- 背景对象Background Object导入图像默认生成的对象作为底层背景。
- 实体Entity由分层直接得到的对象或由补全后得到的对象。
- 活动实体Active Entity由实体派生已生成叙事动画帧的对象。
2. 继承/派生关系
- `背景对象` 不派生自其他对象。
- `实体` 可由背景对象或其他实体拆分得到。
- `活动实体` 必须派生自实体(实体 -> 活动实体)。
3. 层级与遮挡规则
- 工程树越往下,层级越低,表示距离越远。
- 距离近的对象会遮挡距离远的对象。
- 渲染建议采用“先远后近”的顺序(深层节点先绘制,浅层节点后绘制)。
4. 动画语义
- 实体默认是静态对象。
- 活动实体支持循环播放叙事帧(自然动画)。
- 活动实体同时支持触发动画(例如点击热点、时间线事件触发)。
5. 工程树最低能力
- 支持新增多个对象(背景对象、实体、活动实体)。
- 支持父子关系(实体下继续派生实体或活动实体)。
- 支持对象删除(删除父对象时一并删除子对象)。
- 支持在属性区查看对象类型与动画能力状态。
### 3.1.6 属性 Dock按对象动态切换
1. 选中“自动候选区域”时
- 显示候选评分、面积、平均深度、边界平滑系数。
2. 选中“遮罩”时
- 显示画笔大小、羽化、腐蚀/膨胀。
3. 选中“补全任务”时
- 显示 Prompt、负向提示词、步数、强度、随机种子。
4. 选中“叙事帧任务”时
- 显示 Prompt、目标帧数、帧率、运动幅度、循环方式。
5. 选中“热点”时
- 显示标题、描述、绑定叙事序列、触发方式。
### 3.1.7 任务栏与日志区(建议新增)
底部状态区建议包含:
- 当前阶段与子任务状态(例如“分层计算 67%”)。
- 最近一次接口调用耗时(深度/分层/补全/叙事)。
- 错误摘要(可点击展开完整日志)。
- 后台任务队列(允许补全与叙事并行排队)。
### 3.1.8 结合现有代码的改造清单(`client/gui`
以下为现有 GUI 与目标流程不匹配点,以及建议修改方向:
1. `MainWindow` 当前以“单次开始处理”为主,缺少分阶段流程控制
- 现状:`onProcessingStartRequested()` 串行触发,且 `onDepthEstimationFinished()` 中直接进入分层。
- 建议:拆为 `onCropConfirmed``onDepthRequested``onCandidatesRequested``onLayeringRequested` 等独立槽函数,并引入阶段状态机。
2. 缺少裁剪窗口/裁剪模式
- 现状:`onOpenImage()` 直接载入整图到画布。
- 建议:新增裁剪对话框(如 `CropDialog`)或在 `CanvasWidget` 增加 `Crop` 交互模式,确认后生成工作图并替换当前输入。
3. `CanvasWidget` 交互模式不足
- 现状:仅 `View``EditHotspot`
- 建议:扩展为 `Crop``MaskBrushAdd``MaskBrushErase``CandidatePick``PromptRegionSelect` 等模式,并增加对应信号(例如 `maskEdited``candidateToggled`)。
4. 缺少画布右键菜单(补全/叙事生成)
- 现状:右键菜单仅在工程树中提供“删除热点”。
- 建议:在画布选中区域上提供右键菜单,接入补全和叙事任务创建。
5. 缺少图层管理 Dock
- 现状:有工程树与属性面板,但没有分层图层树。
- 建议:新增 `LayerPanel`(可作为新 Dock统一管理前景/背景/补全/叙事层可见性与顺序。
6. 预处理 Dock 语义不完整
- 现状:`ProcessingPanel` 偏模型选择与“一键开始”。
- 建议:重构为“流程控制 + 模型参数”双区结构;模型参数保留,执行入口拆成分步按钮。
7. 深度叠加控制不足
- 现状:`MainWindow` 中仅有分层预览开关(`m_layerPreviewCheck`)。
- 建议:增加“深度叠加开关 + 透明度滑条 + 深度色图选择”,并在 `CanvasWidget` 绘制层中实现。
8. 任务完成后自动隐藏面板不利于调试
- 现状:`onInpaintFinished()` 使用定时器自动隐藏 `m_processingDock`
- 建议:改为默认不自动隐藏,仅在用户手动收起时隐藏。
9. 热点叙事仅文本节点,未覆盖多帧动态任务
- 现状:热点主要绑定 `NarrativeNode` 文本说明。
- 建议:为热点增加“叙事帧任务列表”与预览入口,支持每个热点关联多组动态帧。
10. 缺少撤销/重做与历史快照
- 现状:遮罩、分层、补全、叙事结果缺少统一历史机制。
- 建议引入编辑命令栈Command Pattern并在工具栏提供撤销/重做。
### 3.1.9 推荐新增/调整的 GUI 类
建议在 `client/gui` 增加或调整以下类(命名可按现有风格微调):
- `CropDialog`:裁剪确认窗口。
- `LayerPanel`:图层树与图层操作。
- `TaskPanel`:任务队列与执行状态。
- `PromptDialog`:补全/叙事 Prompt 与高级参数输入。
- `WorkflowController`(可先放 `MainWindow` 内部):统一阶段状态机与按钮可用性。
- `CanvasWidget`:扩展多交互模式与右键上下文菜单能力。
### 3.1.10 GUI 验收标准(编辑界面维度)
1. 用户在不离开主窗口的情况下可完整完成“裁剪 -> 深度 -> 候选修正 -> 分层 -> 补全/叙事 -> 预览”。
2. 每一阶段均有明确可视状态与失败重试入口。
3. 画布右键可直接触发补全与叙事,且默认绑定当前选区。
4. 图层可见性与锁定状态可稳定控制渲染结果。
5. 热点不仅可编辑文本,还可绑定并预览多帧叙事结果。
## 3.2 深度估计
- 输入:用户裁剪后的图像(整图)。
- 输出:深度图(与编辑画布对齐)以及可视化叠加图。
- 要求:结果可回传编辑界面用于后续圈选、点选和分层计算。
## 3.3 分层与遮罩
- 输入:深度图、系统自动候选区域、用户圈选/点选结果、人工修正遮罩。
- 输出:前景层(一个或多个对象层)与背景底层。
- 要求:
- 按深度关系优先分离目标物体。
- 保留可编辑遮罩,支持后续补全和叙事生成直接复用。
### 3.3.1 自动候选区域筛选规则(建议)
为保证“先自动、再人工补充”的效率,建议在深度图上执行以下候选区域提取流程:
1. 深度预处理
- 对深度图进行归一化到 `[0, 1]`
- 使用轻量平滑(如 3x3 中值滤波或双边滤波)抑制噪声,避免过碎片区域。
2. 深度变化检测
- 计算深度梯度幅值(可使用 Sobel
- 以阈值 `T_grad`(默认 0.12)筛选“深度变化较明显”像素,形成初始候选掩码。
3. 连通域与面积过滤
- 对候选掩码进行连通域分析。
- 过滤面积过小区域(默认最小面积为裁剪图总像素的 `0.2%`),减少噪声候选。
4. 形态学修正
- 执行一次闭运算填补小孔洞。
- 执行一次开运算移除细小毛刺,提升候选边界可用性。
5. 候选排序与展示
- 按“区域面积 + 平均梯度强度”综合评分排序。
- 默认展示 Top-K建议 K=5作为可点选候选区域。
6. 人工补充与修正
- 用户可通过圈选新增候选外区域。
- 用户可通过点选/擦除剔除误检区域。
- 最终输出为“自动候选 + 人工修正”的统一遮罩,进入分层流程。
参数建议支持在设置面板中可配:`T_grad`、最小面积比例、Top-K、平滑强度。
## 3.4 补全与纹理延展
- 触发方式:用户对分层区域右键选择 `补全`
- 输入:目标层或遮罩区域 + 文本提示词Prompt
- 输出:补全结果图层,用于修复空洞、扩展纹理或增强局部细节。
- 要求:
- 补全结果与原图层对齐。
- 支持重复执行与结果覆盖/回退。
## 3.5 漫游渲染与预览
- 将分层结果组织为可漫游场景进行实时预览。
- 支持查看层间深度关系导致的视差效果。
- 支持回到编辑阶段继续修改并再次预览(闭环迭代)。
## 3.6 热点与叙事运行(内部能力)
- 面向内部选定场景,不作为默认对外能力。
- 触发方式:用户右键选择目标区域并输入提示词。
- 输出:多帧叙事动态结果(局部区域动态化)。
- 目标:在静态场景中对关键区域生成连续帧,强化叙事表达。
- 说明:该模块与补全共享部分输入(区域与提示词),但输出为时间序列帧。
## 4. 关键数据流
1. 原图 -> 裁剪图
2. 裁剪图(整图) -> 深度图 + 深度叠加图
3. 深度图 + 圈选/点选 + 人工遮罩 -> 分层结果(前景层、背景层)
4. 分层区域 + Prompt -> 补全结果图层
5. 分层区域 + Prompt -> 叙事多帧结果
6. 分层与生成结果 -> 漫游渲染预览
## 5. 交互与状态建议
- 建议提供统一的图层面板,显示前景层、背景层、补全层、叙事层。
- 深度叠加图开关应为全局可见状态,便于在不同阶段快速核对。
- 右键操作应绑定当前选中区域,避免误触发到非目标层。
- 对补全和叙事生成增加任务状态:`待执行`、`执行中`、`完成`、`失败`。
- 建议保留操作历史,支持撤销/重做,便于快速迭代。
## 6. 验收要点(首版)
- 可以完整跑通“加载 -> 裁剪 -> 深度估计 -> 分层 -> 补全/叙事 -> 预览”的链路。
- 深度叠加图可正常显示和关闭。
- 圈选与点选结果可正确作用于分层。
- 自动候选区域应可稳定生成,并支持人工补充与修正后进入分层。
- 右键补全和右键叙事生成均可输入提示词并产出结果。
- 热点与叙事运行可在至少一个内部场景产出多帧动态化效果。

0
doc/editor.md Normal file
View File

90
doc/models.md Normal file
View File

@@ -0,0 +1,90 @@
# 后端模型处理
当前后端主要围绕四类模型提供服务:深度估计、语义分割、图像补全和动画生成。
前端通过 GET /models 获取模型列表和参数配置,用来动态生成 UI;推理接口分别为:
POST /depth
POST /segment
POST /inpaint
POST /animate
## 一、深度估计
输入一张 RGB 图像,输出每个像素的相对深度,用于后续的分层和视差计算。
这一部分是整个伪3D效果的基础深度质量直接决定最终效果上限。
模型:
* ZoeDepthhttps://github.com/isl-org/ZoeDepth.git
* Depth Anything v2https://github.com/DepthAnything/Depth-Anything-V2.git
* MiDaShttps://github.com/isl-org/MiDaS.git
* DPThttps://github.com/isl-org/DPT.git
接口说明
HTTPPOST /depth
请求体DepthRequest
实现models_depth.py 中的 run_depth_inference
## 二、语义分割
对图像进行像素级分区,用于辅助分层(天空 / 山 / 地面 / 建筑等)。
在伪3D流程中这一步主要解决一个问题
哪里可以拆开,哪里必须保持整体
模型:
* Mask2Formerhttps://github.com/facebookresearch/Mask2Former.git
* SAMhttps://github.com/facebookresearch/segment-anything.git
接口说明
HTTPPOST /segment
请求体SegmentRequest
实现models_segmentation.py 中的 run_segmentation_inference
## 三、图像补全
在进行视差变换或分层后,图像中会出现“空洞区域”,需要通过生成模型进行补全。
这一部分主要影响最终画面的“真实感”。
模型:
* SDXL Inpaintinghttps://github.com/AyushUnleashed/sdxl-inpaint.git
* ControlNethttps://github.com/lllyasviel/ControlNet.git
接口说明
HTTPPOST /inpaint
请求体InpaintRequest
实现models_inpaint.py 中的 run_inpaint_inference
## 四、动画生成
通过文本提示词生成短动画GIF用于从静态描述快速预览动态镜头效果。
这部分当前接入 AnimateDiff并通过统一后端接口对外提供调用能力。
模型:
* AnimateDiffhttps://github.com/guoyww/animatediff.git
接口说明
HTTPPOST /animate
请求体AnimateRequest
实现:`python_server/model/Animation/animation_loader.py` + `python_server/server.py` 中的 `animate`

1
python_server/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
outputs/

View File

@@ -0,0 +1 @@

223
python_server/config.py Normal file
View File

@@ -0,0 +1,223 @@
"""
python_server 的统一配置文件。
特点:
- 使用 Python 而不是 YAML方便在代码中集中列举所有可用模型供前端读取。
- 后端加载模型时,也从这里读取默认值,保证单一信息源。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal, TypedDict, List
from model.Depth.zoe_loader import ZoeModelName
# -----------------------------
# 1. 深度模型枚举(给前端展示用)
# -----------------------------
class DepthModelInfo(TypedDict):
    """Metadata describing one selectable depth model (serialized to the frontend)."""
    id: str           # unique ID, e.g. "zoedepth_n"
    family: str       # model family, e.g. "ZoeDepth"
    name: str         # display name, e.g. "ZoeD_N (NYU+KITTI)"
    description: str  # short human-readable description
    backend: str      # backend type: "zoedepth", "depth_anything_v2", "midas", "dpt"
# Registry of depth models exposed to the frontend via GET /models.
DEPTH_MODELS: List[DepthModelInfo] = [
    # ZoeDepth family
    {
        "id": "zoedepth_n",
        "family": "ZoeDepth",
        "name": "ZoeD_N",
        "description": "ZoeDepth zero-shot 模型,适用室内/室外通用场景。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_k",
        "family": "ZoeDepth",
        "name": "ZoeD_K",
        "description": "ZoeDepth Kitti 专用版本,针对户外驾驶场景优化。",
        "backend": "zoedepth",
    },
    {
        "id": "zoedepth_nk",
        "family": "ZoeDepth",
        "name": "ZoeD_NK",
        "description": "ZoeDepth 双头版本NYU+KITTI综合室内/室外场景。",
        "backend": "zoedepth",
    },
    # Reserved: Depth Anything v2
    {
        "id": "depth_anything_v2_s",
        "family": "Depth Anything V2",
        "name": "Depth Anything V2 Small",
        "description": "轻量级 Depth Anything V2 小模型。",
        "backend": "depth_anything_v2",
    },
    # Reserved: MiDaS
    {
        "id": "midas_dpt_large",
        "family": "MiDaS",
        "name": "MiDaS DPT Large",
        "description": "MiDaS DPT-Large 高质量深度模型。",
        "backend": "midas",
    },
    # Reserved: DPT
    {
        "id": "dpt_large",
        "family": "DPT",
        "name": "DPT Large",
        "description": "DPT Large 单目深度估计模型。",
        "backend": "dpt",
    },
]
# -----------------------------
# 1.2 补全模型枚举(给前端展示用)
# -----------------------------
class InpaintModelInfo(TypedDict):
    """Metadata describing one selectable inpainting model (serialized to the frontend)."""
    id: str
    family: str
    name: str
    description: str
    backend: str  # "sdxl_inpaint" | "controlnet"
# Registry of inpainting models exposed to the frontend via GET /models.
INPAINT_MODELS: List[InpaintModelInfo] = [
    {
        "id": "sdxl_inpaint",
        "family": "SDXL",
        "name": "SDXL Inpainting",
        "description": "基于 diffusers 的 SDXL 补全管线(需要 prompt + mask",
        "backend": "sdxl_inpaint",
    },
    {
        "id": "controlnet",
        "family": "ControlNet",
        "name": "ControlNet (placeholder)",
        "description": "ControlNet 补全/控制生成(当前统一封装暂未实现)。",
        "backend": "controlnet",
    },
]
# -----------------------------
# 1.3 动画模型枚举(给前端展示用)
# -----------------------------
class AnimationModelInfo(TypedDict):
    """Metadata describing one selectable animation model (serialized to the frontend)."""
    id: str
    family: str
    name: str
    description: str
    backend: str  # "animatediff"
# Registry of animation models exposed to the frontend via GET /models.
ANIMATION_MODELS: List[AnimationModelInfo] = [
    {
        "id": "animatediff",
        "family": "AnimateDiff",
        "name": "AnimateDiff (Text-to-Video)",
        "description": "基于 AnimateDiff 的文生动画能力,输出 GIF 动画。",
        "backend": "animatediff",
    },
]
# -----------------------------
# 2. 后端默认配置(给服务端用)
# -----------------------------
@dataclass
class DepthConfig:
    """Default depth-estimation settings used by the backend."""
    # Depth backend selection: the frontend does not choose the backend;
    # it may only be switched here in the backend config.
    backend: Literal["zoedepth", "depth_anything_v2", "dpt", "midas"] = "zoedepth"
    # Default ZoeDepth family variant.
    zoe_model: ZoeModelName = "ZoeD_N"
    # Default Depth Anything V2 encoder.
    da_v2_encoder: Literal["vits", "vitb", "vitl", "vitg"] = "vitl"
    # Default DPT model type.
    dpt_model_type: Literal["dpt_large", "dpt_hybrid"] = "dpt_large"
    # Default MiDaS model type.
    midas_model_type: Literal[
        "dpt_beit_large_512",
        "dpt_swin2_large_384",
        "dpt_swin2_tiny_256",
        "dpt_levit_224",
    ] = "dpt_beit_large_512"
    # Default device shared by all depth backends.
    device: str = "cuda"
@dataclass
class InpaintConfig:
    """Default inpainting settings used by the backend."""
    # Default inpainting backend.
    backend: Literal["sdxl_inpaint", "controlnet"] = "sdxl_inpaint"
    # SDXL inpaint base model (HuggingFace model id or local directory).
    sdxl_base_model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    # ControlNet inpaint base model and controlnet weights.
    controlnet_base_model: str = "runwayml/stable-diffusion-inpainting"
    controlnet_model: str = "lllyasviel/control_v11p_sd15_inpaint"
    device: str = "cuda"
@dataclass
class AnimationConfig:
    """Default animation-generation settings used by the backend."""
    # Default animation backend.
    backend: Literal["animatediff"] = "animatediff"
    # AnimateDiff root directory (relative to python_server/, or absolute).
    animate_diff_root: str = "model/Animation/AnimateDiff"
    # Base text-to-image model (HuggingFace model id or local directory).
    pretrained_model_path: str = "runwayml/stable-diffusion-v1-5"
    # AnimateDiff inference config.
    inference_config: str = "configs/inference/inference-v3.yaml"
    # Motion module and personalized base model (empty -> script defaults).
    motion_module: str = "v3_sd15_mm.ckpt"
    dreambooth_model: str = "realisticVisionV60B1_v51VAE.safetensors"
    lora_model: str = ""
    lora_alpha: float = 0.8
    # xformers is flaky in some environments; allow disabling it manually.
    without_xformers: bool = False
    device: str = "cuda"
@dataclass
class AppConfig:
    """Top-level application config aggregating all model sub-configs."""
    # Use default_factory to avoid the mutable-default pitfall of dataclasses.
    depth: DepthConfig = field(default_factory=DepthConfig)
    inpaint: InpaintConfig = field(default_factory=InpaintConfig)
    animation: AnimationConfig = field(default_factory=AnimationConfig)
# Backend code can simply import DEFAULT_CONFIG from this module.
DEFAULT_CONFIG = AppConfig()
def list_depth_models() -> List[DepthModelInfo]:
    """
    Return metadata for all available depth models, for the frontend to
    read through endpoints such as /models.
    """
    return DEPTH_MODELS
def list_inpaint_models() -> List[InpaintModelInfo]:
    """
    Return metadata for all available inpainting models, for the frontend
    to read through endpoints such as /models.
    """
    return INPAINT_MODELS
def list_animation_models() -> List[AnimationModelInfo]:
    """
    Return metadata for all available animation models, for the frontend
    to read through endpoints such as /models.
    """
    return ANIMATION_MODELS

View File

@@ -0,0 +1,153 @@
"""
兼容层:从 Python 配置模块中构造 zoe_loader 需要的 ZoeConfig。
后端其它代码尽量只依赖这里的函数,而不直接依赖 config.py 的具体结构,
便于以后扩展。
"""
from model.Depth.zoe_loader import ZoeConfig
from model.Depth.depth_anything_v2_loader import DepthAnythingV2Config
from model.Depth.dpt_loader import DPTConfig
from model.Depth.midas_loader import MiDaSConfig
from config import AppConfig, DEFAULT_CONFIG
def load_app_config() -> AppConfig:
    """
    Currently just returns DEFAULT_CONFIG.
    If multi-environment / override support is needed later, add it here.
    """
    return DEFAULT_CONFIG
def build_zoe_config_from_app(app_cfg: AppConfig | None = None) -> ZoeConfig:
    """Map AppConfig.depth onto a ZoeConfig for zoe_loader.

    Falls back to the global DEFAULT_CONFIG when app_cfg is not given.
    """
    cfg = load_app_config() if app_cfg is None else app_cfg
    depth_cfg = cfg.depth
    return ZoeConfig(model=depth_cfg.zoe_model, device=depth_cfg.device)
def build_depth_anything_v2_config_from_app(
    app_cfg: AppConfig | None = None,
) -> DepthAnythingV2Config:
    """Map AppConfig.depth onto a DepthAnythingV2Config.

    Falls back to the global DEFAULT_CONFIG when app_cfg is not given.
    """
    cfg = load_app_config() if app_cfg is None else app_cfg
    depth_cfg = cfg.depth
    return DepthAnythingV2Config(encoder=depth_cfg.da_v2_encoder, device=depth_cfg.device)
def build_dpt_config_from_app(app_cfg: AppConfig | None = None) -> DPTConfig:
    """Map AppConfig.depth onto a DPTConfig; defaults to DEFAULT_CONFIG."""
    cfg = load_app_config() if app_cfg is None else app_cfg
    depth_cfg = cfg.depth
    return DPTConfig(model_type=depth_cfg.dpt_model_type, device=depth_cfg.device)
def build_midas_config_from_app(app_cfg: AppConfig | None = None) -> MiDaSConfig:
    """Map AppConfig.depth onto a MiDaSConfig; defaults to DEFAULT_CONFIG."""
    cfg = load_app_config() if app_cfg is None else app_cfg
    depth_cfg = cfg.depth
    return MiDaSConfig(model_type=depth_cfg.midas_model_type, device=depth_cfg.device)
def get_depth_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured depth backend id (e.g. "zoedepth")."""
    return (load_app_config() if app_cfg is None else app_cfg).depth.backend


def get_inpaint_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured inpainting backend id."""
    return (load_app_config() if app_cfg is None else app_cfg).inpaint.backend


def get_sdxl_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the SDXL inpainting base model id or local path."""
    return (load_app_config() if app_cfg is None else app_cfg).inpaint.sdxl_base_model


def get_controlnet_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the ControlNet inpainting base model id or local path."""
    return (load_app_config() if app_cfg is None else app_cfg).inpaint.controlnet_base_model


def get_controlnet_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the ControlNet weights id or local path."""
    return (load_app_config() if app_cfg is None else app_cfg).inpaint.controlnet_model
def get_animation_backend_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured animation backend id (e.g. "animatediff")."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.backend


def get_animatediff_root_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff root directory from the app config."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.animate_diff_root


def get_animatediff_pretrained_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the base text-to-image model id or local path."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.pretrained_model_path


def get_animatediff_inference_config_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the AnimateDiff inference config path."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.inference_config


def get_animatediff_motion_module_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the configured motion module checkpoint name."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.motion_module


def get_animatediff_dreambooth_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the personalized (DreamBooth) base model name; may be empty."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.dreambooth_model


def get_animatediff_lora_model_from_app(app_cfg: AppConfig | None = None) -> str:
    """Return the LoRA model path; empty string means "no LoRA"."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.lora_model


def get_animatediff_lora_alpha_from_app(app_cfg: AppConfig | None = None) -> float:
    """Return the LoRA blending weight."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.lora_alpha


def get_animatediff_without_xformers_from_app(app_cfg: AppConfig | None = None) -> bool:
    """Return whether xformers should be disabled for AnimateDiff."""
    return (load_app_config() if app_cfg is None else app_cfg).animation.without_xformers

Submodule python_server/model/Animation/AnimateDiff added at e92bd5671b

View File

@@ -0,0 +1,12 @@
from .animation_loader import (
AnimationBackend,
UnifiedAnimationConfig,
build_animation_predictor,
)
__all__ = [
"AnimationBackend",
"UnifiedAnimationConfig",
"build_animation_predictor",
]

View File

@@ -0,0 +1,268 @@
from __future__ import annotations
"""
Unified animation model loading entry.
Current support:
- AnimateDiff (script-based invocation)
"""
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
import json
import os
import subprocess
import sys
import tempfile
from typing import Callable
from config_loader import (
load_app_config,
get_animatediff_root_from_app,
get_animatediff_pretrained_model_from_app,
get_animatediff_inference_config_from_app,
get_animatediff_motion_module_from_app,
get_animatediff_dreambooth_model_from_app,
get_animatediff_lora_model_from_app,
get_animatediff_lora_alpha_from_app,
get_animatediff_without_xformers_from_app,
)
class AnimationBackend(str, Enum):
    """Identifiers for the supported animation backends."""
    ANIMATEDIFF = "animatediff"
@dataclass
class UnifiedAnimationConfig:
    """Options for building an animation predictor.

    All fields except ``backend`` are optional overrides: ``None`` means
    "use the value from the app config".
    """
    backend: AnimationBackend = AnimationBackend.ANIMATEDIFF
    # Optional overrides. If None, values come from app config.
    animate_diff_root: str | None = None
    pretrained_model_path: str | None = None
    inference_config: str | None = None
    motion_module: str | None = None
    dreambooth_model: str | None = None
    lora_model: str | None = None
    lora_alpha: float | None = None
    without_xformers: bool | None = None
    # SparseCtrl controlnet weights/config; only used when a control image
    # is passed to the predictor.
    controlnet_path: str | None = None
    controlnet_config: str | None = None
def _yaml_string(value: str) -> str:
return json.dumps(value, ensure_ascii=False)
def _resolve_root(root_cfg: str) -> Path:
root = Path(root_cfg)
if not root.is_absolute():
root = Path(__file__).resolve().parents[2] / root_cfg
return root.resolve()
def _make_animatediff_predictor(
    cfg: UnifiedAnimationConfig,
) -> Callable[..., Path]:
    """Build a predictor that runs AnimateDiff through its CLI script.

    Resolves every setting from ``cfg`` first, falling back to the app
    config; validates the repo layout once up front; then returns a
    ``_predict`` closure that writes a temporary YAML prompt config, runs
    ``scripts/animate.py`` in a subprocess, and returns the path of the
    produced GIF (or PNG-sequence directory).
    """
    app_cfg = load_app_config()
    # AnimateDiff repo root; cfg may override the configured location.
    root = _resolve_root(cfg.animate_diff_root or get_animatediff_root_from_app(app_cfg))
    script_path = root / "scripts" / "animate.py"
    # The script writes its results under samples/; ensure it exists so the
    # before/after directory diff below works.
    samples_dir = root / "samples"
    samples_dir.mkdir(parents=True, exist_ok=True)
    if not script_path.is_file():
        raise FileNotFoundError(f"AnimateDiff script not found: {script_path}")
    # cfg overrides win; otherwise fall back to values from config.py.
    pretrained_model_path = (
        cfg.pretrained_model_path or get_animatediff_pretrained_model_from_app(app_cfg)
    )
    inference_config = cfg.inference_config or get_animatediff_inference_config_from_app(app_cfg)
    motion_module = cfg.motion_module or get_animatediff_motion_module_from_app(app_cfg)
    # For the four fields below only None means "unset": empty string,
    # 0.0 and False are legitimate explicit values, so `or` would be wrong.
    dreambooth_model = cfg.dreambooth_model
    if dreambooth_model is None:
        dreambooth_model = get_animatediff_dreambooth_model_from_app(app_cfg)
    lora_model = cfg.lora_model
    if lora_model is None:
        lora_model = get_animatediff_lora_model_from_app(app_cfg)
    lora_alpha = cfg.lora_alpha
    if lora_alpha is None:
        lora_alpha = get_animatediff_lora_alpha_from_app(app_cfg)
    without_xformers = cfg.without_xformers
    if without_xformers is None:
        without_xformers = get_animatediff_without_xformers_from_app(app_cfg)

    def _predict(
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        guidance_scale: float = 8.0,
        width: int = 512,
        height: int = 512,
        video_length: int = 16,
        seed: int = -1,
        control_image_path: str | None = None,
        output_format: str = "gif",
    ) -> Path:
        """Run one AnimateDiff generation.

        Returns the path of sample.gif, or for ``png_sequence`` output the
        directory containing the PNG frames.
        """
        if output_format not in {"gif", "png_sequence"}:
            raise ValueError("output_format must be 'gif' or 'png_sequence'")
        prompt_value = prompt.strip()
        if not prompt_value:
            raise ValueError("prompt must not be empty")
        negative_prompt_value = negative_prompt or ""
        # Optional YAML lines; _yaml_string JSON-quotes values so they embed
        # safely in the generated config.
        motion_module_line = (
            f'  motion_module: {_yaml_string(motion_module)}\n' if motion_module else ""
        )
        dreambooth_line = (
            f'  dreambooth_path: {_yaml_string(dreambooth_model)}\n' if dreambooth_model else ""
        )
        lora_path_line = f'  lora_model_path: {_yaml_string(lora_model)}\n' if lora_model else ""
        # lora_alpha is only emitted when a LoRA model is configured.
        lora_alpha_line = f"  lora_alpha: {float(lora_alpha)}\n" if lora_model else ""
        controlnet_path_value = cfg.controlnet_path or "v3_sd15_sparsectrl_rgb.ckpt"
        controlnet_config_value = cfg.controlnet_config or "configs/inference/sparsectrl/image_condition.yaml"
        control_image_line = ""
        if control_image_path:
            control_image = Path(control_image_path).expanduser().resolve()
            if not control_image.is_file():
                raise FileNotFoundError(f"control_image_path not found: {control_image}")
            # SparseCtrl conditioning: bind the single control image to frame 0.
            control_image_line = (
                f'  controlnet_path: {_yaml_string(controlnet_path_value)}\n'
                f'  controlnet_config: {_yaml_string(controlnet_config_value)}\n'
                "  controlnet_images:\n"
                f'    - {_yaml_string(str(control_image))}\n'
                "  controlnet_image_indexs:\n"
                "    - 0\n"
            )
        # Assemble the one-entry prompt config consumed by scripts/animate.py.
        config_text = (
            "- prompt:\n"
            f"    - {_yaml_string(prompt_value)}\n"
            "  n_prompt:\n"
            f"    - {_yaml_string(negative_prompt_value)}\n"
            f"  steps: {int(num_inference_steps)}\n"
            f"  guidance_scale: {float(guidance_scale)}\n"
            f"  W: {int(width)}\n"
            f"  H: {int(height)}\n"
            f"  L: {int(video_length)}\n"
            "  seed:\n"
            f"    - {int(seed)}\n"
            f"{motion_module_line}{dreambooth_line}{lora_path_line}{lora_alpha_line}{control_image_line}"
        )
        # Snapshot samples/ so the directory created by this run can be found.
        before_dirs = {p for p in samples_dir.iterdir() if p.is_dir()}
        cfg_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            prefix="animatediff_cfg_",
            dir=str(root),
            delete=False,
            encoding="utf-8",
        )
        cfg_file_path = Path(cfg_file.name)
        try:
            cfg_file.write(config_text)
            cfg_file.flush()
            cfg_file.close()
            cmd = [
                sys.executable,
                str(script_path),
                "--pretrained-model-path",
                pretrained_model_path,
                "--inference-config",
                inference_config,
                "--config",
                str(cfg_file_path),
                "--L",
                str(int(video_length)),
                "--W",
                str(int(width)),
                "--H",
                str(int(height)),
            ]
            if without_xformers:
                cmd.append("--without-xformers")
            if output_format == "png_sequence":
                cmd.append("--save-png-sequence")
            # Make the AnimateDiff repo importable inside the subprocess.
            # NOTE(review): ":" as the PYTHONPATH separator is POSIX-only;
            # use os.pathsep if Windows support is ever needed.
            env = dict(os.environ)
            existing_pythonpath = env.get("PYTHONPATH", "")
            root_pythonpath = str(root)
            env["PYTHONPATH"] = (
                f"{root_pythonpath}:{existing_pythonpath}" if existing_pythonpath else root_pythonpath
            )

            def _run_once(command: list[str]) -> subprocess.CompletedProcess[str]:
                # Run the script from the repo root so its relative paths resolve.
                return subprocess.run(
                    command,
                    cwd=str(root),
                    check=True,
                    capture_output=True,
                    text=True,
                    env=env,
                )

            try:
                proc = _run_once(cmd)
            except subprocess.CalledProcessError as first_error:
                stderr_text = first_error.stderr or ""
                # Known xformers failure signatures: retry once with it disabled.
                should_retry_without_xformers = (
                    not without_xformers
                    and "--without-xformers" not in cmd
                    and (
                        "memory_efficient_attention" in stderr_text
                        or "AcceleratorError" in stderr_text
                        or "invalid configuration argument" in stderr_text
                    )
                )
                if not should_retry_without_xformers:
                    raise
                retry_cmd = [*cmd, "--without-xformers"]
                proc = _run_once(retry_cmd)
            # Output is discovered on disk below; process output unused on success.
            _ = proc
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "AnimateDiff inference failed.\n"
                f"stdout:\n{e.stdout}\n"
                f"stderr:\n{e.stderr}"
            ) from e
        finally:
            # Best-effort cleanup of the temporary YAML config.
            try:
                cfg_file_path.unlink(missing_ok=True)
            except Exception:
                pass
        # Prefer directories created by this run; fall back to any sample.gif
        # holder, newest first.
        after_dirs = [p for p in samples_dir.iterdir() if p.is_dir() and p not in before_dirs]
        candidates = [p for p in after_dirs if (p / "sample.gif").is_file()]
        if not candidates:
            candidates = [p for p in samples_dir.iterdir() if p.is_dir() and (p / "sample.gif").is_file()]
        if not candidates:
            raise FileNotFoundError("AnimateDiff finished but sample.gif was not found in samples/")
        latest = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)[0]
        if output_format == "png_sequence":
            frames_root = latest / "sample_frames"
            if not frames_root.is_dir():
                raise FileNotFoundError("AnimateDiff finished but sample_frames/ was not found in samples/")
            frame_dirs = sorted([p for p in frames_root.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
            if not frame_dirs:
                raise FileNotFoundError("AnimateDiff finished but no PNG sequence directory was found in sample_frames/")
            return frame_dirs[0].resolve()
        return (latest / "sample.gif").resolve()

    return _predict
def build_animation_predictor(
    cfg: UnifiedAnimationConfig | None = None,
) -> tuple[Callable[..., Path], AnimationBackend]:
    """Build the animation predictor for the configured backend."""
    effective = cfg if cfg is not None else UnifiedAnimationConfig()
    if effective.backend != AnimationBackend.ANIMATEDIFF:
        raise ValueError(f"Unsupported animation backend: {effective.backend}")
    return _make_animatediff_predictor(effective), AnimationBackend.ANIMATEDIFF

Submodule python_server/model/Depth/DPT added at cd3fe90bb4

Submodule python_server/model/Depth/Depth-Anything-V2 added at e5a2732d3e

Submodule python_server/model/Depth/MiDaS added at 454597711a

Submodule python_server/model/Depth/ZoeDepth added at d87f17b2f5

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,147 @@
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path
import numpy as np
import torch
import requests
# Make sure the locally cloned Depth Anything V2 repo is on sys.path so its
# internal `from depth_anything_v2...` imports resolve.
_THIS_DIR = Path(__file__).resolve().parent
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
if _DA_REPO_ROOT.is_dir():
    da_path = str(_DA_REPO_ROOT)
    if da_path not in sys.path:
        sys.path.insert(0, da_path)
from depth_anything_v2.dpt import DepthAnythingV2  # type: ignore[import]
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]
@dataclass
class DepthAnythingV2Config:
    """
    Depth Anything V2 model selection config.

    encoder: "vits" | "vitb" | "vitl" | "vitg"
    device: "cuda" | "cpu"
    input_size: inference input resolution (short side); 518 as in the
        official demo.
    """
    encoder: EncoderName = "vitl"
    device: str = "cuda"
    input_size: int = 518
# Per-encoder DPT-head hyperparameters, matching the official repository.
_MODEL_CONFIGS = {
    "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
    "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
    "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
    "vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
}
_DA_V2_WEIGHTS_URLS = {
    # Official weights are hosted on HuggingFace:
    # - Small -> vits
    # - Base  -> vitb
    # - Large -> vitl
    # - Giant -> vitg
    # Edit these URLs directly to switch to a mirror.
    "vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
    "vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
    "vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
    "vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
}
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
    """Download the Depth Anything V2 checkpoint for *encoder* if absent.

    Improvements over the original:
    - a request timeout so a stalled connection cannot hang forever;
    - the response is closed via a context manager;
    - the file is streamed to a ``.part`` sibling and renamed into place on
      success, so an interrupted download never leaves a truncated
      checkpoint that would be mistaken for a valid one on the next run.

    Raises:
        FileNotFoundError: no download URL is configured for this encoder.
        requests.HTTPError: the HTTP request failed.
    """
    if ckpt_path.is_file():
        return
    url = _DA_V2_WEIGHTS_URLS.get(encoder)
    if not url:
        raise FileNotFoundError(
            f"找不到权重文件: {ckpt_path}\n"
            f"且当前未为 encoder='{encoder}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
    tmp_path = ckpt_path.with_suffix(ckpt_path.suffix + ".part")
    try:
        with requests.get(url, stream=True, timeout=60) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            chunk_size = 1024 * 1024
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        # Atomic publish: only a fully downloaded file gets the final name.
        tmp_path.replace(ckpt_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    print("\n权重下载完成。")
def load_depth_anything_v2_from_config(
    cfg: DepthAnythingV2Config,
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
    """
    Load the Depth Anything V2 model for the given config.

    Notes:
        - Checkpoint paths follow the official naming convention:
          checkpoints/depth_anything_v2_{encoder}.pth
          e.g. depth_anything_v2_vitl.pth
        - The weight file is auto-downloaded (or must be placed) under
          python_server/model/Depth/Depth-Anything-V2/checkpoints.
    """
    if cfg.encoder not in _MODEL_CONFIGS:
        raise ValueError(f"不支持的 encoder: {cfg.encoder}")
    ckpt_path = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
    _download_if_missing(cfg.encoder, ckpt_path)
    model = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])
    # Load weights on CPU first, then move to the resolved device.
    state_dict = torch.load(str(ckpt_path), map_location="cpu")
    model.load_state_dict(state_dict)
    # Fall back to CPU when CUDA is requested but unavailable.
    if cfg.device.startswith("cuda") and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = model.to(device).eval()
    # Return a config reflecting the device actually used.
    cfg = DepthAnythingV2Config(
        encoder=cfg.encoder,
        device=device,
        input_size=cfg.input_size,
    )
    return model, cfg
def infer_depth_anything_v2(
    model: DepthAnythingV2,
    image_bgr: np.ndarray,
    input_size: int,
) -> np.ndarray:
    """
    Run depth inference on a single BGR image and return the raw float32
    depth map (not normalized).

    image_bgr: OpenCV-style BGR image (H, W, 3), uint8.
    """
    raw = model.infer_image(image_bgr, input_size)
    return np.asarray(raw, dtype="float32")

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
"""
统一的深度模型加载入口。
当前支持:
- ZoeDepth三种ZoeD_N / ZoeD_K / ZoeD_NK
- Depth Anything V2四种 encodervits / vitb / vitl / vitg
未来如果要加 MiDaS / DPT只需要在这里再接一层即可。
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import numpy as np
from PIL import Image
from config_loader import (
load_app_config,
build_zoe_config_from_app,
build_depth_anything_v2_config_from_app,
build_dpt_config_from_app,
build_midas_config_from_app,
)
from .zoe_loader import load_zoe_from_config
from .depth_anything_v2_loader import (
load_depth_anything_v2_from_config,
infer_depth_anything_v2,
)
from .dpt_loader import load_dpt_from_config, infer_dpt
from .midas_loader import load_midas_from_config, infer_midas
class DepthBackend(str, Enum):
    """Unified identifiers for the depth-model backends."""
    ZOEDEPTH = "zoedepth"
    DEPTH_ANYTHING_V2 = "depth_anything_v2"
    DPT = "dpt"
    MIDAS = "midas"
@dataclass
class UnifiedDepthConfig:
    """
    Unified depth configuration.

    backend: which backend to use
    device: optional device override; when None the value from config.py wins
    """
    backend: DepthBackend = DepthBackend.ZOEDEPTH
    device: str | None = None
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a ZoeDepth predictor closure: PIL image -> float32 depth map."""
    app_cfg = load_app_config()
    zoe_cfg = build_zoe_config_from_app(app_cfg)
    if device_override is not None:
        zoe_cfg.device = device_override
    model, _ = load_zoe_from_config(zoe_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        # infer_pil accepts a PIL image directly and returns a numpy array.
        depth = model.infer_pil(img.convert("RGB"), output_type="numpy")
        return np.asarray(depth, dtype="float32").squeeze()
    return _predict
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a Depth Anything V2 predictor closure: PIL image -> float32 depth map."""
    app_cfg = load_app_config()
    da_cfg = build_depth_anything_v2_config_from_app(app_cfg)
    if device_override is not None:
        da_cfg.device = device_override
    # da_cfg is reassigned: the loader returns a config reflecting the
    # device actually used (CPU fallback when CUDA is unavailable).
    model, da_cfg = load_depth_anything_v2_from_config(da_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        # Depth Anything V2's infer_image expects BGR uint8.
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        bgr = rgb[:, :, ::-1]
        depth = infer_depth_anything_v2(model, bgr, da_cfg.input_size)
        return depth.astype("float32").squeeze()
    return _predict
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a DPT depth predictor closure: PIL image -> float32 depth map.

    Bug fix: the previous implementation converted RGB->BGR with
    ``cv2.cvtColor``, but this module never imports cv2, so every call
    raised ``NameError``. The conversion is now done by reversing the
    channel axis with numpy, matching the Depth Anything V2 predictor.
    """
    app_cfg = load_app_config()
    dpt_cfg = build_dpt_config_from_app(app_cfg)
    if device_override is not None:
        dpt_cfg.device = device_override
    # dpt_cfg is reassigned: the loader returns the device actually chosen.
    model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        rgb = np.array(img.convert("RGB"), dtype=np.uint8)
        # Reverse the channel axis to get BGR; copy to keep the array
        # contiguous for the cv2 operations inside infer_dpt.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
        return depth.astype("float32").squeeze()
    return _predict
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
    """Build a MiDaS predictor closure: PIL image -> float32 depth map."""
    app_cfg = load_app_config()
    midas_cfg = build_midas_config_from_app(app_cfg)
    if device_override is not None:
        midas_cfg.device = device_override
    model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)
    def _predict(img: Image.Image) -> np.ndarray:
        # MiDaS transforms expect float RGB scaled to [0, 1].
        rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
        depth = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
        return depth.astype("float32").squeeze()
    return _predict
def build_depth_predictor(
    cfg: UnifiedDepthConfig | None = None,
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
    """
    Build the unified depth prediction function.

    Returns:
        - predictor(image: PIL.Image) -> np.ndarray[H, W], float32
        - the backend actually in use
    """
    effective = cfg if cfg is not None else UnifiedDepthConfig()
    factories = {
        DepthBackend.ZOEDEPTH: _make_zoe_predictor,
        DepthBackend.DEPTH_ANYTHING_V2: _make_da_v2_predictor,
        DepthBackend.DPT: _make_dpt_predictor,
        DepthBackend.MIDAS: _make_midas_predictor,
    }
    factory = factories.get(effective.backend)
    if factory is None:
        raise ValueError(f"不支持的深度后端: {effective.backend}")
    return factory(effective.device), effective.backend

View File

@@ -0,0 +1,156 @@
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path
import numpy as np
import torch
import requests
# Make sure the locally cloned DPT repo is on sys.path so its internal
# `from dpt...` imports resolve.
_THIS_DIR = Path(__file__).resolve().parent
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
if _DPT_REPO_ROOT.is_dir():
    dpt_path = str(_DPT_REPO_ROOT)
    if dpt_path not in sys.path:
        sys.path.insert(0, dpt_path)
from dpt.models import DPTDepthModel # type: ignore[import]
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
from torchvision.transforms import Compose
import cv2
DPTModelType = Literal["dpt_large", "dpt_hybrid"]
@dataclass
class DPTConfig:
    """DPT model selection config (model variant + target device)."""
    model_type: DPTModelType = "dpt_large"
    device: str = "cuda"
_DPT_WEIGHTS_URLS = {
    # Official DPT weights are listed at:
    # https://github.com/isl-org/DPT#models
    "dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
}
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
    """Download the DPT checkpoint for *model_type* if absent.

    Improvements over the original:
    - a request timeout so a stalled connection cannot hang forever;
    - the response is closed via a context manager;
    - the file is streamed to a ``.part`` sibling and renamed into place on
      success, so an interrupted download never leaves a truncated
      checkpoint behind.

    Raises:
        FileNotFoundError: no download URL is configured for this model type.
        requests.HTTPError: the HTTP request failed.
    """
    if ckpt_path.is_file():
        return
    url = _DPT_WEIGHTS_URLS.get(model_type)
    if not url:
        raise FileNotFoundError(
            f"找不到 DPT 权重文件: {ckpt_path}\n"
            f"且当前未为 model_type='{model_type}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
    tmp_path = ckpt_path.with_suffix(ckpt_path.suffix + ".part")
    try:
        with requests.get(url, stream=True, timeout=60) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            chunk_size = 1024 * 1024
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        # Atomic publish: only a fully downloaded file gets the final name.
        tmp_path.replace(ckpt_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    print("\nDPT 权重下载完成。")
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
    """
    Load the DPT model plus its preprocessing transform.

    Returns (model, resolved_cfg, transform); resolved_cfg carries the
    device actually chosen (CPU fallback when CUDA is unavailable).

    Refactor: the original duplicated the whole construction in two
    branches that differed only in the backbone name; both released depth
    models run at 384x384 and share the same normalization.
    """
    ckpt_name = {
        "dpt_large": "dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
    }[cfg.model_type]
    ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
    _download_if_missing(cfg.model_type, ckpt_path)
    # Only the backbone differs between the two released depth models.
    backbone = "vitl16_384" if cfg.model_type == "dpt_large" else "vitb_rn50_384"
    net_w = net_h = 384
    model = DPTDepthModel(
        path=str(ckpt_path),
        backbone=backbone,
        non_negative=True,
        enable_attention_hooks=False,
    )
    normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model.to(device).eval()
    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )
    # Return a config reflecting the device actually used.
    cfg = DPTConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform
def infer_dpt(
    model: DPTDepthModel,
    transform: Compose,
    image_bgr: np.ndarray,
    device: str,
) -> np.ndarray:
    """
    Run depth inference on a single BGR image and return a float32 depth
    map (not normalized), upsampled back to the input resolution.
    """
    img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    # NOTE(review): DPT's NormalizeImage(mean=0.5) normally expects RGB
    # scaled to [0, 1]; `img` here is uint8 in [0, 255] — verify the
    # transform pipeline handles the scaling.
    sample = transform({"image": img})["image"]
    input_batch = torch.from_numpy(sample).to(device).unsqueeze(0)
    with torch.no_grad():
        prediction = model(input_batch)
        # Upsample the network-resolution prediction back to the image size.
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )
    return prediction.astype("float32")

View File

@@ -0,0 +1,127 @@
from dataclasses import dataclass
from typing import Literal, Tuple
import sys
from pathlib import Path
import numpy as np
import torch
import requests
# Make sure the locally cloned MiDaS repo is on sys.path so its internal
# `from midas...` imports (and its top-level `utils`) resolve.
_THIS_DIR = Path(__file__).resolve().parent
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
if _MIDAS_REPO_ROOT.is_dir():
    midas_path = str(_MIDAS_REPO_ROOT)
    if midas_path not in sys.path:
        sys.path.insert(0, midas_path)
from midas.model_loader import load_model, default_models # type: ignore[import]
import utils # from MiDaS repo
MiDaSModelType = Literal[
    "dpt_beit_large_512",
    "dpt_swin2_large_384",
    "dpt_swin2_tiny_256",
    "dpt_levit_224",
]
@dataclass
class MiDaSConfig:
    """MiDaS model selection config (model variant + target device)."""
    model_type: MiDaSModelType = "dpt_beit_large_512"
    device: str = "cuda"
_MIDAS_WEIGHTS_URLS = {
    # Official weights: see the MiDaS repository README.
    "dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
    "dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
    "dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
    "dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
}
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
    """Download the MiDaS checkpoint for *model_type* if absent.

    Improvements over the original:
    - a request timeout so a stalled connection cannot hang forever;
    - the response is closed via a context manager;
    - the file is streamed to a ``.part`` sibling and renamed into place on
      success, so an interrupted download never leaves a truncated
      checkpoint behind.

    Raises:
        FileNotFoundError: no download URL is configured for this model type.
        requests.HTTPError: the HTTP request failed.
    """
    if ckpt_path.is_file():
        return
    url = _MIDAS_WEIGHTS_URLS.get(model_type)
    if not url:
        raise FileNotFoundError(
            f"找不到 MiDaS 权重文件: {ckpt_path}\n"
            f"且当前未为 model_type='{model_type}' 配置自动下载 URL请手动下载到该路径。"
        )
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
    tmp_path = ckpt_path.with_suffix(ckpt_path.suffix + ".part")
    try:
        with requests.get(url, stream=True, timeout=60) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", "0") or "0")
            downloaded = 0
            chunk_size = 1024 * 1024
            with tmp_path.open("wb") as f:
                for chunk in resp.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        done = int(50 * downloaded / total)
                        print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
        # Atomic publish: only a fully downloaded file gets the final name.
        tmp_path.replace(ckpt_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    print("\nMiDaS 权重下载完成。")
def load_midas_from_config(
    cfg: MiDaSConfig,
) -> Tuple[torch.nn.Module, MiDaSConfig, callable, int, int]:
    """
    Load a MiDaS model together with its matching input transform.

    Returns: model, cfg (reflecting the device actually used), transform, net_w, net_h
    """
    # default_models maps model_type -> info holding the default weight path.
    model_info = default_models[cfg.model_type]
    ckpt_path = _MIDAS_REPO_ROOT / model_info.path
    _download_if_missing(cfg.model_type, ckpt_path)
    # Fall back to CPU when CUDA was requested but is not available.
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model, transform, net_w, net_h = load_model(
        device=torch.device(device),
        model_path=str(ckpt_path),
        model_type=cfg.model_type,
        optimize=False,
        height=None,
        square=False,
    )
    # Return a fresh config carrying the resolved device.
    cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
    return model, cfg, transform, net_w, net_h
def infer_midas(
    model: torch.nn.Module,
    transform: callable,
    image_rgb: np.ndarray,
    net_w: int,
    net_h: int,
    device: str,
) -> np.ndarray:
    """
    Run depth inference on a single RGB image; returns a float32 depth map
    (relative depth, not normalised).
    """
    # MiDaS transforms expect a dict wrapping the image.
    image = transform({"image": image_rgb})["image"]
    prediction = utils.process(
        torch.device(device),
        model,
        model_type="dpt",  # the exact string barely affects utils.process, as long as it does not contain "openvino"
        image=image,
        input_size=(net_w, net_h),
        target_size=image_rgb.shape[1::-1],  # (width, height) of the original image
        optimize=False,
        use_camera=False,
    )
    return np.asarray(prediction, dtype="float32").squeeze()

View File

@@ -0,0 +1,74 @@
from dataclasses import dataclass
from typing import Literal
import sys
from pathlib import Path
import torch
# Make the locally-cloned ZoeDepth repository importable so its internal
# `import zoedepth...` statements resolve.
_THIS_DIR = Path(__file__).resolve().parent
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
if _ZOE_REPO_ROOT.is_dir():
    zoe_path = str(_ZOE_REPO_ROOT)
    if zoe_path not in sys.path:
        sys.path.insert(0, zoe_path)
from zoedepth.models.builder import build_model
from zoedepth.utils.config import get_config
# The three published ZoeDepth variants.
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]
@dataclass
class ZoeConfig:
    """
    ZoeDepth model selection.
    model: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK"
    device: "cuda" | "cpu"
    """
    model: ZoeModelName = "ZoeD_N"
    device: str = "cuda"
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
    """
    Load one of the three ZoeDepth variants by name:
      - "ZoeD_N"  (NYU-trained)
      - "ZoeD_K"  (KITTI-trained)
      - "ZoeD_NK" (joint)
    """
    # Reject unknown names up front, then resolve the matching config.
    if name not in ("ZoeD_N", "ZoeD_K", "ZoeD_NK"):
        raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")
    if name == "ZoeD_NK":
        conf = get_config("zoedepth_nk", "infer")
    elif name == "ZoeD_K":
        conf = get_config("zoedepth", "infer", config_version="kitti")
    else:
        conf = get_config("zoedepth", "infer")
    model = build_model(conf)
    # Honour the requested device only when CUDA is actually available.
    target = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    model = model.to(target)
    model.eval()
    return model, conf
def load_zoe_from_config(config: ZoeConfig):
    """
    Load a ZoeDepth model described by a ZoeConfig.

    Example:
        cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
        model, conf = load_zoe_from_config(cfg)
    """
    model_name = config.model
    target_device = config.device
    return load_zoe_from_name(model_name, target_device)

Submodule python_server/model/Inpaint/ControlNet added at ed85cd1e25

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,413 @@
from __future__ import annotations
"""
统一的补全Inpaint模型加载入口。
当前支持:
- SDXL Inpaintdiffusers AutoPipelineForInpainting
- ControlNet占位暂未统一封装
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import numpy as np
from PIL import Image
from config_loader import (
load_app_config,
get_sdxl_base_model_from_app,
get_controlnet_base_model_from_app,
get_controlnet_model_from_app,
)
class InpaintBackend(str, Enum):
    # Supported inpainting backends.
    SDXL_INPAINT = "sdxl_inpaint"
    CONTROLNET = "controlnet"
@dataclass
class UnifiedInpaintConfig:
    backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
    # Target device ("cuda"/"cpu"); None -> use the app-config default.
    device: str | None = None
    # SDXL base model (HF id or local dir); None -> default from config.py.
    sdxl_base_model: str | None = None
@dataclass
class UnifiedDrawConfig:
    """
    Unified drawing configuration:
      - pure text-to-image: image=None
      - image-to-image (imitate an input): image=<reference image>
    """
    device: str | None = None
    sdxl_base_model: str | None = None
def _resolve_device_and_dtype(device: str | None):
    """Resolve the effective device ("cuda"/"cpu") and matching torch dtype.

    ``None`` falls back to the device named in the app config; CUDA is used
    only when actually available.  fp16 on GPU, fp32 on CPU.
    """
    import torch
    app_cfg = load_app_config()
    if device is None:
        device = app_cfg.inpaint.device
    device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    return device, torch_dtype
def _enable_memory_opts(pipe, device: str) -> None:
if device == "cuda":
try:
pipe.enable_attention_slicing()
except Exception:
pass
try:
pipe.enable_vae_slicing()
except Exception:
pass
try:
pipe.enable_vae_tiling()
except Exception:
pass
try:
pipe.enable_model_cpu_offload()
except Exception:
pass
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
run_w, run_h = orig_w, orig_h
if max_side > 0 and max(orig_w, orig_h) > max_side:
scale = max_side / float(max(orig_w, orig_h))
run_w = int(round(orig_w * scale))
run_h = int(round(orig_h * scale))
run_w = max(8, run_w - (run_w % 8))
run_h = max(8, run_h - (run_h % 8))
return run_w, run_h
def _make_sdxl_inpaint_predictor(
    cfg: UnifiedInpaintConfig,
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
    """
    Build an SDXL inpainting function.

    The returned callable takes (image PIL RGB, mask PIL L/1, prompt,
    negative_prompt, plus keyword tuning parameters) and returns the
    inpainted PIL RGB image at the original resolution.
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    # Reuse the shared helpers instead of duplicating device/dtype and
    # memory-optimisation logic inline (they already existed in this module).
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)
    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)
    # Memory savings (should not change output semantics).  Note: model CPU
    # offload is noticeably slower but greatly reduces VRAM usage.
    _enable_memory_opts(pipe, device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        max_side: int = 1024,
    ) -> Image.Image:
        image = image.convert("RGB")
        # diffusers expects a single-channel mask; white marks the repaint area.
        mask = mask.convert("L")
        # SDXL needs multiple-of-8 sizes; also cap the longer side at
        # ``max_side`` to avoid OOM, then resize the output back so the
        # result matches the input resolution.
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
def _make_controlnet_predictor(cfg: UnifiedInpaintConfig):
    """
    Build a Stable Diffusion ControlNet inpainting function.

    The control image is derived from the input via Canny edges, which keeps
    the repainted region structurally consistent with the original.
    (The config parameter was previously named ``_``; it is now a proper
    name since its ``device`` field is read.)
    """
    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
    app_cfg = load_app_config()
    # Shared device/dtype resolution (identical semantics to the old inline code).
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)
    base_model = get_controlnet_base_model_from_app(app_cfg)
    controlnet_id = get_controlnet_model_from_app(app_cfg)
    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        base_model,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    )
    if device == "cuda":
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass
        try:
            pipe.enable_vae_slicing()
        except Exception:
            pass
        try:
            pipe.enable_vae_tiling()
        except Exception:
            pass
        try:
            # Offloaded pipelines manage device placement themselves; move
            # the pipeline manually only when offloading is unavailable.
            pipe.enable_model_cpu_offload()
        except Exception:
            pipe.to(device)
    else:
        pipe.to(device)

    def _predict(
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        strength: float = 0.8,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        controlnet_conditioning_scale: float = 1.0,
        max_side: int = 768,
    ) -> Image.Image:
        import cv2
        import numpy as np
        image = image.convert("RGB")
        mask = mask.convert("L")
        orig_w, orig_h = image.size
        # Same cap-and-align policy as SDXL, via the shared helper.
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
            mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
        else:
            image_run = image
            mask_run = mask
        # Control image: Canny edges of the (resized) input — the most
        # generally applicable conditioning.
        rgb = np.array(image_run, dtype=np.uint8)
        edges = cv2.Canny(rgb, 100, 200)
        edges3 = np.stack([edges, edges, edges], axis=-1)
        control_image = Image.fromarray(edges3)
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            mask_image=mask_run,
            control_image=control_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            width=run_w,
            height=run_h,
        ).images[0]
        out = out.convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out

    return _predict
def build_inpaint_predictor(
    cfg: UnifiedInpaintConfig | None = None,
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
    """
    Build the inpainting predictor for the configured backend.
    """
    if cfg is None:
        cfg = UnifiedInpaintConfig()
    backend = cfg.backend
    if backend == InpaintBackend.CONTROLNET:
        return _make_controlnet_predictor(cfg), InpaintBackend.CONTROLNET
    if backend == InpaintBackend.SDXL_INPAINT:
        return _make_sdxl_inpaint_predictor(cfg), InpaintBackend.SDXL_INPAINT
    raise ValueError(f"不支持的补全后端: {cfg.backend}")
def build_draw_predictor(
    cfg: UnifiedDrawConfig | None = None,
) -> Callable[..., Image.Image]:
    """
    Build a unified drawing function:
      - text-to-image: draw(prompt, image=None, ...)
      - image-to-image: draw(prompt, image=ref_image, strength=0.55, ...)
    """
    import torch
    from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
    cfg = cfg or UnifiedDrawConfig()
    app_cfg = load_app_config()
    base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
    device, torch_dtype = _resolve_device_and_dtype(cfg.device)
    # One text-to-image pipeline, reused (via from_pipe) for image-to-image
    # so the model weights are loaded only once.
    pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        base_model,
        torch_dtype=torch_dtype,
        variant="fp16" if device == "cuda" else None,
        use_safetensors=True,
    ).to(device)
    pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i).to(device)
    _enable_memory_opts(pipe_t2i, device)
    _enable_memory_opts(pipe_i2i, device)
    def _draw(
        prompt: str,
        image: Image.Image | None = None,
        negative_prompt: str = "",
        strength: float = 0.55,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        width: int = 1024,
        height: int = 1024,
        max_side: int = 1024,
    ) -> Image.Image:
        # Normalise optional text arguments (None -> "").
        prompt = prompt or ""
        negative_prompt = negative_prompt or ""
        if device == "cuda":
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        if image is None:
            # Text-to-image path.
            run_w, run_h = _align_size(width, height, max_side=max_side)
            out = pipe_t2i(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=run_w,
                height=run_h,
            ).images[0]
            return out.convert("RGB")
        # Image-to-image path: run at a capped, 8-aligned size, then resize
        # the result back to the reference image's resolution.
        image = image.convert("RGB")
        orig_w, orig_h = image.size
        run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
        if (run_w, run_h) != (orig_w, orig_h):
            image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
        else:
            image_run = image
        out = pipe_i2i(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image_run,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0].convert("RGB")
        if out.size != (orig_w, orig_h):
            out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
        return out
    return _draw

Submodule python_server/model/Inpaint/sdxl-inpaint added at 29867f540b

Submodule python_server/model/Seg/Mask2Former added at 9b0651c6c1

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Literal, Tuple
import numpy as np
from PIL import Image
@dataclass
class Mask2FormerHFConfig:
    """
    Semantic segmentation via the HuggingFace transformers Mask2Former.

    model_id: HuggingFace model id (default: ADE20K semantic checkpoint)
    device: "cuda" | "cpu"
    """
    model_id: str = "facebook/mask2former-swin-large-ade-semantic"
    device: str = "cuda"
def build_mask2former_hf_predictor(
    cfg: Mask2FormerHFConfig | None = None,
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
    """
    Build predictor(image_rgb_uint8) -> label_map(int32).

    Returns the predictor together with the effective config (device
    resolved to what is actually available).
    """
    cfg = cfg or Mask2FormerHFConfig()
    import torch
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
    # Fall back to CPU when CUDA is requested but unavailable.
    device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
    processor = AutoImageProcessor.from_pretrained(cfg.model_id)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
    model.to(device).eval()
    cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)
    @torch.no_grad()
    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        # Coerce to uint8 as required by the HF image processor.
        if image_rgb.dtype != np.uint8:
            image_rgb_u8 = image_rgb.astype("uint8")
        else:
            image_rgb_u8 = image_rgb
        pil = Image.fromarray(image_rgb_u8, mode="RGB")
        inputs = processor(images=pil, return_tensors="pt").to(device)
        outputs = model(**inputs)
        # post-process back to the original image size
        target_sizes = [(pil.height, pil.width)]
        seg = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
        return seg.detach().to("cpu").numpy().astype("int32")
    return _predict, cfg

View File

@@ -0,0 +1,168 @@
from __future__ import annotations
"""
统一的分割模型加载入口。
当前支持:
- SAM (segment-anything)
- Mask2Former使用 HuggingFace transformers 的语义分割实现)
"""
from dataclasses import dataclass
from enum import Enum
from typing import Callable
import sys
from pathlib import Path
import numpy as np
_THIS_DIR = Path(__file__).resolve().parent
class SegBackend(str, Enum):
    # Supported segmentation backends.
    SAM = "sam"
    MASK2FORMER = "mask2former"
@dataclass
class UnifiedSegConfig:
    # Which segmentation backend to build; defaults to SAM.
    backend: SegBackend = SegBackend.SAM
# -----------------------------
# SAM (Segment Anything)
# -----------------------------
def _ensure_sam_on_path() -> Path:
    """Insert the vendored segment-anything repo into sys.path; return its root."""
    repo_dir = _THIS_DIR / "segment-anything"
    if not repo_dir.is_dir():
        raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {repo_dir}")
    repo_str = str(repo_dir)
    # Prepend so the vendored copy shadows any installed package.
    if repo_str not in sys.path:
        sys.path.insert(0, repo_str)
    return repo_dir
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
    """Return the SAM ViT-H checkpoint path, downloading it if missing.

    Downloads with a timeout and writes to a ".part" temp file renamed on
    success, so an interrupted transfer never leaves a truncated checkpoint
    that would pass the is_file() check on the next run.
    """
    import requests
    ckpt_dir = sam_root / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
    if ckpt_path.is_file():
        return ckpt_path
    url = (
        "https://dl.fbaipublicfiles.com/segment_anything/"
        "sam_vit_h_4b8939.pth"
    )
    print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
    # Timeout so a dead connection cannot hang the caller forever.
    resp = requests.get(url, stream=True, timeout=30)
    resp.raise_for_status()
    total = int(resp.headers.get("content-length", "0") or "0")
    downloaded = 0
    chunk_size = 1024 * 1024
    tmp_path = ckpt_path.with_suffix(ckpt_path.suffix + ".part")
    try:
        with tmp_path.open("wb") as f:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total > 0:
                    done = int(50 * downloaded / total)
                    print(
                        "\r[{}{}] {:.1f}%".format(
                            "#" * done,
                            "." * (50 - done),
                            downloaded * 100 / total,
                        ),
                        end="",
                    )
        tmp_path.replace(ckpt_path)
    finally:
        tmp_path.unlink(missing_ok=True)
    print("\nSAM 权重下载完成。")
    return ckpt_path
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """
    Build a segmentation function:
      - input: RGB uint8 image (H, W, 3)
      - output: label map (H, W), one int id per detected object (ids start at 1)
    """
    sam_root = _ensure_sam_on_path()
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator  # type: ignore[import]
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](
        checkpoint=str(ckpt_path),
    ).to(device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    def _predict(image_rgb: np.ndarray) -> np.ndarray:
        # Coerce to uint8 as expected by the mask generator.
        if image_rgb.dtype != np.uint8:
            image_rgb_u8 = image_rgb.astype("uint8")
        else:
            image_rgb_u8 = image_rgb
        masks = mask_generator.generate(image_rgb_u8)
        h, w, _ = image_rgb_u8.shape
        label_map = np.zeros((h, w), dtype="int32")
        # Later masks overwrite earlier ones where they overlap.
        for idx, m in enumerate(masks, start=1):
            seg = m.get("segmentation")
            if seg is None:
                continue
            label_map[seg.astype(bool)] = idx
        return label_map
    return _predict
# -----------------------------
# Mask2Former (delegates to the HF transformers implementation)
# -----------------------------
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
    """Build the HuggingFace Mask2Former semantic-segmentation predictor."""
    from .mask2former_loader import build_mask2former_hf_predictor
    predictor, _cfg = build_mask2former_hf_predictor()
    return predictor
# -----------------------------
# Unified builder
# -----------------------------
def build_seg_predictor(
    cfg: UnifiedSegConfig | None = None,
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
    """
    Build the segmentation predictor for the configured backend.

    Returns:
        - predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W], int32
        - the backend actually used
    """
    if cfg is None:
        cfg = UnifiedSegConfig()
    builders = {
        SegBackend.SAM: _make_sam_predictor,
        SegBackend.MASK2FORMER: _make_mask2former_predictor,
    }
    builder = builders.get(cfg.backend)
    if builder is None:
        raise ValueError(f"不支持的分割后端: {cfg.backend}")
    return builder(), cfg.backend

Submodule python_server/model/Seg/segment-anything added at dca509fe79

View File

@@ -0,0 +1 @@

407
python_server/server.py Normal file
View File

@@ -0,0 +1,407 @@
from __future__ import annotations
import base64
import datetime
import io
import os
import shutil
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel, Field
from PIL import Image, ImageDraw
from config_loader import load_app_config, get_depth_backend_from_app
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
from model.Seg.seg_loader import UnifiedSegConfig, SegBackend, build_seg_predictor
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
from model.Animation.animation_loader import (
UnifiedAnimationConfig,
AnimationBackend,
build_animation_predictor,
)
# Server-side output directory, created eagerly so endpoints can write to it.
APP_ROOT = Path(__file__).resolve().parent
OUTPUT_DIR = APP_ROOT / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI(title="HFUT Model Server", version="0.1.0")
class ImageInput(BaseModel):
    # Base64-encoded PNG/JPG payload (no "data:" prefix).
    image_b64: str = Field(..., description="PNG/JPG 编码后的 base64不含 data: 前缀)")
    # Model key as listed by /models; None -> backend default.
    model_name: Optional[str] = Field(None, description="模型 key来自 /models")
class DepthRequest(ImageInput):
    # Depth estimation request: image only; model selection is server-side.
    pass
class SegmentRequest(ImageInput):
    # Segmentation request: image plus optional model key.
    pass
class InpaintRequest(ImageInput):
    prompt: Optional[str] = Field("", description="补全 prompt")
    strength: float = Field(0.8, ge=0.0, le=1.0)
    negative_prompt: Optional[str] = Field("", description="负向 prompt")
    # Optional mask; white pixels mark the region to repaint.
    mask_b64: Optional[str] = Field(None, description="mask PNG base64可选")
    # Upper bound on the longer side during inference (OOM guard).
    max_side: int = Field(1024, ge=128, le=2048)
class AnimateRequest(BaseModel):
    model_name: Optional[str] = Field(None, description="模型 key来自 /models")
    prompt: str = Field(..., description="文本提示词")
    negative_prompt: Optional[str] = Field("", description="负向提示词")
    num_inference_steps: int = Field(25, ge=1, le=200)
    guidance_scale: float = Field(8.0, ge=0.0, le=30.0)
    width: int = Field(512, ge=128, le=2048)
    height: int = Field(512, ge=128, le=2048)
    video_length: int = Field(16, ge=1, le=128)
    # -1 requests a random seed per call.
    seed: int = Field(-1, description="-1 表示随机种子")
def _b64_to_pil_image(b64: str) -> Image.Image:
    """Decode a base64-encoded image payload into a PIL image."""
    decoded = base64.b64decode(b64)
    buffer = io.BytesIO(decoded)
    return Image.open(buffer)
def _pil_image_to_png_b64(img: Image.Image) -> str:
    """Encode a PIL image as PNG and return it base64-encoded (ASCII str)."""
    stream = io.BytesIO()
    img.save(stream, format="PNG")
    encoded = base64.b64encode(stream.getvalue())
    return encoded.decode("ascii")
def _depth_to_png16_b64(depth: np.ndarray) -> str:
    """Normalise a float depth map to a 16-bit grayscale PNG, base64-encoded."""
    arr = np.asarray(depth, dtype=np.float32)
    lo = float(arr.min())
    hi = float(arr.max())
    # Min-max normalisation; a constant map becomes all zeros.
    norm = (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr, dtype=np.float32)
    u16 = (norm * 65535.0).clip(0, 65535).astype(np.uint16)
    return _pil_image_to_png_b64(Image.fromarray(u16, mode="I;16"))
def _depth_to_png16_bytes(depth: np.ndarray) -> bytes:
    """Encode a float depth map as PNG bytes for the client.

    NOTE(review): despite the "png16" name this emits an 8-bit grayscale
    PNG — the client/server convention below expects 8-bit.  Renaming would
    break callers, so the name is kept as-is.
    """
    depth = np.asarray(depth, dtype=np.float32)
    dmin = float(depth.min())
    dmax = float(depth.max())
    if dmax > dmin:
        norm = (depth - dmin) / (dmax - dmin)
    else:
        norm = np.zeros_like(depth, dtype=np.float32)
    # Client/server convention: farthest=255, nearest=0 (8-bit), hence the inversion.
    u8 = ((1.0 - norm) * 255.0).clip(0, 255).astype(np.uint8)
    img = Image.fromarray(u8, mode="L")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
def _default_half_mask(img: Image.Image) -> Image.Image:
    """Build a mask whose right half is white (the area to repaint)."""
    width, height = img.size
    mask = Image.new("L", (width, height), 0)
    painter = ImageDraw.Draw(mask)
    painter.rectangle([width // 2, 0, width, height], fill=255)
    return mask
# -----------------------------
# /models给前端/GUI 使用)
# -----------------------------
@app.get("/models")
def get_models() -> Dict[str, Any]:
    """
    Return a schema compatible with the Qt front end:
    {
      "models": {
        "depth":   { "key": { "name": "..."} ... },
        "segment": { ... },
        "inpaint": { "key": { "name": "...", "params": [...] } ... }
      }
    }
    """
    return {
        "models": {
            "depth": {
                # Backwards-compatible default keys from the old config.
                "midas": {"name": "MiDaS (default)"},
                "zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
                "zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
                "zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
                "depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
                "depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
                "depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
                "depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
                "dpt_large": {"name": "DPT (large)"},
                "dpt_hybrid": {"name": "DPT (hybrid)"},
                "midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
                "midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
                "midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
                "midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
            },
            "segment": {
                "sam": {"name": "SAM (vit_h)"},
                # Backwards-compatible default key from the old config.
                "mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
                "mask2former": {"name": "Mask2Former (not implemented)"},
            },
            "inpaint": {
                # Backwards-compatible default key; "copy" means no inpainting.
                "copy": {"name": "Copy (no-op)", "params": []},
                "sdxl_inpaint": {
                    "name": "SDXL Inpaint",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
                "controlnet": {
                    "name": "ControlNet Inpaint (canny)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": True},
                    ],
                },
            },
            "animation": {
                "animatediff": {
                    "name": "AnimateDiff (Text-to-Video)",
                    "params": [
                        {"id": "prompt", "label": "提示词", "optional": False},
                        {"id": "negative_prompt", "label": "负向提示词", "optional": True},
                        {"id": "num_inference_steps", "label": "采样步数", "optional": True},
                        {"id": "guidance_scale", "label": "CFG Scale", "optional": True},
                        {"id": "width", "label": "宽度", "optional": True},
                        {"id": "height", "label": "高度", "optional": True},
                        {"id": "video_length", "label": "帧数", "optional": True},
                        {"id": "seed", "label": "随机种子", "optional": True},
                    ],
                },
            },
        }
    }
# -----------------------------
# Depth
# -----------------------------
# Lazily-initialised depth predictor (module-level singletons).
_depth_predictor = None
_depth_backend: DepthBackend | None = None
def _ensure_depth_predictor() -> None:
    # Build the depth predictor once, using the backend named in config.py.
    global _depth_predictor, _depth_backend
    if _depth_predictor is not None and _depth_backend is not None:
        return
    app_cfg = load_app_config()
    backend_str = get_depth_backend_from_app(app_cfg)
    try:
        backend = DepthBackend(backend_str)
    except Exception as e:
        raise ValueError(f"config.py 中 depth.backend 不合法: {backend_str}") from e
    _depth_predictor, _depth_backend = build_depth_predictor(UnifiedDepthConfig(backend=backend))
@app.post("/depth")
def depth(req: DepthRequest):
    """
    Compute depth and return it directly as a binary grayscale PNG.

    Contract:
      - the client does not choose a model; selection is fixed in config.py
      - success: HTTP 200 + Content-Type: image/png
      - failure: HTTP 500 with the error message in ``detail``
    """
    try:
        _ensure_depth_predictor()
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        depth_arr = _depth_predictor(pil)  # type: ignore[misc]
        png_bytes = _depth_to_png16_bytes(np.asarray(depth_arr))
        return Response(content=png_bytes, media_type="image/png")
    except Exception as e:
        # Chain the original exception so the real cause shows up in logs
        # (a bare re-raise inside `except` loses it).
        raise HTTPException(status_code=500, detail=str(e)) from e
# -----------------------------
# Segment
# -----------------------------
# Cache of built segmentation predictors, keyed by (normalised) model name.
_seg_cache: Dict[str, Any] = {}
def _get_seg_predictor(model_name: str):
    """Return the segmentation predictor for ``model_name``, memoised per key.

    Raises ValueError for unknown names.
    """
    # Normalise legacy keys BEFORE the cache lookup: previously the check
    # happened first, so "mask2former_debug" missed the cache and rebuilt
    # the SAM predictor on every request.
    if model_name == "mask2former_debug":
        model_name = "sam"
    if model_name in _seg_cache:
        return _seg_cache[model_name]
    if model_name == "sam":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.SAM))
    elif model_name == "mask2former":
        pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.MASK2FORMER))
    else:
        raise ValueError(f"未知 segment model_name: {model_name}")
    _seg_cache[model_name] = pred
    return pred
@app.post("/segment")
def segment(req: SegmentRequest) -> Dict[str, Any]:
    """Segment an image and save its label map; returns the saved path."""
    try:
        model_name = req.model_name or "sam"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        rgb = np.array(pil, dtype=np.uint8)
        predictor = _get_seg_predictor(model_name)
        label_map = predictor(rgb).astype(np.int32)
        out_dir = OUTPUT_DIR / "segment"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_label.png"
        # Saved as 8-bit grayscale: labels above 255 are clamped (SAM rarely
        # produces that many objects in practice).
        Image.fromarray(np.clip(label_map, 0, 255).astype(np.uint8), mode="L").save(out_path)
        return {"success": True, "label_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
# -----------------------------
# Inpaint
# -----------------------------
# Cache of built inpaint predictors, keyed by model name.
_inpaint_cache: Dict[str, Any] = {}
def _get_inpaint_predictor(model_name: str):
    """Return (and memoise) the inpaint predictor for ``model_name``."""
    if model_name in _inpaint_cache:
        return _inpaint_cache[model_name]
    if model_name == "copy":
        # No-op backend: returns the input image unchanged (as RGB).
        def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
            return image.convert("RGB")
        predictor = _copy
    elif model_name == "sdxl_inpaint":
        predictor, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.SDXL_INPAINT))
    elif model_name == "controlnet":
        predictor, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.CONTROLNET))
    else:
        raise ValueError(f"未知 inpaint model_name: {model_name}")
    _inpaint_cache[model_name] = predictor
    return predictor
@app.post("/inpaint")
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
    """Inpaint an image (optionally with a caller mask); returns the saved path."""
    try:
        model_name = req.model_name or "sdxl_inpaint"
        pil = _b64_to_pil_image(req.image_b64).convert("RGB")
        # Fall back to repainting the right half when no mask is supplied.
        if req.mask_b64:
            mask = _b64_to_pil_image(req.mask_b64).convert("L")
        else:
            mask = _default_half_mask(pil)
        predictor = _get_inpaint_predictor(model_name)
        out = predictor(
            pil,
            mask,
            req.prompt or "",
            req.negative_prompt or "",
            strength=req.strength,
            max_side=req.max_side,
        )
        out_dir = OUTPUT_DIR / "inpaint"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{model_name}_inpaint.png"
        out.save(out_path)
        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
# Cache of built animation predictors, keyed by model name.
_animation_cache: Dict[str, Any] = {}
def _get_animation_predictor(model_name: str):
    """Return (and memoise) the animation predictor for ``model_name``."""
    cached = _animation_cache.get(model_name)
    if cached is not None:
        return cached
    if model_name != "animatediff":
        raise ValueError(f"未知 animation model_name: {model_name}")
    predictor, _ = build_animation_predictor(
        UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
    )
    _animation_cache[model_name] = predictor
    return predictor
@app.post("/animate")
def animate(req: AnimateRequest) -> Dict[str, Any]:
    """Generate a text-to-video animation; returns the saved GIF path."""
    try:
        model_name = req.model_name or "animatediff"
        predictor = _get_animation_predictor(model_name)
        result_path = predictor(
            prompt=req.prompt,
            negative_prompt=req.negative_prompt or "",
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
            width=req.width,
            height=req.height,
            video_length=req.video_length,
            seed=req.seed,
        )
        out_dir = OUTPUT_DIR / "animation"
        out_dir.mkdir(parents=True, exist_ok=True)
        # Timestamped filename so repeated requests never overwrite each other.
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = out_dir / f"{model_name}_{ts}.gif"
        shutil.copy2(result_path, out_path)
        return {"success": True, "output_path": str(out_path)}
    except Exception as e:
        return {"success": False, "error": str(e)}
@app.get("/health")
def health() -> Dict[str, str]:
    # Liveness probe used by the client / deployment checks.
    return {"status": "ok"}
if __name__ == "__main__":
    import uvicorn
    # Port is configurable via MODEL_SERVER_PORT; defaults to 8000.
    port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
from pathlib import Path
import shutil
from model.Animation.animation_loader import (
build_animation_predictor,
UnifiedAnimationConfig,
AnimationBackend,
)
# -----------------------------
# Configuration (edit as needed)
# -----------------------------
OUTPUT_DIR = "outputs/test_animation"
ANIMATION_BACKEND = AnimationBackend.ANIMATEDIFF
OUTPUT_FORMAT = "png_sequence"  # "gif" | "png_sequence"
PROMPT = "a cinematic mountain landscape, camera slowly pans left"
NEGATIVE_PROMPT = "blurry, low quality"
NUM_INFERENCE_STEPS = 25
GUIDANCE_SCALE = 8.0
WIDTH = 512
HEIGHT = 512
VIDEO_LENGTH = 16
SEED = -1
# Path to the control/reference image (png/jpg); must be set before running.
CONTROL_IMAGE_PATH = "path/to/your_image.png"
def main() -> None:
    """Run one animation generation and copy the result into OUTPUT_DIR."""
    base_dir = Path(__file__).resolve().parent
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    # Validate the control image BEFORE building the predictor: loading the
    # animation model is expensive, so a bad path should fail fast.
    if CONTROL_IMAGE_PATH.strip() in {"", "path/to/your_image.png"}:
        raise ValueError("请先设置 CONTROL_IMAGE_PATH 为你的输入图片路径png/jpg")
    control_image = (base_dir / CONTROL_IMAGE_PATH).resolve()
    if not control_image.is_file():
        raise FileNotFoundError(f"control image not found: {control_image}")
    predictor, used_backend = build_animation_predictor(
        UnifiedAnimationConfig(backend=ANIMATION_BACKEND)
    )
    result_path = predictor(
        prompt=PROMPT,
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        width=WIDTH,
        height=HEIGHT,
        video_length=VIDEO_LENGTH,
        seed=SEED,
        control_image_path=str(control_image),
        output_format=OUTPUT_FORMAT,
    )
    source = Path(result_path)
    if OUTPUT_FORMAT == "png_sequence":
        # PNG sequence: replace any previous frame directory wholesale.
        out_seq_dir = out_dir / f"{used_backend.value}_frames"
        if out_seq_dir.exists():
            shutil.rmtree(out_seq_dir)
        shutil.copytree(source, out_seq_dir)
        print(f"[Animation] backend={used_backend.value}, saved={out_seq_dir}")
        return
    out_path = out_dir / f"{used_backend.value}.gif"
    out_path.write_bytes(source.read_bytes())
    print(f"[Animation] backend={used_backend.value}, saved={out_path}")

101
python_server/test_depth.py Normal file
View File

@@ -0,0 +1,101 @@
from __future__ import annotations
from pathlib import Path
import numpy as np
import cv2
from PIL import Image
# 解除大图限制
Image.MAX_IMAGE_PIXELS = None
from model.Depth.depth_loader import build_depth_predictor, UnifiedDepthConfig, DepthBackend
# ================= Configuration =================
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
OUTPUT_DIR = "outputs/test_depth_v4"
DEPTH_BACKEND = DepthBackend.DEPTH_ANYTHING_V2
# Edge-capture parameters.
# Raising this value yields thinner edges; lowering it yields thicker ones
# (capturing more of the weak gradient signal).
EDGE_TOP_PERCENTILE = 96.0  # keep the strongest-gradient top 4% of pixels
# Sensitivity of the local contrast enhancement; 2.0 - 5.0 is a sane range.
CLAHE_CLIP_LIMIT = 3.0
# Morphological structuring-element size.
MORPH_SIZE = 5
# =========================================
def _extract_robust_edges(depth_norm: np.ndarray) -> np.ndarray:
    """
    Extract closed edge contours via local contrast enhancement + Sobel
    gradients, from a depth map normalised to [0, 1].
    """
    # 1. Convert to 8-bit grayscale.
    depth_u8 = (depth_norm * 255).astype(np.uint8)
    # 2. [Key step] CLAHE adaptive local contrast enhancement — amplifies
    #    the subtle depth differences of buildings/figures in old paintings.
    clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=(16, 16))
    enhanced_depth = clahe.apply(depth_u8)
    # 3. Gaussian blur to suppress digitisation noise.
    blurred = cv2.GaussianBlur(enhanced_depth, (5, 5), 0)
    # 4. Sobel gradient magnitude.
    grad_x = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=3)
    grad_mag = np.sqrt(grad_x**2 + grad_y**2)
    # 5. [Statistical threshold] keep the top X% of gradients rather than a
    #    fixed magnitude cutoff.
    threshold = np.percentile(grad_mag, EDGE_TOP_PERCENTILE)
    binary_edges = (grad_mag >= threshold).astype(np.uint8) * 255
    # 6. Morphological closing bridges small gaps so contours connect.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (MORPH_SIZE, MORPH_SIZE))
    closed = cv2.morphologyEx(binary_edges, cv2.MORPH_CLOSE, kernel)
    # 7. A light dilation widens the guidance region for SAM.
    final_mask = cv2.dilate(closed, kernel, iterations=1)
    return final_mask
def main() -> None:
    """Estimate depth for INPUT_IMAGE, then export a 16-bit depth map and a
    robust edge mask derived from it into OUTPUT_DIR.

    Fixes: removed the unused ``w_orig, h_orig`` unpack and the ``f`` prefix
    on strings that contain no placeholders.
    """
    base_dir = Path(__file__).resolve().parent
    img_path = Path(INPUT_IMAGE)
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    # 1. Build the depth predictor for the configured backend.
    predictor, used_backend = build_depth_predictor(UnifiedDepthConfig(backend=DEPTH_BACKEND))
    # 2. Load the image (RGB; Pillow's size limit is lifted at module import).
    print(f"[Loading] 正在处理: {img_path.name}")
    img_pil = Image.open(img_path).convert("RGB")
    # 3. Depth prediction; squeeze drops any leading batch/channel dims.
    print("[Depth] 正在进行深度估计 (Large Image)...")
    depth = np.asarray(predictor(img_pil), dtype=np.float32).squeeze()
    # 4. Min-max normalize; epsilon guards against a flat depth map (/0).
    dmin, dmax = depth.min(), depth.max()
    depth_norm = (depth - dmin) / (dmax - dmin + 1e-8)
    depth_u16 = (depth_norm * 65535.0).astype(np.uint16)
    Image.fromarray(depth_u16).save(out_dir / f"{img_path.stem}.depth.png")
    # 5. Robust edge extraction (CLAHE + Sobel + morphology).
    print("[Edge] 正在应用 CLAHE + Sobel 增强算法提取边缘...")
    edge_mask = _extract_robust_edges(depth_norm)
    # 6. Export and report edge density so EDGE_TOP_PERCENTILE can be tuned.
    mask_path = out_dir / f"{img_path.stem}.edge_mask_robust.png"
    Image.fromarray(edge_mask).save(mask_path)
    edge_ratio = float((edge_mask > 0).sum()) / float(edge_mask.size)
    print("-" * 30)
    print("提取完成!")
    print(f"边缘密度: {edge_ratio:.2%} (目标通常应在 1% ~ 8% 之间)")
    print("如果 Mask 依然太黑,请调低 EDGE_TOP_PERCENTILE (如 90.0)")
    print("如果 Mask 太乱,请调高 EDGE_TOP_PERCENTILE (如 96.0)")
    print("-" * 30)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,284 @@
from __future__ import annotations
from datetime import datetime
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
from model.Inpaint.inpaint_loader import (
build_inpaint_predictor,
UnifiedInpaintConfig,
InpaintBackend,
build_draw_predictor,
UnifiedDrawConfig,
)
# -----------------------------
# Configuration (edit as needed)
# -----------------------------
# Task mode:
# - "inpaint": completion / inpainting
# - "draw": drawing (text-to-image / image-to-image)
TASK_MODE = "draw"
# Path to the input image, e.g. image_0.png
INPUT_IMAGE = "/home/dwh/code/hfut-bishe/python_server/outputs/test_seg_v2/Up the River During Qingming (detail) - Court painters.objects/Up the River During Qingming (detail) - Court painters.obj_06.png"
INPUT_MASK = "" # empty string means no manual mask is used
OUTPUT_DIR = "outputs/test_inpaint_v2"
MASK_RECT = (256, 0, 512, 512) # x1, y1, x2, y2 (used when AUTO_MASK is off)
USE_AUTO_MASK = True # True: infer the completion region from the image itself
AUTO_MASK_BLACK_THRESHOLD = 6 # near-black pixel threshold for the auto mask
AUTO_MASK_DILATE_ITER = 4 # extra dilation so branch edges are fully covered
FALLBACK_TO_RECT_MASK = False # fall back to the rectangle mask if the auto mask is empty
# Transparency output controls
PRESERVE_TRANSPARENCY = True # keep output transparency
EXPAND_ALPHA_WITH_MASK = True # make the inpainted region opaque in the alpha channel
NON_BLACK_ALPHA_THRESHOLD = 6 # without input alpha, treat near-black as transparent background
INPAINT_BACKEND = InpaintBackend.SDXL_INPAINT # SDXL_INPAINT recommended for better results
# -----------------------------
# Key prompt tweak: generate trees only
# -----------------------------
# Refined PROMPT emphasizing tree species, foliage and style so the result blends with the original
PROMPT = (
    "A highly detailed traditional Chinese ink brush painting. "
    "Restore and complete the existing trees naturally. "
    "Extend the complex tree branches and dense, varied green and teal leaves. "
    "Add more harmonious foliage and intricate bark texture to match the style and flow of the original image_0.png trees. "
    "Focus solely on vegetation. "
    "Maintain the light beige background color and texture."
)
# NEGATIVE_PROMPT explicitly excludes unwanted content
NEGATIVE_PROMPT = (
    "buildings, architecture, houses, pavilions, temples, windows, doors, "
    "people, figures, persons, characters, figures, clothing, faces, hands, "
    "text, writing, characters, words, letters, signatures, seals, stamps, "
    "calligraphy, objects, artifacts, boxes, baskets, tools, "
    "extra branches crossing unnaturally, bad composition, watermark, signature"
)
STRENGTH = 0.8 # keep strength high for generation
GUIDANCE_SCALE = 8.0 # slightly higher so the prompt is followed more strictly
NUM_INFERENCE_STEPS = 35 # slightly more steps for extra detail
MAX_SIDE = 1024
CONTROLNET_SCALE = 1.0
# -----------------------------
# Draw-task parameters
# -----------------------------
# DRAW_INPUT_IMAGE empty: text-to-image
# DRAW_INPUT_IMAGE set: image-to-image (imitate/repaint the reference image)
DRAW_INPUT_IMAGE = ""
DRAW_PROMPT = """
Chinese ink wash painting, vast snowy river under a pale sky,
a small lonely boat at the horizon where water meets sky,
an old fisherman wearing a straw hat sits at the stern, fishing quietly,
gentle snowfall, misty atmosphere, distant mountains barely visible,
minimalist composition, large empty space, soft brush strokes,
calm, cold, and silent mood, poetic and serene
"""
DRAW_NEGATIVE_PROMPT = """
blurry, low quality, many people, bright colors, modern elements, crowded, noisy
"""
DRAW_STRENGTH = 0.55 # img2img only; larger values repaint more aggressively
DRAW_GUIDANCE_SCALE = 9
DRAW_STEPS = 64
DRAW_WIDTH = 2560
DRAW_HEIGHT = 1440
DRAW_MAX_SIDE = 2560
def _dilate(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
out = mask.astype(bool)
for _ in range(max(0, iterations)):
p = np.pad(out, ((1, 1), (1, 1)), mode="constant", constant_values=False)
out = (
p[:-2, :-2] | p[:-2, 1:-1] | p[:-2, 2:]
| p[1:-1, :-2] | p[1:-1, 1:-1] | p[1:-1, 2:]
| p[2:, :-2] | p[2:, 1:-1] | p[2:, 2:]
)
return out
def _make_rect_mask(img_size: tuple[int, int]) -> Image.Image:
    """Build an L-mode mask with the MASK_RECT rectangle filled white (255).

    Coordinates are clamped into the image bounds and each axis pair is
    reordered if given reversed, so any MASK_RECT produces a valid rect.
    """
    width, height = img_size
    left, top, right, bottom = MASK_RECT
    # Clamp every coordinate into [0, dim - 1].
    left = max(0, min(left, width - 1))
    right = max(0, min(right, width - 1))
    top = max(0, min(top, height - 1))
    bottom = max(0, min(bottom, height - 1))
    # Normalize ordering so left <= right and top <= bottom.
    if right < left:
        left, right = right, left
    if bottom < top:
        top, bottom = bottom, top
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle([left, top, right, bottom], fill=255)
    return mask
def _auto_mask_from_image(img_path: Path) -> Image.Image:
    """
    Infer the missing-region mask from the image itself:
      1) if the input has an alpha channel, (semi-)transparent pixels form the mask;
      2) otherwise near-black pixels are treated as the missing region.
    The mask is dilated to cover ragged cut-out borders.
    """
    raw = Image.open(img_path)
    arr = np.asarray(raw)
    if arr.ndim == 3 and arr.shape[2] == 4:
        # Anything not fully opaque counts as missing.
        missing = arr[:, :, 3] < 250
    else:
        rgb = np.asarray(raw.convert("RGB"), dtype=np.uint8)
        # Cut-out tools often export transparent regions as black pixels,
        # so near-black is the best available "missing" signal.
        missing = np.all(rgb <= AUTO_MASK_BLACK_THRESHOLD, axis=-1)
    missing = _dilate(missing, AUTO_MASK_DILATE_ITER)
    return Image.fromarray(missing.astype(np.uint8) * 255, mode="L")
def _build_alpha_from_input(img_path: Path, img_rgb: Image.Image) -> np.ndarray:
    """
    Derive the output alpha channel:
      - reuse the input's alpha channel when one exists;
      - otherwise mark near-black pixels as transparent background.
    Returns a uint8 array (0 = transparent, 255 = opaque).
    """
    raw = Image.open(img_path)
    arr = np.asarray(raw)
    if arr.ndim == 3 and arr.shape[2] == 4:
        return arr[:, :, 3].astype(np.uint8)
    rgb = np.asarray(img_rgb.convert("RGB"), dtype=np.uint8)
    # A pixel is opaque if any channel rises above the near-black threshold.
    opaque = np.any(rgb > NON_BLACK_ALPHA_THRESHOLD, axis=-1)
    return opaque.astype(np.uint8) * 255
def _load_or_make_mask(base_dir: Path, img_path: Path, img_rgb: Image.Image) -> Image.Image:
    """Resolve the inpaint mask by priority: explicit file > auto-inferred >
    rectangle fallback. Raises ValueError if the auto mask is empty and the
    rectangle fallback is disabled."""
    if INPUT_MASK:
        raw_mask_path = Path(INPUT_MASK)
        mask_path = raw_mask_path if raw_mask_path.is_absolute() else base_dir / raw_mask_path
        if mask_path.is_file():
            return Image.open(mask_path).convert("L")
    if USE_AUTO_MASK:
        auto = _auto_mask_from_image(img_path)
        if np.any(np.asarray(auto, dtype=np.uint8) > 0):
            return auto
        if not FALLBACK_TO_RECT_MASK:
            raise ValueError("自动mask为空,请检查输入图像是否存在透明/黑色缺失区域。")
    return _make_rect_mask(img_rgb.size)
def _resolve_path(base_dir: Path, p: str) -> Path:
raw = Path(p)
return raw if raw.is_absolute() else (base_dir / raw)
def run_inpaint_test(base_dir: Path, out_dir: Path) -> None:
    """Run one inpainting pass on INPUT_IMAGE and save the mask and result.

    Pipeline: resolve input path -> build predictor -> derive mask ->
    predict -> (optionally) restore transparency -> save outputs.
    """
    img_path = _resolve_path(base_dir, INPUT_IMAGE)
    if not img_path.is_file():
        raise FileNotFoundError(f"找不到输入图像,请修改 INPUT_IMAGE: {img_path}")
    predictor, used_backend = build_inpaint_predictor(
        UnifiedInpaintConfig(backend=INPAINT_BACKEND)
    )
    img = Image.open(img_path).convert("RGB")
    # Mask source priority: explicit file > auto-detected > rectangle fallback.
    mask = _load_or_make_mask(base_dir, img_path, img)
    # Persist the mask actually used, for debugging/threshold tuning.
    mask_out = out_dir / f"{img_path.stem}.mask_used.png"
    mask.save(mask_out)
    kwargs = dict(
        strength=STRENGTH,
        guidance_scale=GUIDANCE_SCALE,
        num_inference_steps=NUM_INFERENCE_STEPS,
        max_side=MAX_SIDE,
    )
    # Only the ControlNet backend accepts the extra conditioning knob.
    if used_backend == InpaintBackend.CONTROLNET:
        kwargs["controlnet_conditioning_scale"] = CONTROLNET_SCALE
    out = predictor(img, mask, PROMPT, NEGATIVE_PROMPT, **kwargs)
    out_path = out_dir / f"{img_path.stem}.{used_backend.value}.inpaint.png"
    if PRESERVE_TRANSPARENCY:
        # Rebuild an alpha channel from the input; optionally force the
        # inpainted region opaque so newly generated content stays visible.
        alpha = _build_alpha_from_input(img_path, img)
        mask_u8 = np.asarray(mask, dtype=np.uint8)
        if EXPAND_ALPHA_WITH_MASK:
            alpha = np.maximum(alpha, mask_u8)
        out_rgb = np.asarray(out.convert("RGB"), dtype=np.uint8)
        out_rgba = np.concatenate([out_rgb, alpha[..., None]], axis=-1)
        Image.fromarray(out_rgba, mode="RGBA").save(out_path)
    else:
        out.save(out_path)
    # Fraction of the image covered by the mask, for the summary line.
    ratio = float((np.asarray(mask, dtype=np.uint8) > 0).sum()) / float(mask.size[0] * mask.size[1])
    print(f"[Inpaint] backend={used_backend.value}, saved={out_path}")
    print(f"[Mask] saved={mask_out}, ratio={ratio:.4f}")
def run_draw_test(base_dir: Path, out_dir: Path) -> None:
    """
    Drawing test:
      - text-to-image when DRAW_INPUT_IMAGE == ""
      - image-to-image when DRAW_INPUT_IMAGE points at a reference image
    """
    draw_predictor = build_draw_predictor(UnifiedDrawConfig())
    ref_image: Image.Image | None = None
    mode = "text2img"
    if DRAW_INPUT_IMAGE:
        ref_path = _resolve_path(base_dir, DRAW_INPUT_IMAGE)
        if not ref_path.is_file():
            raise FileNotFoundError(f"找不到参考图,请修改 DRAW_INPUT_IMAGE: {ref_path}")
        ref_image = Image.open(ref_path).convert("RGB")
        mode = "img2img"
    # `image=None` selects the text-to-image path inside the predictor;
    # `strength` only takes effect for image-to-image.
    out = draw_predictor(
        prompt=DRAW_PROMPT,
        image=ref_image,
        negative_prompt=DRAW_NEGATIVE_PROMPT,
        strength=DRAW_STRENGTH,
        guidance_scale=DRAW_GUIDANCE_SCALE,
        num_inference_steps=DRAW_STEPS,
        width=DRAW_WIDTH,
        height=DRAW_HEIGHT,
        max_side=DRAW_MAX_SIDE,
    )
    # Timestamped filename avoids clobbering earlier runs.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = out_dir / f"draw_{mode}_{ts}.png"
    out.save(out_path)
    print(f"[Draw] mode={mode}, saved={out_path}")
def main() -> None:
    """Dispatch to the configured task (inpaint or draw) and prepare the
    output directory beforehand."""
    base_dir = Path(__file__).resolve().parent
    out_dir = base_dir / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    # Table-driven dispatch keeps the mode set in one place.
    handlers = {"inpaint": run_inpaint_test, "draw": run_draw_test}
    handler = handlers.get(TASK_MODE)
    if handler is None:
        raise ValueError(f"不支持的 TASK_MODE: {TASK_MODE}(可选: inpaint / draw")
    handler(base_dir, out_dir)
if __name__ == "__main__":
main()

163
python_server/test_seg.py Normal file
View File

@@ -0,0 +1,163 @@
from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageDraw
from scipy.ndimage import label as nd_label
from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed
# ================= Configuration =================
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png"
OUTPUT_DIR = "outputs/test_seg_v2"
SEG_BACKEND = SegBackend.SAM
# Object filtering parameters
TARGET_MIN_AREA = 1000 # filter out fragments that are too small
TARGET_MAX_OBJECTS = 20 # maximum number of objects to extract
SAM_MAX_SIDE = 2048 # long-side cap for SAM inference
# Visualization
MASK_ALPHA = 0.4
BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8) # green boundary
TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8)
SAVE_OBJECT_PNG = True
# =========================================
def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
h, w = arr.shape[:2]
if max_side <= 0 or max(h, w) <= max_side:
return arr, (h, w), (h, w)
scale = float(max_side) / float(max(h, w))
run_w, run_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
resample = Image.NEAREST if is_mask else Image.BICUBIC
pil = Image.fromarray(arr)
out = pil.resize((run_w, run_h), resample=resample)
return np.asarray(out), (h, w), (run_h, run_w)
def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20):
"""
分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点
"""
# 确保是布尔类型
mask_bool = edge_mask > 127
# 连通域标记
labeled_array, num_features = nd_label(mask_bool)
prompts = []
component_info = []
for i in range(1, num_features + 1):
coords = np.argwhere(labeled_array == i)
area = len(coords)
if area < 100: continue # 过滤噪声
# 取几何中心作为引导点
center_y, center_x = np.median(coords, axis=0).astype(int)
component_info.append(((center_x, center_y), area))
# 按面积排序,优先处理大面积线条覆盖的物体
component_info.sort(key=lambda x: x[1], reverse=True)
for pt, _ in component_info[:max_components]:
prompts.append({
"point_coords": np.array([pt], dtype=np.float32),
"point_labels": np.array([1], dtype=np.int32)
})
return prompts
def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str):
    """Save each labeled object as a tightly cropped RGBA PNG where the
    object's pixels are opaque and everything else is transparent."""
    obj_dir = out_dir / f"{stem}.objects"
    obj_dir.mkdir(parents=True, exist_ok=True)
    for idx, lb in enumerate(targets, start=1):
        obj_mask = label_map == lb
        ys, xs = np.where(obj_mask)
        if ys.size == 0:
            continue  # label vanished (e.g. fully overwritten) — nothing to crop
        top, bottom = ys.min(), ys.max()
        left, right = xs.min(), xs.max()
        crop = rgb[top:bottom + 1, left:right + 1]
        alpha = (obj_mask[top:bottom + 1, left:right + 1].astype(np.uint8) * 255)[..., None]
        rgba = np.concatenate([crop, alpha], axis=-1)
        Image.fromarray(rgba).save(obj_dir / f"{stem}.obj_{idx:02d}.png")
def main():
    """Guided SAM segmentation: depth-edge mask -> point prompts -> masks ->
    visualization and per-object RGBA crops.

    Fixes: removed the unused `draw` variable (ImageDraw on a throwaway
    `Image.fromarray` copy had no effect on `marked_img`), and renamed the
    ambiguous comprehension variable `l`.
    """
    # 1. Load inputs and prepare the output directory.
    img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK)
    out_dir = Path(__file__).parent / OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    rgb_orig = np.asarray(Image.open(img_path).convert("RGB"))
    mask_orig = np.asarray(Image.open(mask_path).convert("L"))
    # 2. Downscale for inference to bound memory use.
    rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE)
    mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True)
    # 3. Initialize SAM through SamPredictor directly (the wrapper does not
    #    expose point/bbox prompts).
    import torch
    _ensure_sam_on_path()
    sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything"
    ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
    from segment_anything import sam_model_registry, SamPredictor  # type: ignore[import]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
    predictor = SamPredictor(sam)
    print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}")
    # 4. Derive one guide point per edge cluster.
    prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS)
    print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点")
    # 5. Run inference; later masks overwrite earlier ones in the label map.
    final_label_map_run = np.zeros(run_hw, dtype=np.int32)
    predictor.set_image(rgb_run)
    for idx, p in enumerate(prompts, start=1):
        # multimask_output=True yields several candidates; keep the best score.
        masks, scores, _ = predictor.predict(
            point_coords=p["point_coords"],
            point_labels=p["point_labels"],
            multimask_output=True
        )
        best_mask = masks[np.argmax(scores)]
        # Keep only sufficiently large masks (area threshold scaled to run size).
        if np.sum(best_mask) > TARGET_MIN_AREA * (run_hw[0] / orig_hw[0]):
            final_label_map_run[best_mask > 0] = idx
    # 6. Map labels back to original resolution (NEAREST keeps IDs integral).
    label_map = np.asarray(Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST))
    # 7. Visualize and export.
    unique_labels = [lb for lb in np.unique(label_map) if lb > 0]
    marked_img = rgb_orig.copy()
    # Blend a fill color over every object and trace its contour in green.
    overlay = rgb_orig.astype(np.float32)
    for lb in unique_labels:
        m = (label_map == lb)
        overlay[m] = overlay[m] * (1-MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA
        contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(marked_img, contours, -1, (0, 255, 0), 2)
    final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0))
    final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png")
    if SAVE_OBJECT_PNG:
        _save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem)
    print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}")
if __name__ == "__main__":
main()