initial commit
This commit is contained in:
4
.clang-tidy
Normal file
4
.clang-tidy
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
Checks: >
|
||||||
|
clang-diagnostic-unused-variable,
|
||||||
|
clang-diagnostic-unused-parameter,
|
||||||
|
clang-diagnostic-unused-lambda-capture
|
||||||
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
/build
|
||||||
|
/.venv
|
||||||
|
__pycache__/
|
||||||
20
.vscode/c_cpp_properties.json
vendored
Normal file
20
.vscode/c_cpp_properties.json
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Linux",
|
||||||
|
"includePath": [
|
||||||
|
"${workspaceFolder}/**",
|
||||||
|
"/usr/include/x86_64-linux-gnu/qt6/QtGui",
|
||||||
|
"/usr/include/x86_64-linux-gnu/qt6",
|
||||||
|
"/usr/include/x86_64-linux-gnu/qt6/QtCore",
|
||||||
|
"/usr/include/x86_64-linux-gnu/qt6/QtWidgets"
|
||||||
|
],
|
||||||
|
"defines": [],
|
||||||
|
"compilerPath": "/usr/bin/clang++",
|
||||||
|
"cStandard": "c23",
|
||||||
|
"cppStandard": "c++26",
|
||||||
|
"intelliSenseMode": "linux-clang-x64"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"version": 4
|
||||||
|
}
|
||||||
7
.vscode/launch.json
vendored
Normal file
7
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
// 使用 IntelliSense 了解相关属性。
|
||||||
|
// 悬停以查看现有属性的描述。
|
||||||
|
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": []
|
||||||
|
}
|
||||||
5
.vscode/settings.json
vendored
Normal file
5
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"C_Cpp.codeAnalysis.clangTidy.enabled": true,
|
||||||
|
"C_Cpp.codeAnalysis.clangTidy.useBuildPath": true,
|
||||||
|
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
|
||||||
|
}
|
||||||
46
CMakeLists.txt
Normal file
46
CMakeLists.txt
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.16)
|
||||||
|
|
||||||
|
project(LandscapeInteractiveTool
|
||||||
|
VERSION 0.1.0
|
||||||
|
LANGUAGES CXX
|
||||||
|
)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 26)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
# 开启较严格的编译告警(包含未使用变量等)
|
||||||
|
if (MSVC)
|
||||||
|
add_compile_options(/W4 /permissive-)
|
||||||
|
else()
|
||||||
|
add_compile_options(
|
||||||
|
-Wall
|
||||||
|
-Wextra
|
||||||
|
-Wpedantic
|
||||||
|
-Wconversion
|
||||||
|
-Wsign-conversion
|
||||||
|
-Wunused
|
||||||
|
-Wunused-const-variable=2
|
||||||
|
-Wunused-parameter
|
||||||
|
-Wshadow
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# 尝试优先使用 Qt6,其次 Qt5
|
||||||
|
set(QT_REQUIRED_COMPONENTS Widgets Gui Core Network)
|
||||||
|
|
||||||
|
find_package(Qt6 COMPONENTS ${QT_REQUIRED_COMPONENTS} QUIET)
|
||||||
|
if(Qt6_FOUND)
|
||||||
|
message(STATUS "Configuring with Qt6")
|
||||||
|
set(QT_PACKAGE Qt6)
|
||||||
|
else()
|
||||||
|
find_package(Qt5 COMPONENTS ${QT_REQUIRED_COMPONENTS} REQUIRED)
|
||||||
|
message(STATUS "Configuring with Qt5")
|
||||||
|
set(QT_PACKAGE Qt5)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(CMAKE_AUTOMOC ON)
|
||||||
|
set(CMAKE_AUTOUIC ON)
|
||||||
|
set(CMAKE_AUTORCC ON)
|
||||||
|
|
||||||
|
add_subdirectory(client)
|
||||||
|
|
||||||
10
README.md
Normal file
10
README.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
```
|
||||||
|
cmake -S . -B build
|
||||||
|
cmake --build build -j
|
||||||
|
./build/client/gui/landscape_tool
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
cmake --build build --target update_translations
|
||||||
|
```
|
||||||
4
client/CMakeLists.txt
Normal file
4
client/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
set(SRC_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
|
||||||
|
add_subdirectory(core)
|
||||||
|
add_subdirectory(gui)
|
||||||
24
client/README.md
Normal file
24
client/README.md
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Client(Qt 桌面端)
|
||||||
|
|
||||||
|
## 目录结构(按模块)
|
||||||
|
|
||||||
|
**`core/`**(静态库 `core`,include 根为 `client/core`)
|
||||||
|
|
||||||
|
- `domain/` — 领域模型:`Project`、`Entity` 等
|
||||||
|
- `persistence/` — `PersistentBinaryObject`(统一二进制头与原子写)、`EntityPayloadBinary`(`.hfe` / 旧 `.anim`)
|
||||||
|
- `workspace/` — 项目目录、索引 JSON、撤销栈:`ProjectWorkspace`
|
||||||
|
- `depth/` — 假深度图生成:`DepthService`
|
||||||
|
- `animation/` — 关键帧采样(Hold / 线性插值):`AnimationSampling`
|
||||||
|
|
||||||
|
**`gui/`**(可执行程序 `landscape_tool`,额外 include 根为 `client/gui`)
|
||||||
|
|
||||||
|
- `app/` — 入口 `main.cpp`
|
||||||
|
- `main_window/` — 主窗口与时间轴等:`MainWindow`
|
||||||
|
- `editor/` — 编辑画布:`EditorCanvas`
|
||||||
|
- `dialogs/` — `AboutWindow`、`ImageCropDialog`
|
||||||
|
|
||||||
|
引用方式示例:`#include "core/workspace/ProjectWorkspace.h"`(以 `client/` 为根)、`#include "editor/EditorCanvas.h"`(以 `client/gui/` 为根)。
|
||||||
|
|
||||||
|
## 界面语言
|
||||||
|
|
||||||
|
界面文案为中文(无运行时语言切换)。
|
||||||
39
client/core/CMakeLists.txt
Normal file
39
client/core/CMakeLists.txt
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# 模块:domain、persistence、workspace、depth、animation(时间采样)
|
||||||
|
set(CORE_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
|
||||||
|
set(CORE_SOURCES
|
||||||
|
${CORE_ROOT}/domain/Project.cpp
|
||||||
|
${CORE_ROOT}/workspace/ProjectWorkspace.cpp
|
||||||
|
${CORE_ROOT}/persistence/PersistentBinaryObject.cpp
|
||||||
|
${CORE_ROOT}/persistence/EntityPayloadBinary.cpp
|
||||||
|
${CORE_ROOT}/animation/AnimationSampling.cpp
|
||||||
|
${CORE_ROOT}/depth/DepthService.cpp
|
||||||
|
${CORE_ROOT}/net/ModelServerClient.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
set(CORE_HEADERS
|
||||||
|
${CORE_ROOT}/domain/Project.h
|
||||||
|
${CORE_ROOT}/workspace/ProjectWorkspace.h
|
||||||
|
${CORE_ROOT}/persistence/PersistentBinaryObject.h
|
||||||
|
${CORE_ROOT}/persistence/EntityPayloadBinary.h
|
||||||
|
${CORE_ROOT}/animation/AnimationSampling.h
|
||||||
|
${CORE_ROOT}/depth/DepthService.h
|
||||||
|
${CORE_ROOT}/net/ModelServerClient.h
|
||||||
|
)
|
||||||
|
|
||||||
|
add_library(core STATIC
|
||||||
|
${CORE_SOURCES}
|
||||||
|
${CORE_HEADERS}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(core
|
||||||
|
PUBLIC
|
||||||
|
${CORE_ROOT}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(core
|
||||||
|
PUBLIC
|
||||||
|
${QT_PACKAGE}::Core
|
||||||
|
${QT_PACKAGE}::Gui
|
||||||
|
${QT_PACKAGE}::Network
|
||||||
|
)
|
||||||
191
client/core/animation/AnimationSampling.cpp
Normal file
191
client/core/animation/AnimationSampling.cpp
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
#include "animation/AnimationSampling.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename KeyT, typename FrameGetter>
|
||||||
|
void sortKeysByFrame(QVector<KeyT>& keys, FrameGetter getFrame) {
|
||||||
|
std::sort(keys.begin(), keys.end(), [&](const KeyT& a, const KeyT& b) { return getFrame(a) < getFrame(b); });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
QPointF sampleLocation(const QVector<Project::Entity::KeyframeVec2>& keys,
|
||||||
|
int frame,
|
||||||
|
const QPointF& fallbackOrigin,
|
||||||
|
KeyInterpolation mode) {
|
||||||
|
QVector<Project::Entity::KeyframeVec2> sorted = keys;
|
||||||
|
sortKeysByFrame(sorted, [](const Project::Entity::KeyframeVec2& k) { return k.frame; });
|
||||||
|
|
||||||
|
if (sorted.isEmpty()) {
|
||||||
|
return fallbackOrigin;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mode == KeyInterpolation::Hold) {
|
||||||
|
QPointF out = fallbackOrigin;
|
||||||
|
int best = -1;
|
||||||
|
for (const auto& k : sorted) {
|
||||||
|
if (k.frame <= frame && k.frame >= best) {
|
||||||
|
best = k.frame;
|
||||||
|
out = k.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Linear:区间外夹持到端点;中间在相邻关键帧间线性插值(对 x、y 分别 lerp)
|
||||||
|
const auto& first = sorted.front();
|
||||||
|
const auto& last = sorted.back();
|
||||||
|
if (frame <= first.frame) {
|
||||||
|
return first.value;
|
||||||
|
}
|
||||||
|
if (frame >= last.frame) {
|
||||||
|
return last.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i + 1 < sorted.size(); ++i) {
|
||||||
|
const int f0 = sorted[i].frame;
|
||||||
|
const int f1 = sorted[i + 1].frame;
|
||||||
|
if (frame < f0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (frame <= f1) {
|
||||||
|
if (f1 == f0 || frame == f0) {
|
||||||
|
return sorted[i].value;
|
||||||
|
}
|
||||||
|
const double t = static_cast<double>(frame - f0) / static_cast<double>(f1 - f0);
|
||||||
|
const QPointF& a = sorted[i].value;
|
||||||
|
const QPointF& b = sorted[i + 1].value;
|
||||||
|
return QPointF(a.x() + (b.x() - a.x()) * t, a.y() + (b.y() - a.y()) * t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return last.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
double sampleDepthScale01(const QVector<Project::Entity::KeyframeFloat01>& keys,
|
||||||
|
int frame,
|
||||||
|
double fallback01,
|
||||||
|
KeyInterpolation mode) {
|
||||||
|
QVector<Project::Entity::KeyframeFloat01> sorted = keys;
|
||||||
|
sortKeysByFrame(sorted, [](const Project::Entity::KeyframeFloat01& k) { return k.frame; });
|
||||||
|
|
||||||
|
const double fb = std::clamp(fallback01, 0.0, 1.0);
|
||||||
|
|
||||||
|
if (sorted.isEmpty()) {
|
||||||
|
return fb;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mode == KeyInterpolation::Hold) {
|
||||||
|
double out = fb;
|
||||||
|
int best = -1;
|
||||||
|
for (const auto& k : sorted) {
|
||||||
|
if (k.frame <= frame && k.frame >= best) {
|
||||||
|
best = k.frame;
|
||||||
|
out = k.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::clamp(out, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& first = sorted.front();
|
||||||
|
const auto& last = sorted.back();
|
||||||
|
if (frame <= first.frame) {
|
||||||
|
return std::clamp(first.value, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
if (frame >= last.frame) {
|
||||||
|
return std::clamp(last.value, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i + 1 < sorted.size(); ++i) {
|
||||||
|
const int f0 = sorted[i].frame;
|
||||||
|
const int f1 = sorted[i + 1].frame;
|
||||||
|
if (frame < f0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (frame <= f1) {
|
||||||
|
if (f1 == f0 || frame == f0) {
|
||||||
|
return std::clamp(sorted[i].value, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
const double t = static_cast<double>(frame - f0) / static_cast<double>(f1 - f0);
|
||||||
|
const double a = sorted[i].value;
|
||||||
|
const double b = sorted[i + 1].value;
|
||||||
|
return std::clamp(a + (b - a) * t, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::clamp(last.value, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
double sampleUserScale(const QVector<Project::Entity::KeyframeDouble>& keys,
|
||||||
|
int frame,
|
||||||
|
double fallback,
|
||||||
|
KeyInterpolation mode) {
|
||||||
|
QVector<Project::Entity::KeyframeDouble> sorted = keys;
|
||||||
|
sortKeysByFrame(sorted, [](const Project::Entity::KeyframeDouble& k) { return k.frame; });
|
||||||
|
|
||||||
|
const double fb = std::max(fallback, 1e-6);
|
||||||
|
|
||||||
|
if (sorted.isEmpty()) {
|
||||||
|
return fb;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mode == KeyInterpolation::Hold) {
|
||||||
|
double out = fb;
|
||||||
|
int best = -1;
|
||||||
|
for (const auto& k : sorted) {
|
||||||
|
if (k.frame <= frame && k.frame >= best) {
|
||||||
|
best = k.frame;
|
||||||
|
out = k.value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::max(out, 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& first = sorted.front();
|
||||||
|
const auto& last = sorted.back();
|
||||||
|
if (frame <= first.frame) {
|
||||||
|
return std::max(first.value, 1e-6);
|
||||||
|
}
|
||||||
|
if (frame >= last.frame) {
|
||||||
|
return std::max(last.value, 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i + 1 < sorted.size(); ++i) {
|
||||||
|
const int f0 = sorted[i].frame;
|
||||||
|
const int f1 = sorted[i + 1].frame;
|
||||||
|
if (frame < f0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (frame <= f1) {
|
||||||
|
if (f1 == f0 || frame == f0) {
|
||||||
|
return std::max(sorted[i].value, 1e-6);
|
||||||
|
}
|
||||||
|
const double t = static_cast<double>(frame - f0) / static_cast<double>(f1 - f0);
|
||||||
|
const double a = sorted[i].value;
|
||||||
|
const double b = sorted[i + 1].value;
|
||||||
|
return std::max(a + (b - a) * t, 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::max(last.value, 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
QString sampleImagePath(const QVector<Project::Entity::ImageFrame>& frames,
|
||||||
|
int frame,
|
||||||
|
const QString& fallbackPath) {
|
||||||
|
QVector<Project::Entity::ImageFrame> sorted = frames;
|
||||||
|
sortKeysByFrame(sorted, [](const Project::Entity::ImageFrame& k) { return k.frame; });
|
||||||
|
|
||||||
|
QString out = fallbackPath;
|
||||||
|
int best = -1;
|
||||||
|
for (const auto& k : sorted) {
|
||||||
|
if (k.frame <= frame && k.frame >= best && !k.imagePath.isEmpty()) {
|
||||||
|
best = k.frame;
|
||||||
|
out = k.imagePath;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
33
client/core/animation/AnimationSampling.h
Normal file
33
client/core/animation/AnimationSampling.h
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "domain/Project.h"
|
||||||
|
|
||||||
|
#include <QPointF>
|
||||||
|
#include <QString>
|
||||||
|
#include <QVector>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
enum class KeyInterpolation { Hold, Linear };
|
||||||
|
|
||||||
|
// 关键帧按 frame 排序后使用;内部会对副本排序以保证稳健。
|
||||||
|
[[nodiscard]] QPointF sampleLocation(const QVector<Project::Entity::KeyframeVec2>& keys,
|
||||||
|
int frame,
|
||||||
|
const QPointF& fallbackOrigin,
|
||||||
|
KeyInterpolation mode);
|
||||||
|
|
||||||
|
[[nodiscard]] double sampleDepthScale01(const QVector<Project::Entity::KeyframeFloat01>& keys,
|
||||||
|
int frame,
|
||||||
|
double fallback01,
|
||||||
|
KeyInterpolation mode);
|
||||||
|
|
||||||
|
[[nodiscard]] double sampleUserScale(const QVector<Project::Entity::KeyframeDouble>& keys,
|
||||||
|
int frame,
|
||||||
|
double fallback,
|
||||||
|
KeyInterpolation mode);
|
||||||
|
|
||||||
|
[[nodiscard]] QString sampleImagePath(const QVector<Project::Entity::ImageFrame>& frames,
|
||||||
|
int frame,
|
||||||
|
const QString& fallbackPath);
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
58
client/core/depth/DepthService.cpp
Normal file
58
client/core/depth/DepthService.cpp
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
#include "depth/DepthService.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
QImage DepthService::computeFakeDepth(const QSize& size) {
|
||||||
|
if (size.isEmpty() || size.width() <= 0 || size.height() <= 0) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
QImage depth(size, QImage::Format_Grayscale8);
|
||||||
|
if (depth.isNull()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
depth.fill(0);
|
||||||
|
return depth;
|
||||||
|
}
|
||||||
|
|
||||||
|
QImage DepthService::computeFakeDepthFromBackground(const QImage& background) {
|
||||||
|
if (background.isNull()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return computeFakeDepth(background.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
QImage DepthService::depthToColormapOverlay(const QImage& depth8, int alpha) {
|
||||||
|
if (depth8.isNull()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
const QImage src = (depth8.format() == QImage::Format_Grayscale8) ? depth8 : depth8.convertToFormat(QImage::Format_Grayscale8);
|
||||||
|
if (src.isNull()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const int a = std::clamp(alpha, 0, 255);
|
||||||
|
QImage out(src.size(), QImage::Format_ARGB32_Premultiplied);
|
||||||
|
if (out.isNull()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int y = 0; y < src.height(); ++y) {
|
||||||
|
const uchar* row = src.constScanLine(y);
|
||||||
|
QRgb* dst = reinterpret_cast<QRgb*>(out.scanLine(y));
|
||||||
|
for (int x = 0; x < src.width(); ++x) {
|
||||||
|
const int d = static_cast<int>(row[x]); // 0..255
|
||||||
|
// depth=0(远)-> 蓝;depth=255(近)-> 红
|
||||||
|
const int r = d;
|
||||||
|
const int g = 0;
|
||||||
|
const int b = 255 - d;
|
||||||
|
dst[x] = qRgba(r, g, b, a);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
|
||||||
20
client/core/depth/DepthService.h
Normal file
20
client/core/depth/DepthService.h
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QImage>
|
||||||
|
#include <QSize>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
class DepthService final {
|
||||||
|
public:
|
||||||
|
// 生成 8-bit 深度图:0 最远,255 最近。当前实现为全 0(假深度)。
|
||||||
|
static QImage computeFakeDepth(const QSize& size);
|
||||||
|
static QImage computeFakeDepthFromBackground(const QImage& background);
|
||||||
|
|
||||||
|
// 把 8-bit 深度(Grayscale8)映射为伪彩色 ARGB32(带 alpha),用于叠加显示。
|
||||||
|
// 约定:depth=0(最远)-> 蓝,depth=255(最近)-> 红(线性插值)。
|
||||||
|
static QImage depthToColormapOverlay(const QImage& depth8, int alpha /*0-255*/);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
|
||||||
5
client/core/domain/Project.cpp
Normal file
5
client/core/domain/Project.cpp
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#include "domain/Project.h"
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
96
client/core/domain/Project.h
Normal file
96
client/core/domain/Project.h
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QString>
|
||||||
|
#include <QPointF>
|
||||||
|
#include <QVector>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
class Project {
|
||||||
|
public:
|
||||||
|
void setName(const QString& name) { m_name = name; }
|
||||||
|
const QString& name() const { return m_name; }
|
||||||
|
|
||||||
|
// 背景图在项目目录内的相对路径,例如 "assets/background.png"
|
||||||
|
void setBackgroundImagePath(const QString& relativePath) { m_backgroundImagePath = relativePath; }
|
||||||
|
const QString& backgroundImagePath() const { return m_backgroundImagePath; }
|
||||||
|
|
||||||
|
// 背景在视口/预览中的显隐(默认显示)
|
||||||
|
void setBackgroundVisible(bool on) { m_backgroundVisible = on; }
|
||||||
|
bool backgroundVisible() const { return m_backgroundVisible; }
|
||||||
|
|
||||||
|
void setDepthComputed(bool on) { m_depthComputed = on; }
|
||||||
|
bool depthComputed() const { return m_depthComputed; }
|
||||||
|
|
||||||
|
// 深度图在项目目录内的相对路径,例如 "assets/depth.png"
|
||||||
|
void setDepthMapPath(const QString& relativePath) { m_depthMapPath = relativePath; }
|
||||||
|
const QString& depthMapPath() const { return m_depthMapPath; }
|
||||||
|
|
||||||
|
void setFrameStart(int f) { m_frameStart = f; }
|
||||||
|
int frameStart() const { return m_frameStart; }
|
||||||
|
void setFrameEnd(int f) { m_frameEnd = f; }
|
||||||
|
int frameEnd() const { return m_frameEnd; }
|
||||||
|
void setFps(int fps) { m_fps = std::max(1, fps); }
|
||||||
|
int fps() const { return m_fps; }
|
||||||
|
|
||||||
|
struct Entity {
|
||||||
|
QString id;
|
||||||
|
QString displayName; // 显示名(空则界面用 id)
|
||||||
|
bool visible = true; // Outliner 眼睛:默认显示
|
||||||
|
// 可移动实体形状:存为局部坐标(相对 originWorld)
|
||||||
|
QVector<QPointF> polygonLocal;
|
||||||
|
// 从背景中抠洞的位置:固定在创建时的 world 坐标,不随实体移动
|
||||||
|
QVector<QPointF> cutoutPolygonWorld;
|
||||||
|
QPointF originWorld;
|
||||||
|
int depth = 0; // 0..255
|
||||||
|
QString imagePath; // 相对路径,例如 "assets/entities/entity-1.png"
|
||||||
|
QPointF imageTopLeftWorld; // 贴图左上角 world 坐标
|
||||||
|
// 人为整体缩放,与深度驱动的距离缩放相乘(画布中 visualScale = distanceScale * userScale)
|
||||||
|
double userScale = 1.0;
|
||||||
|
|
||||||
|
struct KeyframeVec2 {
|
||||||
|
int frame = 0;
|
||||||
|
QPointF value;
|
||||||
|
};
|
||||||
|
struct KeyframeFloat01 {
|
||||||
|
int frame = 0;
|
||||||
|
double value = 0.5; // 0..1,默认 0.5 -> scale=1.0(0.5..1.5 映射)
|
||||||
|
};
|
||||||
|
struct KeyframeDouble {
|
||||||
|
int frame = 0;
|
||||||
|
double value = 1.0;
|
||||||
|
};
|
||||||
|
struct ImageFrame {
|
||||||
|
int frame = 0;
|
||||||
|
QString imagePath; // 相对路径
|
||||||
|
};
|
||||||
|
|
||||||
|
// v2:project.json 仅存 id + payload,几何与动画在 entityPayloadPath(.hfe)中。
|
||||||
|
QString entityPayloadPath; // 例如 "assets/entities/entity-1.hfe"
|
||||||
|
// 仅打开 v1 项目时由 JSON 的 animationBundle 填入,用于合并旧 .anim;保存 v2 前应为空。
|
||||||
|
QString legacyAnimSidecarPath;
|
||||||
|
|
||||||
|
QVector<KeyframeVec2> locationKeys;
|
||||||
|
QVector<KeyframeFloat01> depthScaleKeys;
|
||||||
|
QVector<KeyframeDouble> userScaleKeys;
|
||||||
|
QVector<ImageFrame> imageFrames;
|
||||||
|
};
|
||||||
|
|
||||||
|
void setEntities(const QVector<Entity>& entities) { m_entities = entities; }
|
||||||
|
const QVector<Entity>& entities() const { return m_entities; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
QString m_name;
|
||||||
|
QString m_backgroundImagePath;
|
||||||
|
bool m_backgroundVisible = true;
|
||||||
|
bool m_depthComputed = false;
|
||||||
|
QString m_depthMapPath;
|
||||||
|
int m_frameStart = 0;
|
||||||
|
int m_frameEnd = 600;
|
||||||
|
int m_fps = 60;
|
||||||
|
QVector<Entity> m_entities;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
139
client/core/net/ModelServerClient.cpp
Normal file
139
client/core/net/ModelServerClient.cpp
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
#include "net/ModelServerClient.h"
|
||||||
|
|
||||||
|
#include <QEventLoop>
|
||||||
|
#include <QJsonDocument>
|
||||||
|
#include <QJsonObject>
|
||||||
|
#include <QNetworkAccessManager>
|
||||||
|
#include <QNetworkReply>
|
||||||
|
#include <QNetworkRequest>
|
||||||
|
#include <QTimer>
|
||||||
|
#include <QUrl>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
ModelServerClient::ModelServerClient(QObject* parent)
|
||||||
|
: QObject(parent)
|
||||||
|
, m_nam(new QNetworkAccessManager(this)) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void ModelServerClient::setBaseUrl(const QUrl& baseUrl) {
|
||||||
|
m_baseUrl = baseUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
QUrl ModelServerClient::baseUrl() const {
|
||||||
|
return m_baseUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
QNetworkReply* ModelServerClient::computeDepthPng8Async(const QByteArray& imageBytes, QString* outImmediateError) {
|
||||||
|
if (outImmediateError) {
|
||||||
|
outImmediateError->clear();
|
||||||
|
}
|
||||||
|
if (!m_baseUrl.isValid() || m_baseUrl.isEmpty()) {
|
||||||
|
if (outImmediateError) *outImmediateError = QStringLiteral("后端地址无效。");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (imageBytes.isEmpty()) {
|
||||||
|
if (outImmediateError) *outImmediateError = QStringLiteral("输入图像为空。");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const QUrl url = m_baseUrl.resolved(QUrl(QStringLiteral("/depth")));
|
||||||
|
QNetworkRequest req(url);
|
||||||
|
req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
|
||||||
|
|
||||||
|
const QByteArray imageB64 = imageBytes.toBase64();
|
||||||
|
const QJsonObject payload{
|
||||||
|
{QStringLiteral("image_b64"), QString::fromLatin1(imageB64)},
|
||||||
|
};
|
||||||
|
const QByteArray body = QJsonDocument(payload).toJson(QJsonDocument::Compact);
|
||||||
|
return m_nam->post(req, body);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ModelServerClient::computeDepthPng8(
|
||||||
|
const QByteArray& imageBytes,
|
||||||
|
QByteArray& outPngBytes,
|
||||||
|
QString& outError,
|
||||||
|
int timeoutMs
|
||||||
|
) {
|
||||||
|
outPngBytes.clear();
|
||||||
|
outError.clear();
|
||||||
|
|
||||||
|
if (!m_baseUrl.isValid() || m_baseUrl.isEmpty()) {
|
||||||
|
outError = QStringLiteral("后端地址无效。");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (imageBytes.isEmpty()) {
|
||||||
|
outError = QStringLiteral("输入图像为空。");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const QUrl url = m_baseUrl.resolved(QUrl(QStringLiteral("/depth")));
|
||||||
|
|
||||||
|
QNetworkRequest req(url);
|
||||||
|
req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
|
||||||
|
|
||||||
|
const QByteArray imageB64 = imageBytes.toBase64();
|
||||||
|
const QJsonObject payload{
|
||||||
|
{QStringLiteral("image_b64"), QString::fromLatin1(imageB64)},
|
||||||
|
};
|
||||||
|
const QByteArray body = QJsonDocument(payload).toJson(QJsonDocument::Compact);
|
||||||
|
|
||||||
|
QNetworkReply* reply = m_nam->post(req, body);
|
||||||
|
if (!reply) {
|
||||||
|
outError = QStringLiteral("创建网络请求失败。");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
QEventLoop loop;
|
||||||
|
QTimer timer;
|
||||||
|
timer.setSingleShot(true);
|
||||||
|
const int t = (timeoutMs <= 0) ? 30000 : timeoutMs;
|
||||||
|
|
||||||
|
QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
|
||||||
|
QObject::connect(&timer, &QTimer::timeout, &loop, &QEventLoop::quit);
|
||||||
|
timer.start(t);
|
||||||
|
loop.exec();
|
||||||
|
|
||||||
|
if (timer.isActive() == false && reply->isFinished() == false) {
|
||||||
|
reply->abort();
|
||||||
|
reply->deleteLater();
|
||||||
|
outError = QStringLiteral("请求超时(%1ms)。").arg(t);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int httpStatus = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
|
||||||
|
const QByteArray raw = reply->readAll();
|
||||||
|
const auto netErr = reply->error();
|
||||||
|
const QString netErrStr = reply->errorString();
|
||||||
|
reply->deleteLater();
|
||||||
|
|
||||||
|
if (netErr != QNetworkReply::NoError) {
|
||||||
|
outError = QStringLiteral("网络错误:%1").arg(netErrStr);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (httpStatus != 200) {
|
||||||
|
// FastAPI HTTPException 默认返回 {"detail": "..."}
|
||||||
|
QString detail;
|
||||||
|
const QJsonDocument jd = QJsonDocument::fromJson(raw);
|
||||||
|
if (jd.isObject()) {
|
||||||
|
const auto obj = jd.object();
|
||||||
|
detail = obj.value(QStringLiteral("detail")).toString();
|
||||||
|
}
|
||||||
|
outError = detail.isEmpty()
|
||||||
|
? QStringLiteral("后端返回HTTP %1。").arg(httpStatus)
|
||||||
|
: QStringLiteral("后端错误(HTTP %1):%2").arg(httpStatus).arg(detail);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (raw.isEmpty()) {
|
||||||
|
outError = QStringLiteral("后端返回空数据。");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
outPngBytes = raw;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
|
||||||
36
client/core/net/ModelServerClient.h
Normal file
36
client/core/net/ModelServerClient.h
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QByteArray>
|
||||||
|
#include <QObject>
|
||||||
|
#include <QString>
|
||||||
|
#include <QUrl>
|
||||||
|
#include <QNetworkReply>
|
||||||
|
class QNetworkAccessManager;
|
||||||
|
class QUrl;
|
||||||
|
class QNetworkReply;
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
class ModelServerClient final : public QObject {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit ModelServerClient(QObject* parent = nullptr);
|
||||||
|
|
||||||
|
void setBaseUrl(const QUrl& baseUrl);
|
||||||
|
QUrl baseUrl() const;
|
||||||
|
|
||||||
|
// 同步调用:向后端 POST /depth 发送背景图,成功返回 PNG(8-bit 灰度) 的二进制数据。
|
||||||
|
// timeoutMs<=0 表示使用默认超时(30s)。
|
||||||
|
bool computeDepthPng8(const QByteArray& imageBytes, QByteArray& outPngBytes, QString& outError, int timeoutMs = 30000);
|
||||||
|
|
||||||
|
// 异步调用:发起 POST /depth,返回 reply(由 Qt 管理生命周期;调用方负责连接 finished/错误处理)。
|
||||||
|
// 返回 nullptr 表示参数/URL 非法导致无法发起。
|
||||||
|
QNetworkReply* computeDepthPng8Async(const QByteArray& imageBytes, QString* outImmediateError = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
QNetworkAccessManager* m_nam = nullptr;
|
||||||
|
QUrl m_baseUrl;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
|
||||||
324
client/core/persistence/EntityPayloadBinary.cpp
Normal file
324
client/core/persistence/EntityPayloadBinary.cpp
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
#include "persistence/EntityPayloadBinary.h"
|
||||||
|
|
||||||
|
#include "persistence/PersistentBinaryObject.h"
|
||||||
|
|
||||||
|
#include "domain/Project.h"
|
||||||
|
|
||||||
|
#include <QDataStream>
|
||||||
|
#include <QFile>
|
||||||
|
#include <QtGlobal>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void sortByFrame(QVector<Project::Entity::KeyframeVec2>& v) {
|
||||||
|
std::sort(v.begin(), v.end(), [](const auto& a, const auto& b) { return a.frame < b.frame; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void sortByFrame(QVector<Project::Entity::KeyframeFloat01>& v) {
|
||||||
|
std::sort(v.begin(), v.end(), [](const auto& a, const auto& b) { return a.frame < b.frame; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void sortByFrame(QVector<Project::Entity::KeyframeDouble>& v) {
|
||||||
|
std::sort(v.begin(), v.end(), [](const auto& a, const auto& b) { return a.frame < b.frame; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void sortByFrame(QVector<Project::Entity::ImageFrame>& v) {
|
||||||
|
std::sort(v.begin(), v.end(), [](const auto& a, const auto& b) { return a.frame < b.frame; });
|
||||||
|
}
|
||||||
|
|
||||||
|
bool readAnimationBlock(QDataStream& ds, Project::Entity& out, bool hasUserScaleKeys) {
|
||||||
|
out.locationKeys.clear();
|
||||||
|
out.depthScaleKeys.clear();
|
||||||
|
out.userScaleKeys.clear();
|
||||||
|
out.imageFrames.clear();
|
||||||
|
|
||||||
|
qint32 nLoc = 0;
|
||||||
|
ds >> nLoc;
|
||||||
|
if (ds.status() != QDataStream::Ok || nLoc < 0 || nLoc > 1000000) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.locationKeys.reserve(nLoc);
|
||||||
|
for (qint32 i = 0; i < nLoc; ++i) {
|
||||||
|
qint32 frame = 0;
|
||||||
|
double x = 0.0;
|
||||||
|
double y = 0.0;
|
||||||
|
ds >> frame >> x >> y;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.locationKeys.push_back(Project::Entity::KeyframeVec2{frame, QPointF(x, y)});
|
||||||
|
}
|
||||||
|
|
||||||
|
qint32 nDepth = 0;
|
||||||
|
ds >> nDepth;
|
||||||
|
if (ds.status() != QDataStream::Ok || nDepth < 0 || nDepth > 1000000) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.depthScaleKeys.reserve(nDepth);
|
||||||
|
for (qint32 i = 0; i < nDepth; ++i) {
|
||||||
|
qint32 frame = 0;
|
||||||
|
double v = 0.5;
|
||||||
|
ds >> frame >> v;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.depthScaleKeys.push_back(Project::Entity::KeyframeFloat01{frame, v});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasUserScaleKeys) {
|
||||||
|
qint32 nUser = 0;
|
||||||
|
ds >> nUser;
|
||||||
|
if (ds.status() != QDataStream::Ok || nUser < 0 || nUser > 1000000) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.userScaleKeys.reserve(nUser);
|
||||||
|
for (qint32 i = 0; i < nUser; ++i) {
|
||||||
|
qint32 frame = 0;
|
||||||
|
double v = 1.0;
|
||||||
|
ds >> frame >> v;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.userScaleKeys.push_back(Project::Entity::KeyframeDouble{frame, v});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
qint32 nImg = 0;
|
||||||
|
ds >> nImg;
|
||||||
|
if (ds.status() != QDataStream::Ok || nImg < 0 || nImg > 1000000) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out.imageFrames.reserve(nImg);
|
||||||
|
for (qint32 i = 0; i < nImg; ++i) {
|
||||||
|
qint32 frame = 0;
|
||||||
|
QString path;
|
||||||
|
ds >> frame >> path;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!path.isEmpty()) {
|
||||||
|
out.imageFrames.push_back(Project::Entity::ImageFrame{frame, path});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sortByFrame(out.locationKeys);
|
||||||
|
sortByFrame(out.depthScaleKeys);
|
||||||
|
sortByFrame(out.userScaleKeys);
|
||||||
|
sortByFrame(out.imageFrames);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void writeAnimationBlock(QDataStream& ds, const Project::Entity& entity, bool writeUserScaleKeys) {
|
||||||
|
ds << qint32(entity.locationKeys.size());
|
||||||
|
for (const auto& k : entity.locationKeys) {
|
||||||
|
ds << qint32(k.frame) << double(k.value.x()) << double(k.value.y());
|
||||||
|
}
|
||||||
|
|
||||||
|
ds << qint32(entity.depthScaleKeys.size());
|
||||||
|
for (const auto& k : entity.depthScaleKeys) {
|
||||||
|
ds << qint32(k.frame) << double(k.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (writeUserScaleKeys) {
|
||||||
|
ds << qint32(entity.userScaleKeys.size());
|
||||||
|
for (const auto& k : entity.userScaleKeys) {
|
||||||
|
ds << qint32(k.frame) << double(k.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ds << qint32(entity.imageFrames.size());
|
||||||
|
for (const auto& k : entity.imageFrames) {
|
||||||
|
ds << qint32(k.frame) << k.imagePath;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool readEntityPayloadV1(QDataStream& ds, Project::Entity& tmp, bool hasUserScaleKeys) {
|
||||||
|
ds >> tmp.id;
|
||||||
|
qint32 depth = 0;
|
||||||
|
ds >> depth;
|
||||||
|
tmp.depth = static_cast<int>(depth);
|
||||||
|
ds >> tmp.imagePath;
|
||||||
|
double ox = 0.0;
|
||||||
|
double oy = 0.0;
|
||||||
|
double itlx = 0.0;
|
||||||
|
double itly = 0.0;
|
||||||
|
ds >> ox >> oy >> itlx >> itly;
|
||||||
|
tmp.originWorld = QPointF(ox, oy);
|
||||||
|
tmp.imageTopLeftWorld = QPointF(itlx, itly);
|
||||||
|
|
||||||
|
qint32 nLocal = 0;
|
||||||
|
ds >> nLocal;
|
||||||
|
if (ds.status() != QDataStream::Ok || nLocal < 0 || nLocal > 1000000) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tmp.polygonLocal.reserve(nLocal);
|
||||||
|
for (qint32 i = 0; i < nLocal; ++i) {
|
||||||
|
double x = 0.0;
|
||||||
|
double y = 0.0;
|
||||||
|
ds >> x >> y;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tmp.polygonLocal.push_back(QPointF(x, y));
|
||||||
|
}
|
||||||
|
|
||||||
|
qint32 nCut = 0;
|
||||||
|
ds >> nCut;
|
||||||
|
if (ds.status() != QDataStream::Ok || nCut < 0 || nCut > 1000000) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tmp.cutoutPolygonWorld.reserve(nCut);
|
||||||
|
for (qint32 i = 0; i < nCut; ++i) {
|
||||||
|
double x = 0.0;
|
||||||
|
double y = 0.0;
|
||||||
|
ds >> x >> y;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tmp.cutoutPolygonWorld.push_back(QPointF(x, y));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!readAnimationBlock(ds, tmp, hasUserScaleKeys)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tmp.id.isEmpty() || tmp.polygonLocal.isEmpty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
class EntityBinaryRecord final : public PersistentBinaryObject {
|
||||||
|
public:
|
||||||
|
explicit EntityBinaryRecord(const Project::Entity& e) : m_src(&e), m_dst(nullptr) {}
|
||||||
|
explicit EntityBinaryRecord(Project::Entity& e) : m_src(nullptr), m_dst(&e) {}
|
||||||
|
|
||||||
|
quint32 recordMagic() const override { return EntityPayloadBinary::kMagicPayload; }
|
||||||
|
quint32 recordFormatVersion() const override { return EntityPayloadBinary::kPayloadVersion; }
|
||||||
|
|
||||||
|
void writeBody(QDataStream& ds) const override {
|
||||||
|
Q_ASSERT(m_src != nullptr);
|
||||||
|
const Project::Entity& entity = *m_src;
|
||||||
|
ds << entity.id;
|
||||||
|
ds << qint32(entity.depth);
|
||||||
|
ds << entity.imagePath;
|
||||||
|
ds << double(entity.originWorld.x()) << double(entity.originWorld.y());
|
||||||
|
ds << double(entity.imageTopLeftWorld.x()) << double(entity.imageTopLeftWorld.y());
|
||||||
|
|
||||||
|
ds << qint32(entity.polygonLocal.size());
|
||||||
|
for (const auto& pt : entity.polygonLocal) {
|
||||||
|
ds << double(pt.x()) << double(pt.y());
|
||||||
|
}
|
||||||
|
|
||||||
|
ds << qint32(entity.cutoutPolygonWorld.size());
|
||||||
|
for (const auto& pt : entity.cutoutPolygonWorld) {
|
||||||
|
ds << double(pt.x()) << double(pt.y());
|
||||||
|
}
|
||||||
|
|
||||||
|
writeAnimationBlock(ds, entity, true);
|
||||||
|
ds << entity.displayName << double(entity.userScale);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool readBody(QDataStream& ds) override {
|
||||||
|
Q_ASSERT(m_dst != nullptr);
|
||||||
|
Project::Entity tmp;
|
||||||
|
if (!readEntityPayloadV1(ds, tmp, true)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
QString dn;
|
||||||
|
double us = 1.0;
|
||||||
|
ds >> dn >> us;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tmp.displayName = dn;
|
||||||
|
tmp.userScale = std::clamp(us, 1e-3, 1e3);
|
||||||
|
*m_dst = std::move(tmp);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const Project::Entity* m_src;
|
||||||
|
Project::Entity* m_dst;
|
||||||
|
};
|
||||||
|
|
||||||
|
class LegacyAnimSidecarRecord final : public PersistentBinaryObject {
|
||||||
|
public:
|
||||||
|
explicit LegacyAnimSidecarRecord(Project::Entity& e) : m_entity(&e) {}
|
||||||
|
|
||||||
|
quint32 recordMagic() const override { return EntityPayloadBinary::kMagicLegacyAnim; }
|
||||||
|
quint32 recordFormatVersion() const override { return EntityPayloadBinary::kLegacyAnimVersion; }
|
||||||
|
|
||||||
|
void writeBody(QDataStream& ds) const override { Q_UNUSED(ds); }
|
||||||
|
|
||||||
|
bool readBody(QDataStream& ds) override {
|
||||||
|
Project::Entity tmp = *m_entity;
|
||||||
|
if (!readAnimationBlock(ds, tmp, false)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
m_entity->locationKeys = std::move(tmp.locationKeys);
|
||||||
|
m_entity->depthScaleKeys = std::move(tmp.depthScaleKeys);
|
||||||
|
m_entity->userScaleKeys = std::move(tmp.userScaleKeys);
|
||||||
|
m_entity->imageFrames = std::move(tmp.imageFrames);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Project::Entity* m_entity;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool EntityPayloadBinary::save(const QString& absolutePath, const Project::Entity& entity) {
|
||||||
|
if (absolutePath.isEmpty() || entity.id.isEmpty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return EntityBinaryRecord(entity).saveToFile(absolutePath);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool EntityPayloadBinary::load(const QString& absolutePath, Project::Entity& entity) {
|
||||||
|
QFile f(absolutePath);
|
||||||
|
if (!f.open(QIODevice::ReadOnly)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
QDataStream ds(&f);
|
||||||
|
ds.setVersion(QDataStream::Qt_5_15);
|
||||||
|
quint32 magic = 0;
|
||||||
|
quint32 ver = 0;
|
||||||
|
ds >> magic >> ver;
|
||||||
|
if (ds.status() != QDataStream::Ok || magic != kMagicPayload) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (ver != 1 && ver != 2 && ver != 3) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Project::Entity tmp;
|
||||||
|
if (!readEntityPayloadV1(ds, tmp, ver >= 3)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (ver >= 2) {
|
||||||
|
QString dn;
|
||||||
|
double us = 1.0;
|
||||||
|
ds >> dn >> us;
|
||||||
|
if (ds.status() != QDataStream::Ok) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tmp.displayName = dn;
|
||||||
|
tmp.userScale = std::clamp(us, 1e-3, 1e3);
|
||||||
|
} else {
|
||||||
|
tmp.displayName.clear();
|
||||||
|
tmp.userScale = 1.0;
|
||||||
|
}
|
||||||
|
entity = std::move(tmp);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool EntityPayloadBinary::loadLegacyAnimFile(const QString& absolutePath, Project::Entity& entity) {
|
||||||
|
return LegacyAnimSidecarRecord(entity).loadFromFile(absolutePath);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
30
client/core/persistence/EntityPayloadBinary.h
Normal file
30
client/core/persistence/EntityPayloadBinary.h
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "domain/Project.h"
|
||||||
|
|
||||||
|
#include <QString>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
// 实体完整数据(几何 + 贴图路径 + 动画轨道)的二进制格式,与 project.json v2 的 payload 字段对应。
|
||||||
|
// 贴图 PNG 仍单独存放在 assets/entities/,本文件不嵌入像素。
|
||||||
|
// 具体读写通过继承 PersistentBinaryObject 的适配器类完成(见 EntityPayloadBinary.cpp)。
|
||||||
|
class EntityPayloadBinary {
|
||||||
|
public:
|
||||||
|
static constexpr quint32 kMagicPayload = 0x48464550; // 'HFEP'
|
||||||
|
static constexpr quint32 kPayloadVersion = 3; // v3:追加 userScaleKeys(动画轨道)
|
||||||
|
|
||||||
|
// 旧版独立动画文件(仍用于打开 v1 项目时合并)
|
||||||
|
static constexpr quint32 kMagicLegacyAnim = 0x48465441; // 'HFTA'
|
||||||
|
static constexpr quint32 kLegacyAnimVersion = 1;
|
||||||
|
|
||||||
|
static bool save(const QString& absolutePath, const Project::Entity& entity);
|
||||||
|
|
||||||
|
// 读入后覆盖 entity 中除调用方已校验外的字段;失败时尽量保持 entity 不变。
|
||||||
|
static bool load(const QString& absolutePath, Project::Entity& entity);
|
||||||
|
|
||||||
|
// 仅读取旧 .anim(HFTA),写入 entity 的三条动画轨道。
|
||||||
|
static bool loadLegacyAnimFile(const QString& absolutePath, Project::Entity& entity);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
57
client/core/persistence/PersistentBinaryObject.cpp
Normal file
57
client/core/persistence/PersistentBinaryObject.cpp
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
#include "persistence/PersistentBinaryObject.h"
|
||||||
|
|
||||||
|
#include <QDataStream>
|
||||||
|
#include <QDir>
|
||||||
|
#include <QFile>
|
||||||
|
#include <QFileInfo>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
bool PersistentBinaryObject::saveToFile(const QString& absolutePath) const {
|
||||||
|
if (absolutePath.isEmpty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto parent = QFileInfo(absolutePath).absolutePath();
|
||||||
|
if (!QFileInfo(parent).exists()) {
|
||||||
|
QDir().mkpath(parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
const QString tmpPath = absolutePath + QStringLiteral(".tmp");
|
||||||
|
QFile f(tmpPath);
|
||||||
|
if (!f.open(QIODevice::WriteOnly | QIODevice::Truncate)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
QDataStream ds(&f);
|
||||||
|
ds.setVersion(QDataStream::Qt_5_15);
|
||||||
|
ds << quint32(recordMagic()) << quint32(recordFormatVersion());
|
||||||
|
writeBody(ds);
|
||||||
|
|
||||||
|
f.close();
|
||||||
|
if (f.error() != QFile::NoError) {
|
||||||
|
QFile::remove(tmpPath);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
QFile::remove(absolutePath);
|
||||||
|
return QFile::rename(tmpPath, absolutePath);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PersistentBinaryObject::loadFromFile(const QString& absolutePath) {
|
||||||
|
QFile f(absolutePath);
|
||||||
|
if (!f.open(QIODevice::ReadOnly)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
QDataStream ds(&f);
|
||||||
|
ds.setVersion(QDataStream::Qt_5_15);
|
||||||
|
|
||||||
|
quint32 magic = 0;
|
||||||
|
quint32 version = 0;
|
||||||
|
ds >> magic >> version;
|
||||||
|
if (ds.status() != QDataStream::Ok || magic != recordMagic() || version != recordFormatVersion()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return readBody(ds);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
27
client/core/persistence/PersistentBinaryObject.h
Normal file
27
client/core/persistence/PersistentBinaryObject.h
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QString>
|
||||||
|
|
||||||
|
class QDataStream;
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
// 二进制记录的统一持久化基类:魔数/版本、QDataStream 版本、.tmp 原子替换、父目录创建。
|
||||||
|
//
|
||||||
|
// 领域类型(如 Project::Entity)应保持为可拷贝的值类型,不要继承本类;为每种存储格式写一个
|
||||||
|
// 小的适配器类(如 EntityBinaryRecord)继承本类并实现 writeBody/readBody 即可。
|
||||||
|
class PersistentBinaryObject {
|
||||||
|
public:
|
||||||
|
virtual ~PersistentBinaryObject() = default;
|
||||||
|
|
||||||
|
[[nodiscard]] bool saveToFile(const QString& absolutePath) const;
|
||||||
|
[[nodiscard]] bool loadFromFile(const QString& absolutePath);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual quint32 recordMagic() const = 0;
|
||||||
|
virtual quint32 recordFormatVersion() const = 0;
|
||||||
|
virtual void writeBody(QDataStream& ds) const = 0;
|
||||||
|
virtual bool readBody(QDataStream& ds) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
1629
client/core/workspace/ProjectWorkspace.cpp
Normal file
1629
client/core/workspace/ProjectWorkspace.cpp
Normal file
File diff suppressed because it is too large
Load Diff
148
client/core/workspace/ProjectWorkspace.h
Normal file
148
client/core/workspace/ProjectWorkspace.h
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "domain/Project.h"
|
||||||
|
|
||||||
|
#include <QImage>
|
||||||
|
#include <QJsonObject>
|
||||||
|
#include <QRect>
|
||||||
|
#include <QString>
|
||||||
|
#include <QStringList>
|
||||||
|
#include <QVector>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
|
||||||
|
class ProjectWorkspace {
|
||||||
|
public:
|
||||||
|
static constexpr const char* kProjectIndexFileName = "project.json";
|
||||||
|
static constexpr const char* kAssetsDirName = "assets";
|
||||||
|
// 写入 project.json 的 version 字段;仍可读 version 1(内嵌实体 + 可选 .anim)。
|
||||||
|
static constexpr int kProjectIndexFormatVersion = 2;
|
||||||
|
|
||||||
|
ProjectWorkspace() = default;
|
||||||
|
|
||||||
|
// 新建项目:
|
||||||
|
// - 传入的 parentDir 是“父目录”(你在文件对话框中选择的目录)
|
||||||
|
// - 会在 parentDir 下创建一个新的项目目录(默认使用项目名做文件夹名;若重名会自动加后缀)
|
||||||
|
// - 项目结构为(v2):
|
||||||
|
// <projectDir>/project.json (索引:背景/深度路径 + 实体 id 与 .hfe 路径)
|
||||||
|
// <projectDir>/assets/background.png
|
||||||
|
// <projectDir>/assets/entities/*.png / *.hfe
|
||||||
|
bool createNew(const QString& parentDir, const QString& name, const QString& backgroundImageSourcePath);
|
||||||
|
bool createNew(const QString& parentDir, const QString& name, const QString& backgroundImageSourcePath,
|
||||||
|
const QRect& cropRectInSourceImage);
|
||||||
|
bool openExisting(const QString& projectDir);
|
||||||
|
void close();
|
||||||
|
|
||||||
|
bool isOpen() const { return !m_projectDir.isEmpty(); }
|
||||||
|
const QString& projectDir() const { return m_projectDir; }
|
||||||
|
QString indexFilePath() const;
|
||||||
|
QString assetsDirPath() const;
|
||||||
|
bool hasBackground() const { return !m_project.backgroundImagePath().isEmpty(); }
|
||||||
|
QString backgroundAbsolutePath() const;
|
||||||
|
bool backgroundVisible() const { return m_project.backgroundVisible(); }
|
||||||
|
bool setBackgroundVisible(bool on);
|
||||||
|
|
||||||
|
bool hasDepth() const;
|
||||||
|
QString depthAbsolutePath() const;
|
||||||
|
|
||||||
|
// 写入 project.json 的 name 字段(可 undo)
|
||||||
|
bool setProjectTitle(const QString& title);
|
||||||
|
|
||||||
|
Project& project() { return m_project; }
|
||||||
|
const Project& project() const { return m_project; }
|
||||||
|
|
||||||
|
// 历史操作(最多 30 步),类似 Blender:维护 undo/redo 栈
|
||||||
|
bool canUndo() const;
|
||||||
|
bool canRedo() const;
|
||||||
|
bool undo();
|
||||||
|
bool redo();
|
||||||
|
QStringList historyLabelsNewestFirst() const;
|
||||||
|
|
||||||
|
// 追加一次“导入并设置背景图”操作:把图片拷贝进 assets/,并作为背景写入项目(会进入历史)。
|
||||||
|
bool importBackgroundImage(const QString& backgroundImageSourcePath);
|
||||||
|
bool importBackgroundImage(const QString& backgroundImageSourcePath, const QRect& cropRectInSourceImage);
|
||||||
|
|
||||||
|
// 计算并写入假深度图:assets/depth.png,同时更新 project.json(depthComputed/depthMapPath)。
|
||||||
|
bool computeFakeDepthForProject();
|
||||||
|
|
||||||
|
// 从后端计算深度并落盘:assets/depth.png,同时更新 project.json(depthComputed/depthMapPath)。
|
||||||
|
// - serverBaseUrl 为空时:优先读环境变量 MODEL_SERVER_URL,否则默认 http://127.0.0.1:8000
|
||||||
|
// - outError 可选:返回失败原因
|
||||||
|
bool computeDepthForProjectFromServer(const QString& serverBaseUrl, QString* outError = nullptr, int timeoutMs = 30000);
|
||||||
|
|
||||||
|
// 直接保存深度图(PNG bytes)到 assets/depth.png,并更新 project.json。
|
||||||
|
bool saveDepthMapPngBytes(const QByteArray& pngBytes, QString* outError = nullptr);
|
||||||
|
|
||||||
|
const QVector<Project::Entity>& entities() const { return m_project.entities(); }
|
||||||
|
bool addEntity(const Project::Entity& entity, const QImage& image);
|
||||||
|
bool setEntityVisible(const QString& id, bool on);
|
||||||
|
bool setEntityDisplayName(const QString& id, const QString& displayName);
|
||||||
|
bool setEntityUserScale(const QString& id, double userScale);
|
||||||
|
// 将多边形质心平移到 targetCentroidWorld(整体平移);sTotal 须与画布一致
|
||||||
|
bool moveEntityCentroidTo(const QString& id, int frame, const QPointF& targetCentroidWorld, double sTotal,
|
||||||
|
bool autoKeyLocation);
|
||||||
|
// 在保持外形不变的前提下移动枢轴点;sTotal 须与画布一致(距离缩放×整体缩放)
|
||||||
|
bool reanchorEntityPivot(const QString& id, int frame, const QPointF& newPivotWorld, double sTotal);
|
||||||
|
bool reorderEntitiesById(const QStringList& idsInOrder);
|
||||||
|
// currentFrame:自动关键帧时写入位置曲线;autoKeyLocation 为 false 时忽略。
|
||||||
|
bool moveEntityBy(const QString& id, const QPointF& delta, int currentFrame, bool autoKeyLocation);
|
||||||
|
bool setEntityLocationKey(const QString& id, int frame, const QPointF& originWorld);
|
||||||
|
bool setEntityDepthScaleKey(const QString& id, int frame, double value01);
|
||||||
|
bool setEntityUserScaleKey(const QString& id, int frame, double userScale);
|
||||||
|
bool setEntityImageFrame(const QString& id, int frame, const QImage& image, QString* outRelPath = nullptr);
|
||||||
|
bool removeEntityLocationKey(const QString& id, int frame);
|
||||||
|
bool removeEntityDepthScaleKey(const QString& id, int frame);
|
||||||
|
bool removeEntityUserScaleKey(const QString& id, int frame);
|
||||||
|
bool removeEntityImageFrame(const QString& id, int frame);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool writeIndexJson();
|
||||||
|
bool readIndexJson(const QString& indexPath);
|
||||||
|
|
||||||
|
bool syncEntityPayloadsToDisk();
|
||||||
|
bool hydrateEntityPayloadsFromDisk();
|
||||||
|
void loadV1LegacyAnimationSidecars();
|
||||||
|
|
||||||
|
static QJsonObject projectToJson(const Project& project);
|
||||||
|
static bool projectFromJson(const QJsonObject& root, Project& outProject, int* outFileVersion);
|
||||||
|
static QString asRelativeUnderProject(const QString& relativePath);
|
||||||
|
static QString fileSuffixWithDot(const QString& path);
|
||||||
|
static QString asOptionalRelativeUnderProject(const QString& relativePath);
|
||||||
|
static QJsonObject entityToJson(const Project::Entity& e);
|
||||||
|
static bool entityFromJsonV1(const QJsonObject& o, Project::Entity& out);
|
||||||
|
static bool entityStubFromJsonV2(const QJsonObject& o, Project::Entity& out);
|
||||||
|
|
||||||
|
struct Operation {
|
||||||
|
enum class Type { ImportBackground, SetEntities, SetProjectTitle };
|
||||||
|
Type type {Type::ImportBackground};
|
||||||
|
QString label;
|
||||||
|
QString beforeBackgroundPath;
|
||||||
|
QString afterBackgroundPath;
|
||||||
|
QVector<Project::Entity> beforeEntities;
|
||||||
|
QVector<Project::Entity> afterEntities;
|
||||||
|
QString beforeProjectTitle;
|
||||||
|
QString afterProjectTitle;
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr int kMaxHistorySteps = 30;
|
||||||
|
void pushOperation(const Operation& op);
|
||||||
|
|
||||||
|
bool applyBackgroundPath(const QString& relativePath, bool recordHistory, const QString& label);
|
||||||
|
bool applyEntities(const QVector<Project::Entity>& entities, bool recordHistory, const QString& label);
|
||||||
|
QString copyIntoAssetsAsBackground(const QString& sourceFilePath, const QRect& cropRectInSourceImage);
|
||||||
|
bool writeDepthMap(const QImage& depth8);
|
||||||
|
bool writeDepthMapBytes(const QByteArray& pngBytes);
|
||||||
|
QString ensureEntitiesDir() const;
|
||||||
|
bool writeEntityImage(const QString& entityId, const QImage& image, QString& outRelPath);
|
||||||
|
bool writeEntityFrameImage(const QString& entityId, int frame, const QImage& image, QString& outRelPath);
|
||||||
|
|
||||||
|
private:
|
||||||
|
QString m_projectDir;
|
||||||
|
Project m_project;
|
||||||
|
|
||||||
|
QVector<Operation> m_undoStack;
|
||||||
|
QVector<Operation> m_redoStack;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace core
|
||||||
|
|
||||||
62
client/gui/CMakeLists.txt
Normal file
62
client/gui/CMakeLists.txt
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# 模块:app(入口)、main_window(主窗口与时间轴等)、editor(画布)、dialogs(裁剪/关于)
|
||||||
|
set(GUI_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
|
||||||
|
set(GUI_SOURCES
|
||||||
|
${GUI_ROOT}/app/main.cpp
|
||||||
|
${GUI_ROOT}/main_window/MainWindow.cpp
|
||||||
|
${GUI_ROOT}/main_window/RecentProjectHistory.cpp
|
||||||
|
${GUI_ROOT}/dialogs/AboutWindow.cpp
|
||||||
|
${GUI_ROOT}/dialogs/ImageCropDialog.cpp
|
||||||
|
${GUI_ROOT}/dialogs/FrameAnimationDialog.cpp
|
||||||
|
${GUI_ROOT}/dialogs/CancelableTaskDialog.cpp
|
||||||
|
${GUI_ROOT}/editor/EditorCanvas.cpp
|
||||||
|
${GUI_ROOT}/params/ParamControls.cpp
|
||||||
|
${GUI_ROOT}/props/BackgroundPropertySection.cpp
|
||||||
|
${GUI_ROOT}/props/EntityPropertySection.cpp
|
||||||
|
${GUI_ROOT}/timeline/TimelineWidget.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
set(GUI_HEADERS
|
||||||
|
${GUI_ROOT}/main_window/MainWindow.h
|
||||||
|
${GUI_ROOT}/main_window/RecentProjectHistory.h
|
||||||
|
${GUI_ROOT}/dialogs/AboutWindow.h
|
||||||
|
${GUI_ROOT}/dialogs/ImageCropDialog.h
|
||||||
|
${GUI_ROOT}/dialogs/FrameAnimationDialog.h
|
||||||
|
${GUI_ROOT}/dialogs/CancelableTaskDialog.h
|
||||||
|
${GUI_ROOT}/editor/EditorCanvas.h
|
||||||
|
${GUI_ROOT}/params/ParamControls.h
|
||||||
|
${GUI_ROOT}/props/BackgroundPropertySection.h
|
||||||
|
${GUI_ROOT}/props/EntityPropertySection.h
|
||||||
|
${GUI_ROOT}/props/PropertySectionWidget.h
|
||||||
|
${GUI_ROOT}/timeline/TimelineWidget.h
|
||||||
|
)
|
||||||
|
|
||||||
|
if(QT_PACKAGE STREQUAL "Qt6")
|
||||||
|
qt_add_executable(LandscapeInteractiveToolApp
|
||||||
|
${GUI_SOURCES}
|
||||||
|
${GUI_HEADERS}
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
add_executable(LandscapeInteractiveToolApp
|
||||||
|
${GUI_SOURCES}
|
||||||
|
${GUI_HEADERS}
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
target_include_directories(LandscapeInteractiveToolApp
|
||||||
|
PRIVATE
|
||||||
|
${SRC_ROOT}
|
||||||
|
${GUI_ROOT}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(LandscapeInteractiveToolApp
|
||||||
|
PRIVATE
|
||||||
|
${QT_PACKAGE}::Core
|
||||||
|
${QT_PACKAGE}::Gui
|
||||||
|
${QT_PACKAGE}::Widgets
|
||||||
|
core
|
||||||
|
)
|
||||||
|
|
||||||
|
set_target_properties(LandscapeInteractiveToolApp PROPERTIES
|
||||||
|
OUTPUT_NAME "landscape_tool"
|
||||||
|
)
|
||||||
13
client/gui/app/main.cpp
Normal file
13
client/gui/app/main.cpp
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#include "main_window/MainWindow.h"
|
||||||
|
|
||||||
|
#include <QApplication>
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
QApplication app(argc, argv);
|
||||||
|
app.setApplicationName(QStringLiteral("landscape tool"));
|
||||||
|
|
||||||
|
MainWindow window;
|
||||||
|
window.show();
|
||||||
|
|
||||||
|
return app.exec();
|
||||||
|
}
|
||||||
66
client/gui/dialogs/AboutWindow.cpp
Normal file
66
client/gui/dialogs/AboutWindow.cpp
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
#include "dialogs/AboutWindow.h"
|
||||||
|
|
||||||
|
#include <QVBoxLayout>
|
||||||
|
#include <QLabel>
|
||||||
|
#include <QPushButton>
|
||||||
|
#include <QHBoxLayout>
|
||||||
|
#include <QDesktopServices>
|
||||||
|
#include <QUrl>
|
||||||
|
#include <QFont>
|
||||||
|
|
||||||
|
AboutWindow::AboutWindow(QWidget* parent)
|
||||||
|
: QDialog(parent)
|
||||||
|
{
|
||||||
|
setWindowTitle("About");
|
||||||
|
setFixedSize(400, 300);
|
||||||
|
|
||||||
|
// ===== 标题 =====
|
||||||
|
titleLabel = new QLabel("Landscape Interactive Tool");
|
||||||
|
QFont titleFont;
|
||||||
|
titleFont.setPointSize(16);
|
||||||
|
titleFont.setBold(true);
|
||||||
|
titleLabel->setFont(titleFont);
|
||||||
|
titleLabel->setAlignment(Qt::AlignCenter);
|
||||||
|
|
||||||
|
// ===== 版本 =====
|
||||||
|
versionLabel = new QLabel("Version: 1.0.0");
|
||||||
|
versionLabel->setAlignment(Qt::AlignCenter);
|
||||||
|
|
||||||
|
// ===== 作者 =====
|
||||||
|
authorLabel = new QLabel("Author: 丁伟豪");
|
||||||
|
authorLabel->setAlignment(Qt::AlignCenter);
|
||||||
|
|
||||||
|
// ===== 描述 =====
|
||||||
|
descLabel = new QLabel("An interactive tool for landscape visualization.\n"
|
||||||
|
"Built with Qt.");
|
||||||
|
descLabel->setAlignment(Qt::AlignCenter);
|
||||||
|
descLabel->setWordWrap(true);
|
||||||
|
|
||||||
|
// // ===== GitHub 按钮 =====
|
||||||
|
// githubButton = new QPushButton("GitHub");
|
||||||
|
// connect(githubButton, &QPushButton::clicked, []() {
|
||||||
|
// QDesktopServices::openUrl(QUrl("https://github.com/your_repo"));
|
||||||
|
// });
|
||||||
|
|
||||||
|
// ===== 关闭按钮 =====
|
||||||
|
closeButton = new QPushButton("Close");
|
||||||
|
connect(closeButton, &QPushButton::clicked, this, &QDialog::accept);
|
||||||
|
|
||||||
|
// ===== 按钮布局 =====
|
||||||
|
QHBoxLayout* buttonLayout = new QHBoxLayout;
|
||||||
|
buttonLayout->addStretch();
|
||||||
|
// buttonLayout->addWidget(githubButton);
|
||||||
|
buttonLayout->addWidget(closeButton);
|
||||||
|
|
||||||
|
// ===== 主布局 =====
|
||||||
|
QVBoxLayout* layout = new QVBoxLayout(this);
|
||||||
|
layout->addWidget(titleLabel);
|
||||||
|
layout->addWidget(versionLabel);
|
||||||
|
layout->addWidget(authorLabel);
|
||||||
|
layout->addSpacing(10);
|
||||||
|
layout->addWidget(descLabel);
|
||||||
|
layout->addStretch();
|
||||||
|
layout->addLayout(buttonLayout);
|
||||||
|
|
||||||
|
setLayout(layout);
|
||||||
|
}
|
||||||
20
client/gui/dialogs/AboutWindow.h
Normal file
20
client/gui/dialogs/AboutWindow.h
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <QDialog>
|
||||||
|
|
||||||
|
class QLabel;
|
||||||
|
class QPushButton;
|
||||||
|
|
||||||
|
class AboutWindow : public QDialog
|
||||||
|
{
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit AboutWindow(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
QLabel* titleLabel;
|
||||||
|
QLabel* versionLabel;
|
||||||
|
QLabel* authorLabel;
|
||||||
|
QLabel* descLabel;
|
||||||
|
// QPushButton* githubButton;
|
||||||
|
QPushButton* closeButton;
|
||||||
|
};
|
||||||
50
client/gui/dialogs/CancelableTaskDialog.cpp
Normal file
50
client/gui/dialogs/CancelableTaskDialog.cpp
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
#include "dialogs/CancelableTaskDialog.h"
|
||||||
|
|
||||||
|
#include <QBoxLayout>
|
||||||
|
#include <QLabel>
|
||||||
|
#include <QProgressBar>
|
||||||
|
#include <QPushButton>
|
||||||
|
|
||||||
|
CancelableTaskDialog::CancelableTaskDialog(const QString& title,
|
||||||
|
const QString& message,
|
||||||
|
QWidget* parent)
|
||||||
|
: QDialog(parent) {
|
||||||
|
setWindowTitle(title);
|
||||||
|
setModal(true);
|
||||||
|
setMinimumWidth(420);
|
||||||
|
|
||||||
|
auto* root = new QVBoxLayout(this);
|
||||||
|
root->setContentsMargins(14, 14, 14, 14);
|
||||||
|
root->setSpacing(10);
|
||||||
|
|
||||||
|
m_label = new QLabel(message, this);
|
||||||
|
m_label->setWordWrap(true);
|
||||||
|
root->addWidget(m_label);
|
||||||
|
|
||||||
|
m_bar = new QProgressBar(this);
|
||||||
|
m_bar->setRange(0, 0); // 不定进度
|
||||||
|
root->addWidget(m_bar);
|
||||||
|
|
||||||
|
auto* row = new QHBoxLayout();
|
||||||
|
row->addStretch(1);
|
||||||
|
m_btnCancel = new QPushButton(QStringLiteral("取消"), this);
|
||||||
|
row->addWidget(m_btnCancel);
|
||||||
|
root->addLayout(row);
|
||||||
|
|
||||||
|
connect(m_btnCancel, &QPushButton::clicked, this, &CancelableTaskDialog::onCancel);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CancelableTaskDialog::setMessage(const QString& message) {
|
||||||
|
if (m_label) {
|
||||||
|
m_label->setText(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CancelableTaskDialog::onCancel() {
|
||||||
|
if (m_canceled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_canceled = true;
|
||||||
|
emit canceled();
|
||||||
|
}
|
||||||
|
|
||||||
35
client/gui/dialogs/CancelableTaskDialog.h
Normal file
35
client/gui/dialogs/CancelableTaskDialog.h
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QDialog>
|
||||||
|
#include <QString>
|
||||||
|
|
||||||
|
class QLabel;
|
||||||
|
class QProgressBar;
|
||||||
|
class QPushButton;
|
||||||
|
|
||||||
|
// 可复用的“长任务提示框”:显示提示文本 + 不定进度条 + 取消按钮。
|
||||||
|
// - 任务本身由调用方启动(例如网络请求/后台线程)
|
||||||
|
// - 调用方在取消时应中止任务,并调用 reject()/close()
|
||||||
|
class CancelableTaskDialog final : public QDialog {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit CancelableTaskDialog(const QString& title,
|
||||||
|
const QString& message,
|
||||||
|
QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
void setMessage(const QString& message);
|
||||||
|
bool wasCanceled() const { return m_canceled; }
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void canceled();
|
||||||
|
|
||||||
|
private slots:
|
||||||
|
void onCancel();
|
||||||
|
|
||||||
|
private:
|
||||||
|
QLabel* m_label = nullptr;
|
||||||
|
QProgressBar* m_bar = nullptr;
|
||||||
|
QPushButton* m_btnCancel = nullptr;
|
||||||
|
bool m_canceled = false;
|
||||||
|
};
|
||||||
|
|
||||||
252
client/gui/dialogs/FrameAnimationDialog.cpp
Normal file
252
client/gui/dialogs/FrameAnimationDialog.cpp
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
#include "dialogs/FrameAnimationDialog.h"
|
||||||
|
|
||||||
|
#include "core/animation/AnimationSampling.h"
|
||||||
|
#include "core/workspace/ProjectWorkspace.h"
|
||||||
|
|
||||||
|
#include <QBoxLayout>
|
||||||
|
#include <QDir>
|
||||||
|
#include <QFileDialog>
|
||||||
|
#include <QFileInfo>
|
||||||
|
#include <QImage>
|
||||||
|
#include <QLabel>
|
||||||
|
#include <QListWidget>
|
||||||
|
#include <QMessageBox>
|
||||||
|
#include <QPixmap>
|
||||||
|
#include <QPushButton>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
QString resolvedImageAbsForFrame(const core::ProjectWorkspace& ws,
|
||||||
|
const core::Project::Entity& e,
|
||||||
|
int frame) {
|
||||||
|
const QString rel = core::sampleImagePath(e.imageFrames, frame, e.imagePath);
|
||||||
|
if (rel.isEmpty()) return {};
|
||||||
|
const QString abs = QDir(ws.projectDir()).filePath(rel);
|
||||||
|
return abs;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
FrameAnimationDialog::FrameAnimationDialog(core::ProjectWorkspace& workspace,
|
||||||
|
const QString& entityId,
|
||||||
|
int startFrame,
|
||||||
|
int endFrame,
|
||||||
|
QWidget* parent)
|
||||||
|
: QDialog(parent)
|
||||||
|
, m_workspace(workspace)
|
||||||
|
, m_entityId(entityId) {
|
||||||
|
setWindowTitle(QStringLiteral("区间动画帧"));
|
||||||
|
setModal(true);
|
||||||
|
setMinimumSize(720, 420);
|
||||||
|
|
||||||
|
m_start = std::min(startFrame, endFrame);
|
||||||
|
m_end = std::max(startFrame, endFrame);
|
||||||
|
|
||||||
|
auto* root = new QVBoxLayout(this);
|
||||||
|
root->setContentsMargins(12, 12, 12, 12);
|
||||||
|
root->setSpacing(10);
|
||||||
|
|
||||||
|
m_title = new QLabel(this);
|
||||||
|
m_title->setText(QStringLiteral("实体 %1 | 区间 [%2, %3]").arg(m_entityId).arg(m_start).arg(m_end));
|
||||||
|
root->addWidget(m_title);
|
||||||
|
|
||||||
|
auto* mid = new QHBoxLayout();
|
||||||
|
root->addLayout(mid, 1);
|
||||||
|
|
||||||
|
m_list = new QListWidget(this);
|
||||||
|
m_list->setMinimumWidth(240);
|
||||||
|
mid->addWidget(m_list, 0);
|
||||||
|
|
||||||
|
auto* right = new QVBoxLayout();
|
||||||
|
mid->addLayout(right, 1);
|
||||||
|
|
||||||
|
m_preview = new QLabel(this);
|
||||||
|
m_preview->setMinimumSize(320, 240);
|
||||||
|
m_preview->setFrameShape(QFrame::StyledPanel);
|
||||||
|
m_preview->setAlignment(Qt::AlignCenter);
|
||||||
|
m_preview->setText(QStringLiteral("选择一帧"));
|
||||||
|
right->addWidget(m_preview, 1);
|
||||||
|
|
||||||
|
auto* row = new QHBoxLayout();
|
||||||
|
right->addLayout(row);
|
||||||
|
m_btnReplace = new QPushButton(QStringLiteral("替换此帧…"), this);
|
||||||
|
m_btnClear = new QPushButton(QStringLiteral("清除此帧(恢复默认)"), this);
|
||||||
|
row->addWidget(m_btnReplace);
|
||||||
|
row->addWidget(m_btnClear);
|
||||||
|
|
||||||
|
auto* row2 = new QHBoxLayout();
|
||||||
|
right->addLayout(row2);
|
||||||
|
m_btnImportFiles = new QPushButton(QStringLiteral("批量导入(多选图片)…"), this);
|
||||||
|
m_btnImportFolder = new QPushButton(QStringLiteral("批量导入(文件夹)…"), this);
|
||||||
|
row2->addWidget(m_btnImportFiles);
|
||||||
|
row2->addWidget(m_btnImportFolder);
|
||||||
|
row2->addStretch(1);
|
||||||
|
|
||||||
|
auto* closeRow = new QHBoxLayout();
|
||||||
|
root->addLayout(closeRow);
|
||||||
|
closeRow->addStretch(1);
|
||||||
|
auto* btnClose = new QPushButton(QStringLiteral("关闭"), this);
|
||||||
|
closeRow->addWidget(btnClose);
|
||||||
|
|
||||||
|
connect(btnClose, &QPushButton::clicked, this, &QDialog::accept);
|
||||||
|
connect(m_list, &QListWidget::currentRowChanged, this, [this](int) { onSelectFrame(); });
|
||||||
|
connect(m_btnReplace, &QPushButton::clicked, this, &FrameAnimationDialog::onReplaceCurrentFrame);
|
||||||
|
connect(m_btnClear, &QPushButton::clicked, this, &FrameAnimationDialog::onClearCurrentFrame);
|
||||||
|
connect(m_btnImportFiles, &QPushButton::clicked, this, &FrameAnimationDialog::onBatchImportFiles);
|
||||||
|
connect(m_btnImportFolder, &QPushButton::clicked, this, &FrameAnimationDialog::onBatchImportFolder);
|
||||||
|
|
||||||
|
rebuildFrameList();
|
||||||
|
if (m_list->count() > 0) {
|
||||||
|
m_list->setCurrentRow(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FrameAnimationDialog::rebuildFrameList() {
|
||||||
|
m_list->clear();
|
||||||
|
if (!m_workspace.isOpen()) return;
|
||||||
|
|
||||||
|
const auto& ents = m_workspace.entities();
|
||||||
|
const core::Project::Entity* hit = nullptr;
|
||||||
|
for (const auto& e : ents) {
|
||||||
|
if (e.id == m_entityId) {
|
||||||
|
hit = &e;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!hit) return;
|
||||||
|
|
||||||
|
// 默认贴图(用于 UI 提示)
|
||||||
|
m_defaultImageAbs.clear();
|
||||||
|
if (!hit->imagePath.isEmpty()) {
|
||||||
|
const QString abs = QDir(m_workspace.projectDir()).filePath(hit->imagePath);
|
||||||
|
if (QFileInfo::exists(abs)) {
|
||||||
|
m_defaultImageAbs = abs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int f = m_start; f <= m_end; ++f) {
|
||||||
|
bool hasCustom = false;
|
||||||
|
for (const auto& k : hit->imageFrames) {
|
||||||
|
if (k.frame == f) {
|
||||||
|
hasCustom = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto* it = new QListWidgetItem(QStringLiteral("%1%2").arg(f).arg(hasCustom ? QStringLiteral(" *") : QString()));
|
||||||
|
it->setData(Qt::UserRole, f);
|
||||||
|
m_list->addItem(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FrameAnimationDialog::onSelectFrame() {
|
||||||
|
auto* it = m_list->currentItem();
|
||||||
|
if (!it) return;
|
||||||
|
const int f = it->data(Qt::UserRole).toInt();
|
||||||
|
updatePreviewForFrame(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FrameAnimationDialog::updatePreviewForFrame(int frame) {
|
||||||
|
if (!m_workspace.isOpen()) return;
|
||||||
|
const auto& ents = m_workspace.entities();
|
||||||
|
const core::Project::Entity* hit = nullptr;
|
||||||
|
for (const auto& e : ents) {
|
||||||
|
if (e.id == m_entityId) {
|
||||||
|
hit = &e;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!hit) return;
|
||||||
|
|
||||||
|
const QString abs = resolvedImageAbsForFrame(m_workspace, *hit, frame);
|
||||||
|
if (abs.isEmpty() || !QFileInfo::exists(abs)) {
|
||||||
|
m_preview->setText(QStringLiteral("无图像"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
QPixmap pm(abs);
|
||||||
|
if (pm.isNull()) {
|
||||||
|
m_preview->setText(QStringLiteral("加载失败"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_preview->setPixmap(pm.scaled(m_preview->size(), Qt::KeepAspectRatio, Qt::SmoothTransformation));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool FrameAnimationDialog::applyImageToFrame(int frame, const QString& absImagePath) {
|
||||||
|
QImage img(absImagePath);
|
||||||
|
if (img.isNull()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (img.format() != QImage::Format_ARGB32_Premultiplied) {
|
||||||
|
img = img.convertToFormat(QImage::Format_ARGB32_Premultiplied);
|
||||||
|
}
|
||||||
|
return m_workspace.setEntityImageFrame(m_entityId, frame, img);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FrameAnimationDialog::onReplaceCurrentFrame() {
|
||||||
|
auto* it = m_list->currentItem();
|
||||||
|
if (!it) return;
|
||||||
|
const int f = it->data(Qt::UserRole).toInt();
|
||||||
|
const QString path = QFileDialog::getOpenFileName(
|
||||||
|
this,
|
||||||
|
QStringLiteral("选择该帧图像"),
|
||||||
|
QString(),
|
||||||
|
QStringLiteral("Images (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)"));
|
||||||
|
if (path.isEmpty()) return;
|
||||||
|
if (!applyImageToFrame(f, path)) {
|
||||||
|
QMessageBox::warning(this, QStringLiteral("动画帧"), QStringLiteral("写入该帧失败。"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rebuildFrameList();
|
||||||
|
updatePreviewForFrame(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FrameAnimationDialog::onClearCurrentFrame() {
|
||||||
|
auto* it = m_list->currentItem();
|
||||||
|
if (!it) return;
|
||||||
|
const int f = it->data(Qt::UserRole).toInt();
|
||||||
|
if (!m_workspace.removeEntityImageFrame(m_entityId, f)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rebuildFrameList();
|
||||||
|
updatePreviewForFrame(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FrameAnimationDialog::onBatchImportFiles() {
|
||||||
|
const QStringList paths = QFileDialog::getOpenFileNames(
|
||||||
|
this,
|
||||||
|
QStringLiteral("选择逐帧动画图片(按文件名排序)"),
|
||||||
|
QString(),
|
||||||
|
QStringLiteral("Images (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)"));
|
||||||
|
if (paths.isEmpty()) return;
|
||||||
|
QStringList sorted = paths;
|
||||||
|
sorted.sort(Qt::CaseInsensitive);
|
||||||
|
const int need = m_end - m_start + 1;
|
||||||
|
const int count = std::min(need, static_cast<int>(sorted.size()));
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
applyImageToFrame(m_start + i, sorted[i]);
|
||||||
|
}
|
||||||
|
rebuildFrameList();
|
||||||
|
onSelectFrame();
|
||||||
|
}
|
||||||
|
|
||||||
|
void FrameAnimationDialog::onBatchImportFolder() {
|
||||||
|
const QString dir = QFileDialog::getExistingDirectory(this, QStringLiteral("选择逐帧动画图片文件夹"));
|
||||||
|
if (dir.isEmpty()) return;
|
||||||
|
QDir d(dir);
|
||||||
|
d.setFilter(QDir::Files | QDir::Readable);
|
||||||
|
d.setSorting(QDir::Name);
|
||||||
|
const QStringList filters = {QStringLiteral("*.png"),
|
||||||
|
QStringLiteral("*.jpg"),
|
||||||
|
QStringLiteral("*.jpeg"),
|
||||||
|
QStringLiteral("*.bmp"),
|
||||||
|
QStringLiteral("*.webp")};
|
||||||
|
const QStringList files = d.entryList(filters, QDir::Files, QDir::Name);
|
||||||
|
if (files.isEmpty()) return;
|
||||||
|
const int need = m_end - m_start + 1;
|
||||||
|
const int count = std::min(need, static_cast<int>(files.size()));
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
applyImageToFrame(m_start + i, d.filePath(files[i]));
|
||||||
|
}
|
||||||
|
rebuildFrameList();
|
||||||
|
onSelectFrame();
|
||||||
|
}
|
||||||
|
|
||||||
52
client/gui/dialogs/FrameAnimationDialog.h
Normal file
52
client/gui/dialogs/FrameAnimationDialog.h
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QDialog>
|
||||||
|
#include <QString>
|
||||||
|
#include <QStringList>
|
||||||
|
|
||||||
|
namespace core {
|
||||||
|
class ProjectWorkspace;
|
||||||
|
}
|
||||||
|
|
||||||
|
class QLabel;
|
||||||
|
class QListWidget;
|
||||||
|
class QPushButton;
|
||||||
|
|
||||||
|
class FrameAnimationDialog final : public QDialog {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
FrameAnimationDialog(core::ProjectWorkspace& workspace,
|
||||||
|
const QString& entityId,
|
||||||
|
int startFrame,
|
||||||
|
int endFrame,
|
||||||
|
QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
private slots:
|
||||||
|
void onSelectFrame();
|
||||||
|
void onReplaceCurrentFrame();
|
||||||
|
void onClearCurrentFrame();
|
||||||
|
void onBatchImportFiles();
|
||||||
|
void onBatchImportFolder();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void rebuildFrameList();
|
||||||
|
void updatePreviewForFrame(int frame);
|
||||||
|
bool applyImageToFrame(int frame, const QString& absImagePath);
|
||||||
|
|
||||||
|
private:
|
||||||
|
core::ProjectWorkspace& m_workspace;
|
||||||
|
QString m_entityId;
|
||||||
|
int m_start = 0;
|
||||||
|
int m_end = 0;
|
||||||
|
|
||||||
|
QLabel* m_title = nullptr;
|
||||||
|
QListWidget* m_list = nullptr;
|
||||||
|
QLabel* m_preview = nullptr;
|
||||||
|
QPushButton* m_btnReplace = nullptr;
|
||||||
|
QPushButton* m_btnClear = nullptr;
|
||||||
|
QPushButton* m_btnImportFiles = nullptr;
|
||||||
|
QPushButton* m_btnImportFolder = nullptr;
|
||||||
|
|
||||||
|
QString m_defaultImageAbs;
|
||||||
|
};
|
||||||
|
|
||||||
209
client/gui/dialogs/ImageCropDialog.cpp
Normal file
209
client/gui/dialogs/ImageCropDialog.cpp
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
#include "dialogs/ImageCropDialog.h"
|
||||||
|
|
||||||
|
#include <QBoxLayout>
|
||||||
|
#include <QDialogButtonBox>
|
||||||
|
#include <QLabel>
|
||||||
|
#include <QMouseEvent>
|
||||||
|
#include <QPainter>
|
||||||
|
#include <QPushButton>
|
||||||
|
#include <QtMath>
|
||||||
|
|
||||||
|
class ImageCropDialog::CropView final : public QWidget {
|
||||||
|
public:
|
||||||
|
explicit CropView(QWidget* parent = nullptr)
|
||||||
|
: QWidget(parent) {
|
||||||
|
setMouseTracking(true);
|
||||||
|
setMinimumSize(480, 320);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setImage(const QImage& img) {
|
||||||
|
m_image = img;
|
||||||
|
m_selection = {};
|
||||||
|
updateGeometry();
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasSelection() const { return !m_selection.isNull() && m_selection.width() > 0 && m_selection.height() > 0; }
|
||||||
|
|
||||||
|
QRect selectionInImagePixels() const {
|
||||||
|
if (m_image.isNull() || !hasSelection()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
const auto map = viewToImageTransform();
|
||||||
|
// selection 是 view 坐标;映射到 image 像素坐标
|
||||||
|
const QRectF selF = QRectF(m_selection).normalized();
|
||||||
|
bool invertible = false;
|
||||||
|
const QTransform inv = map.inverted(&invertible);
|
||||||
|
if (!invertible) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const QPointF topLeftImg = inv.map(selF.topLeft());
|
||||||
|
const QPointF bottomRightImg = inv.map(selF.bottomRight());
|
||||||
|
|
||||||
|
// 使用 floor/ceil,避免因为取整导致宽高变 0
|
||||||
|
const int left = qFloor(std::min(topLeftImg.x(), bottomRightImg.x()));
|
||||||
|
const int top = qFloor(std::min(topLeftImg.y(), bottomRightImg.y()));
|
||||||
|
const int right = qCeil(std::max(topLeftImg.x(), bottomRightImg.x()));
|
||||||
|
const int bottom = qCeil(std::max(topLeftImg.y(), bottomRightImg.y()));
|
||||||
|
|
||||||
|
QRect r(QPoint(left, top), QPoint(right, bottom));
|
||||||
|
r = r.normalized().intersected(QRect(0, 0, m_image.width(), m_image.height()));
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
void resetSelection() {
|
||||||
|
m_selection = {};
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void paintEvent(QPaintEvent*) override {
|
||||||
|
QPainter p(this);
|
||||||
|
p.fillRect(rect(), palette().window());
|
||||||
|
|
||||||
|
if (m_image.isNull()) {
|
||||||
|
p.setPen(palette().text().color());
|
||||||
|
p.drawText(rect(), Qt::AlignCenter, QStringLiteral("无法加载图片"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto map = viewToImageTransform();
|
||||||
|
p.setRenderHint(QPainter::SmoothPixmapTransform, true);
|
||||||
|
p.setTransform(map);
|
||||||
|
p.drawImage(QPoint(0, 0), m_image);
|
||||||
|
p.resetTransform();
|
||||||
|
|
||||||
|
if (hasSelection()) {
|
||||||
|
// 避免 CompositionMode_Clear 在某些平台/样式下表现异常:
|
||||||
|
// 用“围绕选区画四块遮罩”的方式实现高亮裁剪区域。
|
||||||
|
const QRect sel = m_selection.normalized().intersected(rect());
|
||||||
|
const QColor shade(0, 0, 0, 120);
|
||||||
|
|
||||||
|
// 上
|
||||||
|
p.fillRect(QRect(0, 0, width(), sel.top()), shade);
|
||||||
|
// 下
|
||||||
|
p.fillRect(QRect(0, sel.bottom(), width(), height() - sel.bottom()), shade);
|
||||||
|
// 左
|
||||||
|
p.fillRect(QRect(0, sel.top(), sel.left(), sel.height()), shade);
|
||||||
|
// 右
|
||||||
|
p.fillRect(QRect(sel.right(), sel.top(), width() - sel.right(), sel.height()), shade);
|
||||||
|
|
||||||
|
p.setPen(QPen(QColor(255, 255, 255, 220), 2));
|
||||||
|
p.drawRect(sel);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void mousePressEvent(QMouseEvent* e) override {
|
||||||
|
if (m_image.isNull() || e->button() != Qt::LeftButton) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_dragging = true;
|
||||||
|
m_anchor = e->position().toPoint();
|
||||||
|
m_selection = QRect(m_anchor, m_anchor);
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mouseMoveEvent(QMouseEvent* e) override {
|
||||||
|
if (!m_dragging) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const QPoint cur = e->position().toPoint();
|
||||||
|
m_selection = QRect(m_anchor, cur).normalized();
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mouseReleaseEvent(QMouseEvent* e) override {
|
||||||
|
if (e->button() != Qt::LeftButton) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_dragging = false;
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
QTransform viewToImageTransform() const {
|
||||||
|
// 让图片按比例 fit 到 view 中居中显示
|
||||||
|
const QSizeF viewSize = size();
|
||||||
|
const QSizeF imgSize = m_image.size();
|
||||||
|
const qreal sx = viewSize.width() / imgSize.width();
|
||||||
|
const qreal sy = viewSize.height() / imgSize.height();
|
||||||
|
const qreal s = std::min(sx, sy);
|
||||||
|
|
||||||
|
const qreal drawW = imgSize.width() * s;
|
||||||
|
const qreal drawH = imgSize.height() * s;
|
||||||
|
const qreal offsetX = (viewSize.width() - drawW) / 2.0;
|
||||||
|
const qreal offsetY = (viewSize.height() - drawH) / 2.0;
|
||||||
|
|
||||||
|
QTransform t;
|
||||||
|
t.translate(offsetX, offsetY);
|
||||||
|
t.scale(s, s);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
QImage m_image;
|
||||||
|
bool m_dragging = false;
|
||||||
|
QPoint m_anchor;
|
||||||
|
QRect m_selection;
|
||||||
|
};
|
||||||
|
|
||||||
|
ImageCropDialog::ImageCropDialog(const QString& imagePath, QWidget* parent)
|
||||||
|
: QDialog(parent),
|
||||||
|
m_imagePath(imagePath) {
|
||||||
|
setWindowTitle(QStringLiteral("裁剪图片"));
|
||||||
|
setModal(true);
|
||||||
|
resize(900, 600);
|
||||||
|
loadImageOrClose();
|
||||||
|
rebuildUi();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ImageCropDialog::loadImageOrClose() {
|
||||||
|
m_image = QImage(m_imagePath);
|
||||||
|
if (m_image.isNull()) {
|
||||||
|
reject();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ImageCropDialog::rebuildUi() {
|
||||||
|
auto* root = new QVBoxLayout(this);
|
||||||
|
|
||||||
|
auto* hint = new QLabel(QStringLiteral("拖拽选择裁剪区域(不选则使用整张图)。"), this);
|
||||||
|
root->addWidget(hint);
|
||||||
|
|
||||||
|
m_view = new CropView(this);
|
||||||
|
m_view->setImage(m_image);
|
||||||
|
root->addWidget(m_view, 1);
|
||||||
|
|
||||||
|
auto* buttons = new QDialogButtonBox(QDialogButtonBox::Ok | QDialogButtonBox::Cancel, this);
|
||||||
|
m_okButton = buttons->button(QDialogButtonBox::Ok);
|
||||||
|
auto* resetBtn = new QPushButton(QStringLiteral("重置选择"), this);
|
||||||
|
buttons->addButton(resetBtn, QDialogButtonBox::ActionRole);
|
||||||
|
|
||||||
|
connect(resetBtn, &QPushButton::clicked, this, &ImageCropDialog::onReset);
|
||||||
|
connect(buttons, &QDialogButtonBox::accepted, this, &ImageCropDialog::onOk);
|
||||||
|
connect(buttons, &QDialogButtonBox::rejected, this, &ImageCropDialog::reject);
|
||||||
|
root->addWidget(buttons);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ImageCropDialog::hasValidSelection() const {
|
||||||
|
return m_view && m_view->hasSelection();
|
||||||
|
}
|
||||||
|
|
||||||
|
QRect ImageCropDialog::selectedRectInImagePixels() const {
|
||||||
|
if (!m_view) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return m_view->selectionInImagePixels();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ImageCropDialog::onReset() {
|
||||||
|
if (m_view) {
|
||||||
|
m_view->resetSelection();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ImageCropDialog::onOk() {
|
||||||
|
accept();
|
||||||
|
}
|
||||||
|
|
||||||
34
client/gui/dialogs/ImageCropDialog.h
Normal file
34
client/gui/dialogs/ImageCropDialog.h
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QDialog>
|
||||||
|
#include <QImage>
|
||||||
|
#include <QRect>
|
||||||
|
|
||||||
|
class QLabel;
|
||||||
|
class QPushButton;
|
||||||
|
|
||||||
|
class ImageCropDialog final : public QDialog {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit ImageCropDialog(const QString& imagePath, QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
bool hasValidSelection() const;
|
||||||
|
QRect selectedRectInImagePixels() const;
|
||||||
|
|
||||||
|
private slots:
|
||||||
|
void onReset();
|
||||||
|
void onOk();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void loadImageOrClose();
|
||||||
|
void rebuildUi();
|
||||||
|
|
||||||
|
private:
|
||||||
|
class CropView;
|
||||||
|
CropView* m_view = nullptr;
|
||||||
|
QPushButton* m_okButton = nullptr;
|
||||||
|
|
||||||
|
QString m_imagePath;
|
||||||
|
QImage m_image;
|
||||||
|
};
|
||||||
|
|
||||||
1327
client/gui/editor/EditorCanvas.cpp
Normal file
1327
client/gui/editor/EditorCanvas.cpp
Normal file
File diff suppressed because it is too large
Load Diff
179
client/gui/editor/EditorCanvas.h
Normal file
179
client/gui/editor/EditorCanvas.h
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "core/domain/Project.h"
|
||||||
|
|
||||||
|
#include <QPixmap>
|
||||||
|
#include <QPointF>
|
||||||
|
#include <QImage>
|
||||||
|
#include <QPainterPath>
|
||||||
|
#include <QVector>
|
||||||
|
#include <QWidget>
|
||||||
|
#include <QElapsedTimer>
|
||||||
|
|
||||||
|
class EditorCanvas final : public QWidget {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
enum class Tool { Move, Zoom, CreateEntity };
|
||||||
|
Q_ENUM(Tool)
|
||||||
|
|
||||||
|
explicit EditorCanvas(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
void setBackgroundImagePath(const QString& absolutePath);
|
||||||
|
QString backgroundImagePath() const { return m_bgAbsPath; }
|
||||||
|
void setBackgroundVisible(bool on);
|
||||||
|
bool backgroundVisible() const { return m_backgroundVisible; }
|
||||||
|
|
||||||
|
void setDepthMapPath(const QString& absolutePath);
|
||||||
|
void setDepthOverlayEnabled(bool on);
|
||||||
|
bool depthOverlayEnabled() const { return m_depthOverlayEnabled; }
|
||||||
|
|
||||||
|
void setTool(Tool tool);
|
||||||
|
Tool tool() const { return m_tool; }
|
||||||
|
|
||||||
|
void resetView();
|
||||||
|
void zoomToFit();
|
||||||
|
|
||||||
|
void setWorldAxesVisible(bool on);
|
||||||
|
bool worldAxesVisible() const { return m_worldAxesVisible; }
|
||||||
|
|
||||||
|
void setAxisLabelsVisible(bool on);
|
||||||
|
bool axisLabelsVisible() const { return m_axisLabelsVisible; }
|
||||||
|
|
||||||
|
void setGizmoLabelsVisible(bool on);
|
||||||
|
bool gizmoLabelsVisible() const { return m_gizmoLabelsVisible; }
|
||||||
|
|
||||||
|
void setGridVisible(bool on);
|
||||||
|
bool gridVisible() const { return m_gridVisible; }
|
||||||
|
|
||||||
|
void setCheckerboardVisible(bool on);
|
||||||
|
bool checkerboardVisible() const { return m_checkerboardVisible; }
|
||||||
|
|
||||||
|
// 预览呈现:完整背景 + 全部实体(忽略显隐开关),隐藏编辑辅助元素,仅可平移/缩放查看
|
||||||
|
void setPresentationPreviewMode(bool on);
|
||||||
|
bool presentationPreviewMode() const { return m_presentationPreviewMode; }
|
||||||
|
|
||||||
|
void setEntities(const QVector<core::Project::Entity>& entities, const QString& projectDirAbs);
|
||||||
|
void setCurrentFrame(int frame);
|
||||||
|
int currentFrame() const { return m_currentFrame; }
|
||||||
|
|
||||||
|
bool isDraggingEntity() const { return m_draggingEntity; }
|
||||||
|
|
||||||
|
void selectEntityById(const QString& id);
|
||||||
|
void clearEntitySelection();
|
||||||
|
|
||||||
|
// 与动画求值一致的原点/缩放(用于 K 帧与自动关键帧)
|
||||||
|
QPointF selectedAnimatedOriginWorld() const;
|
||||||
|
double selectedDepthScale01() const;
|
||||||
|
QPointF selectedEntityCentroidWorld() const;
|
||||||
|
double selectedDistanceScaleMultiplier() const;
|
||||||
|
double selectedUserScale() const;
|
||||||
|
double selectedCombinedScale() const;
|
||||||
|
|
||||||
|
enum class DragMode { None, Free, AxisX, AxisY };
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void hoveredWorldPosChanged(const QPointF& worldPos);
|
||||||
|
void hoveredWorldPosDepthChanged(const QPointF& worldPos, int depthZ);
|
||||||
|
void selectedEntityChanged(bool hasSelection, const QString& id, int depth, const QPointF& originWorld);
|
||||||
|
void requestAddEntity(const core::Project::Entity& entity, const QImage& image);
|
||||||
|
void requestMoveEntity(const QString& id, const QPointF& delta);
|
||||||
|
void entityDragActiveChanged(bool on);
|
||||||
|
void selectedEntityPreviewChanged(const QString& id, int depth, const QPointF& originWorld);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void paintEvent(QPaintEvent* e) override;
|
||||||
|
void resizeEvent(QResizeEvent* e) override;
|
||||||
|
void mousePressEvent(QMouseEvent* e) override;
|
||||||
|
void mouseMoveEvent(QMouseEvent* e) override;
|
||||||
|
void mouseReleaseEvent(QMouseEvent* e) override;
|
||||||
|
void wheelEvent(QWheelEvent* e) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void ensurePixmapLoaded() const;
|
||||||
|
void invalidatePixmap();
|
||||||
|
void updateCursor();
|
||||||
|
|
||||||
|
QPointF viewToWorld(const QPointF& v) const;
|
||||||
|
QPointF worldToView(const QPointF& w) const;
|
||||||
|
QRectF worldRectOfBackground() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Entity {
|
||||||
|
QString id;
|
||||||
|
QRectF rect; // world 坐标(用于拖拽与约束)
|
||||||
|
QVector<QPointF> polygonWorld; // 非空则使用 polygon
|
||||||
|
QPainterPath pathWorld; // polygonWorld 对应的 world 路径(缓存,避免每帧重建)
|
||||||
|
QVector<QPointF> cutoutPolygonWorld;
|
||||||
|
QColor color;
|
||||||
|
|
||||||
|
// 实体独立信息:
|
||||||
|
int depth = 0; // 0..255,来自划分区域平均深度
|
||||||
|
QImage image; // 抠图后的实体图像(带透明)
|
||||||
|
QPointF imageTopLeft; // image 对应的 world 左上角
|
||||||
|
double visualScale = 1.0; // 实体在 world 坐标下的缩放(用于贴图绘制)
|
||||||
|
double userScale = 1.0; // 与深度距离缩放相乘
|
||||||
|
QPointF animatedOriginWorld;
|
||||||
|
double animatedDepthScale01 = 0.5;
|
||||||
|
// 编辑模式下实体被设为隐藏时:不响应点选且不绘制,除非当前选中(便于树选隐藏实体)
|
||||||
|
bool hiddenInEditMode = false;
|
||||||
|
};
|
||||||
|
int hitTestEntity(const QPointF& worldPos) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
QString m_bgAbsPath;
|
||||||
|
bool m_backgroundVisible = true;
|
||||||
|
mutable QPixmap m_bgPixmap;
|
||||||
|
mutable bool m_pixmapDirty = true;
|
||||||
|
mutable QImage m_bgImage; // 原背景(用于抠图/填充)
|
||||||
|
mutable QImage m_bgImageCutout; // 抠图后的背景(实体区域填黑)
|
||||||
|
mutable bool m_bgImageDirty = true;
|
||||||
|
mutable bool m_bgCutoutDirty = true;
|
||||||
|
|
||||||
|
QString m_depthAbsPath;
|
||||||
|
mutable QImage m_depthImage8;
|
||||||
|
mutable bool m_depthDirty = true;
|
||||||
|
bool m_depthOverlayEnabled = false;
|
||||||
|
int m_depthOverlayAlpha = 110;
|
||||||
|
bool m_worldAxesVisible = true;
|
||||||
|
bool m_axisLabelsVisible = true;
|
||||||
|
bool m_gizmoLabelsVisible = true;
|
||||||
|
bool m_gridVisible = true;
|
||||||
|
bool m_checkerboardVisible = true;
|
||||||
|
bool m_presentationPreviewMode = false;
|
||||||
|
|
||||||
|
Tool m_tool = Tool::Move;
|
||||||
|
qreal m_scale = 1.0;
|
||||||
|
QPointF m_pan; // world 原点对应的 view 坐标偏移(view = world*scale + pan)
|
||||||
|
|
||||||
|
bool m_dragging = false;
|
||||||
|
bool m_draggingEntity = false;
|
||||||
|
bool m_drawingEntity = false;
|
||||||
|
QPointF m_lastMouseView;
|
||||||
|
// 拖动以“实体原点 animatedOriginWorld”为基准,避免因缩放导致 rect/topLeft 抖动
|
||||||
|
QPointF m_entityDragOffsetOriginWorld;
|
||||||
|
QPointF m_entityDragStartAnimatedOrigin;
|
||||||
|
// 拖动性能优化:拖动过程中不逐点修改 polygonWorld,而是保留基准形状+增量参数,在 paint 时做变换预览
|
||||||
|
bool m_dragPreviewActive = false;
|
||||||
|
QVector<QPointF> m_dragPolyBase;
|
||||||
|
QPainterPath m_dragPathBase;
|
||||||
|
QPointF m_dragImageTopLeftBase;
|
||||||
|
QRectF m_dragRectBase;
|
||||||
|
QPointF m_dragOriginBase;
|
||||||
|
QPointF m_dragDelta; // 纯平移
|
||||||
|
QPointF m_dragCentroidBase;
|
||||||
|
double m_dragScaleBase = 1.0; // 拖动开始时的 visualScale
|
||||||
|
double m_dragScaleRatio = 1.0; // 相对 m_dragScaleBase 的缩放比(由深度重算驱动)
|
||||||
|
QElapsedTimer m_previewEmitTimer;
|
||||||
|
qint64 m_lastPreviewEmitMs = 0;
|
||||||
|
qint64 m_lastDepthScaleRecalcMs = 0;
|
||||||
|
int m_selectedEntity = -1;
|
||||||
|
|
||||||
|
DragMode m_dragMode = DragMode::None;
|
||||||
|
QPointF m_dragStartMouseWorld;
|
||||||
|
|
||||||
|
QVector<Entity> m_entities;
|
||||||
|
QVector<QPointF> m_strokeWorld;
|
||||||
|
|
||||||
|
int m_currentFrame = 0;
|
||||||
|
};
|
||||||
|
|
||||||
2097
client/gui/main_window/MainWindow.cpp
Normal file
2097
client/gui/main_window/MainWindow.cpp
Normal file
File diff suppressed because it is too large
Load Diff
177
client/gui/main_window/MainWindow.h
Normal file
177
client/gui/main_window/MainWindow.h
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "core/workspace/ProjectWorkspace.h"
|
||||||
|
#include "main_window/RecentProjectHistory.h"
|
||||||
|
|
||||||
|
#include <QMainWindow>
|
||||||
|
#include <QPointF>
|
||||||
|
#include <QFrame>
|
||||||
|
#include <QIcon>
|
||||||
|
#include <QTimer>
|
||||||
|
|
||||||
|
class QAction;
|
||||||
|
class QCheckBox;
|
||||||
|
class QComboBox;
|
||||||
|
class QDockWidget;
|
||||||
|
class QFormLayout;
|
||||||
|
class QLabel;
|
||||||
|
class QMenu;
|
||||||
|
class QFrame;
|
||||||
|
class QIcon;
|
||||||
|
class QPushButton;
|
||||||
|
class QSlider;
|
||||||
|
class QStackedWidget;
|
||||||
|
class QToolButton;
|
||||||
|
class QTreeWidget;
|
||||||
|
class QTreeWidgetItem;
|
||||||
|
class QWidget;
|
||||||
|
class EditorCanvas;
|
||||||
|
class TimelineWidget;
|
||||||
|
namespace gui {
|
||||||
|
class BackgroundPropertySection;
|
||||||
|
class EntityPropertySection;
|
||||||
|
}
|
||||||
|
|
||||||
|
class MainWindow : public QMainWindow {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit MainWindow(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool eventFilter(QObject* watched, QEvent* event) override;
|
||||||
|
|
||||||
|
private slots:
|
||||||
|
|
||||||
|
// 文件菜单槽函数
|
||||||
|
void onNewProject();
|
||||||
|
void onOpenProject();
|
||||||
|
void onSaveProject();
|
||||||
|
void onCloseProject();
|
||||||
|
|
||||||
|
// 编辑菜单槽函数
|
||||||
|
void onUndo();
|
||||||
|
void onRedo();
|
||||||
|
void onCopyObject();
|
||||||
|
void onPasteObject();
|
||||||
|
|
||||||
|
// 帮助菜单槽函数
|
||||||
|
void onAbout();
|
||||||
|
void onComputeDepth();
|
||||||
|
void onTogglePlay(bool on);
|
||||||
|
void onInsertCombinedKey(); // 位置 + userScale
|
||||||
|
|
||||||
|
void onProjectTreeItemClicked(QTreeWidgetItem* item, int column);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void computeDepthAsync();
|
||||||
|
// UI 状态分三种:
|
||||||
|
// - Welcome:未打开项目。只显示欢迎页,其它 dock 一律隐藏,视图开关禁用。
|
||||||
|
// - Editor:已打开项目。显示编辑页,按默认规则显示 dock,同时允许用户通过“视图”菜单控制。
|
||||||
|
// - Preview:预览展示。用于全流程完成后的展示(要求:项目已打开且背景不为空)。
|
||||||
|
enum class UiMode { Welcome, Editor, Preview };
|
||||||
|
|
||||||
|
void createMenus(); // 菜单和工具栏
|
||||||
|
void createFileMenu(); // 文件菜单
|
||||||
|
void createEditMenu(); // 编辑菜单
|
||||||
|
void createHelpMenu(); // 帮助菜单
|
||||||
|
void createViewMenu(); // 视图菜单
|
||||||
|
void createProjectTreeDock();
|
||||||
|
void createTimelineDock();
|
||||||
|
void refreshProjectTree();
|
||||||
|
void updateUiEnabledState(); // 更新“可用性/勾选/默认显隐”,不要做业务逻辑
|
||||||
|
void applyUiMode(UiMode mode); // 统一控制 welcome/editor 两态的显隐策略
|
||||||
|
UiMode currentUiMode() const; // 根据 workspace 状态推导
|
||||||
|
void syncCanvasViewMenuFromState();
|
||||||
|
|
||||||
|
void showProjectRootContextMenu(const QPoint& globalPos);
|
||||||
|
void showBackgroundContextMenu(const QPoint& globalPos);
|
||||||
|
void rebuildCentralPages();
|
||||||
|
void showWelcomePage();
|
||||||
|
void showEditorPage();
|
||||||
|
void showPreviewPage();
|
||||||
|
void refreshWelcomeRecentList();
|
||||||
|
void openProjectFromPath(const QString& dir);
|
||||||
|
void refreshPreviewPage();
|
||||||
|
void refreshEditorPage();
|
||||||
|
void applyTimelineFromProject();
|
||||||
|
void refreshDopeSheet();
|
||||||
|
void setPreviewRequested(bool preview);
|
||||||
|
|
||||||
|
QStackedWidget* m_centerStack = nullptr;
|
||||||
|
QWidget* m_pageWelcome = nullptr;
|
||||||
|
QTreeWidget* m_welcomeRecentTree = nullptr;
|
||||||
|
QLabel* m_welcomeRecentEmptyLabel = nullptr;
|
||||||
|
QWidget* m_pageEditor = nullptr;
|
||||||
|
QWidget* m_canvasHost = nullptr;
|
||||||
|
QFrame* m_floatingModeDock = nullptr;
|
||||||
|
QFrame* m_floatingToolDock = nullptr;
|
||||||
|
QComboBox* m_modeSelector = nullptr;
|
||||||
|
QStackedWidget* m_propertyStack = nullptr;
|
||||||
|
gui::BackgroundPropertySection* m_bgPropertySection = nullptr;
|
||||||
|
gui::EntityPropertySection* m_entityPropertySection = nullptr;
|
||||||
|
QToolButton* m_btnCreateEntity = nullptr;
|
||||||
|
QToolButton* m_btnToggleDepthOverlay = nullptr;
|
||||||
|
|
||||||
|
EditorCanvas* m_editorCanvas = nullptr;
|
||||||
|
|
||||||
|
QTreeWidget* m_projectTree = nullptr;
|
||||||
|
QDockWidget* m_dockProjectTree = nullptr;
|
||||||
|
QDockWidget* m_dockProperties = nullptr;
|
||||||
|
QDockWidget* m_dockTimeline = nullptr;
|
||||||
|
QTreeWidgetItem* m_itemBackground = nullptr;
|
||||||
|
|
||||||
|
QAction* m_actionUndo = nullptr;
|
||||||
|
QAction* m_actionRedo = nullptr;
|
||||||
|
QAction* m_actionCopy = nullptr;
|
||||||
|
QAction* m_actionPaste = nullptr;
|
||||||
|
QAction* m_actionToggleProjectTree = nullptr;
|
||||||
|
QAction* m_actionToggleProperties = nullptr;
|
||||||
|
QAction* m_actionToggleTimeline = nullptr;
|
||||||
|
QAction* m_actionEnterPreview = nullptr;
|
||||||
|
QAction* m_actionBackToEditor = nullptr;
|
||||||
|
QAction* m_actionCanvasWorldAxes = nullptr;
|
||||||
|
QAction* m_actionCanvasAxisValues = nullptr;
|
||||||
|
QAction* m_actionCanvasGrid = nullptr;
|
||||||
|
QAction* m_actionCanvasCheckerboard = nullptr;
|
||||||
|
QAction* m_actionCanvasDepthOverlay = nullptr;
|
||||||
|
QAction* m_actionCanvasGizmoLabels = nullptr;
|
||||||
|
|
||||||
|
core::ProjectWorkspace m_workspace;
|
||||||
|
RecentProjectHistory m_recentHistory;
|
||||||
|
bool m_previewRequested = false;
|
||||||
|
/// 因右侧栏过窄自动收起;用户通过视图菜单再次打开时清除
|
||||||
|
bool m_rightDocksNarrowHidden = false;
|
||||||
|
|
||||||
|
QPointF m_lastWorldPos;
|
||||||
|
int m_lastWorldZ = -1;
|
||||||
|
bool m_hasSelectedEntity = false;
|
||||||
|
bool m_syncingTreeSelection = false;
|
||||||
|
int m_selectedEntityDepth = 0;
|
||||||
|
QPointF m_selectedEntityOrigin;
|
||||||
|
QString m_selectedEntityId;
|
||||||
|
QString m_selectedEntityDisplayNameCache;
|
||||||
|
QString m_bgAbsCache;
|
||||||
|
QString m_bgSizeTextCache;
|
||||||
|
void updateStatusBarText();
|
||||||
|
void refreshPropertyPanel();
|
||||||
|
void refreshEntityPropertyPanelFast();
|
||||||
|
void syncProjectTreeFromCanvasSelection();
|
||||||
|
|
||||||
|
bool m_timelineScrubbing = false;
|
||||||
|
bool m_entityDragging = false;
|
||||||
|
QTimer* m_propertySyncTimer = nullptr;
|
||||||
|
|
||||||
|
int m_currentFrame = 0;
|
||||||
|
bool m_playing = false;
|
||||||
|
QTimer* m_playTimer = nullptr;
|
||||||
|
TimelineWidget* m_timeline = nullptr;
|
||||||
|
QToolButton* m_btnPlay = nullptr;
|
||||||
|
QLabel* m_frameLabel = nullptr;
|
||||||
|
// 时间轴区间选择(用于逐帧贴图动画)
|
||||||
|
int m_timelineRangeStart = -1;
|
||||||
|
int m_timelineRangeEnd = -1;
|
||||||
|
QCheckBox* m_chkAutoKeyframe = nullptr;
|
||||||
|
// 旧版 DopeSheet 已移除,这里保留占位便于后续扩展区间 UI(如自定义小部件)
|
||||||
|
QTreeWidget* m_dopeTree = nullptr;
|
||||||
|
QPushButton* m_btnDopeDeleteKey = nullptr;
|
||||||
|
};
|
||||||
100
client/gui/main_window/RecentProjectHistory.cpp
Normal file
100
client/gui/main_window/RecentProjectHistory.cpp
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
#include "main_window/RecentProjectHistory.h"
|
||||||
|
|
||||||
|
#include <QDir>
|
||||||
|
#include <QFile>
|
||||||
|
#include <QFileInfo>
|
||||||
|
#include <QJsonArray>
|
||||||
|
#include <QJsonDocument>
|
||||||
|
#include <QJsonValue>
|
||||||
|
#include <QDebug>
|
||||||
|
#include <QStandardPaths>
|
||||||
|
|
||||||
|
QString RecentProjectHistory::cacheFilePath() {
|
||||||
|
const QString base = QStandardPaths::writableLocation(QStandardPaths::GenericCacheLocation);
|
||||||
|
return QDir(base).filePath(QStringLiteral("landscape_tool/recent_projects.cache"));
|
||||||
|
}
|
||||||
|
|
||||||
|
QString RecentProjectHistory::normalizePath(const QString& path) {
|
||||||
|
if (path.isEmpty()) {
|
||||||
|
return QString();
|
||||||
|
}
|
||||||
|
const QFileInfo fi(path);
|
||||||
|
const QString c = fi.canonicalFilePath();
|
||||||
|
return c.isEmpty() ? QDir::cleanPath(fi.absoluteFilePath()) : c;
|
||||||
|
}
|
||||||
|
|
||||||
|
QStringList RecentProjectHistory::dedupeNewestFirst(const QStringList& paths) {
|
||||||
|
QStringList out;
|
||||||
|
out.reserve(paths.size());
|
||||||
|
for (const QString& p : paths) {
|
||||||
|
const QString n = normalizePath(p);
|
||||||
|
if (n.isEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (out.contains(n)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
out.append(n);
|
||||||
|
if (out.size() >= kMaxEntries) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
QStringList RecentProjectHistory::load() const {
|
||||||
|
const QString filePath = cacheFilePath();
|
||||||
|
QFile f(filePath);
|
||||||
|
if (!f.open(QIODevice::ReadOnly)) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
const QJsonDocument doc = QJsonDocument::fromJson(f.readAll());
|
||||||
|
if (!doc.isArray()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
QStringList paths;
|
||||||
|
for (const QJsonValue& v : doc.array()) {
|
||||||
|
if (v.isString()) {
|
||||||
|
paths.append(v.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dedupeNewestFirst(paths);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RecentProjectHistory::save(const QStringList& paths) const {
|
||||||
|
const QString filePath = cacheFilePath();
|
||||||
|
const QFileInfo fi(filePath);
|
||||||
|
QDir().mkpath(fi.absolutePath());
|
||||||
|
|
||||||
|
QJsonArray arr;
|
||||||
|
for (const QString& p : dedupeNewestFirst(paths)) {
|
||||||
|
arr.append(p);
|
||||||
|
}
|
||||||
|
const QJsonDocument doc(arr);
|
||||||
|
|
||||||
|
QFile f(filePath);
|
||||||
|
if (!f.open(QIODevice::WriteOnly | QIODevice::Truncate)) {
|
||||||
|
qWarning() << "RecentProjectHistory: cannot write" << filePath;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
f.write(doc.toJson(QJsonDocument::Compact));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RecentProjectHistory::addAndSave(const QString& projectDir) {
|
||||||
|
const QString n = normalizePath(projectDir);
|
||||||
|
if (n.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
QStringList paths = load();
|
||||||
|
paths.removeAll(n);
|
||||||
|
paths.prepend(n);
|
||||||
|
save(paths);
|
||||||
|
}
|
||||||
|
|
||||||
|
void RecentProjectHistory::removeAndSave(const QString& projectDir) {
|
||||||
|
const QString n = normalizePath(projectDir);
|
||||||
|
QStringList paths = load();
|
||||||
|
paths.removeAll(n);
|
||||||
|
save(paths);
|
||||||
|
}
|
||||||
21
client/gui/main_window/RecentProjectHistory.h
Normal file
21
client/gui/main_window/RecentProjectHistory.h
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QString>
|
||||||
|
#include <QStringList>
|
||||||
|
|
||||||
|
class RecentProjectHistory {
|
||||||
|
public:
|
||||||
|
static constexpr int kMaxEntries = 15;
|
||||||
|
|
||||||
|
static QString cacheFilePath();
|
||||||
|
|
||||||
|
QStringList load() const;
|
||||||
|
bool save(const QStringList& paths) const;
|
||||||
|
void addAndSave(const QString& projectDir);
|
||||||
|
void removeAndSave(const QString& projectDir);
|
||||||
|
|
||||||
|
static QString normalizePath(const QString& path);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static QStringList dedupeNewestFirst(const QStringList& paths);
|
||||||
|
};
|
||||||
127
client/gui/params/ParamControls.cpp
Normal file
127
client/gui/params/ParamControls.cpp
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#include "params/ParamControls.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <QDoubleSpinBox>
|
||||||
|
#include <QHBoxLayout>
|
||||||
|
#include <QSlider>
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
|
||||||
|
Float01ParamControl::Float01ParamControl(QWidget* parent)
|
||||||
|
: QWidget(parent) {
|
||||||
|
auto* row = new QHBoxLayout(this);
|
||||||
|
row->setContentsMargins(0, 0, 0, 0);
|
||||||
|
row->setSpacing(8);
|
||||||
|
|
||||||
|
m_slider = new QSlider(Qt::Horizontal, this);
|
||||||
|
m_slider->setRange(0, 1000);
|
||||||
|
m_slider->setSingleStep(1);
|
||||||
|
m_slider->setPageStep(10);
|
||||||
|
row->addWidget(m_slider, 1);
|
||||||
|
|
||||||
|
m_spin = new QDoubleSpinBox(this);
|
||||||
|
m_spin->setRange(0.0, 1.0);
|
||||||
|
m_spin->setDecimals(3);
|
||||||
|
m_spin->setSingleStep(0.01);
|
||||||
|
m_spin->setMinimumWidth(84);
|
||||||
|
row->addWidget(m_spin);
|
||||||
|
|
||||||
|
connect(m_slider, &QSlider::valueChanged, this, [this]() { syncFromSlider(); });
|
||||||
|
connect(m_spin, qOverload<double>(&QDoubleSpinBox::valueChanged), this, [this]() { syncFromSpin(); });
|
||||||
|
|
||||||
|
setValue01(0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Float01ParamControl::setEnabled(bool on) {
|
||||||
|
QWidget::setEnabled(on);
|
||||||
|
if (m_slider) m_slider->setEnabled(on);
|
||||||
|
if (m_spin) m_spin->setEnabled(on);
|
||||||
|
}
|
||||||
|
|
||||||
|
double Float01ParamControl::value01() const {
|
||||||
|
return m_spin ? m_spin->value() : 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Float01ParamControl::setValue01(double v) {
|
||||||
|
const double clamped = std::clamp(v, 0.0, 1.0);
|
||||||
|
m_block = true;
|
||||||
|
if (m_spin) m_spin->setValue(clamped);
|
||||||
|
if (m_slider) m_slider->setValue(static_cast<int>(std::lround(clamped * 1000.0)));
|
||||||
|
m_block = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Float01ParamControl::syncFromSlider() {
|
||||||
|
if (m_block || !m_slider || !m_spin) return;
|
||||||
|
m_block = true;
|
||||||
|
const double v = static_cast<double>(m_slider->value()) / 1000.0;
|
||||||
|
m_spin->setValue(v);
|
||||||
|
m_block = false;
|
||||||
|
emit valueChanged01(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Float01ParamControl::syncFromSpin() {
|
||||||
|
if (m_block || !m_slider || !m_spin) return;
|
||||||
|
m_block = true;
|
||||||
|
const double v = m_spin->value();
|
||||||
|
m_slider->setValue(static_cast<int>(std::lround(v * 1000.0)));
|
||||||
|
m_block = false;
|
||||||
|
emit valueChanged01(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
Vec2ParamControl::Vec2ParamControl(QWidget* parent)
|
||||||
|
: QWidget(parent) {
|
||||||
|
auto* row = new QHBoxLayout(this);
|
||||||
|
row->setContentsMargins(0, 0, 0, 0);
|
||||||
|
row->setSpacing(8);
|
||||||
|
|
||||||
|
m_x = new QDoubleSpinBox(this);
|
||||||
|
m_x->setRange(-1e9, 1e9);
|
||||||
|
m_x->setDecimals(2);
|
||||||
|
m_x->setSingleStep(1.0);
|
||||||
|
m_x->setMinimumWidth(88);
|
||||||
|
row->addWidget(m_x, 1);
|
||||||
|
|
||||||
|
m_y = new QDoubleSpinBox(this);
|
||||||
|
m_y->setRange(-1e9, 1e9);
|
||||||
|
m_y->setDecimals(2);
|
||||||
|
m_y->setSingleStep(1.0);
|
||||||
|
m_y->setMinimumWidth(88);
|
||||||
|
row->addWidget(m_y, 1);
|
||||||
|
|
||||||
|
connect(m_x, qOverload<double>(&QDoubleSpinBox::valueChanged), this, [this]() { emitIfChanged(); });
|
||||||
|
connect(m_y, qOverload<double>(&QDoubleSpinBox::valueChanged), this, [this]() { emitIfChanged(); });
|
||||||
|
|
||||||
|
setValue(0.0, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Vec2ParamControl::setEnabled(bool on) {
|
||||||
|
QWidget::setEnabled(on);
|
||||||
|
if (m_x) m_x->setEnabled(on);
|
||||||
|
if (m_y) m_y->setEnabled(on);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Vec2ParamControl::setValue(double x, double y) {
|
||||||
|
m_block = true;
|
||||||
|
if (m_x) m_x->setValue(x);
|
||||||
|
if (m_y) m_y->setValue(y);
|
||||||
|
m_lastX = x;
|
||||||
|
m_lastY = y;
|
||||||
|
m_block = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
double Vec2ParamControl::x() const { return m_x ? m_x->value() : 0.0; }
|
||||||
|
double Vec2ParamControl::y() const { return m_y ? m_y->value() : 0.0; }
|
||||||
|
|
||||||
|
void Vec2ParamControl::emitIfChanged() {
|
||||||
|
if (m_block || !m_x || !m_y) return;
|
||||||
|
const double nx = m_x->value();
|
||||||
|
const double ny = m_y->value();
|
||||||
|
if (nx == m_lastX && ny == m_lastY) return;
|
||||||
|
m_lastX = nx;
|
||||||
|
m_lastY = ny;
|
||||||
|
emit valueChanged(nx, ny);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gui
|
||||||
|
|
||||||
60
client/gui/params/ParamControls.h
Normal file
60
client/gui/params/ParamControls.h
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QWidget>
|
||||||
|
|
||||||
|
class QDoubleSpinBox;
|
||||||
|
class QSlider;
|
||||||
|
class QLabel;
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
|
||||||
|
// 0..1 浮点参数:Slider + DoubleSpinBox(可复用)
|
||||||
|
class Float01ParamControl final : public QWidget {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit Float01ParamControl(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
void setValue01(double v);
|
||||||
|
double value01() const;
|
||||||
|
|
||||||
|
void setEnabled(bool on);
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void valueChanged01(double v);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void syncFromSlider();
|
||||||
|
void syncFromSpin();
|
||||||
|
|
||||||
|
QSlider* m_slider = nullptr;
|
||||||
|
QDoubleSpinBox* m_spin = nullptr;
|
||||||
|
bool m_block = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Vec2 参数:两个 DoubleSpinBox(可复用)
|
||||||
|
class Vec2ParamControl final : public QWidget {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit Vec2ParamControl(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
void setValue(double x, double y);
|
||||||
|
double x() const;
|
||||||
|
double y() const;
|
||||||
|
|
||||||
|
void setEnabled(bool on);
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void valueChanged(double x, double y);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void emitIfChanged();
|
||||||
|
|
||||||
|
QDoubleSpinBox* m_x = nullptr;
|
||||||
|
QDoubleSpinBox* m_y = nullptr;
|
||||||
|
bool m_block = false;
|
||||||
|
double m_lastX = 0.0;
|
||||||
|
double m_lastY = 0.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gui
|
||||||
|
|
||||||
77
client/gui/props/BackgroundPropertySection.cpp
Normal file
77
client/gui/props/BackgroundPropertySection.cpp
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
#include "props/BackgroundPropertySection.h"
|
||||||
|
|
||||||
|
#include <QCheckBox>
|
||||||
|
#include <QFormLayout>
|
||||||
|
#include <QLabel>
|
||||||
|
#include <QVBoxLayout>
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
|
||||||
|
BackgroundPropertySection::BackgroundPropertySection(QWidget* parent)
|
||||||
|
: PropertySectionWidget(parent) {
|
||||||
|
auto* lay = new QVBoxLayout(this);
|
||||||
|
lay->setContentsMargins(0, 0, 0, 0);
|
||||||
|
lay->setSpacing(6);
|
||||||
|
|
||||||
|
auto* form = new QFormLayout();
|
||||||
|
form->setContentsMargins(0, 0, 0, 0);
|
||||||
|
form->setSpacing(6);
|
||||||
|
|
||||||
|
m_sizeLabel = new QLabel(QStringLiteral("-"), this);
|
||||||
|
m_sizeLabel->setTextInteractionFlags(Qt::TextSelectableByMouse);
|
||||||
|
form->addRow(QStringLiteral("背景尺寸"), m_sizeLabel);
|
||||||
|
|
||||||
|
m_showBackground = new QCheckBox(QStringLiteral("显示背景"), this);
|
||||||
|
m_showBackground->setToolTip(QStringLiteral("是否绘制背景图"));
|
||||||
|
form->addRow(QString(), m_showBackground);
|
||||||
|
|
||||||
|
m_depthOverlay = new QCheckBox(QStringLiteral("叠加深度"), this);
|
||||||
|
m_depthOverlay->setToolTip(QStringLiteral("在背景上叠加深度伪彩图"));
|
||||||
|
form->addRow(QString(), m_depthOverlay);
|
||||||
|
|
||||||
|
lay->addLayout(form);
|
||||||
|
lay->addStretch(1);
|
||||||
|
|
||||||
|
connect(m_showBackground, &QCheckBox::toggled, this, &BackgroundPropertySection::backgroundVisibleToggled);
|
||||||
|
connect(m_depthOverlay, &QCheckBox::toggled, this, &BackgroundPropertySection::depthOverlayToggled);
|
||||||
|
}
|
||||||
|
|
||||||
|
void BackgroundPropertySection::setBackgroundSizeText(const QString& text) {
|
||||||
|
if (m_sizeLabel) {
|
||||||
|
m_sizeLabel->setText(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void BackgroundPropertySection::syncBackgroundVisible(bool visible, bool controlsEnabled) {
|
||||||
|
if (!m_showBackground) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_showBackground->blockSignals(true);
|
||||||
|
m_showBackground->setChecked(visible);
|
||||||
|
m_showBackground->setEnabled(controlsEnabled);
|
||||||
|
m_showBackground->blockSignals(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void BackgroundPropertySection::syncDepthOverlayChecked(bool on) {
|
||||||
|
if (!m_depthOverlay) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_depthOverlay->blockSignals(true);
|
||||||
|
m_depthOverlay->setChecked(on);
|
||||||
|
m_depthOverlay->blockSignals(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void BackgroundPropertySection::setDepthOverlayCheckEnabled(bool on) {
|
||||||
|
if (m_depthOverlay) {
|
||||||
|
m_depthOverlay->setEnabled(on);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void BackgroundPropertySection::setProjectClosedAppearance() {
|
||||||
|
setBackgroundSizeText(QStringLiteral("-"));
|
||||||
|
syncBackgroundVisible(true, false);
|
||||||
|
syncDepthOverlayChecked(false);
|
||||||
|
setDepthOverlayCheckEnabled(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gui
|
||||||
32
client/gui/props/BackgroundPropertySection.h
Normal file
32
client/gui/props/BackgroundPropertySection.h
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "props/PropertySectionWidget.h"
|
||||||
|
|
||||||
|
class QLabel;
|
||||||
|
class QCheckBox;
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
|
||||||
|
// 背景相关属性:尺寸、显隐、深度叠加(可嵌入 QStackedWidget 的一页)
|
||||||
|
class BackgroundPropertySection final : public PropertySectionWidget {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit BackgroundPropertySection(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
void setBackgroundSizeText(const QString& text);
|
||||||
|
void syncBackgroundVisible(bool visible, bool controlsEnabled);
|
||||||
|
void syncDepthOverlayChecked(bool on);
|
||||||
|
void setDepthOverlayCheckEnabled(bool on);
|
||||||
|
void setProjectClosedAppearance();
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void backgroundVisibleToggled(bool on);
|
||||||
|
void depthOverlayToggled(bool on);
|
||||||
|
|
||||||
|
private:
|
||||||
|
QLabel* m_sizeLabel = nullptr;
|
||||||
|
QCheckBox* m_showBackground = nullptr;
|
||||||
|
QCheckBox* m_depthOverlay = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gui
|
||||||
108
client/gui/props/EntityPropertySection.cpp
Normal file
108
client/gui/props/EntityPropertySection.cpp
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#include "props/EntityPropertySection.h"
|
||||||
|
|
||||||
|
#include "params/ParamControls.h"
|
||||||
|
|
||||||
|
#include <QDoubleSpinBox>
|
||||||
|
#include <QFormLayout>
|
||||||
|
#include <QLabel>
|
||||||
|
#include <QLineEdit>
|
||||||
|
#include <QVBoxLayout>
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
|
||||||
|
EntityPropertySection::EntityPropertySection(QWidget* parent)
|
||||||
|
: PropertySectionWidget(parent) {
|
||||||
|
auto* lay = new QVBoxLayout(this);
|
||||||
|
lay->setContentsMargins(0, 0, 0, 0);
|
||||||
|
lay->setSpacing(6);
|
||||||
|
|
||||||
|
auto* form = new QFormLayout();
|
||||||
|
form->setContentsMargins(0, 0, 0, 0);
|
||||||
|
form->setSpacing(6);
|
||||||
|
|
||||||
|
m_name = new QLineEdit(this);
|
||||||
|
m_name->setPlaceholderText(QStringLiteral("显示名称"));
|
||||||
|
m_name->setToolTip(QStringLiteral("仅显示用;内部 id 不变"));
|
||||||
|
form->addRow(QStringLiteral("名称"), m_name);
|
||||||
|
|
||||||
|
m_depth = new QLabel(QStringLiteral("-"), this);
|
||||||
|
m_distScale = new QLabel(QStringLiteral("-"), this);
|
||||||
|
for (QLabel* lab : {m_depth, m_distScale}) {
|
||||||
|
lab->setTextInteractionFlags(Qt::TextSelectableByMouse);
|
||||||
|
}
|
||||||
|
form->addRow(QStringLiteral("深度"), m_depth);
|
||||||
|
form->addRow(QStringLiteral("距离缩放"), m_distScale);
|
||||||
|
|
||||||
|
m_pivot = new Vec2ParamControl(this);
|
||||||
|
m_pivot->setToolTip(QStringLiteral("枢轴在世界坐标中的位置(限制在轮廓包络内),用于重定位局部原点"));
|
||||||
|
form->addRow(QStringLiteral("中心坐标"), m_pivot);
|
||||||
|
|
||||||
|
m_centroid = new Vec2ParamControl(this);
|
||||||
|
m_centroid->setToolTip(QStringLiteral("实体几何质心的世界坐标;修改将整体平移实体"));
|
||||||
|
form->addRow(QStringLiteral("位置"), m_centroid);
|
||||||
|
|
||||||
|
m_userScale = new QDoubleSpinBox(this);
|
||||||
|
m_userScale->setRange(0.05, 20.0);
|
||||||
|
m_userScale->setDecimals(3);
|
||||||
|
m_userScale->setSingleStep(0.05);
|
||||||
|
m_userScale->setValue(1.0);
|
||||||
|
m_userScale->setToolTip(QStringLiteral("人为整体缩放,与深度距离缩放相乘"));
|
||||||
|
form->addRow(QStringLiteral("整体缩放"), m_userScale);
|
||||||
|
|
||||||
|
lay->addLayout(form);
|
||||||
|
lay->addStretch(1);
|
||||||
|
|
||||||
|
connect(m_name, &QLineEdit::editingFinished, this, [this]() {
|
||||||
|
if (m_name) {
|
||||||
|
emit displayNameCommitted(m_name->text());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
connect(m_pivot, &Vec2ParamControl::valueChanged, this, &EntityPropertySection::pivotEdited);
|
||||||
|
connect(m_centroid, &Vec2ParamControl::valueChanged, this, &EntityPropertySection::centroidEdited);
|
||||||
|
connect(m_userScale, qOverload<double>(&QDoubleSpinBox::valueChanged), this, &EntityPropertySection::userScaleEdited);
|
||||||
|
}
|
||||||
|
|
||||||
|
void EntityPropertySection::clearDisconnected() {
|
||||||
|
setEditingEnabled(false);
|
||||||
|
if (m_name) {
|
||||||
|
m_name->blockSignals(true);
|
||||||
|
m_name->clear();
|
||||||
|
m_name->blockSignals(false);
|
||||||
|
}
|
||||||
|
if (m_depth) m_depth->setText(QStringLiteral("-"));
|
||||||
|
if (m_distScale) m_distScale->setText(QStringLiteral("-"));
|
||||||
|
if (m_pivot) m_pivot->setValue(0.0, 0.0);
|
||||||
|
if (m_centroid) m_centroid->setValue(0.0, 0.0);
|
||||||
|
if (m_userScale) {
|
||||||
|
m_userScale->blockSignals(true);
|
||||||
|
m_userScale->setValue(1.0);
|
||||||
|
m_userScale->blockSignals(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void EntityPropertySection::applyState(const EntityPropertyUiState& s) {
|
||||||
|
setEditingEnabled(true);
|
||||||
|
if (m_name) {
|
||||||
|
m_name->blockSignals(true);
|
||||||
|
m_name->setText(s.displayName);
|
||||||
|
m_name->blockSignals(false);
|
||||||
|
}
|
||||||
|
if (m_depth) m_depth->setText(QString::number(s.depthZ));
|
||||||
|
if (m_distScale) m_distScale->setText(s.distanceScaleText);
|
||||||
|
if (m_pivot) m_pivot->setValue(s.pivot.x(), s.pivot.y());
|
||||||
|
if (m_centroid) m_centroid->setValue(s.centroid.x(), s.centroid.y());
|
||||||
|
if (m_userScale) {
|
||||||
|
m_userScale->blockSignals(true);
|
||||||
|
m_userScale->setValue(s.userScale);
|
||||||
|
m_userScale->blockSignals(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void EntityPropertySection::setEditingEnabled(bool on) {
|
||||||
|
if (m_name) m_name->setEnabled(on);
|
||||||
|
if (m_pivot) m_pivot->setEnabled(on);
|
||||||
|
if (m_centroid) m_centroid->setEnabled(on);
|
||||||
|
if (m_userScale) m_userScale->setEnabled(on);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gui
|
||||||
52
client/gui/props/EntityPropertySection.h
Normal file
52
client/gui/props/EntityPropertySection.h
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "props/PropertySectionWidget.h"
|
||||||
|
|
||||||
|
#include <QPointF>
|
||||||
|
#include <QString>
|
||||||
|
|
||||||
|
class QLabel;
|
||||||
|
class QLineEdit;
|
||||||
|
class QDoubleSpinBox;
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
class Vec2ParamControl;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
|
||||||
|
struct EntityPropertyUiState {
|
||||||
|
QString displayName;
|
||||||
|
int depthZ = 0;
|
||||||
|
QString distanceScaleText;
|
||||||
|
QPointF pivot;
|
||||||
|
QPointF centroid;
|
||||||
|
double userScale = 1.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// 实体相关属性(可嵌入 QStackedWidget 的一页)
|
||||||
|
class EntityPropertySection final : public PropertySectionWidget {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit EntityPropertySection(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
void clearDisconnected();
|
||||||
|
void applyState(const EntityPropertyUiState& s);
|
||||||
|
void setEditingEnabled(bool on);
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void displayNameCommitted(const QString& text);
|
||||||
|
void pivotEdited(double x, double y);
|
||||||
|
void centroidEdited(double x, double y);
|
||||||
|
void userScaleEdited(double value);
|
||||||
|
|
||||||
|
private:
|
||||||
|
QLineEdit* m_name = nullptr;
|
||||||
|
QLabel* m_depth = nullptr;
|
||||||
|
QLabel* m_distScale = nullptr;
|
||||||
|
Vec2ParamControl* m_pivot = nullptr;
|
||||||
|
Vec2ParamControl* m_centroid = nullptr;
|
||||||
|
QDoubleSpinBox* m_userScale = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gui
|
||||||
13
client/gui/props/PropertySectionWidget.h
Normal file
13
client/gui/props/PropertySectionWidget.h
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <QWidget>
|
||||||
|
|
||||||
|
namespace gui {
|
||||||
|
|
||||||
|
// 属性 dock 中可切换的「一节」的公共基类:便于以后扩展更多对象类型(灯光、相机等)
|
||||||
|
class PropertySectionWidget : public QWidget {
|
||||||
|
public:
|
||||||
|
explicit PropertySectionWidget(QWidget* parent = nullptr) : QWidget(parent) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gui
|
||||||
310
client/gui/timeline/TimelineWidget.cpp
Normal file
310
client/gui/timeline/TimelineWidget.cpp
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
#include "timeline/TimelineWidget.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include <QMouseEvent>
|
||||||
|
#include <QPainter>
|
||||||
|
#include <QWheelEvent>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
int clampFrame(int f, int a, int b) {
|
||||||
|
if (a > b) std::swap(a, b);
|
||||||
|
return std::clamp(f, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TimelineWidget::TimelineWidget(QWidget* parent)
|
||||||
|
: QWidget(parent) {
|
||||||
|
setMouseTracking(true);
|
||||||
|
setMinimumHeight(28);
|
||||||
|
setFocusPolicy(Qt::StrongFocus);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::setFrameRange(int start, int end) {
|
||||||
|
if (m_start == start && m_end == end) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_start = start;
|
||||||
|
m_end = end;
|
||||||
|
m_currentFrame = clampFrame(m_currentFrame, m_start, m_end);
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::setCurrentFrame(int frame) {
|
||||||
|
setFrameInternal(frame, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::setSelectionRange(int start, int end) {
|
||||||
|
if (start < 0 || end < 0) {
|
||||||
|
m_selStart = -1;
|
||||||
|
m_selEnd = -1;
|
||||||
|
update();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_selStart = clampFrame(std::min(start, end), m_start, m_end);
|
||||||
|
m_selEnd = clampFrame(std::max(start, end), m_start, m_end);
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::setKeyframeTracks(const core::Project::Entity* e) {
|
||||||
|
m_locFrames.clear();
|
||||||
|
m_scaleFrames.clear();
|
||||||
|
m_imgFrames.clear();
|
||||||
|
if (!e) {
|
||||||
|
update();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_locFrames.reserve(e->locationKeys.size());
|
||||||
|
for (const auto& k : e->locationKeys) m_locFrames.push_back(k.frame);
|
||||||
|
m_scaleFrames.reserve(e->userScaleKeys.size());
|
||||||
|
for (const auto& k : e->userScaleKeys) m_scaleFrames.push_back(k.frame);
|
||||||
|
m_imgFrames.reserve(e->imageFrames.size());
|
||||||
|
for (const auto& k : e->imageFrames) m_imgFrames.push_back(k.frame);
|
||||||
|
|
||||||
|
auto uniqSort = [](QVector<int>& v) {
|
||||||
|
std::sort(v.begin(), v.end());
|
||||||
|
v.erase(std::unique(v.begin(), v.end()), v.end());
|
||||||
|
};
|
||||||
|
uniqSort(m_locFrames);
|
||||||
|
uniqSort(m_scaleFrames);
|
||||||
|
uniqSort(m_imgFrames);
|
||||||
|
// 轨道变了:若当前选中的关键帧不再存在,则清除
|
||||||
|
auto contains = [](const QVector<int>& v, int f) {
|
||||||
|
return std::binary_search(v.begin(), v.end(), f);
|
||||||
|
};
|
||||||
|
bool ok = true;
|
||||||
|
if (m_selKeyKind == KeyKind::Location) ok = contains(m_locFrames, m_selKeyFrame);
|
||||||
|
if (m_selKeyKind == KeyKind::UserScale) ok = contains(m_scaleFrames, m_selKeyFrame);
|
||||||
|
if (m_selKeyKind == KeyKind::Image) ok = contains(m_imgFrames, m_selKeyFrame);
|
||||||
|
if (!ok) {
|
||||||
|
m_selKeyKind = KeyKind::None;
|
||||||
|
m_selKeyFrame = -1;
|
||||||
|
emit keyframeSelectionChanged(m_selKeyKind, m_selKeyFrame);
|
||||||
|
}
|
||||||
|
update();
|
||||||
|
}
|
||||||
|
|
||||||
|
QRect TimelineWidget::trackRect() const {
|
||||||
|
const int pad = 8;
|
||||||
|
const int h = height();
|
||||||
|
return QRect(pad, 0, std::max(1, width() - pad * 2), h);
|
||||||
|
}
|
||||||
|
|
||||||
|
int TimelineWidget::xToFrame(int x) const {
|
||||||
|
const QRect r = trackRect();
|
||||||
|
if (r.width() <= 1) return m_start;
|
||||||
|
const double t = std::clamp((x - r.left()) / double(r.width() - 1), 0.0, 1.0);
|
||||||
|
const int span = std::max(1, m_end - m_start);
|
||||||
|
const int f = m_start + int(std::round(t * span));
|
||||||
|
return clampFrame(f, m_start, m_end);
|
||||||
|
}
|
||||||
|
|
||||||
|
int TimelineWidget::frameToX(int frame) const {
|
||||||
|
const QRect r = trackRect();
|
||||||
|
if (r.width() <= 1) return r.left();
|
||||||
|
const int f = clampFrame(frame, m_start, m_end);
|
||||||
|
const int span = std::max(1, m_end - m_start);
|
||||||
|
const double t = double(f - m_start) / double(span);
|
||||||
|
return r.left() + int(std::round(t * (r.width() - 1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::setFrameInternal(int frame, bool commit) {
|
||||||
|
const int f = clampFrame(frame, m_start, m_end);
|
||||||
|
if (m_currentFrame == f && !commit) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
m_currentFrame = f;
|
||||||
|
update();
|
||||||
|
emit frameScrubbed(f);
|
||||||
|
if (commit) {
|
||||||
|
emit frameCommitted(f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::paintEvent(QPaintEvent*) {
|
||||||
|
QPainter p(this);
|
||||||
|
p.setRenderHint(QPainter::Antialiasing, true);
|
||||||
|
|
||||||
|
const QRect r = rect();
|
||||||
|
p.fillRect(r, palette().base());
|
||||||
|
|
||||||
|
const QRect tr = trackRect().adjusted(0, 8, 0, -8);
|
||||||
|
const QColor rail = palette().mid().color();
|
||||||
|
p.setPen(Qt::NoPen);
|
||||||
|
p.setBrush(rail);
|
||||||
|
p.drawRoundedRect(tr, 6, 6);
|
||||||
|
|
||||||
|
// selection range
|
||||||
|
if (m_selStart >= 0 && m_selEnd >= 0) {
|
||||||
|
const int x0 = frameToX(m_selStart);
|
||||||
|
const int x1 = frameToX(m_selEnd);
|
||||||
|
QRect sel(QPoint(std::min(x0, x1), tr.top()), QPoint(std::max(x0, x1), tr.bottom()));
|
||||||
|
sel = sel.adjusted(0, 2, 0, -2);
|
||||||
|
QColor c = palette().highlight().color();
|
||||||
|
c.setAlpha(50);
|
||||||
|
p.setBrush(c);
|
||||||
|
p.drawRoundedRect(sel, 4, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto drawDots = [&](const QVector<int>& frames, const QColor& c, int y) {
|
||||||
|
p.setBrush(c);
|
||||||
|
p.setPen(Qt::NoPen);
|
||||||
|
for (int f : frames) {
|
||||||
|
if (f < m_start || f > m_end) continue;
|
||||||
|
const int x = frameToX(f);
|
||||||
|
const bool sel =
|
||||||
|
(m_selKeyFrame == f)
|
||||||
|
&& ((m_selKeyKind == KeyKind::Image && &frames == &m_imgFrames)
|
||||||
|
|| (m_selKeyKind == KeyKind::Location && &frames == &m_locFrames)
|
||||||
|
|| (m_selKeyKind == KeyKind::UserScale && &frames == &m_scaleFrames));
|
||||||
|
if (sel) {
|
||||||
|
p.setPen(QPen(palette().highlight().color(), 2.0));
|
||||||
|
p.setBrush(c);
|
||||||
|
p.drawEllipse(QPointF(x, y), 4.4, 4.4);
|
||||||
|
p.setPen(Qt::NoPen);
|
||||||
|
} else {
|
||||||
|
p.drawEllipse(QPointF(x, y), 2.6, 2.6);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const int yMid = tr.center().y();
|
||||||
|
drawDots(m_imgFrames, QColor(80, 160, 255, 230), yMid - 6);
|
||||||
|
drawDots(m_locFrames, QColor(255, 120, 0, 230), yMid);
|
||||||
|
drawDots(m_scaleFrames, QColor(140, 220, 140, 230), yMid + 6);
|
||||||
|
|
||||||
|
// current frame caret
|
||||||
|
const int cx = frameToX(m_currentFrame);
|
||||||
|
p.setPen(QPen(palette().highlight().color(), 2.0));
|
||||||
|
p.drawLine(QPoint(cx, tr.top() - 6), QPoint(cx, tr.bottom() + 6));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool hitDot(const QPoint& pos, int dotX, int dotY, int radiusPx) {
|
||||||
|
const int dx = pos.x() - dotX;
|
||||||
|
const int dy = pos.y() - dotY;
|
||||||
|
return (dx * dx + dy * dy) <= (radiusPx * radiusPx);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int findNearestFrameInTrack(const QVector<int>& frames, int frame) {
|
||||||
|
if (frames.isEmpty()) return -1;
|
||||||
|
const auto it = std::lower_bound(frames.begin(), frames.end(), frame);
|
||||||
|
if (it == frames.begin()) return *it;
|
||||||
|
if (it == frames.end()) return frames.back();
|
||||||
|
const int a = *(it - 1);
|
||||||
|
const int b = *it;
|
||||||
|
return (std::abs(frame - a) <= std::abs(b - frame)) ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void findIntervalAround(const QVector<int>& allFrames, int frame, int& outA, int& outB) {
|
||||||
|
outA = -1;
|
||||||
|
outB = -1;
|
||||||
|
if (allFrames.size() < 2) return;
|
||||||
|
const auto it = std::upper_bound(allFrames.begin(), allFrames.end(), frame);
|
||||||
|
if (it == allFrames.begin() || it == allFrames.end()) return;
|
||||||
|
outA = *(it - 1);
|
||||||
|
outB = *it;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::mousePressEvent(QMouseEvent* e) {
|
||||||
|
if (e->button() == Qt::RightButton) {
|
||||||
|
emit contextMenuRequested(mapToGlobal(e->pos()), xToFrame(e->pos().x()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (e->button() == Qt::LeftButton) {
|
||||||
|
m_pressPos = e->pos();
|
||||||
|
m_moved = false;
|
||||||
|
m_dragging = true;
|
||||||
|
setFrameInternal(xToFrame(e->pos().x()), false);
|
||||||
|
e->accept();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
QWidget::mousePressEvent(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::mouseMoveEvent(QMouseEvent* e) {
|
||||||
|
if (m_dragging) {
|
||||||
|
if ((e->pos() - m_pressPos).manhattanLength() > 3) {
|
||||||
|
m_moved = true;
|
||||||
|
}
|
||||||
|
setFrameInternal(xToFrame(e->pos().x()), false);
|
||||||
|
e->accept();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
QWidget::mouseMoveEvent(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::mouseReleaseEvent(QMouseEvent* e) {
|
||||||
|
if (m_dragging && e->button() == Qt::LeftButton) {
|
||||||
|
m_dragging = false;
|
||||||
|
const int f = xToFrame(e->pos().x());
|
||||||
|
setFrameInternal(f, true);
|
||||||
|
|
||||||
|
// 点击(非拖拽)时做选中:关键帧或区间
|
||||||
|
if (!m_moved) {
|
||||||
|
const QRect tr = trackRect().adjusted(0, 8, 0, -8);
|
||||||
|
const int yMid = tr.center().y();
|
||||||
|
const int yImg = yMid - 6;
|
||||||
|
const int yLoc = yMid;
|
||||||
|
const int ySc = yMid + 6;
|
||||||
|
const int rad = 7;
|
||||||
|
|
||||||
|
auto trySelectKey = [&](KeyKind kind, const QVector<int>& frames, int laneY) -> bool {
|
||||||
|
const int nearest = findNearestFrameInTrack(frames, f);
|
||||||
|
if (nearest < 0) return false;
|
||||||
|
const int x = frameToX(nearest);
|
||||||
|
if (hitDot(e->pos(), x, laneY, rad)) {
|
||||||
|
m_selKeyKind = kind;
|
||||||
|
m_selKeyFrame = nearest;
|
||||||
|
emit keyframeSelectionChanged(m_selKeyKind, m_selKeyFrame);
|
||||||
|
update();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// 先尝试命中关键帧(按 lane 优先)
|
||||||
|
if (trySelectKey(KeyKind::Image, m_imgFrames, yImg)
|
||||||
|
|| trySelectKey(KeyKind::Location, m_locFrames, yLoc)
|
||||||
|
|| trySelectKey(KeyKind::UserScale, m_scaleFrames, ySc)) {
|
||||||
|
// 选中关键帧时清掉区间
|
||||||
|
if (m_selStart >= 0 && m_selEnd >= 0) {
|
||||||
|
m_selStart = -1;
|
||||||
|
m_selEnd = -1;
|
||||||
|
emit intervalSelectionChanged(m_selStart, m_selEnd);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 未命中关键帧:尝试选中由关键帧切分出的区间(使用三轨道的并集)
|
||||||
|
QVector<int> all = m_locFrames;
|
||||||
|
all += m_scaleFrames;
|
||||||
|
all += m_imgFrames;
|
||||||
|
std::sort(all.begin(), all.end());
|
||||||
|
all.erase(std::unique(all.begin(), all.end()), all.end());
|
||||||
|
int a = -1, b = -1;
|
||||||
|
findIntervalAround(all, f, a, b);
|
||||||
|
if (a >= 0 && b >= 0) {
|
||||||
|
setSelectionRange(a, b);
|
||||||
|
emit intervalSelectionChanged(m_selStart, m_selEnd);
|
||||||
|
// 选中区间时清掉关键帧选中
|
||||||
|
if (m_selKeyKind != KeyKind::None) {
|
||||||
|
m_selKeyKind = KeyKind::None;
|
||||||
|
m_selKeyFrame = -1;
|
||||||
|
emit keyframeSelectionChanged(m_selKeyKind, m_selKeyFrame);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
e->accept();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
QWidget::mouseReleaseEvent(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TimelineWidget::wheelEvent(QWheelEvent* e) {
|
||||||
|
const int delta = (e->angleDelta().y() > 0) ? 1 : -1;
|
||||||
|
setFrameInternal(m_currentFrame + delta, true);
|
||||||
|
e->accept();
|
||||||
|
}
|
||||||
|
|
||||||
69
client/gui/timeline/TimelineWidget.h
Normal file
69
client/gui/timeline/TimelineWidget.h
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "core/domain/Project.h"
|
||||||
|
|
||||||
|
#include <QWidget>
|
||||||
|
|
||||||
|
class TimelineWidget final : public QWidget {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
explicit TimelineWidget(QWidget* parent = nullptr);
|
||||||
|
|
||||||
|
void setFrameRange(int start, int end);
|
||||||
|
void setCurrentFrame(int frame);
|
||||||
|
int currentFrame() const { return m_currentFrame; }
|
||||||
|
|
||||||
|
void setSelectionRange(int start, int end); // -1,-1 清除
|
||||||
|
int selectionStart() const { return m_selStart; }
|
||||||
|
int selectionEnd() const { return m_selEnd; }
|
||||||
|
|
||||||
|
// 只显示“当前选中实体”的关键帧标记
|
||||||
|
void setKeyframeTracks(const core::Project::Entity* entityOrNull);
|
||||||
|
|
||||||
|
enum class KeyKind { None, Location, UserScale, Image };
|
||||||
|
KeyKind selectedKeyKind() const { return m_selKeyKind; }
|
||||||
|
int selectedKeyFrame() const { return m_selKeyFrame; }
|
||||||
|
bool hasSelectedKeyframe() const { return m_selKeyKind != KeyKind::None && m_selKeyFrame >= 0; }
|
||||||
|
|
||||||
|
signals:
|
||||||
|
void frameScrubbed(int frame); // 拖动中实时触发(用于实时预览)
|
||||||
|
void frameCommitted(int frame); // 松手/点击确认(用于较重的刷新)
|
||||||
|
void contextMenuRequested(const QPoint& globalPos, int frame);
|
||||||
|
void keyframeSelectionChanged(KeyKind kind, int frame);
|
||||||
|
void intervalSelectionChanged(int start, int end);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void paintEvent(QPaintEvent*) override;
|
||||||
|
void mousePressEvent(QMouseEvent*) override;
|
||||||
|
void mouseMoveEvent(QMouseEvent*) override;
|
||||||
|
void mouseReleaseEvent(QMouseEvent*) override;
|
||||||
|
void wheelEvent(QWheelEvent*) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int xToFrame(int x) const;
|
||||||
|
int frameToX(int frame) const;
|
||||||
|
QRect trackRect() const;
|
||||||
|
|
||||||
|
void setFrameInternal(int frame, bool commit);
|
||||||
|
|
||||||
|
private:
|
||||||
|
int m_start = 0;
|
||||||
|
int m_end = 600;
|
||||||
|
int m_currentFrame = 0;
|
||||||
|
|
||||||
|
int m_selStart = -1;
|
||||||
|
int m_selEnd = -1;
|
||||||
|
|
||||||
|
bool m_dragging = false;
|
||||||
|
QPoint m_pressPos;
|
||||||
|
bool m_moved = false;
|
||||||
|
|
||||||
|
// snapshot(避免频繁遍历 workspace)
|
||||||
|
QVector<int> m_locFrames;
|
||||||
|
QVector<int> m_scaleFrames;
|
||||||
|
QVector<int> m_imgFrames;
|
||||||
|
|
||||||
|
KeyKind m_selKeyKind = KeyKind::None;
|
||||||
|
int m_selKeyFrame = -1;
|
||||||
|
};
|
||||||
|
|
||||||
347
doc/editor-workflow.md
Normal file
347
doc/editor-workflow.md
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
# 编辑界面功能流程说明
|
||||||
|
|
||||||
|
本文档用于定义编辑界面的核心模块、用户操作流程和各模块之间的数据流,作为实现与联调依据。
|
||||||
|
|
||||||
|
## 1. 功能模块
|
||||||
|
|
||||||
|
系统包含以下功能模块:
|
||||||
|
|
||||||
|
1. 编辑界面
|
||||||
|
2. 深度估计
|
||||||
|
3. 分层与遮罩
|
||||||
|
4. 补全与纹理延展
|
||||||
|
5. 漫游渲染与预览
|
||||||
|
6. 热点与叙事运行
|
||||||
|
|
||||||
|
其中,`热点与叙事运行` 用于内部部分场景生成多帧动态化效果,作为静态场景的局部增强能力。
|
||||||
|
|
||||||
|
## 2. 总体流程(用户视角)
|
||||||
|
|
||||||
|
用户在编辑界面的标准操作流程如下:
|
||||||
|
|
||||||
|
1. 打开程序后选择图片。
|
||||||
|
2. 对图片进行裁剪,确定编辑画布范围。
|
||||||
|
3. 对裁剪后的全部图像执行深度估计。
|
||||||
|
4. 执行深度估计。
|
||||||
|
5. 估计完成后展示深度叠加图,并支持开关切换(显示/隐藏叠加)。
|
||||||
|
6. 进入圈选与点选阶段,系统先基于深度图自动选出深度变化较明显的候选区域,再由人工圈画/点选补充与修正,确定前景目标与背景区域。
|
||||||
|
7. 将结果输入分层程序,按深度关系分离目标物体,未分离区域作为底层背景。
|
||||||
|
8. 对被分层出的区域,支持右键触发补全,输入提示词进行内容补全与纹理延展。
|
||||||
|
9. 对被分层出的区域,支持右键输入提示词进行叙事帧生成。
|
||||||
|
10. 在漫游渲染与预览中检查效果并迭代调整。
|
||||||
|
|
||||||
|
## 3. 模块详细要求
|
||||||
|
|
||||||
|
## 3.1 编辑界面
|
||||||
|
|
||||||
|
- 支持图片加载与初始化展示。
|
||||||
|
- 提供裁剪工具,输出裁剪后的工作区域。
|
||||||
|
- 深度估计默认作用于裁剪后的全部图像。
|
||||||
|
- 深度估计完成后支持深度叠加图开关。
|
||||||
|
- 支持圈选与点选两类交互,用于前景目标确认与细化。
|
||||||
|
- 右键菜单需至少包含:`补全`、`叙事帧生成`。
|
||||||
|
|
||||||
|
### 3.1.1 GUI 总体信息架构(窗口与面板)
|
||||||
|
|
||||||
|
建议采用“单主窗口 + 多 Dock + 少量模态对话框”的结构,保证复杂流程可见且可回退。
|
||||||
|
|
||||||
|
1. 主窗口(`MainWindow`)
|
||||||
|
- 顶部:菜单栏、主工具栏、阶段切换按钮。
|
||||||
|
- 中央:主画布(编辑/预览共用)。
|
||||||
|
- 左侧:工程树 + 图层树(可 tab)。
|
||||||
|
- 右侧:流程控制、参数设置、属性面板(可 tab)。
|
||||||
|
- 底部:任务状态栏与日志条。
|
||||||
|
|
||||||
|
2. 模态/半模态对话框
|
||||||
|
- 裁剪对话框:用于确认裁剪框、比例锁定。
|
||||||
|
- 提示词对话框:用于补全/叙事输入 Prompt。
|
||||||
|
- 任务详情对话框:查看失败原因与重试入口。
|
||||||
|
- 偏好设置对话框:模型端点、默认参数、快捷键。
|
||||||
|
|
||||||
|
3. 关键 Dock(建议)
|
||||||
|
- `流程控制 Dock`:按步骤触发(裁剪 -> 深度 -> 候选区域 -> 分层 -> 补全/叙事)。
|
||||||
|
- `图层 Dock`:背景层/对象层/补全层/叙事层管理。
|
||||||
|
- `属性 Dock`:当前选中区域或图层的参数。
|
||||||
|
- `工程树 Dock`:场景、热点、叙事节点。
|
||||||
|
- `预览 Dock`:漫游参数、播放控制、帧速率信息。
|
||||||
|
|
||||||
|
### 3.1.2 主窗口布局与阶段条(详细)
|
||||||
|
|
||||||
|
主窗口建议增加明确的“阶段条(Step Bar)”,放在工具栏下方:
|
||||||
|
|
||||||
|
1. `S1 导入与裁剪`
|
||||||
|
2. `S2 深度估计(整图)`
|
||||||
|
3. `S3 候选区域与遮罩修正`
|
||||||
|
4. `S4 分层`
|
||||||
|
5. `S5 补全/纹理延展`
|
||||||
|
6. `S6 热点与叙事帧`
|
||||||
|
7. `S7 漫游预览`
|
||||||
|
|
||||||
|
每个阶段显示 4 种状态:`未开始`、`进行中`、`完成`、`失败`。
|
||||||
|
仅允许“当前阶段 + 已完成阶段”可编辑,避免跨阶段误操作。
|
||||||
|
|
||||||
|
### 3.1.3 画布交互设计(核心)
|
||||||
|
|
||||||
|
画布(建议继续以 `CanvasWidget` 为核心)需支持以下显示层与交互模式:
|
||||||
|
|
||||||
|
1. 显示层(可开关)
|
||||||
|
- 原图层
|
||||||
|
- 深度叠加层(支持透明度 0~100)
|
||||||
|
- 自动候选区域层(描边 + 半透明填充)
|
||||||
|
- 人工遮罩层(新增/擦除轨迹)
|
||||||
|
- 分层结果层(前景、中景、背景)
|
||||||
|
- 叙事帧预览层(时间轴预览时启用)
|
||||||
|
|
||||||
|
2. 交互模式
|
||||||
|
- 浏览模式:平移、缩放、查看。
|
||||||
|
- 裁剪模式:拖拽裁剪框、锁定比例、确认/取消。
|
||||||
|
- 圈选模式:套索或画笔新增区域。
|
||||||
|
- 点选模式:点击候选区域进行“收录/排除”。
|
||||||
|
- 擦除模式:从最终遮罩中删除误选区域。
|
||||||
|
- 热点编辑模式:绘制热点框并绑定叙事节点。
|
||||||
|
|
||||||
|
3. 右键菜单(在画布选区上触发)
|
||||||
|
- `补全...`(打开 Prompt 对话框)
|
||||||
|
- `叙事帧生成...`(打开 Prompt 对话框 + 帧数参数)
|
||||||
|
- `加入前景层` / `加入背景层`
|
||||||
|
- `从遮罩中移除`
|
||||||
|
- `复制遮罩到新图层`
|
||||||
|
|
||||||
|
### 3.1.4 流程控制 Dock(按钮与状态)
|
||||||
|
|
||||||
|
流程控制 Dock 建议替代“单个开始处理”按钮,拆为分步触发:
|
||||||
|
|
||||||
|
1. `开始裁剪` / `应用裁剪`
|
||||||
|
2. `执行深度估计(整图)`
|
||||||
|
3. `生成自动候选区域`
|
||||||
|
4. `确认遮罩并分层`
|
||||||
|
5. `对选中区域补全`
|
||||||
|
6. `对选中区域生成叙事帧`
|
||||||
|
7. `进入漫游预览`
|
||||||
|
|
||||||
|
每个按钮旁应有状态灯(灰/蓝/绿/红)与耗时。
|
||||||
|
失败时在同一行提供 `重试` 与 `查看详情`。
|
||||||
|
|
||||||
|
### 3.1.5 图层 Dock(建议新增)
|
||||||
|
|
||||||
|
图层 Dock 至少包含以下节点:
|
||||||
|
|
||||||
|
- `Base`(裁剪后的底图)
|
||||||
|
- `DepthOverlay`(仅可视化)
|
||||||
|
- `AutoCandidates`(自动候选区域)
|
||||||
|
- `MaskFinal`(自动 + 人工修正结果)
|
||||||
|
- `Foreground_i`(可多个)
|
||||||
|
- `Background`
|
||||||
|
- `Inpaint_i`
|
||||||
|
- `NarrativeFrames_i`(关联某一热点/区域)
|
||||||
|
|
||||||
|
每个图层支持:可见性、锁定、重命名、透明度、删除(受阶段约束)。
|
||||||
|
|
||||||
|
### 3.1.5A 工程树对象模型(本次需求)
|
||||||
|
|
||||||
|
工程树中的所有节点统一称为“对象(Object)”,按层级表达空间远近关系与遮挡关系。
|
||||||
|
|
||||||
|
1. 对象类型定义
|
||||||
|
- 背景对象(Background Object):导入图像默认生成的对象,作为底层背景。
|
||||||
|
- 实体(Entity):由分层直接得到的对象,或由补全后得到的对象。
|
||||||
|
- 活动实体(Active Entity):由实体派生,已生成叙事动画帧的对象。
|
||||||
|
|
||||||
|
2. 继承/派生关系
|
||||||
|
- `背景对象` 不派生自其他对象。
|
||||||
|
- `实体` 可由背景对象或其他实体拆分得到。
|
||||||
|
- `活动实体` 必须派生自实体(实体 -> 活动实体)。
|
||||||
|
|
||||||
|
3. 层级与遮挡规则
|
||||||
|
- 工程树越往下,层级越低,表示距离越远。
|
||||||
|
- 距离近的对象会遮挡距离远的对象。
|
||||||
|
- 渲染建议采用“先远后近”的顺序(深层节点先绘制,浅层节点后绘制)。
|
||||||
|
|
||||||
|
4. 动画语义
|
||||||
|
- 实体默认是静态对象。
|
||||||
|
- 活动实体支持循环播放叙事帧(自然动画)。
|
||||||
|
- 活动实体同时支持触发动画(例如点击热点、时间线事件触发)。
|
||||||
|
|
||||||
|
5. 工程树最低能力
|
||||||
|
- 支持新增多个对象(背景对象、实体、活动实体)。
|
||||||
|
- 支持父子关系(实体下继续派生实体或活动实体)。
|
||||||
|
- 支持对象删除(删除父对象时一并删除子对象)。
|
||||||
|
- 支持在属性区查看对象类型与动画能力状态。
|
||||||
|
|
||||||
|
### 3.1.6 属性 Dock(按对象动态切换)
|
||||||
|
|
||||||
|
1. 选中“自动候选区域”时
|
||||||
|
- 显示候选评分、面积、平均深度、边界平滑系数。
|
||||||
|
2. 选中“遮罩”时
|
||||||
|
- 显示画笔大小、羽化、腐蚀/膨胀。
|
||||||
|
3. 选中“补全任务”时
|
||||||
|
- 显示 Prompt、负向提示词、步数、强度、随机种子。
|
||||||
|
4. 选中“叙事帧任务”时
|
||||||
|
- 显示 Prompt、目标帧数、帧率、运动幅度、循环方式。
|
||||||
|
5. 选中“热点”时
|
||||||
|
- 显示标题、描述、绑定叙事序列、触发方式。
|
||||||
|
|
||||||
|
### 3.1.7 任务栏与日志区(建议新增)
|
||||||
|
|
||||||
|
底部状态区建议包含:
|
||||||
|
|
||||||
|
- 当前阶段与子任务状态(例如“分层计算 67%”)。
|
||||||
|
- 最近一次接口调用耗时(深度/分层/补全/叙事)。
|
||||||
|
- 错误摘要(可点击展开完整日志)。
|
||||||
|
- 后台任务队列(允许补全与叙事并行排队)。
|
||||||
|
|
||||||
|
### 3.1.8 结合现有代码的改造清单(`client/gui`)
|
||||||
|
|
||||||
|
以下为现有 GUI 与目标流程不匹配点,以及建议修改方向:
|
||||||
|
|
||||||
|
1. `MainWindow` 当前以“单次开始处理”为主,缺少分阶段流程控制
|
||||||
|
- 现状:`onProcessingStartRequested()` 串行触发,且 `onDepthEstimationFinished()` 中直接进入分层。
|
||||||
|
- 建议:拆为 `onCropConfirmed`、`onDepthRequested`、`onCandidatesRequested`、`onLayeringRequested` 等独立槽函数,并引入阶段状态机。
|
||||||
|
|
||||||
|
2. 缺少裁剪窗口/裁剪模式
|
||||||
|
- 现状:`onOpenImage()` 直接载入整图到画布。
|
||||||
|
- 建议:新增裁剪对话框(如 `CropDialog`)或在 `CanvasWidget` 增加 `Crop` 交互模式,确认后生成工作图并替换当前输入。
|
||||||
|
|
||||||
|
3. `CanvasWidget` 交互模式不足
|
||||||
|
- 现状:仅 `View` 与 `EditHotspot`。
|
||||||
|
- 建议:扩展为 `Crop`、`MaskBrushAdd`、`MaskBrushErase`、`CandidatePick`、`PromptRegionSelect` 等模式,并增加对应信号(例如 `maskEdited`、`candidateToggled`)。
|
||||||
|
|
||||||
|
4. 缺少画布右键菜单(补全/叙事生成)
|
||||||
|
- 现状:右键菜单仅在工程树中提供“删除热点”。
|
||||||
|
- 建议:在画布选中区域上提供右键菜单,接入补全和叙事任务创建。
|
||||||
|
|
||||||
|
5. 缺少图层管理 Dock
|
||||||
|
- 现状:有工程树与属性面板,但没有分层图层树。
|
||||||
|
- 建议:新增 `LayerPanel`(可作为新 Dock),统一管理前景/背景/补全/叙事层可见性与顺序。
|
||||||
|
|
||||||
|
6. 预处理 Dock 语义不完整
|
||||||
|
- 现状:`ProcessingPanel` 偏模型选择与“一键开始”。
|
||||||
|
- 建议:重构为“流程控制 + 模型参数”双区结构;模型参数保留,执行入口拆成分步按钮。
|
||||||
|
|
||||||
|
7. 深度叠加控制不足
|
||||||
|
- 现状:`MainWindow` 中仅有分层预览开关(`m_layerPreviewCheck`)。
|
||||||
|
- 建议:增加“深度叠加开关 + 透明度滑条 + 深度色图选择”,并在 `CanvasWidget` 绘制层中实现。
|
||||||
|
|
||||||
|
8. 任务完成后自动隐藏面板不利于调试
|
||||||
|
- 现状:`onInpaintFinished()` 使用定时器自动隐藏 `m_processingDock`。
|
||||||
|
- 建议:改为默认不自动隐藏,仅在用户手动收起时隐藏。
|
||||||
|
|
||||||
|
9. 热点叙事仅文本节点,未覆盖多帧动态任务
|
||||||
|
- 现状:热点主要绑定 `NarrativeNode` 文本说明。
|
||||||
|
- 建议:为热点增加“叙事帧任务列表”与预览入口,支持每个热点关联多组动态帧。
|
||||||
|
|
||||||
|
10. 缺少撤销/重做与历史快照
|
||||||
|
- 现状:遮罩、分层、补全、叙事结果缺少统一历史机制。
|
||||||
|
- 建议:引入编辑命令栈(Command Pattern)并在工具栏提供撤销/重做。
|
||||||
|
|
||||||
|
### 3.1.9 推荐新增/调整的 GUI 类
|
||||||
|
|
||||||
|
建议在 `client/gui` 增加或调整以下类(命名可按现有风格微调):
|
||||||
|
|
||||||
|
- `CropDialog`:裁剪确认窗口。
|
||||||
|
- `LayerPanel`:图层树与图层操作。
|
||||||
|
- `TaskPanel`:任务队列与执行状态。
|
||||||
|
- `PromptDialog`:补全/叙事 Prompt 与高级参数输入。
|
||||||
|
- `WorkflowController`(可先放 `MainWindow` 内部):统一阶段状态机与按钮可用性。
|
||||||
|
- `CanvasWidget`:扩展多交互模式与右键上下文菜单能力。
|
||||||
|
|
||||||
|
### 3.1.10 GUI 验收标准(编辑界面维度)
|
||||||
|
|
||||||
|
1. 用户在不离开主窗口的情况下可完整完成“裁剪 -> 深度 -> 候选修正 -> 分层 -> 补全/叙事 -> 预览”。
|
||||||
|
2. 每一阶段均有明确可视状态与失败重试入口。
|
||||||
|
3. 画布右键可直接触发补全与叙事,且默认绑定当前选区。
|
||||||
|
4. 图层可见性与锁定状态可稳定控制渲染结果。
|
||||||
|
5. 热点不仅可编辑文本,还可绑定并预览多帧叙事结果。
|
||||||
|
|
||||||
|
## 3.2 深度估计
|
||||||
|
|
||||||
|
- 输入:用户裁剪后的图像(整图)。
|
||||||
|
- 输出:深度图(与编辑画布对齐)以及可视化叠加图。
|
||||||
|
- 要求:结果可回传编辑界面用于后续圈选、点选和分层计算。
|
||||||
|
|
||||||
|
## 3.3 分层与遮罩
|
||||||
|
|
||||||
|
- 输入:深度图、系统自动候选区域、用户圈选/点选结果、人工修正遮罩。
|
||||||
|
- 输出:前景层(一个或多个对象层)与背景底层。
|
||||||
|
- 要求:
|
||||||
|
- 按深度关系优先分离目标物体。
|
||||||
|
- 保留可编辑遮罩,支持后续补全和叙事生成直接复用。
|
||||||
|
|
||||||
|
### 3.3.1 自动候选区域筛选规则(建议)
|
||||||
|
|
||||||
|
为保证“先自动、再人工补充”的效率,建议在深度图上执行以下候选区域提取流程:
|
||||||
|
|
||||||
|
1. 深度预处理
|
||||||
|
- 对深度图进行归一化到 `[0, 1]`。
|
||||||
|
- 使用轻量平滑(如 3x3 中值滤波或双边滤波)抑制噪声,避免过碎片区域。
|
||||||
|
|
||||||
|
2. 深度变化检测
|
||||||
|
- 计算深度梯度幅值(可使用 Sobel)。
|
||||||
|
- 以阈值 `T_grad`(默认 0.12)筛选“深度变化较明显”像素,形成初始候选掩码。
|
||||||
|
|
||||||
|
3. 连通域与面积过滤
|
||||||
|
- 对候选掩码进行连通域分析。
|
||||||
|
- 过滤面积过小区域(默认最小面积为裁剪图总像素的 `0.2%`),减少噪声候选。
|
||||||
|
|
||||||
|
4. 形态学修正
|
||||||
|
- 执行一次闭运算填补小孔洞。
|
||||||
|
- 执行一次开运算移除细小毛刺,提升候选边界可用性。
|
||||||
|
|
||||||
|
5. 候选排序与展示
|
||||||
|
- 按“区域面积 + 平均梯度强度”综合评分排序。
|
||||||
|
- 默认展示 Top-K(建议 K=5)作为可点选候选区域。
|
||||||
|
|
||||||
|
6. 人工补充与修正
|
||||||
|
- 用户可通过圈选新增候选外区域。
|
||||||
|
- 用户可通过点选/擦除剔除误检区域。
|
||||||
|
- 最终输出为“自动候选 + 人工修正”的统一遮罩,进入分层流程。
|
||||||
|
|
||||||
|
参数建议支持在设置面板中可配:`T_grad`、最小面积比例、Top-K、平滑强度。
|
||||||
|
|
||||||
|
## 3.4 补全与纹理延展
|
||||||
|
|
||||||
|
- 触发方式:用户对分层区域右键选择 `补全`。
|
||||||
|
- 输入:目标层或遮罩区域 + 文本提示词(Prompt)。
|
||||||
|
- 输出:补全结果图层,用于修复空洞、扩展纹理或增强局部细节。
|
||||||
|
- 要求:
|
||||||
|
- 补全结果与原图层对齐。
|
||||||
|
- 支持重复执行与结果覆盖/回退。
|
||||||
|
|
||||||
|
## 3.5 漫游渲染与预览
|
||||||
|
|
||||||
|
- 将分层结果组织为可漫游场景进行实时预览。
|
||||||
|
- 支持查看层间深度关系导致的视差效果。
|
||||||
|
- 支持回到编辑阶段继续修改并再次预览(闭环迭代)。
|
||||||
|
|
||||||
|
## 3.6 热点与叙事运行(内部能力)
|
||||||
|
|
||||||
|
- 面向内部选定场景,不作为默认对外能力。
|
||||||
|
- 触发方式:用户右键选择目标区域并输入提示词。
|
||||||
|
- 输出:多帧叙事动态结果(局部区域动态化)。
|
||||||
|
- 目标:在静态场景中对关键区域生成连续帧,强化叙事表达。
|
||||||
|
- 说明:该模块与补全共享部分输入(区域与提示词),但输出为时间序列帧。
|
||||||
|
|
||||||
|
## 4. 关键数据流
|
||||||
|
|
||||||
|
1. 原图 -> 裁剪图
|
||||||
|
2. 裁剪图(整图) -> 深度图 + 深度叠加图
|
||||||
|
3. 深度图 + 圈选/点选 + 人工遮罩 -> 分层结果(前景层、背景层)
|
||||||
|
4. 分层区域 + Prompt -> 补全结果图层
|
||||||
|
5. 分层区域 + Prompt -> 叙事多帧结果
|
||||||
|
6. 分层与生成结果 -> 漫游渲染预览
|
||||||
|
|
||||||
|
## 5. 交互与状态建议
|
||||||
|
|
||||||
|
- 建议提供统一的图层面板,显示前景层、背景层、补全层、叙事层。
|
||||||
|
- 深度叠加图开关应为全局可见状态,便于在不同阶段快速核对。
|
||||||
|
- 右键操作应绑定当前选中区域,避免误触发到非目标层。
|
||||||
|
- 对补全和叙事生成增加任务状态:`待执行`、`执行中`、`完成`、`失败`。
|
||||||
|
- 建议保留操作历史,支持撤销/重做,便于快速迭代。
|
||||||
|
|
||||||
|
## 6. 验收要点(首版)
|
||||||
|
|
||||||
|
- 可以完整跑通“加载 -> 裁剪 -> 深度估计 -> 分层 -> 补全/叙事 -> 预览”的链路。
|
||||||
|
- 深度叠加图可正常显示和关闭。
|
||||||
|
- 圈选与点选结果可正确作用于分层。
|
||||||
|
- 自动候选区域应可稳定生成,并支持人工补充与修正后进入分层。
|
||||||
|
- 右键补全和右键叙事生成均可输入提示词并产出结果。
|
||||||
|
- 热点与叙事运行可在至少一个内部场景产出多帧动态化效果。
|
||||||
0
doc/editor.md
Normal file
0
doc/editor.md
Normal file
90
doc/models.md
Normal file
90
doc/models.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# 后端模型处理
|
||||||
|
|
||||||
|
当前后端主要围绕四类模型提供服务:深度估计、语义分割、图像补全和动画生成。
|
||||||
|
|
||||||
|
前端通过 GET /models 获取模型列表和参数配置,用来动态生成 UI;推理接口分别为:
|
||||||
|
|
||||||
|
POST /depth
|
||||||
|
|
||||||
|
POST /segment
|
||||||
|
|
||||||
|
POST /inpaint
|
||||||
|
|
||||||
|
POST /animate
|
||||||
|
|
||||||
|
## 一、深度估计
|
||||||
|
|
||||||
|
输入一张 RGB 图像,输出每个像素的相对深度,用于后续的分层和视差计算。
|
||||||
|
|
||||||
|
这一部分是整个伪3D效果的基础,深度质量直接决定最终效果上限。
|
||||||
|
|
||||||
|
模型:
|
||||||
|
|
||||||
|
* ZoeDepth:https://github.com/isl-org/ZoeDepth.git
|
||||||
|
* Depth Anything v2:https://github.com/DepthAnything/Depth-Anything-V2.git
|
||||||
|
* MiDaS:https://github.com/isl-org/MiDaS.git
|
||||||
|
* DPT:https://github.com/isl-org/DPT.git
|
||||||
|
|
||||||
|
接口说明
|
||||||
|
|
||||||
|
HTTP:POST /depth
|
||||||
|
|
||||||
|
请求体:DepthRequest
|
||||||
|
|
||||||
|
实现:models_depth.py 中的 run_depth_inference
|
||||||
|
|
||||||
|
|
||||||
|
## 二、语义分割
|
||||||
|
|
||||||
|
对图像进行像素级分区,用于辅助分层(天空 / 山 / 地面 / 建筑等)。
|
||||||
|
|
||||||
|
在伪3D流程中,这一步主要解决一个问题:
|
||||||
|
|
||||||
|
哪里可以拆开,哪里必须保持整体
|
||||||
|
|
||||||
|
模型:
|
||||||
|
* Mask2Former:https://github.com/facebookresearch/Mask2Former.git
|
||||||
|
* SAM:https://github.com/facebookresearch/segment-anything.git
|
||||||
|
|
||||||
|
接口说明
|
||||||
|
|
||||||
|
HTTP:POST /segment
|
||||||
|
|
||||||
|
请求体:SegmentRequest
|
||||||
|
|
||||||
|
实现:models_segmentation.py 中的 run_segmentation_inference
|
||||||
|
|
||||||
|
## 三、图像补全
|
||||||
|
|
||||||
|
在进行视差变换或分层后,图像中会出现“空洞区域”,需要通过生成模型进行补全。
|
||||||
|
|
||||||
|
这一部分主要影响最终画面的“真实感”。
|
||||||
|
|
||||||
|
模型:
|
||||||
|
* SDXL Inpainting:https://github.com/AyushUnleashed/sdxl-inpaint.git
|
||||||
|
* ControlNet:https://github.com/lllyasviel/ControlNet.git
|
||||||
|
|
||||||
|
接口说明
|
||||||
|
|
||||||
|
HTTP:POST /inpaint
|
||||||
|
|
||||||
|
请求体:InpaintRequest
|
||||||
|
|
||||||
|
实现:models_inpaint.py 中的 run_inpaint_inference
|
||||||
|
|
||||||
|
## 四、动画生成
|
||||||
|
|
||||||
|
通过文本提示词生成短动画(GIF),用于从静态描述快速预览动态镜头效果。
|
||||||
|
|
||||||
|
这部分当前接入 AnimateDiff,并通过统一后端接口对外提供调用能力。
|
||||||
|
|
||||||
|
模型:
|
||||||
|
* AnimateDiff:https://github.com/guoyww/animatediff.git
|
||||||
|
|
||||||
|
接口说明
|
||||||
|
|
||||||
|
HTTP:POST /animate
|
||||||
|
|
||||||
|
请求体:AnimateRequest
|
||||||
|
|
||||||
|
实现:`python_server/model/Animation/animation_loader.py` + `python_server/server.py` 中的 `animate`
|
||||||
1
python_server/.gitignore
vendored
Normal file
1
python_server/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
outputs/
|
||||||
1
python_server/__init__.py
Normal file
1
python_server/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
223
python_server/config.py
Normal file
223
python_server/config.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""
|
||||||
|
python_server 的统一配置文件。
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 使用 Python 而不是 YAML,方便在代码中集中列举所有可用模型,供前端读取。
|
||||||
|
- 后端加载模型时,也从这里读取默认值,保证单一信息源。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal, TypedDict, List
|
||||||
|
|
||||||
|
from model.Depth.zoe_loader import ZoeModelName
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 1. 深度模型枚举(给前端展示用)
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class DepthModelInfo(TypedDict):
|
||||||
|
id: str # 唯一 ID,如 "zoedepth_n"
|
||||||
|
family: str # 模型家族,如 "ZoeDepth"
|
||||||
|
name: str # 展示名,如 "ZoeD_N (NYU+KITT)"
|
||||||
|
description: str # 简短描述
|
||||||
|
backend: str # 后端类型,如 "zoedepth", "depth_anything_v2", "midas", "dpt"
|
||||||
|
|
||||||
|
|
||||||
|
DEPTH_MODELS: List[DepthModelInfo] = [
|
||||||
|
# ZoeDepth 系列
|
||||||
|
{
|
||||||
|
"id": "zoedepth_n",
|
||||||
|
"family": "ZoeDepth",
|
||||||
|
"name": "ZoeD_N",
|
||||||
|
"description": "ZoeDepth zero-shot 模型,适用室内/室外通用场景。",
|
||||||
|
"backend": "zoedepth",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "zoedepth_k",
|
||||||
|
"family": "ZoeDepth",
|
||||||
|
"name": "ZoeD_K",
|
||||||
|
"description": "ZoeDepth Kitti 专用版本,针对户外驾驶场景优化。",
|
||||||
|
"backend": "zoedepth",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "zoedepth_nk",
|
||||||
|
"family": "ZoeDepth",
|
||||||
|
"name": "ZoeD_NK",
|
||||||
|
"description": "ZoeDepth 双头版本(NYU+KITTI),综合室内/室外场景。",
|
||||||
|
"backend": "zoedepth",
|
||||||
|
},
|
||||||
|
# 预留:Depth Anything v2
|
||||||
|
{
|
||||||
|
"id": "depth_anything_v2_s",
|
||||||
|
"family": "Depth Anything V2",
|
||||||
|
"name": "Depth Anything V2 Small",
|
||||||
|
"description": "轻量级 Depth Anything V2 小模型。",
|
||||||
|
"backend": "depth_anything_v2",
|
||||||
|
},
|
||||||
|
# 预留:MiDaS
|
||||||
|
{
|
||||||
|
"id": "midas_dpt_large",
|
||||||
|
"family": "MiDaS",
|
||||||
|
"name": "MiDaS DPT Large",
|
||||||
|
"description": "MiDaS DPT-Large 高质量深度模型。",
|
||||||
|
"backend": "midas",
|
||||||
|
},
|
||||||
|
# 预留:DPT
|
||||||
|
{
|
||||||
|
"id": "dpt_large",
|
||||||
|
"family": "DPT",
|
||||||
|
"name": "DPT Large",
|
||||||
|
"description": "DPT Large 单目深度估计模型。",
|
||||||
|
"backend": "dpt",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 1.2 补全模型枚举(给前端展示用)
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintModelInfo(TypedDict):
|
||||||
|
id: str
|
||||||
|
family: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
backend: str # "sdxl_inpaint" | "controlnet"
|
||||||
|
|
||||||
|
|
||||||
|
INPAINT_MODELS: List[InpaintModelInfo] = [
|
||||||
|
{
|
||||||
|
"id": "sdxl_inpaint",
|
||||||
|
"family": "SDXL",
|
||||||
|
"name": "SDXL Inpainting",
|
||||||
|
"description": "基于 diffusers 的 SDXL 补全管线(需要 prompt + mask)。",
|
||||||
|
"backend": "sdxl_inpaint",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "controlnet",
|
||||||
|
"family": "ControlNet",
|
||||||
|
"name": "ControlNet (placeholder)",
|
||||||
|
"description": "ControlNet 补全/控制生成(当前统一封装暂未实现)。",
|
||||||
|
"backend": "controlnet",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 1.3 动画模型枚举(给前端展示用)
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AnimationModelInfo(TypedDict):
|
||||||
|
id: str
|
||||||
|
family: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
backend: str # "animatediff"
|
||||||
|
|
||||||
|
|
||||||
|
ANIMATION_MODELS: List[AnimationModelInfo] = [
|
||||||
|
{
|
||||||
|
"id": "animatediff",
|
||||||
|
"family": "AnimateDiff",
|
||||||
|
"name": "AnimateDiff (Text-to-Video)",
|
||||||
|
"description": "基于 AnimateDiff 的文生动画能力,输出 GIF 动画。",
|
||||||
|
"backend": "animatediff",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 2. 后端默认配置(给服务端用)
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DepthConfig:
|
||||||
|
# 深度后端选择:前端不参与选择;只允许在后端配置中切换
|
||||||
|
backend: Literal["zoedepth", "depth_anything_v2", "dpt", "midas"] = "zoedepth"
|
||||||
|
# ZoeDepth 家族默认选择
|
||||||
|
zoe_model: ZoeModelName = "ZoeD_N"
|
||||||
|
# Depth Anything V2 默认 encoder
|
||||||
|
da_v2_encoder: Literal["vits", "vitb", "vitl", "vitg"] = "vitl"
|
||||||
|
# DPT 默认模型类型
|
||||||
|
dpt_model_type: Literal["dpt_large", "dpt_hybrid"] = "dpt_large"
|
||||||
|
# MiDaS 默认模型类型
|
||||||
|
midas_model_type: Literal[
|
||||||
|
"dpt_beit_large_512",
|
||||||
|
"dpt_swin2_large_384",
|
||||||
|
"dpt_swin2_tiny_256",
|
||||||
|
"dpt_levit_224",
|
||||||
|
] = "dpt_beit_large_512"
|
||||||
|
# 统一的默认运行设备
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InpaintConfig:
|
||||||
|
# 统一补全默认后端
|
||||||
|
backend: Literal["sdxl_inpaint", "controlnet"] = "sdxl_inpaint"
|
||||||
|
# SDXL Inpaint 的基础模型(可写 HuggingFace model id 或本地目录)
|
||||||
|
sdxl_base_model: str = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||||
|
# ControlNet Inpaint 基础模型与 controlnet 权重
|
||||||
|
controlnet_base_model: str = "runwayml/stable-diffusion-inpainting"
|
||||||
|
controlnet_model: str = "lllyasviel/control_v11p_sd15_inpaint"
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AnimationConfig:
|
||||||
|
# 统一动画默认后端
|
||||||
|
backend: Literal["animatediff"] = "animatediff"
|
||||||
|
# AnimateDiff 根目录(相对 python_server/ 或绝对路径)
|
||||||
|
animate_diff_root: str = "model/Animation/AnimateDiff"
|
||||||
|
# 文生图基础模型(HuggingFace model id 或本地目录)
|
||||||
|
pretrained_model_path: str = "runwayml/stable-diffusion-v1-5"
|
||||||
|
# AnimateDiff 推理配置
|
||||||
|
inference_config: str = "configs/inference/inference-v3.yaml"
|
||||||
|
# 运动模块与个性化底模(为空则由脚本按默认处理)
|
||||||
|
motion_module: str = "v3_sd15_mm.ckpt"
|
||||||
|
dreambooth_model: str = "realisticVisionV60B1_v51VAE.safetensors"
|
||||||
|
lora_model: str = ""
|
||||||
|
lora_alpha: float = 0.8
|
||||||
|
# 部分环境 xformers 兼容性差,可手动关闭
|
||||||
|
without_xformers: bool = False
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AppConfig:
|
||||||
|
# 使用 default_factory 避免 dataclass 的可变默认值问题
|
||||||
|
depth: DepthConfig = field(default_factory=DepthConfig)
|
||||||
|
inpaint: InpaintConfig = field(default_factory=InpaintConfig)
|
||||||
|
animation: AnimationConfig = field(default_factory=AnimationConfig)
|
||||||
|
|
||||||
|
|
||||||
|
# 后端代码直接 import DEFAULT_CONFIG 即可
|
||||||
|
DEFAULT_CONFIG = AppConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def list_depth_models() -> List[DepthModelInfo]:
|
||||||
|
"""
|
||||||
|
返回所有可用深度模型的元信息,方便前端通过 /models 等接口读取。
|
||||||
|
"""
|
||||||
|
return DEPTH_MODELS
|
||||||
|
|
||||||
|
|
||||||
|
def list_inpaint_models() -> List[InpaintModelInfo]:
|
||||||
|
"""
|
||||||
|
返回所有可用补全模型的元信息,方便前端通过 /models 等接口读取。
|
||||||
|
"""
|
||||||
|
return INPAINT_MODELS
|
||||||
|
|
||||||
|
|
||||||
|
def list_animation_models() -> List[AnimationModelInfo]:
|
||||||
|
"""
|
||||||
|
返回所有可用动画模型的元信息,方便前端通过 /models 等接口读取。
|
||||||
|
"""
|
||||||
|
return ANIMATION_MODELS
|
||||||
|
|
||||||
153
python_server/config_loader.py
Normal file
153
python_server/config_loader.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""
|
||||||
|
兼容层:从 Python 配置模块中构造 zoe_loader 需要的 ZoeConfig。
|
||||||
|
|
||||||
|
后端其它代码尽量只依赖这里的函数,而不直接依赖 config.py 的具体结构,
|
||||||
|
便于以后扩展。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from model.Depth.zoe_loader import ZoeConfig
|
||||||
|
from model.Depth.depth_anything_v2_loader import DepthAnythingV2Config
|
||||||
|
from model.Depth.dpt_loader import DPTConfig
|
||||||
|
from model.Depth.midas_loader import MiDaSConfig
|
||||||
|
from config import AppConfig, DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def load_app_config() -> AppConfig:
|
||||||
|
"""
|
||||||
|
当前直接返回 DEFAULT_CONFIG。
|
||||||
|
如未来需要支持多环境 / 覆盖配置,可以在这里加逻辑。
|
||||||
|
"""
|
||||||
|
return DEFAULT_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def build_zoe_config_from_app(app_cfg: AppConfig | None = None) -> ZoeConfig:
|
||||||
|
"""
|
||||||
|
将 AppConfig.depth 映射为 ZoeConfig,供 zoe_loader 使用。
|
||||||
|
如果未显式传入 app_cfg,则使用全局 DEFAULT_CONFIG。
|
||||||
|
"""
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
|
||||||
|
return ZoeConfig(
|
||||||
|
model=app_cfg.depth.zoe_model,
|
||||||
|
device=app_cfg.depth.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_depth_anything_v2_config_from_app(
|
||||||
|
app_cfg: AppConfig | None = None,
|
||||||
|
) -> DepthAnythingV2Config:
|
||||||
|
"""
|
||||||
|
将 AppConfig.depth 映射为 DepthAnythingV2Config。
|
||||||
|
"""
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
|
||||||
|
return DepthAnythingV2Config(
|
||||||
|
encoder=app_cfg.depth.da_v2_encoder,
|
||||||
|
device=app_cfg.depth.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_dpt_config_from_app(app_cfg: AppConfig | None = None) -> DPTConfig:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return DPTConfig(
|
||||||
|
model_type=app_cfg.depth.dpt_model_type,
|
||||||
|
device=app_cfg.depth.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_midas_config_from_app(app_cfg: AppConfig | None = None) -> MiDaSConfig:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return MiDaSConfig(
|
||||||
|
model_type=app_cfg.depth.midas_model_type,
|
||||||
|
device=app_cfg.depth.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_depth_backend_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.depth.backend
|
||||||
|
|
||||||
|
|
||||||
|
def get_inpaint_backend_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.inpaint.backend
|
||||||
|
|
||||||
|
|
||||||
|
def get_sdxl_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.inpaint.sdxl_base_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_controlnet_base_model_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.inpaint.controlnet_base_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_controlnet_model_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.inpaint.controlnet_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_animation_backend_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.backend
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_root_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.animate_diff_root
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_pretrained_model_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.pretrained_model_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_inference_config_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.inference_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_motion_module_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.motion_module
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_dreambooth_model_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.dreambooth_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_lora_model_from_app(app_cfg: AppConfig | None = None) -> str:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.lora_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_lora_alpha_from_app(app_cfg: AppConfig | None = None) -> float:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.lora_alpha
|
||||||
|
|
||||||
|
|
||||||
|
def get_animatediff_without_xformers_from_app(app_cfg: AppConfig | None = None) -> bool:
|
||||||
|
if app_cfg is None:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
return app_cfg.animation.without_xformers
|
||||||
|
|
||||||
|
|
||||||
1
python_server/model/Animation/AnimateDiff
Submodule
1
python_server/model/Animation/AnimateDiff
Submodule
Submodule python_server/model/Animation/AnimateDiff added at e92bd5671b
12
python_server/model/Animation/__init__.py
Normal file
12
python_server/model/Animation/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from .animation_loader import (
|
||||||
|
AnimationBackend,
|
||||||
|
UnifiedAnimationConfig,
|
||||||
|
build_animation_predictor,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AnimationBackend",
|
||||||
|
"UnifiedAnimationConfig",
|
||||||
|
"build_animation_predictor",
|
||||||
|
]
|
||||||
|
|
||||||
268
python_server/model/Animation/animation_loader.py
Normal file
268
python_server/model/Animation/animation_loader.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unified animation model loading entry.
|
||||||
|
|
||||||
|
Current support:
|
||||||
|
- AnimateDiff (script-based invocation)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from config_loader import (
|
||||||
|
load_app_config,
|
||||||
|
get_animatediff_root_from_app,
|
||||||
|
get_animatediff_pretrained_model_from_app,
|
||||||
|
get_animatediff_inference_config_from_app,
|
||||||
|
get_animatediff_motion_module_from_app,
|
||||||
|
get_animatediff_dreambooth_model_from_app,
|
||||||
|
get_animatediff_lora_model_from_app,
|
||||||
|
get_animatediff_lora_alpha_from_app,
|
||||||
|
get_animatediff_without_xformers_from_app,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimationBackend(str, Enum):
|
||||||
|
ANIMATEDIFF = "animatediff"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UnifiedAnimationConfig:
|
||||||
|
backend: AnimationBackend = AnimationBackend.ANIMATEDIFF
|
||||||
|
# Optional overrides. If None, values come from app config.
|
||||||
|
animate_diff_root: str | None = None
|
||||||
|
pretrained_model_path: str | None = None
|
||||||
|
inference_config: str | None = None
|
||||||
|
motion_module: str | None = None
|
||||||
|
dreambooth_model: str | None = None
|
||||||
|
lora_model: str | None = None
|
||||||
|
lora_alpha: float | None = None
|
||||||
|
without_xformers: bool | None = None
|
||||||
|
controlnet_path: str | None = None
|
||||||
|
controlnet_config: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _yaml_string(value: str) -> str:
|
||||||
|
return json.dumps(value, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_root(root_cfg: str) -> Path:
|
||||||
|
root = Path(root_cfg)
|
||||||
|
if not root.is_absolute():
|
||||||
|
root = Path(__file__).resolve().parents[2] / root_cfg
|
||||||
|
return root.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_animatediff_predictor(
|
||||||
|
cfg: UnifiedAnimationConfig,
|
||||||
|
) -> Callable[..., Path]:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
|
||||||
|
root = _resolve_root(cfg.animate_diff_root or get_animatediff_root_from_app(app_cfg))
|
||||||
|
script_path = root / "scripts" / "animate.py"
|
||||||
|
samples_dir = root / "samples"
|
||||||
|
samples_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if not script_path.is_file():
|
||||||
|
raise FileNotFoundError(f"AnimateDiff script not found: {script_path}")
|
||||||
|
|
||||||
|
pretrained_model_path = (
|
||||||
|
cfg.pretrained_model_path or get_animatediff_pretrained_model_from_app(app_cfg)
|
||||||
|
)
|
||||||
|
inference_config = cfg.inference_config or get_animatediff_inference_config_from_app(app_cfg)
|
||||||
|
motion_module = cfg.motion_module or get_animatediff_motion_module_from_app(app_cfg)
|
||||||
|
dreambooth_model = cfg.dreambooth_model
|
||||||
|
if dreambooth_model is None:
|
||||||
|
dreambooth_model = get_animatediff_dreambooth_model_from_app(app_cfg)
|
||||||
|
lora_model = cfg.lora_model
|
||||||
|
if lora_model is None:
|
||||||
|
lora_model = get_animatediff_lora_model_from_app(app_cfg)
|
||||||
|
lora_alpha = cfg.lora_alpha
|
||||||
|
if lora_alpha is None:
|
||||||
|
lora_alpha = get_animatediff_lora_alpha_from_app(app_cfg)
|
||||||
|
without_xformers = cfg.without_xformers
|
||||||
|
if without_xformers is None:
|
||||||
|
without_xformers = get_animatediff_without_xformers_from_app(app_cfg)
|
||||||
|
|
||||||
|
def _predict(
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
num_inference_steps: int = 25,
|
||||||
|
guidance_scale: float = 8.0,
|
||||||
|
width: int = 512,
|
||||||
|
height: int = 512,
|
||||||
|
video_length: int = 16,
|
||||||
|
seed: int = -1,
|
||||||
|
control_image_path: str | None = None,
|
||||||
|
output_format: str = "gif",
|
||||||
|
) -> Path:
|
||||||
|
if output_format not in {"gif", "png_sequence"}:
|
||||||
|
raise ValueError("output_format must be 'gif' or 'png_sequence'")
|
||||||
|
prompt_value = prompt.strip()
|
||||||
|
if not prompt_value:
|
||||||
|
raise ValueError("prompt must not be empty")
|
||||||
|
|
||||||
|
negative_prompt_value = negative_prompt or ""
|
||||||
|
|
||||||
|
motion_module_line = (
|
||||||
|
f' motion_module: {_yaml_string(motion_module)}\n' if motion_module else ""
|
||||||
|
)
|
||||||
|
dreambooth_line = (
|
||||||
|
f' dreambooth_path: {_yaml_string(dreambooth_model)}\n' if dreambooth_model else ""
|
||||||
|
)
|
||||||
|
lora_path_line = f' lora_model_path: {_yaml_string(lora_model)}\n' if lora_model else ""
|
||||||
|
lora_alpha_line = f" lora_alpha: {float(lora_alpha)}\n" if lora_model else ""
|
||||||
|
controlnet_path_value = cfg.controlnet_path or "v3_sd15_sparsectrl_rgb.ckpt"
|
||||||
|
controlnet_config_value = cfg.controlnet_config or "configs/inference/sparsectrl/image_condition.yaml"
|
||||||
|
control_image_line = ""
|
||||||
|
if control_image_path:
|
||||||
|
control_image = Path(control_image_path).expanduser().resolve()
|
||||||
|
if not control_image.is_file():
|
||||||
|
raise FileNotFoundError(f"control_image_path not found: {control_image}")
|
||||||
|
control_image_line = (
|
||||||
|
f' controlnet_path: {_yaml_string(controlnet_path_value)}\n'
|
||||||
|
f' controlnet_config: {_yaml_string(controlnet_config_value)}\n'
|
||||||
|
" controlnet_images:\n"
|
||||||
|
f' - {_yaml_string(str(control_image))}\n'
|
||||||
|
" controlnet_image_indexs:\n"
|
||||||
|
" - 0\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config_text = (
|
||||||
|
"- prompt:\n"
|
||||||
|
f" - {_yaml_string(prompt_value)}\n"
|
||||||
|
" n_prompt:\n"
|
||||||
|
f" - {_yaml_string(negative_prompt_value)}\n"
|
||||||
|
f" steps: {int(num_inference_steps)}\n"
|
||||||
|
f" guidance_scale: {float(guidance_scale)}\n"
|
||||||
|
f" W: {int(width)}\n"
|
||||||
|
f" H: {int(height)}\n"
|
||||||
|
f" L: {int(video_length)}\n"
|
||||||
|
" seed:\n"
|
||||||
|
f" - {int(seed)}\n"
|
||||||
|
f"{motion_module_line}{dreambooth_line}{lora_path_line}{lora_alpha_line}{control_image_line}"
|
||||||
|
)
|
||||||
|
|
||||||
|
before_dirs = {p for p in samples_dir.iterdir() if p.is_dir()}
|
||||||
|
cfg_file = tempfile.NamedTemporaryFile(
|
||||||
|
mode="w",
|
||||||
|
suffix=".yaml",
|
||||||
|
prefix="animatediff_cfg_",
|
||||||
|
dir=str(root),
|
||||||
|
delete=False,
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
cfg_file_path = Path(cfg_file.name)
|
||||||
|
try:
|
||||||
|
cfg_file.write(config_text)
|
||||||
|
cfg_file.flush()
|
||||||
|
cfg_file.close()
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
str(script_path),
|
||||||
|
"--pretrained-model-path",
|
||||||
|
pretrained_model_path,
|
||||||
|
"--inference-config",
|
||||||
|
inference_config,
|
||||||
|
"--config",
|
||||||
|
str(cfg_file_path),
|
||||||
|
"--L",
|
||||||
|
str(int(video_length)),
|
||||||
|
"--W",
|
||||||
|
str(int(width)),
|
||||||
|
"--H",
|
||||||
|
str(int(height)),
|
||||||
|
]
|
||||||
|
if without_xformers:
|
||||||
|
cmd.append("--without-xformers")
|
||||||
|
if output_format == "png_sequence":
|
||||||
|
cmd.append("--save-png-sequence")
|
||||||
|
|
||||||
|
env = dict(os.environ)
|
||||||
|
existing_pythonpath = env.get("PYTHONPATH", "")
|
||||||
|
root_pythonpath = str(root)
|
||||||
|
env["PYTHONPATH"] = (
|
||||||
|
f"{root_pythonpath}:{existing_pythonpath}" if existing_pythonpath else root_pythonpath
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_once(command: list[str]) -> subprocess.CompletedProcess[str]:
|
||||||
|
return subprocess.run(
|
||||||
|
command,
|
||||||
|
cwd=str(root),
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = _run_once(cmd)
|
||||||
|
except subprocess.CalledProcessError as first_error:
|
||||||
|
stderr_text = first_error.stderr or ""
|
||||||
|
should_retry_without_xformers = (
|
||||||
|
not without_xformers
|
||||||
|
and "--without-xformers" not in cmd
|
||||||
|
and (
|
||||||
|
"memory_efficient_attention" in stderr_text
|
||||||
|
or "AcceleratorError" in stderr_text
|
||||||
|
or "invalid configuration argument" in stderr_text
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not should_retry_without_xformers:
|
||||||
|
raise
|
||||||
|
|
||||||
|
retry_cmd = [*cmd, "--without-xformers"]
|
||||||
|
proc = _run_once(retry_cmd)
|
||||||
|
_ = proc
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"AnimateDiff inference failed.\n"
|
||||||
|
f"stdout:\n{e.stdout}\n"
|
||||||
|
f"stderr:\n{e.stderr}"
|
||||||
|
) from e
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
cfg_file_path.unlink(missing_ok=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
after_dirs = [p for p in samples_dir.iterdir() if p.is_dir() and p not in before_dirs]
|
||||||
|
candidates = [p for p in after_dirs if (p / "sample.gif").is_file()]
|
||||||
|
if not candidates:
|
||||||
|
candidates = [p for p in samples_dir.iterdir() if p.is_dir() and (p / "sample.gif").is_file()]
|
||||||
|
if not candidates:
|
||||||
|
raise FileNotFoundError("AnimateDiff finished but sample.gif was not found in samples/")
|
||||||
|
|
||||||
|
latest = sorted(candidates, key=lambda p: p.stat().st_mtime, reverse=True)[0]
|
||||||
|
if output_format == "png_sequence":
|
||||||
|
frames_root = latest / "sample_frames"
|
||||||
|
if not frames_root.is_dir():
|
||||||
|
raise FileNotFoundError("AnimateDiff finished but sample_frames/ was not found in samples/")
|
||||||
|
frame_dirs = sorted([p for p in frames_root.iterdir() if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
if not frame_dirs:
|
||||||
|
raise FileNotFoundError("AnimateDiff finished but no PNG sequence directory was found in sample_frames/")
|
||||||
|
return frame_dirs[0].resolve()
|
||||||
|
return (latest / "sample.gif").resolve()
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def build_animation_predictor(
|
||||||
|
cfg: UnifiedAnimationConfig | None = None,
|
||||||
|
) -> tuple[Callable[..., Path], AnimationBackend]:
|
||||||
|
cfg = cfg or UnifiedAnimationConfig()
|
||||||
|
|
||||||
|
if cfg.backend == AnimationBackend.ANIMATEDIFF:
|
||||||
|
return _make_animatediff_predictor(cfg), AnimationBackend.ANIMATEDIFF
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported animation backend: {cfg.backend}")
|
||||||
|
|
||||||
1
python_server/model/Depth/DPT
Submodule
1
python_server/model/Depth/DPT
Submodule
Submodule python_server/model/Depth/DPT added at cd3fe90bb4
1
python_server/model/Depth/Depth-Anything-V2
Submodule
1
python_server/model/Depth/Depth-Anything-V2
Submodule
Submodule python_server/model/Depth/Depth-Anything-V2 added at e5a2732d3e
1
python_server/model/Depth/MiDaS
Submodule
1
python_server/model/Depth/MiDaS
Submodule
Submodule python_server/model/Depth/MiDaS added at 454597711a
1
python_server/model/Depth/ZoeDepth
Submodule
1
python_server/model/Depth/ZoeDepth
Submodule
Submodule python_server/model/Depth/ZoeDepth added at d87f17b2f5
1
python_server/model/Depth/__init__.py
Normal file
1
python_server/model/Depth/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
147
python_server/model/Depth/depth_anything_v2_loader.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Tuple
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# 确保本地克隆的 Depth Anything V2 仓库在 sys.path 中,
|
||||||
|
# 这样其内部的 `from depth_anything_v2...` 导入才能正常工作。
|
||||||
|
_THIS_DIR = Path(__file__).resolve().parent
|
||||||
|
_DA_REPO_ROOT = _THIS_DIR / "Depth-Anything-V2"
|
||||||
|
if _DA_REPO_ROOT.is_dir():
|
||||||
|
da_path = str(_DA_REPO_ROOT)
|
||||||
|
if da_path not in sys.path:
|
||||||
|
sys.path.insert(0, da_path)
|
||||||
|
|
||||||
|
from depth_anything_v2.dpt import DepthAnythingV2 # type: ignore[import]
|
||||||
|
|
||||||
|
|
||||||
|
EncoderName = Literal["vits", "vitb", "vitl", "vitg"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DepthAnythingV2Config:
|
||||||
|
"""
|
||||||
|
Depth Anything V2 模型选择配置。
|
||||||
|
|
||||||
|
encoder: "vits" | "vitb" | "vitl" | "vitg"
|
||||||
|
device: "cuda" | "cpu"
|
||||||
|
input_size: 推理时的输入分辨率(短边),参考官方 demo,默认 518。
|
||||||
|
"""
|
||||||
|
|
||||||
|
encoder: EncoderName = "vitl"
|
||||||
|
device: str = "cuda"
|
||||||
|
input_size: int = 518
|
||||||
|
|
||||||
|
|
||||||
|
_MODEL_CONFIGS = {
|
||||||
|
"vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
|
||||||
|
"vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
|
||||||
|
"vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
|
||||||
|
"vitg": {"encoder": "vitg", "features": 384, "out_channels": [1536, 1536, 1536, 1536]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_DA_V2_WEIGHTS_URLS = {
|
||||||
|
# 官方权重托管在 HuggingFace:
|
||||||
|
# - Small -> vits
|
||||||
|
# - Base -> vitb
|
||||||
|
# - Large -> vitl
|
||||||
|
# - Giant -> vitg
|
||||||
|
# 如需替换为国内镜像,可直接修改这些 URL。
|
||||||
|
"vits": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth",
|
||||||
|
"vitb": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth",
|
||||||
|
"vitl": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth",
|
||||||
|
"vitg": "https://huggingface.co/depth-anything/Depth-Anything-V2-Giant/resolve/main/depth_anything_v2_vitg.pth",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _download_if_missing(encoder: str, ckpt_path: Path) -> None:
|
||||||
|
if ckpt_path.is_file():
|
||||||
|
return
|
||||||
|
|
||||||
|
url = _DA_V2_WEIGHTS_URLS.get(encoder)
|
||||||
|
if not url:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"找不到权重文件: {ckpt_path}\n"
|
||||||
|
f"且当前未为 encoder='{encoder}' 配置自动下载 URL,请手动下载到该路径。"
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"自动下载 Depth Anything V2 权重 ({encoder}):\n {url}\n -> {ckpt_path}")
|
||||||
|
|
||||||
|
resp = requests.get(url, stream=True)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
total = int(resp.headers.get("content-length", "0") or "0")
|
||||||
|
downloaded = 0
|
||||||
|
chunk_size = 1024 * 1024
|
||||||
|
|
||||||
|
with ckpt_path.open("wb") as f:
|
||||||
|
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
if total > 0:
|
||||||
|
done = int(50 * downloaded / total)
|
||||||
|
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||||
|
|
||||||
|
print("\n权重下载完成。")
|
||||||
|
|
||||||
|
|
||||||
|
def load_depth_anything_v2_from_config(
|
||||||
|
cfg: DepthAnythingV2Config,
|
||||||
|
) -> Tuple[DepthAnythingV2, DepthAnythingV2Config]:
|
||||||
|
"""
|
||||||
|
根据配置加载 Depth Anything V2 模型与对应配置。
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- 权重文件路径遵循官方命名约定:
|
||||||
|
checkpoints/depth_anything_v2_{encoder}.pth
|
||||||
|
例如:depth_anything_v2_vitl.pth
|
||||||
|
- 请确保上述权重文件已下载到
|
||||||
|
python_server/model/Depth/Depth-Anything-V2/checkpoints 下。
|
||||||
|
"""
|
||||||
|
if cfg.encoder not in _MODEL_CONFIGS:
|
||||||
|
raise ValueError(f"不支持的 encoder: {cfg.encoder}")
|
||||||
|
|
||||||
|
ckpt_path = _DA_REPO_ROOT / "checkpoints" / f"depth_anything_v2_{cfg.encoder}.pth"
|
||||||
|
_download_if_missing(cfg.encoder, ckpt_path)
|
||||||
|
|
||||||
|
model = DepthAnythingV2(**_MODEL_CONFIGS[cfg.encoder])
|
||||||
|
|
||||||
|
state_dict = torch.load(str(ckpt_path), map_location="cpu")
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
if cfg.device.startswith("cuda") and torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
model = model.to(device).eval()
|
||||||
|
cfg = DepthAnythingV2Config(
|
||||||
|
encoder=cfg.encoder,
|
||||||
|
device=device,
|
||||||
|
input_size=cfg.input_size,
|
||||||
|
)
|
||||||
|
return model, cfg
|
||||||
|
|
||||||
|
|
||||||
|
def infer_depth_anything_v2(
|
||||||
|
model: DepthAnythingV2,
|
||||||
|
image_bgr: np.ndarray,
|
||||||
|
input_size: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
对单张 BGR 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||||
|
image_bgr: OpenCV 读取的 BGR 图像 (H, W, 3), uint8
|
||||||
|
"""
|
||||||
|
depth = model.infer_image(image_bgr, input_size)
|
||||||
|
depth = np.asarray(depth, dtype="float32")
|
||||||
|
return depth
|
||||||
|
|
||||||
148
python_server/model/Depth/depth_loader.py
Normal file
148
python_server/model/Depth/depth_loader.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
"""
|
||||||
|
统一的深度模型加载入口。
|
||||||
|
|
||||||
|
当前支持:
|
||||||
|
- ZoeDepth(三种:ZoeD_N / ZoeD_K / ZoeD_NK)
|
||||||
|
- Depth Anything V2(四种 encoder:vits / vitb / vitl / vitg)
|
||||||
|
|
||||||
|
未来如果要加 MiDaS / DPT,只需要在这里再接一层即可。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from config_loader import (
|
||||||
|
load_app_config,
|
||||||
|
build_zoe_config_from_app,
|
||||||
|
build_depth_anything_v2_config_from_app,
|
||||||
|
build_dpt_config_from_app,
|
||||||
|
build_midas_config_from_app,
|
||||||
|
)
|
||||||
|
from .zoe_loader import load_zoe_from_config
|
||||||
|
from .depth_anything_v2_loader import (
|
||||||
|
load_depth_anything_v2_from_config,
|
||||||
|
infer_depth_anything_v2,
|
||||||
|
)
|
||||||
|
from .dpt_loader import load_dpt_from_config, infer_dpt
|
||||||
|
from .midas_loader import load_midas_from_config, infer_midas
|
||||||
|
|
||||||
|
|
||||||
|
class DepthBackend(str, Enum):
|
||||||
|
"""统一的深度模型后端类型。"""
|
||||||
|
|
||||||
|
ZOEDEPTH = "zoedepth"
|
||||||
|
DEPTH_ANYTHING_V2 = "depth_anything_v2"
|
||||||
|
DPT = "dpt"
|
||||||
|
MIDAS = "midas"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UnifiedDepthConfig:
|
||||||
|
"""
|
||||||
|
统一深度配置。
|
||||||
|
|
||||||
|
backend: 使用哪个后端
|
||||||
|
device: 强制设备(可选),不填则使用 config.py 中的设置
|
||||||
|
"""
|
||||||
|
|
||||||
|
backend: DepthBackend = DepthBackend.ZOEDEPTH
|
||||||
|
device: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _make_zoe_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
zoe_cfg = build_zoe_config_from_app(app_cfg)
|
||||||
|
if device_override is not None:
|
||||||
|
zoe_cfg.device = device_override
|
||||||
|
|
||||||
|
model, _ = load_zoe_from_config(zoe_cfg)
|
||||||
|
|
||||||
|
def _predict(img: Image.Image) -> np.ndarray:
|
||||||
|
depth = model.infer_pil(img.convert("RGB"), output_type="numpy")
|
||||||
|
return np.asarray(depth, dtype="float32").squeeze()
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def _make_da_v2_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
da_cfg = build_depth_anything_v2_config_from_app(app_cfg)
|
||||||
|
if device_override is not None:
|
||||||
|
da_cfg.device = device_override
|
||||||
|
|
||||||
|
model, da_cfg = load_depth_anything_v2_from_config(da_cfg)
|
||||||
|
|
||||||
|
def _predict(img: Image.Image) -> np.ndarray:
|
||||||
|
# Depth Anything V2 的 infer_image 接收 BGR uint8
|
||||||
|
rgb = np.array(img.convert("RGB"), dtype=np.uint8)
|
||||||
|
bgr = rgb[:, :, ::-1]
|
||||||
|
depth = infer_depth_anything_v2(model, bgr, da_cfg.input_size)
|
||||||
|
return depth.astype("float32").squeeze()
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dpt_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
dpt_cfg = build_dpt_config_from_app(app_cfg)
|
||||||
|
if device_override is not None:
|
||||||
|
dpt_cfg.device = device_override
|
||||||
|
|
||||||
|
model, dpt_cfg, transform = load_dpt_from_config(dpt_cfg)
|
||||||
|
|
||||||
|
def _predict(img: Image.Image) -> np.ndarray:
|
||||||
|
bgr = cv2.cvtColor(np.array(img.convert("RGB"), dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
||||||
|
depth = infer_dpt(model, transform, bgr, dpt_cfg.device)
|
||||||
|
return depth.astype("float32").squeeze()
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def _make_midas_predictor(device_override: str | None = None) -> Callable[[Image.Image], np.ndarray]:
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
midas_cfg = build_midas_config_from_app(app_cfg)
|
||||||
|
if device_override is not None:
|
||||||
|
midas_cfg.device = device_override
|
||||||
|
|
||||||
|
model, midas_cfg, transform, net_w, net_h = load_midas_from_config(midas_cfg)
|
||||||
|
|
||||||
|
def _predict(img: Image.Image) -> np.ndarray:
|
||||||
|
rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
|
||||||
|
depth = infer_midas(model, transform, rgb, net_w, net_h, midas_cfg.device)
|
||||||
|
return depth.astype("float32").squeeze()
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def build_depth_predictor(
|
||||||
|
cfg: UnifiedDepthConfig | None = None,
|
||||||
|
) -> tuple[Callable[[Image.Image], np.ndarray], DepthBackend]:
|
||||||
|
"""
|
||||||
|
统一构建深度预测函数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- predictor(image: PIL.Image) -> np.ndarray[H, W], float32
|
||||||
|
- 实际使用的 backend 类型
|
||||||
|
"""
|
||||||
|
cfg = cfg or UnifiedDepthConfig()
|
||||||
|
|
||||||
|
if cfg.backend == DepthBackend.ZOEDEPTH:
|
||||||
|
return _make_zoe_predictor(cfg.device), DepthBackend.ZOEDEPTH
|
||||||
|
|
||||||
|
if cfg.backend == DepthBackend.DEPTH_ANYTHING_V2:
|
||||||
|
return _make_da_v2_predictor(cfg.device), DepthBackend.DEPTH_ANYTHING_V2
|
||||||
|
|
||||||
|
if cfg.backend == DepthBackend.DPT:
|
||||||
|
return _make_dpt_predictor(cfg.device), DepthBackend.DPT
|
||||||
|
|
||||||
|
if cfg.backend == DepthBackend.MIDAS:
|
||||||
|
return _make_midas_predictor(cfg.device), DepthBackend.MIDAS
|
||||||
|
|
||||||
|
raise ValueError(f"不支持的深度后端: {cfg.backend}")
|
||||||
|
|
||||||
156
python_server/model/Depth/dpt_loader.py
Normal file
156
python_server/model/Depth/dpt_loader.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Tuple
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import requests
|
||||||
|
|
||||||
|
_THIS_DIR = Path(__file__).resolve().parent
|
||||||
|
_DPT_REPO_ROOT = _THIS_DIR / "DPT"
|
||||||
|
if _DPT_REPO_ROOT.is_dir():
|
||||||
|
dpt_path = str(_DPT_REPO_ROOT)
|
||||||
|
if dpt_path not in sys.path:
|
||||||
|
sys.path.insert(0, dpt_path)
|
||||||
|
|
||||||
|
from dpt.models import DPTDepthModel # type: ignore[import]
|
||||||
|
from dpt.transforms import Resize, NormalizeImage, PrepareForNet # type: ignore[import]
|
||||||
|
from torchvision.transforms import Compose
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
DPTModelType = Literal["dpt_large", "dpt_hybrid"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DPTConfig:
|
||||||
|
model_type: DPTModelType = "dpt_large"
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
_DPT_WEIGHTS_URLS = {
|
||||||
|
# 官方 DPT 模型权重托管在:
|
||||||
|
# https://github.com/isl-org/DPT#models
|
||||||
|
"dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||||
|
"dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||||
|
if ckpt_path.is_file():
|
||||||
|
return
|
||||||
|
|
||||||
|
url = _DPT_WEIGHTS_URLS.get(model_type)
|
||||||
|
if not url:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"找不到 DPT 权重文件: {ckpt_path}\n"
|
||||||
|
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"自动下载 DPT 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||||
|
|
||||||
|
resp = requests.get(url, stream=True)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
total = int(resp.headers.get("content-length", "0") or "0")
|
||||||
|
downloaded = 0
|
||||||
|
chunk_size = 1024 * 1024
|
||||||
|
|
||||||
|
with ckpt_path.open("wb") as f:
|
||||||
|
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
if total > 0:
|
||||||
|
done = int(50 * downloaded / total)
|
||||||
|
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||||
|
print("\nDPT 权重下载完成。")
|
||||||
|
|
||||||
|
|
||||||
|
def load_dpt_from_config(cfg: DPTConfig) -> Tuple[DPTDepthModel, DPTConfig, Compose]:
|
||||||
|
"""
|
||||||
|
加载 DPT 模型与对应的预处理 transform。
|
||||||
|
"""
|
||||||
|
ckpt_name = {
|
||||||
|
"dpt_large": "dpt_large-midas-2f21e586.pt",
|
||||||
|
"dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
|
||||||
|
}[cfg.model_type]
|
||||||
|
|
||||||
|
ckpt_path = _DPT_REPO_ROOT / "weights" / ckpt_name
|
||||||
|
_download_if_missing(cfg.model_type, ckpt_path)
|
||||||
|
|
||||||
|
if cfg.model_type == "dpt_large":
|
||||||
|
net_w = net_h = 384
|
||||||
|
model = DPTDepthModel(
|
||||||
|
path=str(ckpt_path),
|
||||||
|
backbone="vitl16_384",
|
||||||
|
non_negative=True,
|
||||||
|
enable_attention_hooks=False,
|
||||||
|
)
|
||||||
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
else:
|
||||||
|
net_w = net_h = 384
|
||||||
|
model = DPTDepthModel(
|
||||||
|
path=str(ckpt_path),
|
||||||
|
backbone="vitb_rn50_384",
|
||||||
|
non_negative=True,
|
||||||
|
enable_attention_hooks=False,
|
||||||
|
)
|
||||||
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
|
||||||
|
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||||
|
model.to(device).eval()
|
||||||
|
|
||||||
|
transform = Compose(
|
||||||
|
[
|
||||||
|
Resize(
|
||||||
|
net_w,
|
||||||
|
net_h,
|
||||||
|
resize_target=None,
|
||||||
|
keep_aspect_ratio=True,
|
||||||
|
ensure_multiple_of=32,
|
||||||
|
resize_method="minimal",
|
||||||
|
image_interpolation_method=cv2.INTER_CUBIC,
|
||||||
|
),
|
||||||
|
normalization,
|
||||||
|
PrepareForNet(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DPTConfig(model_type=cfg.model_type, device=device)
|
||||||
|
return model, cfg, transform
|
||||||
|
|
||||||
|
|
||||||
|
def infer_dpt(
|
||||||
|
model: DPTDepthModel,
|
||||||
|
transform: Compose,
|
||||||
|
image_bgr: np.ndarray,
|
||||||
|
device: str,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
对单张 BGR 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||||
|
"""
|
||||||
|
img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
|
||||||
|
sample = transform({"image": img})["image"]
|
||||||
|
|
||||||
|
input_batch = torch.from_numpy(sample).to(device).unsqueeze(0)
|
||||||
|
with torch.no_grad():
|
||||||
|
prediction = model(input_batch)
|
||||||
|
prediction = (
|
||||||
|
torch.nn.functional.interpolate(
|
||||||
|
prediction.unsqueeze(1),
|
||||||
|
size=img.shape[:2],
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
.squeeze()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
|
||||||
|
return prediction.astype("float32")
|
||||||
|
|
||||||
127
python_server/model/Depth/midas_loader.py
Normal file
127
python_server/model/Depth/midas_loader.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Tuple
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import requests
|
||||||
|
|
||||||
|
_THIS_DIR = Path(__file__).resolve().parent
|
||||||
|
_MIDAS_REPO_ROOT = _THIS_DIR / "MiDaS"
|
||||||
|
if _MIDAS_REPO_ROOT.is_dir():
|
||||||
|
midas_path = str(_MIDAS_REPO_ROOT)
|
||||||
|
if midas_path not in sys.path:
|
||||||
|
sys.path.insert(0, midas_path)
|
||||||
|
|
||||||
|
from midas.model_loader import load_model, default_models # type: ignore[import]
|
||||||
|
import utils # from MiDaS repo
|
||||||
|
|
||||||
|
|
||||||
|
MiDaSModelType = Literal[
|
||||||
|
"dpt_beit_large_512",
|
||||||
|
"dpt_swin2_large_384",
|
||||||
|
"dpt_swin2_tiny_256",
|
||||||
|
"dpt_levit_224",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MiDaSConfig:
|
||||||
|
model_type: MiDaSModelType = "dpt_beit_large_512"
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
_MIDAS_WEIGHTS_URLS = {
|
||||||
|
# 官方权重参见 MiDaS 仓库 README
|
||||||
|
"dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
|
||||||
|
"dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
|
||||||
|
"dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
|
||||||
|
"dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _download_if_missing(model_type: str, ckpt_path: Path) -> None:
|
||||||
|
if ckpt_path.is_file():
|
||||||
|
return
|
||||||
|
|
||||||
|
url = _MIDAS_WEIGHTS_URLS.get(model_type)
|
||||||
|
if not url:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"找不到 MiDaS 权重文件: {ckpt_path}\n"
|
||||||
|
f"且当前未为 model_type='{model_type}' 配置自动下载 URL,请手动下载到该路径。"
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"自动下载 MiDaS 权重 ({model_type}):\n {url}\n -> {ckpt_path}")
|
||||||
|
|
||||||
|
resp = requests.get(url, stream=True)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
total = int(resp.headers.get("content-length", "0") or "0")
|
||||||
|
downloaded = 0
|
||||||
|
chunk_size = 1024 * 1024
|
||||||
|
|
||||||
|
with ckpt_path.open("wb") as f:
|
||||||
|
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
if total > 0:
|
||||||
|
done = int(50 * downloaded / total)
|
||||||
|
print("\r[{}{}] {:.1f}%".format("#" * done, "." * (50 - done), downloaded * 100 / total), end="")
|
||||||
|
print("\nMiDaS 权重下载完成。")
|
||||||
|
|
||||||
|
|
||||||
|
def load_midas_from_config(
|
||||||
|
cfg: MiDaSConfig,
|
||||||
|
) -> Tuple[torch.nn.Module, MiDaSConfig, callable, int, int]:
|
||||||
|
"""
|
||||||
|
加载 MiDaS 模型与对应 transform。
|
||||||
|
返回: model, cfg, transform, net_w, net_h
|
||||||
|
"""
|
||||||
|
# default_models 中给了默认权重路径名
|
||||||
|
model_info = default_models[cfg.model_type]
|
||||||
|
ckpt_path = _MIDAS_REPO_ROOT / model_info.path
|
||||||
|
_download_if_missing(cfg.model_type, ckpt_path)
|
||||||
|
|
||||||
|
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||||
|
model, transform, net_w, net_h = load_model(
|
||||||
|
device=torch.device(device),
|
||||||
|
model_path=str(ckpt_path),
|
||||||
|
model_type=cfg.model_type,
|
||||||
|
optimize=False,
|
||||||
|
height=None,
|
||||||
|
square=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = MiDaSConfig(model_type=cfg.model_type, device=device)
|
||||||
|
return model, cfg, transform, net_w, net_h
|
||||||
|
|
||||||
|
|
||||||
|
def infer_midas(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
transform: callable,
|
||||||
|
image_rgb: np.ndarray,
|
||||||
|
net_w: int,
|
||||||
|
net_h: int,
|
||||||
|
device: str,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
对单张 RGB 图像做深度推理,返回 float32 深度图(未归一化)。
|
||||||
|
"""
|
||||||
|
image = transform({"image": image_rgb})["image"]
|
||||||
|
prediction = utils.process(
|
||||||
|
torch.device(device),
|
||||||
|
model,
|
||||||
|
model_type="dpt", # 这里具体字符串对 utils.process 的逻辑影响不大,只要不包含 "openvino"
|
||||||
|
image=image,
|
||||||
|
input_size=(net_w, net_h),
|
||||||
|
target_size=image_rgb.shape[1::-1],
|
||||||
|
optimize=False,
|
||||||
|
use_camera=False,
|
||||||
|
)
|
||||||
|
return np.asarray(prediction, dtype="float32").squeeze()
|
||||||
|
|
||||||
74
python_server/model/Depth/zoe_loader.py
Normal file
74
python_server/model/Depth/zoe_loader.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 确保本地克隆的 ZoeDepth 仓库在 sys.path 中,
|
||||||
|
# 这样其内部的 `import zoedepth...` 才能正常工作。
|
||||||
|
_THIS_DIR = Path(__file__).resolve().parent
|
||||||
|
_ZOE_REPO_ROOT = _THIS_DIR / "ZoeDepth"
|
||||||
|
if _ZOE_REPO_ROOT.is_dir():
|
||||||
|
zoe_path = str(_ZOE_REPO_ROOT)
|
||||||
|
if zoe_path not in sys.path:
|
||||||
|
sys.path.insert(0, zoe_path)
|
||||||
|
|
||||||
|
from zoedepth.models.builder import build_model
|
||||||
|
from zoedepth.utils.config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
ZoeModelName = Literal["ZoeD_N", "ZoeD_K", "ZoeD_NK"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ZoeConfig:
|
||||||
|
"""
|
||||||
|
ZoeDepth 模型选择配置。
|
||||||
|
|
||||||
|
model: "ZoeD_N" | "ZoeD_K" | "ZoeD_NK"
|
||||||
|
device: "cuda" | "cpu"
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: ZoeModelName = "ZoeD_N"
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
def load_zoe_from_name(name: ZoeModelName, device: str = "cuda"):
|
||||||
|
"""
|
||||||
|
手动加载 ZoeDepth 三种模型之一:
|
||||||
|
- "ZoeD_N"
|
||||||
|
- "ZoeD_K"
|
||||||
|
- "ZoeD_NK"
|
||||||
|
"""
|
||||||
|
if name == "ZoeD_N":
|
||||||
|
conf = get_config("zoedepth", "infer")
|
||||||
|
elif name == "ZoeD_K":
|
||||||
|
conf = get_config("zoedepth", "infer", config_version="kitti")
|
||||||
|
elif name == "ZoeD_NK":
|
||||||
|
conf = get_config("zoedepth_nk", "infer")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的 ZoeDepth 模型名称: {name}")
|
||||||
|
|
||||||
|
model = build_model(conf)
|
||||||
|
|
||||||
|
if device.startswith("cuda") and torch.cuda.is_available():
|
||||||
|
model = model.to("cuda")
|
||||||
|
else:
|
||||||
|
model = model.to("cpu")
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
return model, conf
|
||||||
|
|
||||||
|
|
||||||
|
def load_zoe_from_config(config: ZoeConfig):
|
||||||
|
"""
|
||||||
|
根据 ZoeConfig 加载模型。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
cfg = ZoeConfig(model="ZoeD_NK", device="cuda")
|
||||||
|
model, conf = load_zoe_from_config(cfg)
|
||||||
|
"""
|
||||||
|
return load_zoe_from_name(config.model, config.device)
|
||||||
|
|
||||||
1
python_server/model/Inpaint/ControlNet
Submodule
1
python_server/model/Inpaint/ControlNet
Submodule
Submodule python_server/model/Inpaint/ControlNet added at ed85cd1e25
1
python_server/model/Inpaint/__init__.py
Normal file
1
python_server/model/Inpaint/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
413
python_server/model/Inpaint/inpaint_loader.py
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
"""
|
||||||
|
统一的补全(Inpaint)模型加载入口。
|
||||||
|
|
||||||
|
当前支持:
|
||||||
|
- SDXL Inpaint(diffusers AutoPipelineForInpainting)
|
||||||
|
- ControlNet(占位,暂未统一封装)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from config_loader import (
|
||||||
|
load_app_config,
|
||||||
|
get_sdxl_base_model_from_app,
|
||||||
|
get_controlnet_base_model_from_app,
|
||||||
|
get_controlnet_model_from_app,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintBackend(str, Enum):
|
||||||
|
SDXL_INPAINT = "sdxl_inpaint"
|
||||||
|
CONTROLNET = "controlnet"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UnifiedInpaintConfig:
|
||||||
|
backend: InpaintBackend = InpaintBackend.SDXL_INPAINT
|
||||||
|
device: str | None = None
|
||||||
|
# SDXL base model (HF id 或本地目录),不填则用 config.py 的默认值
|
||||||
|
sdxl_base_model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UnifiedDrawConfig:
|
||||||
|
"""
|
||||||
|
统一绘图配置:
|
||||||
|
- 纯文生图:image=None
|
||||||
|
- 图生图(模仿输入图):image=某张参考图
|
||||||
|
"""
|
||||||
|
device: str | None = None
|
||||||
|
sdxl_base_model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_device_and_dtype(device: str | None):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
if device is None:
|
||||||
|
device = app_cfg.inpaint.device
|
||||||
|
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||||
|
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
||||||
|
return device, torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _enable_memory_opts(pipe, device: str) -> None:
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
pipe.enable_attention_slicing()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_vae_slicing()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_vae_tiling()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _align_size(orig_w: int, orig_h: int, max_side: int) -> tuple[int, int]:
|
||||||
|
run_w, run_h = orig_w, orig_h
|
||||||
|
if max_side > 0 and max(orig_w, orig_h) > max_side:
|
||||||
|
scale = max_side / float(max(orig_w, orig_h))
|
||||||
|
run_w = int(round(orig_w * scale))
|
||||||
|
run_h = int(round(orig_h * scale))
|
||||||
|
run_w = max(8, run_w - (run_w % 8))
|
||||||
|
run_h = max(8, run_h - (run_h % 8))
|
||||||
|
return run_w, run_h
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sdxl_inpaint_predictor(
|
||||||
|
cfg: UnifiedInpaintConfig,
|
||||||
|
) -> Callable[[Image.Image, Image.Image, str, str], Image.Image]:
|
||||||
|
"""
|
||||||
|
返回补全函数:
|
||||||
|
- 输入:image(PIL RGB), mask(PIL L/1), prompt, negative_prompt
|
||||||
|
- 输出:PIL RGB 结果图
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
|
||||||
|
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
|
||||||
|
|
||||||
|
device = cfg.device
|
||||||
|
if device is None:
|
||||||
|
device = app_cfg.inpaint.device
|
||||||
|
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
||||||
|
|
||||||
|
pipe_t2i = AutoPipelineForText2Image.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
variant="fp16" if device == "cuda" else None,
|
||||||
|
use_safetensors=True,
|
||||||
|
).to(device)
|
||||||
|
pipe = AutoPipelineForInpainting.from_pipe(pipe_t2i).to(device)
|
||||||
|
|
||||||
|
# 省显存设置(尽量不改变输出语义)
|
||||||
|
# 注意:CPU offload 会明显变慢,但能显著降低显存占用。
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
pipe.enable_attention_slicing()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_vae_slicing()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_vae_tiling()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _predict(
|
||||||
|
image: Image.Image,
|
||||||
|
mask: Image.Image,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
strength: float = 0.8,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
max_side: int = 1024,
|
||||||
|
) -> Image.Image:
|
||||||
|
image = image.convert("RGB")
|
||||||
|
# diffusers 要求 mask 为单通道,白色区域为需要重绘
|
||||||
|
mask = mask.convert("L")
|
||||||
|
|
||||||
|
# SDXL / diffusers 通常要求宽高为 8 的倍数;同时为了避免 OOM,
|
||||||
|
# 推理时将图像按比例缩放到不超过 max_side(默认 1024)并对齐到 8 的倍数。
|
||||||
|
# 推理后再 resize 回原始尺寸,保证输出与原图分辨率一致。
|
||||||
|
orig_w, orig_h = image.size
|
||||||
|
run_w, run_h = orig_w, orig_h
|
||||||
|
|
||||||
|
if max(orig_w, orig_h) > max_side:
|
||||||
|
scale = max_side / float(max(orig_w, orig_h))
|
||||||
|
run_w = int(round(orig_w * scale))
|
||||||
|
run_h = int(round(orig_h * scale))
|
||||||
|
|
||||||
|
run_w = max(8, run_w - (run_w % 8))
|
||||||
|
run_h = max(8, run_h - (run_h % 8))
|
||||||
|
if run_w <= 0:
|
||||||
|
run_w = 8
|
||||||
|
if run_h <= 0:
|
||||||
|
run_h = 8
|
||||||
|
|
||||||
|
if (run_w, run_h) != (orig_w, orig_h):
|
||||||
|
image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
|
||||||
|
mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
|
||||||
|
else:
|
||||||
|
image_run = image
|
||||||
|
mask_run = mask
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
out = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
image=image_run,
|
||||||
|
mask_image=mask_run,
|
||||||
|
strength=strength,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
width=run_w,
|
||||||
|
height=run_h,
|
||||||
|
).images[0]
|
||||||
|
out = out.convert("RGB")
|
||||||
|
if out.size != (orig_w, orig_h):
|
||||||
|
out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
|
||||||
|
return out
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def _make_controlnet_predictor(_: UnifiedInpaintConfig):
|
||||||
|
import torch
|
||||||
|
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
|
||||||
|
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
|
||||||
|
device = _.device
|
||||||
|
if device is None:
|
||||||
|
device = app_cfg.inpaint.device
|
||||||
|
device = "cuda" if device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||||
|
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
||||||
|
|
||||||
|
base_model = get_controlnet_base_model_from_app(app_cfg)
|
||||||
|
controlnet_id = get_controlnet_model_from_app(app_cfg)
|
||||||
|
|
||||||
|
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)
|
||||||
|
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
controlnet=controlnet,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
safety_checker=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
pipe.enable_attention_slicing()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_vae_slicing()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_vae_tiling()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
except Exception:
|
||||||
|
pipe.to(device)
|
||||||
|
else:
|
||||||
|
pipe.to(device)
|
||||||
|
|
||||||
|
def _predict(
|
||||||
|
image: Image.Image,
|
||||||
|
mask: Image.Image,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
strength: float = 0.8,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
controlnet_conditioning_scale: float = 1.0,
|
||||||
|
max_side: int = 768,
|
||||||
|
) -> Image.Image:
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
image = image.convert("RGB")
|
||||||
|
mask = mask.convert("L")
|
||||||
|
|
||||||
|
orig_w, orig_h = image.size
|
||||||
|
run_w, run_h = orig_w, orig_h
|
||||||
|
if max(orig_w, orig_h) > max_side:
|
||||||
|
scale = max_side / float(max(orig_w, orig_h))
|
||||||
|
run_w = int(round(orig_w * scale))
|
||||||
|
run_h = int(round(orig_h * scale))
|
||||||
|
run_w = max(8, run_w - (run_w % 8))
|
||||||
|
run_h = max(8, run_h - (run_h % 8))
|
||||||
|
|
||||||
|
if (run_w, run_h) != (orig_w, orig_h):
|
||||||
|
image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
|
||||||
|
mask_run = mask.resize((run_w, run_h), resample=Image.NEAREST)
|
||||||
|
else:
|
||||||
|
image_run = image
|
||||||
|
mask_run = mask
|
||||||
|
|
||||||
|
# control image:使用 canny 边缘作为约束(最通用)
|
||||||
|
rgb = np.array(image_run, dtype=np.uint8)
|
||||||
|
edges = cv2.Canny(rgb, 100, 200)
|
||||||
|
edges3 = np.stack([edges, edges, edges], axis=-1)
|
||||||
|
control_image = Image.fromarray(edges3)
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
out = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
image=image_run,
|
||||||
|
mask_image=mask_run,
|
||||||
|
control_image=control_image,
|
||||||
|
strength=strength,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||||
|
width=run_w,
|
||||||
|
height=run_h,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
out = out.convert("RGB")
|
||||||
|
if out.size != (orig_w, orig_h):
|
||||||
|
out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
|
||||||
|
return out
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
def build_inpaint_predictor(
|
||||||
|
cfg: UnifiedInpaintConfig | None = None,
|
||||||
|
) -> tuple[Callable[..., Image.Image], InpaintBackend]:
|
||||||
|
"""
|
||||||
|
统一构建补全预测函数。
|
||||||
|
"""
|
||||||
|
cfg = cfg or UnifiedInpaintConfig()
|
||||||
|
|
||||||
|
if cfg.backend == InpaintBackend.SDXL_INPAINT:
|
||||||
|
return _make_sdxl_inpaint_predictor(cfg), InpaintBackend.SDXL_INPAINT
|
||||||
|
|
||||||
|
if cfg.backend == InpaintBackend.CONTROLNET:
|
||||||
|
return _make_controlnet_predictor(cfg), InpaintBackend.CONTROLNET
|
||||||
|
|
||||||
|
raise ValueError(f"不支持的补全后端: {cfg.backend}")
|
||||||
|
|
||||||
|
|
||||||
|
def build_draw_predictor(
|
||||||
|
cfg: UnifiedDrawConfig | None = None,
|
||||||
|
) -> Callable[..., Image.Image]:
|
||||||
|
"""
|
||||||
|
构建统一绘图函数:
|
||||||
|
- 文生图:draw(prompt, image=None, ...)
|
||||||
|
- 图生图:draw(prompt, image=ref_image, strength=0.55, ...)
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||||
|
|
||||||
|
cfg = cfg or UnifiedDrawConfig()
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
base_model = cfg.sdxl_base_model or get_sdxl_base_model_from_app(app_cfg)
|
||||||
|
device, torch_dtype = _resolve_device_and_dtype(cfg.device)
|
||||||
|
|
||||||
|
pipe_t2i = AutoPipelineForText2Image.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
variant="fp16" if device == "cuda" else None,
|
||||||
|
use_safetensors=True,
|
||||||
|
).to(device)
|
||||||
|
pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i).to(device)
|
||||||
|
|
||||||
|
_enable_memory_opts(pipe_t2i, device)
|
||||||
|
_enable_memory_opts(pipe_i2i, device)
|
||||||
|
|
||||||
|
def _draw(
|
||||||
|
prompt: str,
|
||||||
|
image: Image.Image | None = None,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
strength: float = 0.55,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
width: int = 1024,
|
||||||
|
height: int = 1024,
|
||||||
|
max_side: int = 1024,
|
||||||
|
) -> Image.Image:
|
||||||
|
prompt = prompt or ""
|
||||||
|
negative_prompt = negative_prompt or ""
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
try:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
run_w, run_h = _align_size(width, height, max_side=max_side)
|
||||||
|
out = pipe_t2i(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
width=run_w,
|
||||||
|
height=run_h,
|
||||||
|
).images[0]
|
||||||
|
return out.convert("RGB")
|
||||||
|
|
||||||
|
image = image.convert("RGB")
|
||||||
|
orig_w, orig_h = image.size
|
||||||
|
run_w, run_h = _align_size(orig_w, orig_h, max_side=max_side)
|
||||||
|
if (run_w, run_h) != (orig_w, orig_h):
|
||||||
|
image_run = image.resize((run_w, run_h), resample=Image.BICUBIC)
|
||||||
|
else:
|
||||||
|
image_run = image
|
||||||
|
|
||||||
|
out = pipe_i2i(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
image=image_run,
|
||||||
|
strength=strength,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
).images[0].convert("RGB")
|
||||||
|
|
||||||
|
if out.size != (orig_w, orig_h):
|
||||||
|
out = out.resize((orig_w, orig_h), resample=Image.BICUBIC)
|
||||||
|
return out
|
||||||
|
|
||||||
|
return _draw
|
||||||
|
|
||||||
1
python_server/model/Inpaint/sdxl-inpaint
Submodule
1
python_server/model/Inpaint/sdxl-inpaint
Submodule
Submodule python_server/model/Inpaint/sdxl-inpaint added at 29867f540b
1
python_server/model/Seg/Mask2Former
Submodule
1
python_server/model/Seg/Mask2Former
Submodule
Submodule python_server/model/Seg/Mask2Former added at 9b0651c6c1
1
python_server/model/Seg/__init__.py
Normal file
1
python_server/model/Seg/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
59
python_server/model/Seg/mask2former_loader.py
Normal file
59
python_server/model/Seg/mask2former_loader.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Literal, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Mask2FormerHFConfig:
|
||||||
|
"""
|
||||||
|
使用 HuggingFace transformers 版本的 Mask2Former 语义分割。
|
||||||
|
|
||||||
|
model_id: HuggingFace 模型 id(默认 ADE20K semantic)
|
||||||
|
device: "cuda" | "cpu"
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_id: str = "facebook/mask2former-swin-large-ade-semantic"
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
def build_mask2former_hf_predictor(
|
||||||
|
cfg: Mask2FormerHFConfig | None = None,
|
||||||
|
) -> Tuple[Callable[[np.ndarray], np.ndarray], Mask2FormerHFConfig]:
|
||||||
|
"""
|
||||||
|
返回 predictor(image_rgb_uint8) -> label_map(int32)。
|
||||||
|
"""
|
||||||
|
cfg = cfg or Mask2FormerHFConfig()
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
|
||||||
|
|
||||||
|
device = "cuda" if cfg.device.startswith("cuda") and torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
processor = AutoImageProcessor.from_pretrained(cfg.model_id)
|
||||||
|
model = Mask2FormerForUniversalSegmentation.from_pretrained(cfg.model_id)
|
||||||
|
model.to(device).eval()
|
||||||
|
|
||||||
|
cfg = Mask2FormerHFConfig(model_id=cfg.model_id, device=device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _predict(image_rgb: np.ndarray) -> np.ndarray:
|
||||||
|
if image_rgb.dtype != np.uint8:
|
||||||
|
image_rgb_u8 = image_rgb.astype("uint8")
|
||||||
|
else:
|
||||||
|
image_rgb_u8 = image_rgb
|
||||||
|
|
||||||
|
pil = Image.fromarray(image_rgb_u8, mode="RGB")
|
||||||
|
inputs = processor(images=pil, return_tensors="pt").to(device)
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
# post-process to original size
|
||||||
|
target_sizes = [(pil.height, pil.width)]
|
||||||
|
seg = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
|
||||||
|
return seg.detach().to("cpu").numpy().astype("int32")
|
||||||
|
|
||||||
|
return _predict, cfg
|
||||||
|
|
||||||
168
python_server/model/Seg/seg_loader.py
Normal file
168
python_server/model/Seg/seg_loader.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
"""
|
||||||
|
统一的分割模型加载入口。
|
||||||
|
|
||||||
|
当前支持:
|
||||||
|
- SAM (segment-anything)
|
||||||
|
- Mask2Former(使用 HuggingFace transformers 的语义分割实现)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
_THIS_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
class SegBackend(str, Enum):
|
||||||
|
SAM = "sam"
|
||||||
|
MASK2FORMER = "mask2former"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UnifiedSegConfig:
|
||||||
|
backend: SegBackend = SegBackend.SAM
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# SAM (Segment Anything)
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_sam_on_path() -> Path:
|
||||||
|
sam_root = _THIS_DIR / "segment-anything"
|
||||||
|
if not sam_root.is_dir():
|
||||||
|
raise FileNotFoundError(f"未找到 segment-anything 仓库目录: {sam_root}")
|
||||||
|
sam_path = str(sam_root)
|
||||||
|
if sam_path not in sys.path:
|
||||||
|
sys.path.insert(0, sam_path)
|
||||||
|
return sam_root
|
||||||
|
|
||||||
|
|
||||||
|
def _download_sam_checkpoint_if_needed(sam_root: Path) -> Path:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
ckpt_dir = sam_root / "checkpoints"
|
||||||
|
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
ckpt_path = ckpt_dir / "sam_vit_h_4b8939.pth"
|
||||||
|
|
||||||
|
if ckpt_path.is_file():
|
||||||
|
return ckpt_path
|
||||||
|
|
||||||
|
url = (
|
||||||
|
"https://dl.fbaipublicfiles.com/segment_anything/"
|
||||||
|
"sam_vit_h_4b8939.pth"
|
||||||
|
)
|
||||||
|
print(f"自动下载 SAM 权重:\n {url}\n -> {ckpt_path}")
|
||||||
|
|
||||||
|
resp = requests.get(url, stream=True)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
total = int(resp.headers.get("content-length", "0") or "0")
|
||||||
|
downloaded = 0
|
||||||
|
chunk_size = 1024 * 1024
|
||||||
|
|
||||||
|
with ckpt_path.open("wb") as f:
|
||||||
|
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
if total > 0:
|
||||||
|
done = int(50 * downloaded / total)
|
||||||
|
print(
|
||||||
|
"\r[{}{}] {:.1f}%".format(
|
||||||
|
"#" * done,
|
||||||
|
"." * (50 - done),
|
||||||
|
downloaded * 100 / total,
|
||||||
|
),
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
print("\nSAM 权重下载完成。")
|
||||||
|
return ckpt_path
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sam_predictor() -> Callable[[np.ndarray], np.ndarray]:
|
||||||
|
"""
|
||||||
|
返回一个分割函数:
|
||||||
|
- 输入:RGB uint8 图像 (H, W, 3)
|
||||||
|
- 输出:语义标签图 (H, W),每个目标一个 int id(从 1 开始)
|
||||||
|
"""
|
||||||
|
sam_root = _ensure_sam_on_path()
|
||||||
|
ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
|
||||||
|
|
||||||
|
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator # type: ignore[import]
|
||||||
|
import torch
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
sam = sam_model_registry["vit_h"](
|
||||||
|
checkpoint=str(ckpt_path),
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||||
|
|
||||||
|
def _predict(image_rgb: np.ndarray) -> np.ndarray:
|
||||||
|
if image_rgb.dtype != np.uint8:
|
||||||
|
image_rgb_u8 = image_rgb.astype("uint8")
|
||||||
|
else:
|
||||||
|
image_rgb_u8 = image_rgb
|
||||||
|
|
||||||
|
masks = mask_generator.generate(image_rgb_u8)
|
||||||
|
h, w, _ = image_rgb_u8.shape
|
||||||
|
label_map = np.zeros((h, w), dtype="int32")
|
||||||
|
|
||||||
|
for idx, m in enumerate(masks, start=1):
|
||||||
|
seg = m.get("segmentation")
|
||||||
|
if seg is None:
|
||||||
|
continue
|
||||||
|
label_map[seg.astype(bool)] = idx
|
||||||
|
|
||||||
|
return label_map
|
||||||
|
|
||||||
|
return _predict
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Mask2Former (占位)
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mask2former_predictor() -> Callable[[np.ndarray], np.ndarray]:
|
||||||
|
from .mask2former_loader import build_mask2former_hf_predictor
|
||||||
|
|
||||||
|
predictor, _ = build_mask2former_hf_predictor()
|
||||||
|
return predictor
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 统一构建函数
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def build_seg_predictor(
|
||||||
|
cfg: UnifiedSegConfig | None = None,
|
||||||
|
) -> tuple[Callable[[np.ndarray], np.ndarray], SegBackend]:
|
||||||
|
"""
|
||||||
|
统一构建分割预测函数。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- predictor(image_rgb: np.ndarray[H, W, 3], uint8) -> np.ndarray[H, W], int32
|
||||||
|
- 实际使用的 backend
|
||||||
|
"""
|
||||||
|
cfg = cfg or UnifiedSegConfig()
|
||||||
|
|
||||||
|
if cfg.backend == SegBackend.SAM:
|
||||||
|
return _make_sam_predictor(), SegBackend.SAM
|
||||||
|
|
||||||
|
if cfg.backend == SegBackend.MASK2FORMER:
|
||||||
|
return _make_mask2former_predictor(), SegBackend.MASK2FORMER
|
||||||
|
|
||||||
|
raise ValueError(f"不支持的分割后端: {cfg.backend}")
|
||||||
|
|
||||||
1
python_server/model/Seg/segment-anything
Submodule
1
python_server/model/Seg/segment-anything
Submodule
Submodule python_server/model/Seg/segment-anything added at dca509fe79
1
python_server/model/__init__.py
Normal file
1
python_server/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
407
python_server/server.py
Normal file
407
python_server/server.py
Normal file
@@ -0,0 +1,407 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from dataclasses import asdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
from config_loader import load_app_config, get_depth_backend_from_app
|
||||||
|
from model.Depth.depth_loader import UnifiedDepthConfig, DepthBackend, build_depth_predictor
|
||||||
|
|
||||||
|
from model.Seg.seg_loader import UnifiedSegConfig, SegBackend, build_seg_predictor
|
||||||
|
from model.Inpaint.inpaint_loader import UnifiedInpaintConfig, InpaintBackend, build_inpaint_predictor
|
||||||
|
from model.Animation.animation_loader import (
|
||||||
|
UnifiedAnimationConfig,
|
||||||
|
AnimationBackend,
|
||||||
|
build_animation_predictor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
APP_ROOT = Path(__file__).resolve().parent
|
||||||
|
OUTPUT_DIR = APP_ROOT / "outputs"
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
app = FastAPI(title="HFUT Model Server", version="0.1.0")
|
||||||
|
|
||||||
|
|
||||||
|
class ImageInput(BaseModel):
|
||||||
|
image_b64: str = Field(..., description="PNG/JPG 编码后的 base64(不含 data: 前缀)")
|
||||||
|
model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
|
||||||
|
|
||||||
|
|
||||||
|
class DepthRequest(ImageInput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentRequest(ImageInput):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintRequest(ImageInput):
|
||||||
|
prompt: Optional[str] = Field("", description="补全 prompt")
|
||||||
|
strength: float = Field(0.8, ge=0.0, le=1.0)
|
||||||
|
negative_prompt: Optional[str] = Field("", description="负向 prompt")
|
||||||
|
# 可选 mask(白色区域为重绘)
|
||||||
|
mask_b64: Optional[str] = Field(None, description="mask PNG base64(可选)")
|
||||||
|
# 推理缩放上限(避免 OOM)
|
||||||
|
max_side: int = Field(1024, ge=128, le=2048)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimateRequest(BaseModel):
|
||||||
|
model_name: Optional[str] = Field(None, description="模型 key(来自 /models)")
|
||||||
|
prompt: str = Field(..., description="文本提示词")
|
||||||
|
negative_prompt: Optional[str] = Field("", description="负向提示词")
|
||||||
|
num_inference_steps: int = Field(25, ge=1, le=200)
|
||||||
|
guidance_scale: float = Field(8.0, ge=0.0, le=30.0)
|
||||||
|
width: int = Field(512, ge=128, le=2048)
|
||||||
|
height: int = Field(512, ge=128, le=2048)
|
||||||
|
video_length: int = Field(16, ge=1, le=128)
|
||||||
|
seed: int = Field(-1, description="-1 表示随机种子")
|
||||||
|
|
||||||
|
|
||||||
|
def _b64_to_pil_image(b64: str) -> Image.Image:
|
||||||
|
raw = base64.b64decode(b64)
|
||||||
|
return Image.open(io.BytesIO(raw))
|
||||||
|
|
||||||
|
|
||||||
|
def _pil_image_to_png_b64(img: Image.Image) -> str:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format="PNG")
|
||||||
|
return base64.b64encode(buf.getvalue()).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
def _depth_to_png16_b64(depth: np.ndarray) -> str:
|
||||||
|
depth = np.asarray(depth, dtype=np.float32)
|
||||||
|
dmin = float(depth.min())
|
||||||
|
dmax = float(depth.max())
|
||||||
|
if dmax > dmin:
|
||||||
|
norm = (depth - dmin) / (dmax - dmin)
|
||||||
|
else:
|
||||||
|
norm = np.zeros_like(depth, dtype=np.float32)
|
||||||
|
u16 = (norm * 65535.0).clip(0, 65535).astype(np.uint16)
|
||||||
|
img = Image.fromarray(u16, mode="I;16")
|
||||||
|
return _pil_image_to_png_b64(img)
|
||||||
|
|
||||||
|
def _depth_to_png16_bytes(depth: np.ndarray) -> bytes:
|
||||||
|
depth = np.asarray(depth, dtype=np.float32)
|
||||||
|
dmin = float(depth.min())
|
||||||
|
dmax = float(depth.max())
|
||||||
|
if dmax > dmin:
|
||||||
|
norm = (depth - dmin) / (dmax - dmin)
|
||||||
|
else:
|
||||||
|
norm = np.zeros_like(depth, dtype=np.float32)
|
||||||
|
# 前后端约定:最远=255,最近=0(8-bit)
|
||||||
|
u8 = ((1.0 - norm) * 255.0).clip(0, 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(u8, mode="L")
|
||||||
|
buf = io.BytesIO()
|
||||||
|
img.save(buf, format="PNG")
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def _default_half_mask(img: Image.Image) -> Image.Image:
|
||||||
|
w, h = img.size
|
||||||
|
mask = Image.new("L", (w, h), 0)
|
||||||
|
draw = ImageDraw.Draw(mask)
|
||||||
|
draw.rectangle([w // 2, 0, w, h], fill=255)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# /models(给前端/GUI 使用)
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/models")
|
||||||
|
def get_models() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
返回一个兼容 Qt 前端的 schema:
|
||||||
|
{
|
||||||
|
"models": {
|
||||||
|
"depth": { "key": { "name": "..."} ... },
|
||||||
|
"segment": { ... },
|
||||||
|
"inpaint": { "key": { "name": "...", "params": [...] } ... }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"models": {
|
||||||
|
"depth": {
|
||||||
|
# 兼容旧配置默认值
|
||||||
|
"midas": {"name": "MiDaS (default)"},
|
||||||
|
"zoedepth_n": {"name": "ZoeDepth (ZoeD_N)"},
|
||||||
|
"zoedepth_k": {"name": "ZoeDepth (ZoeD_K)"},
|
||||||
|
"zoedepth_nk": {"name": "ZoeDepth (ZoeD_NK)"},
|
||||||
|
"depth_anything_v2_vits": {"name": "Depth Anything V2 (vits)"},
|
||||||
|
"depth_anything_v2_vitb": {"name": "Depth Anything V2 (vitb)"},
|
||||||
|
"depth_anything_v2_vitl": {"name": "Depth Anything V2 (vitl)"},
|
||||||
|
"depth_anything_v2_vitg": {"name": "Depth Anything V2 (vitg)"},
|
||||||
|
"dpt_large": {"name": "DPT (large)"},
|
||||||
|
"dpt_hybrid": {"name": "DPT (hybrid)"},
|
||||||
|
"midas_dpt_beit_large_512": {"name": "MiDaS (dpt_beit_large_512)"},
|
||||||
|
"midas_dpt_swin2_large_384": {"name": "MiDaS (dpt_swin2_large_384)"},
|
||||||
|
"midas_dpt_swin2_tiny_256": {"name": "MiDaS (dpt_swin2_tiny_256)"},
|
||||||
|
"midas_dpt_levit_224": {"name": "MiDaS (dpt_levit_224)"},
|
||||||
|
},
|
||||||
|
"segment": {
|
||||||
|
"sam": {"name": "SAM (vit_h)"},
|
||||||
|
# 兼容旧配置默认值
|
||||||
|
"mask2former_debug": {"name": "SAM (compat mask2former_debug)"},
|
||||||
|
"mask2former": {"name": "Mask2Former (not implemented)"},
|
||||||
|
},
|
||||||
|
"inpaint": {
|
||||||
|
# 兼容旧配置默认值:copy 表示不做补全
|
||||||
|
"copy": {"name": "Copy (no-op)", "params": []},
|
||||||
|
"sdxl_inpaint": {
|
||||||
|
"name": "SDXL Inpaint",
|
||||||
|
"params": [
|
||||||
|
{"id": "prompt", "label": "提示词", "optional": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"controlnet": {
|
||||||
|
"name": "ControlNet Inpaint (canny)",
|
||||||
|
"params": [
|
||||||
|
{"id": "prompt", "label": "提示词", "optional": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"animation": {
|
||||||
|
"animatediff": {
|
||||||
|
"name": "AnimateDiff (Text-to-Video)",
|
||||||
|
"params": [
|
||||||
|
{"id": "prompt", "label": "提示词", "optional": False},
|
||||||
|
{"id": "negative_prompt", "label": "负向提示词", "optional": True},
|
||||||
|
{"id": "num_inference_steps", "label": "采样步数", "optional": True},
|
||||||
|
{"id": "guidance_scale", "label": "CFG Scale", "optional": True},
|
||||||
|
{"id": "width", "label": "宽度", "optional": True},
|
||||||
|
{"id": "height", "label": "高度", "optional": True},
|
||||||
|
{"id": "video_length", "label": "帧数", "optional": True},
|
||||||
|
{"id": "seed", "label": "随机种子", "optional": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Depth
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_depth_predictor = None
|
||||||
|
_depth_backend: DepthBackend | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_depth_predictor() -> None:
|
||||||
|
global _depth_predictor, _depth_backend
|
||||||
|
if _depth_predictor is not None and _depth_backend is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
app_cfg = load_app_config()
|
||||||
|
backend_str = get_depth_backend_from_app(app_cfg)
|
||||||
|
try:
|
||||||
|
backend = DepthBackend(backend_str)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"config.py 中 depth.backend 不合法: {backend_str}") from e
|
||||||
|
|
||||||
|
_depth_predictor, _depth_backend = build_depth_predictor(UnifiedDepthConfig(backend=backend))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/depth")
|
||||||
|
def depth(req: DepthRequest):
|
||||||
|
"""
|
||||||
|
计算深度并直接返回二进制 PNG(16-bit 灰度)。
|
||||||
|
|
||||||
|
约束:
|
||||||
|
- 前端不传/不选模型;模型选择写死在后端 config.py
|
||||||
|
- 成功:HTTP 200 + Content-Type: image/png
|
||||||
|
- 失败:HTTP 500,detail 为错误信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_ensure_depth_predictor()
|
||||||
|
pil = _b64_to_pil_image(req.image_b64).convert("RGB")
|
||||||
|
depth_arr = _depth_predictor(pil) # type: ignore[misc]
|
||||||
|
png_bytes = _depth_to_png16_bytes(np.asarray(depth_arr))
|
||||||
|
return Response(content=png_bytes, media_type="image/png")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Segment
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_seg_cache: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_seg_predictor(model_name: str):
|
||||||
|
if model_name in _seg_cache:
|
||||||
|
return _seg_cache[model_name]
|
||||||
|
|
||||||
|
# 兼容旧默认 key
|
||||||
|
if model_name == "mask2former_debug":
|
||||||
|
model_name = "sam"
|
||||||
|
|
||||||
|
if model_name == "sam":
|
||||||
|
pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.SAM))
|
||||||
|
_seg_cache[model_name] = pred
|
||||||
|
return pred
|
||||||
|
|
||||||
|
if model_name == "mask2former":
|
||||||
|
pred, _ = build_seg_predictor(UnifiedSegConfig(backend=SegBackend.MASK2FORMER))
|
||||||
|
_seg_cache[model_name] = pred
|
||||||
|
return pred
|
||||||
|
|
||||||
|
raise ValueError(f"未知 segment model_name: {model_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/segment")
|
||||||
|
def segment(req: SegmentRequest) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
model_name = req.model_name or "sam"
|
||||||
|
pil = _b64_to_pil_image(req.image_b64).convert("RGB")
|
||||||
|
rgb = np.array(pil, dtype=np.uint8)
|
||||||
|
|
||||||
|
predictor = _get_seg_predictor(model_name)
|
||||||
|
label_map = predictor(rgb).astype(np.int32)
|
||||||
|
|
||||||
|
out_dir = OUTPUT_DIR / "segment"
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = out_dir / f"{model_name}_label.png"
|
||||||
|
# 保存为 8-bit 灰度(若 label 超过 255 会截断;当前 SAM 通常不会太大)
|
||||||
|
Image.fromarray(np.clip(label_map, 0, 255).astype(np.uint8), mode="L").save(out_path)
|
||||||
|
|
||||||
|
return {"success": True, "label_path": str(out_path)}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Inpaint
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_inpaint_cache: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_inpaint_predictor(model_name: str):
|
||||||
|
if model_name in _inpaint_cache:
|
||||||
|
return _inpaint_cache[model_name]
|
||||||
|
|
||||||
|
if model_name == "copy":
|
||||||
|
def _copy(image: Image.Image, *_args, **_kwargs) -> Image.Image:
|
||||||
|
return image.convert("RGB")
|
||||||
|
_inpaint_cache[model_name] = _copy
|
||||||
|
return _copy
|
||||||
|
|
||||||
|
if model_name == "sdxl_inpaint":
|
||||||
|
pred, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.SDXL_INPAINT))
|
||||||
|
_inpaint_cache[model_name] = pred
|
||||||
|
return pred
|
||||||
|
|
||||||
|
if model_name == "controlnet":
|
||||||
|
pred, _ = build_inpaint_predictor(UnifiedInpaintConfig(backend=InpaintBackend.CONTROLNET))
|
||||||
|
_inpaint_cache[model_name] = pred
|
||||||
|
return pred
|
||||||
|
|
||||||
|
raise ValueError(f"未知 inpaint model_name: {model_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/inpaint")
|
||||||
|
def inpaint(req: InpaintRequest) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
model_name = req.model_name or "sdxl_inpaint"
|
||||||
|
pil = _b64_to_pil_image(req.image_b64).convert("RGB")
|
||||||
|
|
||||||
|
if req.mask_b64:
|
||||||
|
mask = _b64_to_pil_image(req.mask_b64).convert("L")
|
||||||
|
else:
|
||||||
|
mask = _default_half_mask(pil)
|
||||||
|
|
||||||
|
predictor = _get_inpaint_predictor(model_name)
|
||||||
|
out = predictor(
|
||||||
|
pil,
|
||||||
|
mask,
|
||||||
|
req.prompt or "",
|
||||||
|
req.negative_prompt or "",
|
||||||
|
strength=req.strength,
|
||||||
|
max_side=req.max_side,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = OUTPUT_DIR / "inpaint"
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = out_dir / f"{model_name}_inpaint.png"
|
||||||
|
out.save(out_path)
|
||||||
|
|
||||||
|
return {"success": True, "output_path": str(out_path)}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
_animation_cache: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_animation_predictor(model_name: str):
|
||||||
|
if model_name in _animation_cache:
|
||||||
|
return _animation_cache[model_name]
|
||||||
|
|
||||||
|
if model_name == "animatediff":
|
||||||
|
pred, _ = build_animation_predictor(
|
||||||
|
UnifiedAnimationConfig(backend=AnimationBackend.ANIMATEDIFF)
|
||||||
|
)
|
||||||
|
_animation_cache[model_name] = pred
|
||||||
|
return pred
|
||||||
|
|
||||||
|
raise ValueError(f"未知 animation model_name: {model_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/animate")
|
||||||
|
def animate(req: AnimateRequest) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
model_name = req.model_name or "animatediff"
|
||||||
|
predictor = _get_animation_predictor(model_name)
|
||||||
|
|
||||||
|
result_path = predictor(
|
||||||
|
prompt=req.prompt,
|
||||||
|
negative_prompt=req.negative_prompt or "",
|
||||||
|
num_inference_steps=req.num_inference_steps,
|
||||||
|
guidance_scale=req.guidance_scale,
|
||||||
|
width=req.width,
|
||||||
|
height=req.height,
|
||||||
|
video_length=req.video_length,
|
||||||
|
seed=req.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = OUTPUT_DIR / "animation"
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
out_path = out_dir / f"{model_name}_{ts}.gif"
|
||||||
|
shutil.copy2(result_path, out_path)
|
||||||
|
|
||||||
|
return {"success": True, "output_path": str(out_path)}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health() -> Dict[str, str]:
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
port = int(os.environ.get("MODEL_SERVER_PORT", "8000"))
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||||
|
|
||||||
77
python_server/test_animation.py
Normal file
77
python_server/test_animation.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from model.Animation.animation_loader import (
|
||||||
|
build_animation_predictor,
|
||||||
|
UnifiedAnimationConfig,
|
||||||
|
AnimationBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 配置区(按需修改)
|
||||||
|
# -----------------------------
|
||||||
|
OUTPUT_DIR = "outputs/test_animation"
|
||||||
|
ANIMATION_BACKEND = AnimationBackend.ANIMATEDIFF
|
||||||
|
OUTPUT_FORMAT = "png_sequence" # "gif" | "png_sequence"
|
||||||
|
|
||||||
|
PROMPT = "a cinematic mountain landscape, camera slowly pans left"
|
||||||
|
NEGATIVE_PROMPT = "blurry, low quality"
|
||||||
|
NUM_INFERENCE_STEPS = 25
|
||||||
|
GUIDANCE_SCALE = 8.0
|
||||||
|
WIDTH = 512
|
||||||
|
HEIGHT = 512
|
||||||
|
VIDEO_LENGTH = 16
|
||||||
|
SEED = -1
|
||||||
|
CONTROL_IMAGE_PATH = "path/to/your_image.png"
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
out_dir = base_dir / OUTPUT_DIR
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
predictor, used_backend = build_animation_predictor(
|
||||||
|
UnifiedAnimationConfig(backend=ANIMATION_BACKEND)
|
||||||
|
)
|
||||||
|
|
||||||
|
if CONTROL_IMAGE_PATH.strip() in {"", "path/to/your_image.png"}:
|
||||||
|
raise ValueError("请先设置 CONTROL_IMAGE_PATH 为你的输入图片路径(png/jpg)。")
|
||||||
|
|
||||||
|
control_image = (base_dir / CONTROL_IMAGE_PATH).resolve()
|
||||||
|
if not control_image.is_file():
|
||||||
|
raise FileNotFoundError(f"control image not found: {control_image}")
|
||||||
|
|
||||||
|
result_path = predictor(
|
||||||
|
prompt=PROMPT,
|
||||||
|
negative_prompt=NEGATIVE_PROMPT,
|
||||||
|
num_inference_steps=NUM_INFERENCE_STEPS,
|
||||||
|
guidance_scale=GUIDANCE_SCALE,
|
||||||
|
width=WIDTH,
|
||||||
|
height=HEIGHT,
|
||||||
|
video_length=VIDEO_LENGTH,
|
||||||
|
seed=SEED,
|
||||||
|
control_image_path=str(control_image),
|
||||||
|
output_format=OUTPUT_FORMAT,
|
||||||
|
)
|
||||||
|
|
||||||
|
source = Path(result_path)
|
||||||
|
if OUTPUT_FORMAT == "png_sequence":
|
||||||
|
out_seq_dir = out_dir / f"{used_backend.value}_frames"
|
||||||
|
if out_seq_dir.exists():
|
||||||
|
shutil.rmtree(out_seq_dir)
|
||||||
|
shutil.copytree(source, out_seq_dir)
|
||||||
|
print(f"[Animation] backend={used_backend.value}, saved={out_seq_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
out_path = out_dir / f"{used_backend.value}.gif"
|
||||||
|
out_path.write_bytes(source.read_bytes())
|
||||||
|
print(f"[Animation] backend={used_backend.value}, saved={out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
101
python_server/test_depth.py
Normal file
101
python_server/test_depth.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# 解除大图限制
|
||||||
|
Image.MAX_IMAGE_PIXELS = None
|
||||||
|
|
||||||
|
from model.Depth.depth_loader import build_depth_predictor, UnifiedDepthConfig, DepthBackend
|
||||||
|
|
||||||
|
# ================= 配置区 =================
|
||||||
|
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
|
||||||
|
OUTPUT_DIR = "outputs/test_depth_v4"
|
||||||
|
DEPTH_BACKEND = DepthBackend.DEPTH_ANYTHING_V2
|
||||||
|
|
||||||
|
# 边缘捕捉参数
|
||||||
|
# 增大这个值会让边缘更细,减小会让边缘更粗(捕捉更多微弱信息)
|
||||||
|
EDGE_TOP_PERCENTILE = 96.0 # 选取梯度最强的前 7.0% 的像素
|
||||||
|
# 局部增强的灵敏度,建议在 2.0 - 5.0 之间
|
||||||
|
CLAHE_CLIP_LIMIT = 3.0
|
||||||
|
# 形态学核大小
|
||||||
|
MORPH_SIZE = 5
|
||||||
|
# =========================================
|
||||||
|
|
||||||
|
def _extract_robust_edges(depth_norm: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
通过局部增强和 Sobel 梯度提取闭合边缘
|
||||||
|
"""
|
||||||
|
# 1. 转换为 8 位灰度
|
||||||
|
depth_u8 = (depth_norm * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# 2. 【核心步骤】CLAHE 局部自适应对比度增强
|
||||||
|
# 这会强行放大古画中细微的建筑/人物深度差
|
||||||
|
clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=(16, 16))
|
||||||
|
enhanced_depth = clahe.apply(depth_u8)
|
||||||
|
|
||||||
|
# 3. 高斯模糊:减少数字化噪声
|
||||||
|
blurred = cv2.GaussianBlur(enhanced_depth, (5, 5), 0)
|
||||||
|
|
||||||
|
# 4. Sobel 算子计算梯度强度
|
||||||
|
grad_x = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=3)
|
||||||
|
grad_y = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=3)
|
||||||
|
grad_mag = np.sqrt(grad_x**2 + grad_y**2)
|
||||||
|
|
||||||
|
# 5. 【统计学阈值】不再死守固定数值,而是选 Top X%
|
||||||
|
threshold = np.percentile(grad_mag, EDGE_TOP_PERCENTILE)
|
||||||
|
binary_edges = (grad_mag >= threshold).astype(np.uint8) * 255
|
||||||
|
|
||||||
|
# 6. 形态学闭合:桥接裂缝,让线条连起来
|
||||||
|
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (MORPH_SIZE, MORPH_SIZE))
|
||||||
|
closed = cv2.morphologyEx(binary_edges, cv2.MORPH_CLOSE, kernel)
|
||||||
|
|
||||||
|
# 7. 再次轻微膨胀,为 SAM 提供更好的引导范围
|
||||||
|
final_mask = cv2.dilate(closed, kernel, iterations=1)
|
||||||
|
|
||||||
|
return final_mask
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
img_path = Path(INPUT_IMAGE)
|
||||||
|
out_dir = base_dir / OUTPUT_DIR
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. 初始化深度模型
|
||||||
|
predictor, used_backend = build_depth_predictor(UnifiedDepthConfig(backend=DEPTH_BACKEND))
|
||||||
|
|
||||||
|
# 2. 加载图像
|
||||||
|
print(f"[Loading] 正在处理: {img_path.name}")
|
||||||
|
img_pil = Image.open(img_path).convert("RGB")
|
||||||
|
w_orig, h_orig = img_pil.size
|
||||||
|
|
||||||
|
# 3. 深度预测
|
||||||
|
print(f"[Depth] 正在进行深度估计 (Large Image)...")
|
||||||
|
depth = np.asarray(predictor(img_pil), dtype=np.float32).squeeze()
|
||||||
|
|
||||||
|
# 4. 归一化与保存
|
||||||
|
dmin, dmax = depth.min(), depth.max()
|
||||||
|
depth_norm = (depth - dmin) / (dmax - dmin + 1e-8)
|
||||||
|
depth_u16 = (depth_norm * 65535.0).astype(np.uint16)
|
||||||
|
Image.fromarray(depth_u16).save(out_dir / f"{img_path.stem}.depth.png")
|
||||||
|
|
||||||
|
# 5. 提取强鲁棒性边缘
|
||||||
|
print(f"[Edge] 正在应用 CLAHE + Sobel 增强算法提取边缘...")
|
||||||
|
edge_mask = _extract_robust_edges(depth_norm)
|
||||||
|
|
||||||
|
# 6. 导出
|
||||||
|
mask_path = out_dir / f"{img_path.stem}.edge_mask_robust.png"
|
||||||
|
Image.fromarray(edge_mask).save(mask_path)
|
||||||
|
|
||||||
|
edge_ratio = float((edge_mask > 0).sum()) / float(edge_mask.size)
|
||||||
|
print("-" * 30)
|
||||||
|
print(f"提取完成!")
|
||||||
|
print(f"边缘密度: {edge_ratio:.2%} (目标通常应在 1% ~ 8% 之间)")
|
||||||
|
print(f"如果 Mask 依然太黑,请调低 EDGE_TOP_PERCENTILE (如 90.0)")
|
||||||
|
print(f"如果 Mask 太乱,请调高 EDGE_TOP_PERCENTILE (如 96.0)")
|
||||||
|
print("-" * 30)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
284
python_server/test_inpaint.py
Normal file
284
python_server/test_inpaint.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
from model.Inpaint.inpaint_loader import (
|
||||||
|
build_inpaint_predictor,
|
||||||
|
UnifiedInpaintConfig,
|
||||||
|
InpaintBackend,
|
||||||
|
build_draw_predictor,
|
||||||
|
UnifiedDrawConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 配置区(按需修改)
|
||||||
|
# -----------------------------
|
||||||
|
# 任务模式:
|
||||||
|
# - "inpaint": 补全
|
||||||
|
# - "draw": 绘图(文生图 / 图生图)
|
||||||
|
TASK_MODE = "draw"
|
||||||
|
|
||||||
|
# 指向你的输入图像,例如 image_0.png
|
||||||
|
INPUT_IMAGE = "/home/dwh/code/hfut-bishe/python_server/outputs/test_seg_v2/Up the River During Qingming (detail) - Court painters.objects/Up the River During Qingming (detail) - Court painters.obj_06.png"
|
||||||
|
INPUT_MASK = "" # 为空表示不使用手工 mask
|
||||||
|
OUTPUT_DIR = "outputs/test_inpaint_v2"
|
||||||
|
MASK_RECT = (256, 0, 512, 512) # x1, y1, x2, y2 (如果不使用 AUTO_MASK)
|
||||||
|
USE_AUTO_MASK = True # True: 自动从图像推断补全区域
|
||||||
|
AUTO_MASK_BLACK_THRESHOLD = 6 # 自动mask时,接近黑色像素阈值
|
||||||
|
AUTO_MASK_DILATE_ITER = 4 # 增加膨胀,确保树枝边缘被覆盖
|
||||||
|
FALLBACK_TO_RECT_MASK = False # 自动mask为空时是否回退到矩形mask
|
||||||
|
|
||||||
|
# 透明输出控制
|
||||||
|
PRESERVE_TRANSPARENCY = True # 保持输出透明
|
||||||
|
EXPAND_ALPHA_WITH_MASK = True # 将补全区域 alpha 设为不透明
|
||||||
|
NON_BLACK_ALPHA_THRESHOLD = 6 # 无 alpha 输入时,用近黑判定透明背景
|
||||||
|
|
||||||
|
INPAINT_BACKEND = InpaintBackend.SDXL_INPAINT # 推荐使用 SDXL_INPAINT 获得更好效果
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 关键 Prompt 修改:只生成树
|
||||||
|
# -----------------------------
|
||||||
|
# 细化 PROMPT,强调树的种类、叶子和风格,使其与原图融合
|
||||||
|
PROMPT = (
|
||||||
|
"A highly detailed traditional Chinese ink brush painting. "
|
||||||
|
"Restore and complete the existing trees naturally. "
|
||||||
|
"Extend the complex tree branches and dense, varied green and teal leaves. "
|
||||||
|
"Add more harmonious foliage and intricate bark texture to match the style and flow of the original image_0.png trees. "
|
||||||
|
"Focus solely on vegetation. "
|
||||||
|
"Maintain the light beige background color and texture."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用 NEGATIVE_PROMPT 显式排除不想要的内容
|
||||||
|
NEGATIVE_PROMPT = (
|
||||||
|
"buildings, architecture, houses, pavilions, temples, windows, doors, "
|
||||||
|
"people, figures, persons, characters, figures, clothing, faces, hands, "
|
||||||
|
"text, writing, characters, words, letters, signatures, seals, stamps, "
|
||||||
|
"calligraphy, objects, artifacts, boxes, baskets, tools, "
|
||||||
|
"extra branches crossing unnaturally, bad composition, watermark, signature"
|
||||||
|
)
|
||||||
|
|
||||||
|
STRENGTH = 0.8 # 保持较高强度以进行生成
|
||||||
|
GUIDANCE_SCALE = 8.0 # 稍微增加,更严格遵循 prompt
|
||||||
|
NUM_INFERENCE_STEPS = 35 # 稍微增加,提升细节
|
||||||
|
MAX_SIDE = 1024
|
||||||
|
CONTROLNET_SCALE = 1.0
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# 绘图(draw)参数
|
||||||
|
# -----------------------------
|
||||||
|
# DRAW_INPUT_IMAGE 为空时:文生图
|
||||||
|
# DRAW_INPUT_IMAGE 不为空时:图生图(按输入图进行模仿/重绘)
|
||||||
|
DRAW_INPUT_IMAGE = ""
|
||||||
|
DRAW_PROMPT = """
|
||||||
|
Chinese ink wash painting, vast snowy river under a pale sky,
|
||||||
|
a small lonely boat at the horizon where water meets sky,
|
||||||
|
an old fisherman wearing a straw hat sits at the stern, fishing quietly,
|
||||||
|
gentle snowfall, misty atmosphere, distant mountains barely visible,
|
||||||
|
minimalist composition, large empty space, soft brush strokes,
|
||||||
|
calm, cold, and silent mood, poetic and serene
|
||||||
|
"""
|
||||||
|
|
||||||
|
DRAW_NEGATIVE_PROMPT = """
|
||||||
|
blurry, low quality, many people, bright colors, modern elements, crowded, noisy
|
||||||
|
"""
|
||||||
|
DRAW_STRENGTH = 0.55 # 仅图生图使用,越大越偏向重绘
|
||||||
|
DRAW_GUIDANCE_SCALE = 9
|
||||||
|
DRAW_STEPS = 64
|
||||||
|
DRAW_WIDTH = 2560
|
||||||
|
DRAW_HEIGHT = 1440
|
||||||
|
DRAW_MAX_SIDE = 2560
|
||||||
|
|
||||||
|
|
||||||
|
def _dilate(mask: np.ndarray, iterations: int = 1) -> np.ndarray:
|
||||||
|
out = mask.astype(bool)
|
||||||
|
for _ in range(max(0, iterations)):
|
||||||
|
p = np.pad(out, ((1, 1), (1, 1)), mode="constant", constant_values=False)
|
||||||
|
out = (
|
||||||
|
p[:-2, :-2] | p[:-2, 1:-1] | p[:-2, 2:]
|
||||||
|
| p[1:-1, :-2] | p[1:-1, 1:-1] | p[1:-1, 2:]
|
||||||
|
| p[2:, :-2] | p[2:, 1:-1] | p[2:, 2:]
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rect_mask(img_size: tuple[int, int]) -> Image.Image:
|
||||||
|
w, h = img_size
|
||||||
|
x1, y1, x2, y2 = MASK_RECT
|
||||||
|
x1 = max(0, min(x1, w - 1))
|
||||||
|
x2 = max(0, min(x2, w - 1))
|
||||||
|
y1 = max(0, min(y1, h - 1))
|
||||||
|
y2 = max(0, min(y2, h - 1))
|
||||||
|
if x2 < x1:
|
||||||
|
x1, x2 = x2, x1
|
||||||
|
if y2 < y1:
|
||||||
|
y1, y2 = y2, y1
|
||||||
|
|
||||||
|
mask = Image.new("L", (w, h), 0)
|
||||||
|
draw = ImageDraw.Draw(mask)
|
||||||
|
draw.rectangle([x1, y1, x2, y2], fill=255)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_mask_from_image(img_path: Path) -> Image.Image:
|
||||||
|
"""
|
||||||
|
自动推断缺失区域:
|
||||||
|
1) 若输入带 alpha,透明区域作为 mask
|
||||||
|
2) 否则将“接近黑色”区域作为候选缺失区域
|
||||||
|
"""
|
||||||
|
raw = Image.open(img_path)
|
||||||
|
arr = np.asarray(raw)
|
||||||
|
|
||||||
|
if arr.ndim == 3 and arr.shape[2] == 4:
|
||||||
|
alpha = arr[:, :, 3]
|
||||||
|
mask_bool = alpha < 250
|
||||||
|
else:
|
||||||
|
rgb = np.asarray(raw.convert("RGB"), dtype=np.uint8)
|
||||||
|
# 抠图后透明区域常被写成黑色,优先把近黑区域视作缺失
|
||||||
|
dark = np.all(rgb <= AUTO_MASK_BLACK_THRESHOLD, axis=-1)
|
||||||
|
mask_bool = dark
|
||||||
|
|
||||||
|
mask_bool = _dilate(mask_bool, AUTO_MASK_DILATE_ITER)
|
||||||
|
return Image.fromarray((mask_bool.astype(np.uint8) * 255), mode="L")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_alpha_from_input(img_path: Path, img_rgb: Image.Image) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
生成输出 alpha:
|
||||||
|
- 输入若有 alpha,优先沿用
|
||||||
|
- 输入若无 alpha,则把接近黑色区域视作透明背景
|
||||||
|
"""
|
||||||
|
raw = Image.open(img_path)
|
||||||
|
arr = np.asarray(raw)
|
||||||
|
if arr.ndim == 3 and arr.shape[2] == 4:
|
||||||
|
return arr[:, :, 3].astype(np.uint8)
|
||||||
|
|
||||||
|
rgb = np.asarray(img_rgb.convert("RGB"), dtype=np.uint8)
|
||||||
|
non_black = np.any(rgb > NON_BLACK_ALPHA_THRESHOLD, axis=-1)
|
||||||
|
return (non_black.astype(np.uint8) * 255)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_or_make_mask(base_dir: Path, img_path: Path, img_rgb: Image.Image) -> Image.Image:
|
||||||
|
if INPUT_MASK:
|
||||||
|
raw_mask_path = Path(INPUT_MASK)
|
||||||
|
mask_path = raw_mask_path if raw_mask_path.is_absolute() else (base_dir / raw_mask_path)
|
||||||
|
if mask_path.is_file():
|
||||||
|
return Image.open(mask_path).convert("L")
|
||||||
|
|
||||||
|
if USE_AUTO_MASK:
|
||||||
|
auto = _auto_mask_from_image(img_path)
|
||||||
|
auto_arr = np.asarray(auto, dtype=np.uint8)
|
||||||
|
if (auto_arr > 0).any():
|
||||||
|
return auto
|
||||||
|
if not FALLBACK_TO_RECT_MASK:
|
||||||
|
raise ValueError("自动mask为空,请检查输入图像是否存在透明/黑色缺失区域。")
|
||||||
|
|
||||||
|
return _make_rect_mask(img_rgb.size)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(base_dir: Path, p: str) -> Path:
|
||||||
|
raw = Path(p)
|
||||||
|
return raw if raw.is_absolute() else (base_dir / raw)
|
||||||
|
|
||||||
|
|
||||||
|
def run_inpaint_test(base_dir: Path, out_dir: Path) -> None:
|
||||||
|
img_path = _resolve_path(base_dir, INPUT_IMAGE)
|
||||||
|
if not img_path.is_file():
|
||||||
|
raise FileNotFoundError(f"找不到输入图像,请修改 INPUT_IMAGE: {img_path}")
|
||||||
|
|
||||||
|
predictor, used_backend = build_inpaint_predictor(
|
||||||
|
UnifiedInpaintConfig(backend=INPAINT_BACKEND)
|
||||||
|
)
|
||||||
|
|
||||||
|
img = Image.open(img_path).convert("RGB")
|
||||||
|
mask = _load_or_make_mask(base_dir, img_path, img)
|
||||||
|
mask_out = out_dir / f"{img_path.stem}.mask_used.png"
|
||||||
|
mask.save(mask_out)
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
strength=STRENGTH,
|
||||||
|
guidance_scale=GUIDANCE_SCALE,
|
||||||
|
num_inference_steps=NUM_INFERENCE_STEPS,
|
||||||
|
max_side=MAX_SIDE,
|
||||||
|
)
|
||||||
|
if used_backend == InpaintBackend.CONTROLNET:
|
||||||
|
kwargs["controlnet_conditioning_scale"] = CONTROLNET_SCALE
|
||||||
|
|
||||||
|
out = predictor(img, mask, PROMPT, NEGATIVE_PROMPT, **kwargs)
|
||||||
|
out_path = out_dir / f"{img_path.stem}.{used_backend.value}.inpaint.png"
|
||||||
|
if PRESERVE_TRANSPARENCY:
|
||||||
|
alpha = _build_alpha_from_input(img_path, img)
|
||||||
|
mask_u8 = np.asarray(mask, dtype=np.uint8)
|
||||||
|
if EXPAND_ALPHA_WITH_MASK:
|
||||||
|
alpha = np.maximum(alpha, mask_u8)
|
||||||
|
|
||||||
|
out_rgb = np.asarray(out.convert("RGB"), dtype=np.uint8)
|
||||||
|
out_rgba = np.concatenate([out_rgb, alpha[..., None]], axis=-1)
|
||||||
|
Image.fromarray(out_rgba, mode="RGBA").save(out_path)
|
||||||
|
else:
|
||||||
|
out.save(out_path)
|
||||||
|
|
||||||
|
ratio = float((np.asarray(mask, dtype=np.uint8) > 0).sum()) / float(mask.size[0] * mask.size[1])
|
||||||
|
print(f"[Inpaint] backend={used_backend.value}, saved={out_path}")
|
||||||
|
print(f"[Mask] saved={mask_out}, ratio={ratio:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_draw_test(base_dir: Path, out_dir: Path) -> None:
|
||||||
|
"""
|
||||||
|
绘图测试:
|
||||||
|
- 文生图:DRAW_INPUT_IMAGE=""
|
||||||
|
- 图生图:DRAW_INPUT_IMAGE 指向参考图
|
||||||
|
"""
|
||||||
|
draw_predictor = build_draw_predictor(UnifiedDrawConfig())
|
||||||
|
ref_image: Image.Image | None = None
|
||||||
|
mode = "text2img"
|
||||||
|
|
||||||
|
if DRAW_INPUT_IMAGE:
|
||||||
|
ref_path = _resolve_path(base_dir, DRAW_INPUT_IMAGE)
|
||||||
|
if not ref_path.is_file():
|
||||||
|
raise FileNotFoundError(f"找不到参考图,请修改 DRAW_INPUT_IMAGE: {ref_path}")
|
||||||
|
ref_image = Image.open(ref_path).convert("RGB")
|
||||||
|
mode = "img2img"
|
||||||
|
|
||||||
|
out = draw_predictor(
|
||||||
|
prompt=DRAW_PROMPT,
|
||||||
|
image=ref_image,
|
||||||
|
negative_prompt=DRAW_NEGATIVE_PROMPT,
|
||||||
|
strength=DRAW_STRENGTH,
|
||||||
|
guidance_scale=DRAW_GUIDANCE_SCALE,
|
||||||
|
num_inference_steps=DRAW_STEPS,
|
||||||
|
width=DRAW_WIDTH,
|
||||||
|
height=DRAW_HEIGHT,
|
||||||
|
max_side=DRAW_MAX_SIDE,
|
||||||
|
)
|
||||||
|
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
out_path = out_dir / f"draw_{mode}_{ts}.png"
|
||||||
|
out.save(out_path)
|
||||||
|
print(f"[Draw] mode={mode}, saved={out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
base_dir = Path(__file__).resolve().parent
|
||||||
|
out_dir = base_dir / OUTPUT_DIR
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if TASK_MODE == "inpaint":
|
||||||
|
run_inpaint_test(base_dir, out_dir)
|
||||||
|
return
|
||||||
|
if TASK_MODE == "draw":
|
||||||
|
run_draw_test(base_dir, out_dir)
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(f"不支持的 TASK_MODE: {TASK_MODE}(可选: inpaint / draw)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
163
python_server/test_seg.py
Normal file
163
python_server/test_seg.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
from scipy.ndimage import label as nd_label
|
||||||
|
|
||||||
|
from model.Seg.seg_loader import SegBackend, _ensure_sam_on_path, _download_sam_checkpoint_if_needed
|
||||||
|
|
||||||
|
# ================= 配置区 =================
|
||||||
|
INPUT_IMAGE = "/home/dwh/Documents/毕业设计/dwh/数据集/Up the River During Qingming (detail) - Court painters.jpg"
|
||||||
|
INPUT_MASK = "/home/dwh/code/hfut-bishe/python_server/outputs/test_depth/Up the River During Qingming (detail) - Court painters.depth_anything_v2.edge_mask.png"
|
||||||
|
OUTPUT_DIR = "outputs/test_seg_v2"
|
||||||
|
SEG_BACKEND = SegBackend.SAM
|
||||||
|
|
||||||
|
# 目标筛选参数
|
||||||
|
TARGET_MIN_AREA = 1000 # 过滤太小的碎片
|
||||||
|
TARGET_MAX_OBJECTS = 20 # 最多提取多少个物体
|
||||||
|
SAM_MAX_SIDE = 2048 # SAM 推理时的长边限制
|
||||||
|
|
||||||
|
# 视觉效果
|
||||||
|
MASK_ALPHA = 0.4
|
||||||
|
BOUNDARY_COLOR = np.array([0, 255, 0], dtype=np.uint8) # 边界绿色
|
||||||
|
TARGET_FILL_COLOR = np.array([255, 230, 0], dtype=np.uint8)
|
||||||
|
SAVE_OBJECT_PNG = True
|
||||||
|
# =========================================
|
||||||
|
|
||||||
|
def _resize_long_side(arr: np.ndarray, max_side: int, is_mask: bool = False) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
|
||||||
|
h, w = arr.shape[:2]
|
||||||
|
if max_side <= 0 or max(h, w) <= max_side:
|
||||||
|
return arr, (h, w), (h, w)
|
||||||
|
scale = float(max_side) / float(max(h, w))
|
||||||
|
run_w, run_h = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
|
||||||
|
resample = Image.NEAREST if is_mask else Image.BICUBIC
|
||||||
|
pil = Image.fromarray(arr)
|
||||||
|
out = pil.resize((run_w, run_h), resample=resample)
|
||||||
|
return np.asarray(out), (h, w), (run_h, run_w)
|
||||||
|
|
||||||
|
def _get_prompts_from_mask(edge_mask: np.ndarray, max_components: int = 20):
|
||||||
|
"""
|
||||||
|
分析边缘 Mask 的连通域,为每个独立的边缘簇提取一个引导点
|
||||||
|
"""
|
||||||
|
# 确保是布尔类型
|
||||||
|
mask_bool = edge_mask > 127
|
||||||
|
# 连通域标记
|
||||||
|
labeled_array, num_features = nd_label(mask_bool)
|
||||||
|
|
||||||
|
prompts = []
|
||||||
|
component_info = []
|
||||||
|
for i in range(1, num_features + 1):
|
||||||
|
coords = np.argwhere(labeled_array == i)
|
||||||
|
area = len(coords)
|
||||||
|
if area < 100: continue # 过滤噪声
|
||||||
|
# 取几何中心作为引导点
|
||||||
|
center_y, center_x = np.median(coords, axis=0).astype(int)
|
||||||
|
component_info.append(((center_x, center_y), area))
|
||||||
|
|
||||||
|
# 按面积排序,优先处理大面积线条覆盖的物体
|
||||||
|
component_info.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
for pt, _ in component_info[:max_components]:
|
||||||
|
prompts.append({
|
||||||
|
"point_coords": np.array([pt], dtype=np.float32),
|
||||||
|
"point_labels": np.array([1], dtype=np.int32)
|
||||||
|
})
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
def _save_object_pngs(rgb: np.ndarray, label_map: np.ndarray, targets: list[int], out_dir: Path, stem: str):
|
||||||
|
obj_dir = out_dir / f"{stem}.objects"
|
||||||
|
obj_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
for idx, lb in enumerate(targets, start=1):
|
||||||
|
mask = (label_map == lb)
|
||||||
|
ys, xs = np.where(mask)
|
||||||
|
if len(ys) == 0: continue
|
||||||
|
y1, y2, x1, x2 = ys.min(), ys.max(), xs.min(), xs.max()
|
||||||
|
rgb_crop = rgb[y1:y2+1, x1:x2+1]
|
||||||
|
alpha = (mask[y1:y2+1, x1:x2+1].astype(np.uint8) * 255)[..., None]
|
||||||
|
rgba = np.concatenate([rgb_crop, alpha], axis=-1)
|
||||||
|
Image.fromarray(rgba).save(obj_dir / f"{stem}.obj_{idx:02d}.png")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. 加载资源
|
||||||
|
img_path, mask_path = Path(INPUT_IMAGE), Path(INPUT_MASK)
|
||||||
|
out_dir = Path(__file__).parent / OUTPUT_DIR
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
rgb_orig = np.asarray(Image.open(img_path).convert("RGB"))
|
||||||
|
mask_orig = np.asarray(Image.open(mask_path).convert("L"))
|
||||||
|
|
||||||
|
# 2. 准备推理尺寸 (缩放以节省显存)
|
||||||
|
rgb_run, orig_hw, run_hw = _resize_long_side(rgb_orig, SAM_MAX_SIDE)
|
||||||
|
mask_run, _, _ = _resize_long_side(mask_orig, SAM_MAX_SIDE, is_mask=True)
|
||||||
|
|
||||||
|
# 3. 初始化 SAM(直接使用 SamPredictor,避免 wrapper 接口不支持 point/bbox prompt)
|
||||||
|
import torch
|
||||||
|
|
||||||
|
_ensure_sam_on_path()
|
||||||
|
sam_root = Path(__file__).resolve().parent / "model" / "Seg" / "segment-anything"
|
||||||
|
ckpt_path = _download_sam_checkpoint_if_needed(sam_root)
|
||||||
|
from segment_anything import sam_model_registry, SamPredictor # type: ignore[import]
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
sam = sam_model_registry["vit_h"](checkpoint=str(ckpt_path)).to(device)
|
||||||
|
predictor = SamPredictor(sam)
|
||||||
|
|
||||||
|
print(f"[SAM] 正在处理图像: {img_path.name},推理尺寸: {run_hw}")
|
||||||
|
|
||||||
|
# 4. 提取引导点
|
||||||
|
prompts = _get_prompts_from_mask(mask_run, max_components=TARGET_MAX_OBJECTS)
|
||||||
|
print(f"[SAM] 从 Mask 提取了 {len(prompts)} 个引导点")
|
||||||
|
|
||||||
|
# 5. 执行推理
|
||||||
|
final_label_map_run = np.zeros(run_hw, dtype=np.int32)
|
||||||
|
predictor.set_image(rgb_run)
|
||||||
|
|
||||||
|
for idx, p in enumerate(prompts, start=1):
|
||||||
|
# multimask_output=True 可以获得更稳定的结果
|
||||||
|
masks, scores, _ = predictor.predict(
|
||||||
|
point_coords=p["point_coords"],
|
||||||
|
point_labels=p["point_labels"],
|
||||||
|
multimask_output=True
|
||||||
|
)
|
||||||
|
# 挑选得分最高的 mask
|
||||||
|
best_mask = masks[np.argmax(scores)]
|
||||||
|
|
||||||
|
# 只有面积足够才保留
|
||||||
|
if np.sum(best_mask) > TARGET_MIN_AREA * (run_hw[0] / orig_hw[0]):
|
||||||
|
# 将新 mask 覆盖到 label_map 上,后续的覆盖前面的
|
||||||
|
final_label_map_run[best_mask > 0] = idx
|
||||||
|
|
||||||
|
# 6. 后处理与映射回原图
|
||||||
|
# 映射回原图尺寸 (Nearest 保证 label ID 不会产生小数)
|
||||||
|
label_map = np.asarray(Image.fromarray(final_label_map_run).resize((orig_hw[1], orig_hw[0]), Image.NEAREST))
|
||||||
|
|
||||||
|
# 7. 导出与可视化
|
||||||
|
unique_labels = [l for l in np.unique(label_map) if l > 0]
|
||||||
|
|
||||||
|
# 绘制可视化图
|
||||||
|
marked_img = rgb_orig.copy()
|
||||||
|
draw = ImageDraw.Draw(Image.fromarray(marked_img)) # 这里只是为了画框方便
|
||||||
|
|
||||||
|
# 混合颜色显示
|
||||||
|
overlay = rgb_orig.astype(np.float32)
|
||||||
|
for lb in unique_labels:
|
||||||
|
m = (label_map == lb)
|
||||||
|
overlay[m] = overlay[m] * (1-MASK_ALPHA) + TARGET_FILL_COLOR * MASK_ALPHA
|
||||||
|
# 简单边缘处理
|
||||||
|
contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
cv2.drawContours(marked_img, contours, -1, (0, 255, 0), 2)
|
||||||
|
|
||||||
|
final_vis = Image.fromarray(cv2.addWeighted(marked_img, 0.7, overlay.astype(np.uint8), 0.3, 0))
|
||||||
|
final_vis.save(out_dir / f"{img_path.stem}.sam_guided_result.png")
|
||||||
|
|
||||||
|
if SAVE_OBJECT_PNG:
|
||||||
|
_save_object_pngs(rgb_orig, label_map, unique_labels, out_dir, img_path.stem)
|
||||||
|
|
||||||
|
print(f"[Done] 分割完成。提取了 {len(unique_labels)} 个物体。结果保存在: {out_dir}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user