219 lines
7.2 KiB
C++
219 lines
7.2 KiB
C++
#include "net/ModelServerClient.h"
|
||
|
||
#include <QEventLoop>
|
||
#include <QJsonArray>
|
||
#include <QJsonDocument>
|
||
#include <QJsonObject>
|
||
#include <QNetworkAccessManager>
|
||
#include <QNetworkReply>
|
||
#include <QNetworkRequest>
|
||
#include <QTimer>
|
||
#include <QUrl>
|
||
|
||
namespace core {
|
||
|
||
ModelServerClient::ModelServerClient(QObject* parent)
|
||
: QObject(parent)
|
||
, m_nam(new QNetworkAccessManager(this)) {
|
||
}
|
||
|
||
/// Stores the server root used to build endpoint URLs.
/// No validation happens here; the request methods reject an empty or invalid URL.
void ModelServerClient::setBaseUrl(const QUrl& baseUrl)
{
    m_baseUrl = baseUrl;
}
|
||
|
||
/// Returns a copy of the stored server root (m_baseUrl).
QUrl ModelServerClient::baseUrl() const
{
    return m_baseUrl;
}
|
||
|
||
/// Starts an asynchronous depth request:
/// POST <base>/depth with the compact JSON body {"image_b64": <base64(imageBytes)>}.
///
/// NOTE: the endpoint path is absolute ("/depth"), so QUrl::resolved() replaces
/// any path component of the base URL; only scheme/authority are kept from it.
///
/// @param imageBytes         Encoded input image; must be non-empty.
/// @param outImmediateError  Optional. Cleared on entry; receives a message when
///                           the request cannot be started.
/// @return The in-flight reply created by m_nam (the caller handles finished()
///         and releases it), or nullptr when validation fails.
QNetworkReply* ModelServerClient::computeDepthPng8Async(const QByteArray& imageBytes, QString* outImmediateError) {
    // Report a validation failure through the optional out-parameter.
    const auto reject = [outImmediateError](const QString& why) -> QNetworkReply* {
        if (outImmediateError != nullptr) {
            *outImmediateError = why;
        }
        return nullptr;
    };

    if (outImmediateError != nullptr) {
        outImmediateError->clear();
    }
    if (m_baseUrl.isEmpty() || !m_baseUrl.isValid()) {
        return reject(QStringLiteral("后端地址无效。"));
    }
    if (imageBytes.isEmpty()) {
        return reject(QStringLiteral("输入图像为空。"));
    }

    QNetworkRequest request(m_baseUrl.resolved(QUrl(QStringLiteral("/depth"))));
    request.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));

    QJsonObject json;
    json.insert(QStringLiteral("image_b64"), QString::fromLatin1(imageBytes.toBase64()));
    return m_nam->post(request, QJsonDocument(json).toJson(QJsonDocument::Compact));
}
|
||
|
||
/// Starts an asynchronous SAM prompt-segmentation request:
/// POST <base>/segment/sam_prompt with a compact JSON body.
///
/// NOTE: the endpoint path is absolute, so QUrl::resolved() replaces any path
/// component of the base URL.
///
/// @param cropRgbPngBytes    Encoded crop image; must be non-empty (sent base64 as "image_b64").
/// @param overlayPngBytes    Optional overlay image; sent base64 as "overlay_b64" only when non-empty.
/// @param pointCoords        Forwarded verbatim as "point_coords".
/// @param pointLabels        Forwarded verbatim as "point_labels" (sizes are not cross-checked here).
/// @param boxXyxy            Forwarded verbatim as "box_xyxy" (sent even when empty).
/// @param outImmediateError  Optional. Cleared on entry; receives a message when
///                           the request cannot be started.
/// @return The in-flight reply created by m_nam, or nullptr when validation fails.
QNetworkReply* ModelServerClient::segmentSamPromptAsync(
    const QByteArray& cropRgbPngBytes,
    const QByteArray& overlayPngBytes,
    const QJsonArray& pointCoords,
    const QJsonArray& pointLabels,
    const QJsonArray& boxXyxy,
    QString* outImmediateError
) {
    // Report a validation failure through the optional out-parameter.
    const auto reject = [outImmediateError](const QString& why) -> QNetworkReply* {
        if (outImmediateError != nullptr) {
            *outImmediateError = why;
        }
        return nullptr;
    };

    if (outImmediateError != nullptr) {
        outImmediateError->clear();
    }
    if (m_baseUrl.isEmpty() || !m_baseUrl.isValid()) {
        return reject(QStringLiteral("后端地址无效。"));
    }
    if (cropRgbPngBytes.isEmpty()) {
        return reject(QStringLiteral("裁剪图像为空。"));
    }

    QNetworkRequest request(m_baseUrl.resolved(QUrl(QStringLiteral("/segment/sam_prompt"))));
    request.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));

    QJsonObject json{
        {QStringLiteral("image_b64"), QString::fromLatin1(cropRgbPngBytes.toBase64())},
        {QStringLiteral("point_coords"), pointCoords},
        {QStringLiteral("point_labels"), pointLabels},
        {QStringLiteral("box_xyxy"), boxXyxy},
    };
    // The overlay is optional: the key is omitted entirely when no bytes are given.
    if (!overlayPngBytes.isEmpty()) {
        json.insert(QStringLiteral("overlay_b64"), QString::fromLatin1(overlayPngBytes.toBase64()));
    }

    return m_nam->post(request, QJsonDocument(json).toJson(QJsonDocument::Compact));
}
|
||
|
||
/// Starts an asynchronous inpainting request:
/// POST <base>/inpaint with a compact JSON body holding the crop, the mask and
/// generation options.
///
/// NOTE: the endpoint path is absolute, so QUrl::resolved() replaces any path
/// component of the base URL.
///
/// @param cropRgbPngBytes    Encoded crop image; must be non-empty (sent base64 as "image_b64").
/// @param maskPngBytes       Encoded mask image; must be non-empty (sent base64 as "mask_b64").
/// @param prompt             Sent as "prompt".
/// @param negativePrompt     Sent as "negative_prompt".
/// @param strength           Sent as "strength"; not range-checked here.
/// @param maxSide            Sent as "max_side"; not range-checked here.
/// @param outImmediateError  Optional. Cleared on entry; receives a message when
///                           the request cannot be started.
/// @return The in-flight reply created by m_nam, or nullptr when validation fails.
QNetworkReply* ModelServerClient::inpaintAsync(
    const QByteArray& cropRgbPngBytes,
    const QByteArray& maskPngBytes,
    const QString& prompt,
    const QString& negativePrompt,
    double strength,
    int maxSide,
    QString* outImmediateError
) {
    // Report a validation failure through the optional out-parameter.
    const auto reject = [outImmediateError](const QString& why) -> QNetworkReply* {
        if (outImmediateError != nullptr) {
            *outImmediateError = why;
        }
        return nullptr;
    };

    if (outImmediateError != nullptr) {
        outImmediateError->clear();
    }
    if (m_baseUrl.isEmpty() || !m_baseUrl.isValid()) {
        return reject(QStringLiteral("后端地址无效。"));
    }
    if (cropRgbPngBytes.isEmpty()) {
        return reject(QStringLiteral("裁剪图像为空。"));
    }
    if (maskPngBytes.isEmpty()) {
        return reject(QStringLiteral("Mask 为空。"));
    }

    QNetworkRequest request(m_baseUrl.resolved(QUrl(QStringLiteral("/inpaint"))));
    request.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));

    const QJsonObject json{
        {QStringLiteral("image_b64"), QString::fromLatin1(cropRgbPngBytes.toBase64())},
        {QStringLiteral("mask_b64"), QString::fromLatin1(maskPngBytes.toBase64())},
        {QStringLiteral("prompt"), prompt},
        {QStringLiteral("negative_prompt"), negativePrompt},
        {QStringLiteral("strength"), strength},
        {QStringLiteral("max_side"), maxSide},
    };

    return m_nam->post(request, QJsonDocument(json).toJson(QJsonDocument::Compact));
}
|
||
|
||
bool ModelServerClient::computeDepthPng8(
|
||
const QByteArray& imageBytes,
|
||
QByteArray& outPngBytes,
|
||
QString& outError,
|
||
int timeoutMs
|
||
) {
|
||
outPngBytes.clear();
|
||
outError.clear();
|
||
|
||
if (!m_baseUrl.isValid() || m_baseUrl.isEmpty()) {
|
||
outError = QStringLiteral("后端地址无效。");
|
||
return false;
|
||
}
|
||
if (imageBytes.isEmpty()) {
|
||
outError = QStringLiteral("输入图像为空。");
|
||
return false;
|
||
}
|
||
|
||
const QUrl url = m_baseUrl.resolved(QUrl(QStringLiteral("/depth")));
|
||
|
||
QNetworkRequest req(url);
|
||
req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));
|
||
|
||
const QByteArray imageB64 = imageBytes.toBase64();
|
||
const QJsonObject payload{
|
||
{QStringLiteral("image_b64"), QString::fromLatin1(imageB64)},
|
||
};
|
||
const QByteArray body = QJsonDocument(payload).toJson(QJsonDocument::Compact);
|
||
|
||
QNetworkReply* reply = m_nam->post(req, body);
|
||
if (!reply) {
|
||
outError = QStringLiteral("创建网络请求失败。");
|
||
return false;
|
||
}
|
||
|
||
QEventLoop loop;
|
||
QTimer timer;
|
||
timer.setSingleShot(true);
|
||
const int t = (timeoutMs <= 0) ? 30000 : timeoutMs;
|
||
|
||
QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
|
||
QObject::connect(&timer, &QTimer::timeout, &loop, &QEventLoop::quit);
|
||
timer.start(t);
|
||
loop.exec();
|
||
|
||
if (timer.isActive() == false && reply->isFinished() == false) {
|
||
reply->abort();
|
||
reply->deleteLater();
|
||
outError = QStringLiteral("请求超时(%1ms)。").arg(t);
|
||
return false;
|
||
}
|
||
|
||
const int httpStatus = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
|
||
const QByteArray raw = reply->readAll();
|
||
const auto netErr = reply->error();
|
||
const QString netErrStr = reply->errorString();
|
||
reply->deleteLater();
|
||
|
||
if (netErr != QNetworkReply::NoError) {
|
||
outError = QStringLiteral("网络错误:%1").arg(netErrStr);
|
||
return false;
|
||
}
|
||
|
||
if (httpStatus != 200) {
|
||
// FastAPI HTTPException 默认返回 {"detail": "..."}
|
||
QString detail;
|
||
const QJsonDocument jd = QJsonDocument::fromJson(raw);
|
||
if (jd.isObject()) {
|
||
const auto obj = jd.object();
|
||
detail = obj.value(QStringLiteral("detail")).toString();
|
||
}
|
||
outError = detail.isEmpty()
|
||
? QStringLiteral("后端返回HTTP %1。").arg(httpStatus)
|
||
: QStringLiteral("后端错误(HTTP %1):%2").arg(httpStatus).arg(detail);
|
||
return false;
|
||
}
|
||
|
||
if (raw.isEmpty()) {
|
||
outError = QStringLiteral("后端返回空数据。");
|
||
return false;
|
||
}
|
||
|
||
outPngBytes = raw;
|
||
return true;
|
||
}
|
||
|
||
} // namespace core
|
||
|