diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 2a34fa7..26facab 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -3,12 +3,18 @@ set(SERVER_BINARY_NAME "${PROJECT_BINARY_NAME}_server") set(SOURCES src/main.cpp src/TerminalApp.cpp src/SocketApp.cpp) if (MSVC) - list(APPEND SOURCES ${CMAKE_SOURCE_DIR}/minimal.rc) + set(BNAME ${SERVER_BINARY_NAME}) + set(COMPONENT_NAME "${PROJECT_DISPLAY_NAME} Server") + configure_file(../platform/msvc/app.rc.in app.rc) + configure_file(../platform/msvc/minimal.rc.in minimal.rc) + list(APPEND SOURCES minimal.rc) + list(APPEND SOURCES ${CMAKE_CURRENT_BINARY_DIR}/app.rc) + list(APPEND SOURCES ${CMAKE_CURRENT_BINARY_DIR}/minimal.rc) endif() add_executable(${SERVER_BINARY_NAME} ${SOURCES} ${CMAKE_SOURCE_DIR}/src/libs/SharedLibrary.cpp ${CMAKE_SOURCE_DIR}/src/libs/SharedMemoryManager.cpp) add_dependencies(${SERVER_BINARY_NAME} sockets_cpp) -#target_precompile_headers(${SERVER_BINARY_NAME} PRIVATE src/pch.h) +target_precompile_headers(${SERVER_BINARY_NAME} PRIVATE src/pch.h) if (MSVC) target_link_options(${SERVER_BINARY_NAME} PRIVATE /STACK:${STACK_SIZE}) @@ -17,6 +23,8 @@ else() set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,stack-size=${STACK_SIZE}") endif() + + target_include_directories(${SERVER_BINARY_NAME} PRIVATE ${CMAKE_SOURCE_DIR}/server/src ${CMAKE_SOURCE_DIR}/src ${sockets_cpp_SOURCE} ${wxWidgets_SOURCE_DIR}/include ${exiv2_INCLUDE_DIR} ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/server/src/ServerConfig.h b/server/src/ServerConfig.h index 590a380..c6ad621 100644 --- a/server/src/ServerConfig.h +++ b/server/src/ServerConfig.h @@ -1,11 +1,6 @@ #ifndef __SERVER_CONFIG_H #define __SERVER_CONFIG_H -#include -#include - -#include "libs/json.hpp" - enum class backend_type { AVX, AVX2, diff --git a/server/src/SocketApp.cpp b/server/src/SocketApp.cpp index 6c876d1..5083869 100644 --- a/server/src/SocketApp.cpp +++ b/server/src/SocketApp.cpp @@ -1,5 +1,4 @@ #include "SocketApp.h" -#include "TerminalApp.h" SocketApp::SocketApp(const char* listenAddr, uint16_t port, TerminalApp* parent) : m_socketOpt({sockets::TX_BUFFER_SIZE, sockets::RX_BUFFER_SIZE, listenAddr}), m_server(*this, &m_socketOpt), parent(parent) { sockets::SocketRet ret = m_server.start(port); @@ -80,8 +79,8 @@ void SocketApp::onClientConnect(const sockets::ClientHandle& client) { this->m_clientInfo[client] = {ipAddr, port, client, wxGetLocalTime()}; } sd_gui_utils::networks::Packet auth_required_packet; - auth_required_packet.type = sd_gui_utils::networks::PacketType::REQUEST; - auth_required_packet.param = sd_gui_utils::networks::PacketParam::AUTH; + auth_required_packet.type = sd_gui_utils::networks::Packet::Type::REQUEST_TYPE; + auth_required_packet.param = sd_gui_utils::networks::Packet::Param::PARAM_AUTH; this->sendMsg(client, auth_required_packet); } } @@ -107,7 +106,7 @@ void SocketApp::OnTimer() { } else { if (it->second.last_keepalive == 0 || (wxGetLocalTime() - it->second.last_keepalive) > this->parent->configData->tcp_keepalive) { it->second.last_keepalive = wxGetLocalTime(); - this->sendMsg(it->second.idx, sd_gui_utils::networks::Packet(sd_gui_utils::networks::PacketType::REQUEST, sd_gui_utils::networks::PacketParam::KEEPALIVE)); + this->sendMsg(it->second.idx, sd_gui_utils::networks::Packet(sd_gui_utils::networks::Packet::Type::REQUEST_TYPE, sd_gui_utils::networks::Packet::Param::PARAM_KEEPALIVE)); } ++it; } @@ -118,19 +117,19 @@ void SocketApp::parseMsg(sd_gui_utils::networks::Packet& packet) { // parse from cbor to struct Packet try { if (packet.version != SD_GUI_VERSION) { - auto errorPacket = sd_gui_utils::networks::Packet(sd_gui_utils::networks::PacketType::RESPONSE, sd_gui_utils::networks::PacketParam::ERROR); + auto errorPacket = sd_gui_utils::networks::Packet(sd_gui_utils::networks::Packet::Type::RESPONSE_TYPE, sd_gui_utils::networks::Packet::Param::PARAM_ERROR); errorPacket.SetData("Version Mismatch, server: " + std::string(SD_GUI_VERSION) + ", client: " + packet.version); this->sendMsg(packet.source_idx, errorPacket); this->parent->sendLogEvent(wxString::Format("Version Mismatch, server: %s, client: %s", SD_GUI_VERSION, packet.version), wxLOG_Error); this->DisconnectClient(packet.source_idx); return; } - if (packet.type == sd_gui_utils::networks::PacketType::REQUEST) { + if (packet.type == sd_gui_utils::networks::Packet::Type::REQUEST_TYPE) { this->parent->ProcessReceivedSocketPackages(packet); } // handle the response to the auth request - if (packet.type == sd_gui_utils::networks::PacketType::RESPONSE) { - if (packet.param == sd_gui_utils::networks::PacketParam::AUTH) { + if (packet.type == sd_gui_utils::networks::Packet::Type::RESPONSE_TYPE) { + if (packet.param == sd_gui_utils::networks::Packet::Param::PARAM_AUTH) { { std::lock_guard guard(m_mutex); std::string packetData = packet.GetData(); @@ -140,7 +139,7 @@ void SocketApp::parseMsg(sd_gui_utils::networks::Packet& packet) { this->m_clientInfo[packet.source_idx].apikey = packetData; this->parent->sendLogEvent("Client authenticated: " + std::to_string(packet.source_idx), wxLOG_Info); } else { - auto errorPacket = sd_gui_utils::networks::Packet(sd_gui_utils::networks::PacketType::RESPONSE, sd_gui_utils::networks::PacketParam::ERROR); + auto errorPacket = sd_gui_utils::networks::Packet(sd_gui_utils::networks::Packet::Type::RESPONSE_TYPE, sd_gui_utils::networks::Packet::Param::PARAM_ERROR); errorPacket.SetData("Authentication Failed"); this->sendMsg(packet.source_idx, errorPacket); this->parent->sendLogEvent("Authentication Failed, ip: " + this->m_clientInfo[packet.source_idx].host + " port: " + std::to_string(this->m_clientInfo[packet.source_idx].port) + " key: " + this->m_clientInfo[packet.source_idx].apikey + "", wxLOG_Error); diff --git a/server/src/SocketApp.h b/server/src/SocketApp.h index f30bbc4..e1964f6 100644 --- a/server/src/SocketApp.h +++ b/server/src/SocketApp.h @@ -1,14 +1,5 @@ #ifndef _SERVER_SOCKETAPP_H #define _SERVER_SOCKETAPP_H -#include - -#include -#include -#include -#include "libs/json.hpp" - -#include "network/packets.h" -#include "sockets-cpp/TcpServer.h" diff --git a/server/src/TerminalApp.cpp b/server/src/TerminalApp.cpp index 1a07102..c424ad9 100644 --- a/server/src/TerminalApp.cpp +++ b/server/src/TerminalApp.cpp @@ -355,7 +355,7 @@ bool TerminalApp::ProcessEventHandler(std::string message) { } try { nlohmann::json msg = nlohmann::json::parse(message); - sd_gui_utils::networks::Packet packet(sd_gui_utils::networks::PacketType::REQUEST, sd_gui_utils::networks::PacketParam::ERROR); + sd_gui_utils::networks::Packet packet(sd_gui_utils::networks::Packet::Type::REQUEST_TYPE, sd_gui_utils::networks::Packet::Param::PARAM_ERROR); packet.SetData(message); this->socket->sendMsg(0, packet); @@ -371,8 +371,8 @@ void TerminalApp::ProcessReceivedSocketPackages(const sd_gui_utils::networks::Pa this->sendLogEvent("Invalid source index", wxLOG_Error); return; } - if (packet.param == sd_gui_utils::networks::PacketParam::MODEL_LIST) { - auto response = sd_gui_utils::networks::Packet(sd_gui_utils::networks::PacketType::RESPONSE, sd_gui_utils::networks::PacketParam::MODEL_LIST); + if (packet.param == sd_gui_utils::networks::Packet::Param::PARAM_MODEL_LIST) { + auto response = sd_gui_utils::networks::Packet(sd_gui_utils::networks::Packet::Type::RESPONSE_TYPE, sd_gui_utils::networks::Packet::Param::PARAM_MODEL_LIST); std::vector list; for (auto model : this->modelFiles) { // change the model's path diff --git a/server/src/TerminalApp.h b/server/src/TerminalApp.h index bfbb3c7..e814e79 100644 --- a/server/src/TerminalApp.h +++ b/server/src/TerminalApp.h @@ -1,27 +1,5 @@ #ifndef _SERVER_TERMINALAPP_H #define _SERVER_TERMINALAPP_H - -#include "libs/SharedLibrary.h" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include "libs/json.hpp" -#include "network/RemoteModelInfo.h" - -#include "ServerConfig.h" -#include "SocketApp.h" -#include "helpers/sslUtils.hpp" -#include "libs/ExternalProcess.h" - -#include "EventQueue.h" - wxDECLARE_APP(TerminalApp); class TerminalApp : public wxAppConsole { diff --git a/server/src/pch.h b/server/src/pch.h new file mode 100644 index 0000000..2e0ab54 --- /dev/null +++ b/server/src/pch.h @@ -0,0 +1,40 @@ +#include +#include +#include +#include +#include + +#include "libs/SharedLibrary.h" + + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sockets-cpp/TcpServer.h" +#include "EventQueue.h" +#include "libs/SharedLibrary.h" +#include "libs/SharedMemoryManager.h" +#include "ver.hpp" +#include "extprocess/config.hpp" +#include "libs/json.hpp" +#include "helpers/sslUtils.hpp" +#include "network/packets.h" +#include "network/RemoteModelInfo.h" + +#include "ServerConfig.h" +#include "SocketApp.h" +#include "TerminalApp.h" +#include "EventQueue.h" + +#include "ServerConfig.h" + + +#include "libs/json.hpp" + diff --git a/src/libs/ExternalProcess.cpp b/src/libs/ExternalProcess.cpp deleted file mode 100644 index 8fa9ec4..0000000 --- a/src/libs/ExternalProcess.cpp +++ /dev/null @@ -1,304 +0,0 @@ -#include "ExternalProcess.h" -#include -#include -#include -#include -#include -#include -#include "wx/event.h" - -#if defined(_WIN32) || defined(_WIN64) -#include -#else -#include -#include -#include -#include -#include -#endif - -std::unordered_map ExternalProcess::instances; -std::mutex ExternalProcess::instanceMutex; - -ExternalProcess::ExternalProcess(const std::string& command, const std::string& arguments, bool autoRestart) - : command(command), arguments(arguments), autoRestart(autoRestart), running(false), restartRequested(false) { - this->command = std::filesystem::canonical(this->command).string(); - this->sharedMemory = std::make_shared(SHARED_MEMORY_PATH, SHARED_MEMORY_SIZE, true); -} - -ExternalProcess::~ExternalProcess() { - stop(); -} - -void ExternalProcess::start(onMessageCallBack callback) { - if (running) { - return; - } - this->onMessage = callback; -#if defined(_WIN32) || defined(_WIN64) - SECURITY_ATTRIBUTES sa; - sa.nLength = sizeof(SECURITY_ATTRIBUTES); - sa.bInheritHandle = TRUE; - sa.lpSecurityDescriptor = NULL; - // Create pipes for stdout and stderr - if (this->onStdOut) { - HANDLE stdoutRead, stdoutWrite - CreatePipe(&stdoutRead, &stdoutWrite, &sa, 0); - SetHandleInformation(stdoutRead, HANDLE_FLAG_INHERIT, 0); - } - - if (this->onStdErr) { - HANDLE stderrRead, stderrWrite; - CreatePipe(&stderrRead, &stderrWrite, &sa, 0); - SetHandleInformation(stderrRead, HANDLE_FLAG_INHERIT, 0); - } - - STARTUPINFOA si; - ZeroMemory(&si, sizeof(si)); - si.cb = sizeof(si); - si.dwFlags |= STARTF_USESTDHANDLES; - - if (this->onStdOut) { - si.hStdOutput = stdoutWrite; - } - if (this->onStdErr) { - si.hStdError = stderrWrite; - } - - ZeroMemory(&processInfo, sizeof(processInfo)); - - if (CreateProcessA(NULL, const_cast(command.c_str()), NULL, NULL, FALSE, 0, NULL, NULL, &si, &processInfo)) { - running = true; - // Close unused ends of the pipes - - if (this->onStdOut) { - CloseHandle(stdoutWrite); - stdoutThread = std::thread(&ExternalProcess::readPipeOutput, this, stdoutRead, this->onStdOut); - } - if (this->onStdErr) { - CloseHandle(stderrWrite); - stderrThread = std::thread(&ExternalProcess::readPipeOutput, this, stderrRead, this->onStdErr); - } - processWatchThread = std::thread([this]() { - WaitForSingleObject(processInfo.hProcess, INFINITE); - DWORD exitCode; - GetExitCodeProcess(processInfo.hProcess, &exitCode); - - if (onExit) { - autoRestart = onExit(false); - } - - if (autoRestart) { - restartRequested = true; // Újraindítás kérés beállítása - } else { - running = false; - } - - CloseHandle(processInfo.hProcess); - CloseHandle(processInfo.hThread); - }); - - if (onStarted) { - onStarted(); - } - outputThread = std::thread(&ExternalProcess::processOutput, this, this->onMessage); - } else { - throw std::runtime_error("Failed to start process: " + std::to_string(GetLastError())); - } -#else - signal(SIGCHLD, sigchldHandler); - int stdoutPipe[2], stderrPipe[2]; - if (pipe(stdoutPipe) == -1 || pipe(stderrPipe) == -1) { - throw std::runtime_error("Failed to create pipes for stdout or stderr."); - } - - pid_t pid = fork(); - if (pid == -1) { - throw std::runtime_error("Failed to fork: " + std::string(strerror(errno))); - } else if (pid == 0) { // Child process - close(stdoutPipe[0]); - close(stderrPipe[0]); - - if (this->onStdOut) { - dup2(stdoutPipe[1], STDOUT_FILENO); - } - - if (this->onStdErr) { - dup2(stderrPipe[1], STDERR_FILENO); - } - - execl(command.c_str(), command.c_str(), arguments.c_str(), (char*)0); - _exit(1); // Exit if execl fails - } else { // Parent process - processId = pid; - running = true; - - if (this->onStdOut) { - close(stdoutPipe[1]); - stdoutThread = std::thread(&ExternalProcess::readPipeOutput, this, stdoutPipe[0], this->onStdOut); - } - if (this->onStdErr) { - close(stderrPipe[1]); - stderrThread = std::thread(&ExternalProcess::readPipeOutput, this, stderrPipe[0], this->onStdErr); - } - - if (onStarted) { - onStarted(); - } - outputThread = std::thread(&ExternalProcess::processOutput, this, this->onMessage); - std::lock_guard lock(instanceMutex); - instances[processId] = this; - } -#endif -} - -void ExternalProcess::stop() { - if (!running) - return; - -#if defined(_WIN32) || defined(_WIN64) - TerminateProcess(processInfo.hProcess, 0); - CloseHandle(processInfo.hProcess); - CloseHandle(processInfo.hThread); -#else - kill(processId, SIGTERM); - waitpid(processId, nullptr, 0); // Wait for child process to terminate - { - std::lock_guard lock(instanceMutex); - instances.erase(processId); - } -#endif - - running = false; - if (stdoutThread.joinable()) { - stdoutThread.join(); - } - - if (stderrThread.joinable()) { - stderrThread.join(); - } - if (outputThread.joinable()) { - outputThread.join(); - } - if (onExit) { - onExit(true); - } -} - -void ExternalProcess::readPipeOutput(int pipeFd, onStdCallBack callback) { - char buffer[EPROCESS_STD_BUFFER]; - while (running) { -#if defined(_WIN32) || defined(_WIN64) - DWORD bytesRead; - if (ReadFile((HANDLE)pipeFd, buffer, sizeof(buffer) - 1, &bytesRead, NULL) && bytesRead > 0) { - buffer[bytesRead] = '\0'; - if (callback) - callback(std::string(buffer)); - } -#else - ssize_t bytesRead = read(pipeFd, buffer, sizeof(buffer) - 1); - if (bytesRead > 0) { - buffer[bytesRead] = '\0'; - if (callback) - callback(std::string(buffer)); - } -#endif - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } -} - -void ExternalProcess::send(std::string message) { - sharedMemory->write(message.c_str(), message.size()); -} - -bool ExternalProcess::isRunning() const { - return running; -} - -void ExternalProcess::processOutput(onMessageCallBack callback) { - while (running) { -#if defined(_WIN32) || defined(_WIN64) - DWORD exitCode; - GetExitCodeProcess(processInfo.hProcess, &exitCode); - if (exitCode != STILL_ACTIVE) { - running = false; - if (onExit) { - autoRestart = onExit(false); - } - if (autoRestart) { - restartRequested = true; // Set flag to request restart - } else { - running = false; - } - return; - } -#else - int status; - if (waitpid(processId, &status, WNOHANG) > 0) { - if (WIFEXITED(status) || WIFSIGNALED(status)) { - running = false; - if (onExit) { - autoRestart = onExit(false); - } - if (autoRestart) { - restartRequested = true; // Set flag to request restart - } else { - running = false; - } - return; - } - } -#endif - - char buffer[SHARED_MEMORY_SIZE]; - if (sharedMemory->read(buffer, sizeof(buffer)) && callback) { - if (std::strlen(buffer) > 0) { - std::string message = std::string(buffer, SHARED_MEMORY_SIZE); - if (callback(message) == true) { - sharedMemory->clear(); - } - } - } - - std::this_thread::sleep_for(std::chrono::milliseconds(EPROCESS_SLEEP_TIME)); - } -} - -/** - * @brief Restart the process if the flag is set. - * - * This function restarts the process if the restart flag is set. - * The flag is set when the process exits and the onExit callback - * returns true. - * - * @return true if the process was restarted, false otherwise. - */ -bool ExternalProcess::restartIfNeeded() { - if (restartRequested) { - restartRequested = false; - restartProcess(this->onMessage); - return true; - } - return false; -} - -void ExternalProcess::restartProcess(onMessageCallBack callback) { - stop(); - start(callback); -} - -void ExternalProcess::setOnExitCallback(onExitCallBack callback) { - onExit = callback; -} - -void ExternalProcess::setOnStartedCallback(onStartedCallBack callback) { - onStarted = callback; -} - -void ExternalProcess::setOnStdErrCallback(onStdCallBack callback) { - onStdErr = callback; -} - -void ExternalProcess::setOnStdOutCallback(onStdCallBack callback) { - onStdOut = callback; -} diff --git a/src/libs/ExternalProcess.h b/src/libs/ExternalProcess.h deleted file mode 100644 index 62eab8d..0000000 --- a/src/libs/ExternalProcess.h +++ /dev/null @@ -1,94 +0,0 @@ -#ifndef EXTERNAL_PROCESS_H -#define EXTERNAL_PROCESS_H - -#include -#include -#include -#include -#include -#include -#include "SharedMemoryManager.h" - -#if defined(_WIN32) || defined(_WIN64) -#include -#else -#include -#endif -#include "extprocess/config.hpp" - -typedef std::function onMessageCallBack; -typedef std::function onExitCallBack; -typedef std::function onStartedCallBack; -typedef std::function onStdCallBack; - -class ExternalProcess { -public: - // Constructor to initialize with command, arguments, and auto-restart option - ExternalProcess(const std::string& command, const std::string& arguments, bool autoRestart = false); - ~ExternalProcess(); - - void start(onMessageCallBack callback = nullptr); - void setOnExitCallback(onExitCallBack callback); - void setOnStartedCallback(onStartedCallBack callback); - void setOnStdOutCallback(onStdCallBack callback); - void setOnStdErrCallback(onStdCallBack callback); - void stop(); - void send(std::string data); - void clear(); - std::string receive(); - bool isRunning() const; - inline std::string getArguments() const { return arguments; } - bool restartIfNeeded(); - -private: - std::atomic restartRequested; - std::string command; - std::string arguments; - bool autoRestart = false; - std::atomic running; - std::thread outputThread; - std::thread stdoutThread; - std::thread stderrThread; - void processOutput(onMessageCallBack callback); - void restartProcess(onMessageCallBack callback = nullptr); - void readPipeOutput(int pipeFd, onStdCallBack callback); - - onExitCallBack onExit = nullptr; - onStartedCallBack onStarted = nullptr; - onMessageCallBack onMessage = nullptr; - onStdCallBack onStdOut = nullptr; - onStdCallBack onStdErr = nullptr; - - std::shared_ptr sharedMemory; - - static std::unordered_map instances; - static std::mutex instanceMutex; - -#if defined(_WIN32) || defined(_WIN64) - PROCESS_INFORMATION processInfo; - HANDLE sharedMemoryHandle; -#else - pid_t processId; - int shmFd; -#endif - - inline static void sigchldHandler(int /*signum*/) { - int status; - pid_t pid; - - while ((pid = waitpid(-1, &status, WNOHANG)) > 0) { - std::lock_guard lock(instanceMutex); - auto it = instances.find(pid); - - if (it != instances.end()) { - ExternalProcess* instance = it->second; - if (instance && instance->onExit) { - instance->onExit(false); - } - instances.erase(it); - } - } - } -}; - -#endif // EXTERNAL_PROCESS_H diff --git a/src/libs/TcpClient.cpp b/src/libs/TcpClient.cpp index 80307f8..0cd7162 100644 --- a/src/libs/TcpClient.cpp +++ b/src/libs/TcpClient.cpp @@ -1,109 +1,111 @@ #include "TcpClient.h" -namespace sd_gui_utils::networks { - void TcpClient::sendMsg(const char* data, size_t len) { - auto ret = m_client.sendMsg(data, len); - if (!ret.m_success) { - std::cout << "Send Error: " << ret.m_msg << "\n"; - if (this->onErrorClb != nullptr) { - this->onErrorClb("Send Error: " + ret.m_msg); +namespace sd_gui_utils { + inline namespace networks { + void TcpClient::sendMsg(const char* data, size_t len) { + auto ret = m_client.sendMsg(data, len); + if (!ret.m_success) { + std::cout << "Send Error: " << ret.m_msg << "\n"; + if (this->onErrorClb != nullptr) { + this->onErrorClb("Send Error: " + ret.m_msg); + } } } - } - void TcpClient::onReceiveData(const char* data, size_t size) { - if (this->expected_size == 0 && size >= sizeof(size_t)) { - memcpy(&this->expected_size, data, sizeof(this->expected_size)); + void TcpClient::onReceiveData(const char* data, size_t size) { + if (this->expected_size == 0 && size >= sizeof(size_t)) { + memcpy(&this->expected_size, data, sizeof(this->expected_size)); - size -= sizeof(this->expected_size); - data += sizeof(this->expected_size); - if (size > 0) { + size -= sizeof(this->expected_size); + data += sizeof(this->expected_size); + if (size > 0) { + this->buffer.insert(this->buffer.end(), data, data + size); + } + std::cout << "Expected size: " << this->expected_size << "\n"; + } else if (this->expected_size > 0) { this->buffer.insert(this->buffer.end(), data, data + size); } - std::cout << "Expected size: " << this->expected_size << "\n"; - } else if (this->expected_size > 0) { - this->buffer.insert(this->buffer.end(), data, data + size); - } - if (this->buffer.size() == this->expected_size && this->expected_size > 0) { - auto packet = sd_gui_utils::networks::Packet::DeSerialize(this->buffer.data(), this->buffer.size()); + if (this->buffer.size() == this->expected_size && this->expected_size > 0) { + auto packet = sd_gui_utils::networks::Packet::DeSerialize(this->buffer.data(), this->buffer.size()); - this->HandlePackets(packet); - this->buffer.clear(); - this->expected_size = 0; + this->HandlePackets(packet); + this->buffer.clear(); + this->expected_size = 0; + } } - } - void TcpClient::HandlePackets(Packet& packet) { - if (packet.type == sd_gui_utils::networks::PacketType::RESPONSE) { - if (packet.param == sd_gui_utils::networks::PacketParam::ERROR) { - if (this->onErrorClb != nullptr) { - this->disconnect_reason = packet.GetData(); - this->onErrorClb(this->disconnect_reason); + void TcpClient::HandlePackets(Packet& packet) { + if (packet.type == sd_gui_utils::networks::Packet::Type::RESPONSE_TYPE) { + if (packet.param == sd_gui_utils::networks::Packet::Param::PARAM_ERROR) { + if (this->onErrorClb != nullptr) { + this->disconnect_reason = packet.GetData(); + this->onErrorClb(this->disconnect_reason); + } + // this->server->disconnect_reason = packet.GetData(); + // this->server->needToRun = false; + // this->server->enabled = false; + // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_ERROR, this->server); } - // this->server->disconnect_reason = packet.GetData(); - // this->server->needToRun = false; - // this->server->enabled = false; - // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_ERROR, this->server); - } - if (packet.param == sd_gui_utils::networks::PacketParam::MODEL_LIST) { - size_t packet_id = this->receivedPackets.size(); - this->receivedPackets[packet_id] = packet; - // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_MODEL_LIST_UPDATE, server, packet_id); - if (this->onMessageClb != nullptr) { - this->onMessageClb(packet_id); + if (packet.param == sd_gui_utils::networks::Packet::Param::PARAM_MODEL_LIST) { + size_t packet_id = this->receivedPackets.size(); + this->receivedPackets[packet_id] = packet; + // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_MODEL_LIST_UPDATE, server, packet_id); + if (this->onMessageClb != nullptr) { + this->onMessageClb(packet_id); + } } } - } - if (packet.type == sd_gui_utils::networks::PacketType::REQUEST) { - if (packet.param == sd_gui_utils::networks::PacketParam::AUTH) { - // get the auth token from secret store - wxSecretStore store = wxSecretStore::GetDefault(); - wxString serviceName = wxString::Format(wxT("%s/%s_%d"), PROJECT_NAME, wxString::FromUTF8Unchecked(this->host), this->port); - wxString username; - wxSecretValue authkey; + if (packet.type == sd_gui_utils::networks::Packet::Type::REQUEST_TYPE) { + if (packet.param == sd_gui_utils::networks::Packet::Param::PARAM_AUTH) { + // get the auth token from secret store + wxSecretStore store = wxSecretStore::GetDefault(); + wxString serviceName = wxString::Format(wxT("%s/%s_%d"), PROJECT_NAME, wxString::FromUTF8Unchecked(this->host), this->port); + wxString username; + wxSecretValue authkey; - if (store.IsOk() && store.Load(serviceName, username, authkey)) { - auto responsePacket = sd_gui_utils::networks::Packet(); - responsePacket.type = sd_gui_utils::networks::PacketType::RESPONSE; - responsePacket.param = sd_gui_utils::networks::PacketParam::AUTH; - responsePacket.SetData(authkey.GetAsString().utf8_string()); + if (store.IsOk() && store.Load(serviceName, username, authkey)) { + auto responsePacket = sd_gui_utils::networks::Packet(); + responsePacket.type = sd_gui_utils::networks::Packet::Type::RESPONSE_TYPE; + responsePacket.param = sd_gui_utils::networks::Packet::Param::PARAM_AUTH; + responsePacket.SetData(authkey.GetAsString().utf8_string()); - this->sendMsg(responsePacket); - // send model list request - auto modelListPacket = sd_gui_utils::networks::Packet(); - modelListPacket.type = sd_gui_utils::networks::PacketType::REQUEST; - modelListPacket.param = sd_gui_utils::networks::PacketParam::MODEL_LIST; - this->sendMsg(modelListPacket); - } else { - this->disconnect_reason = _("Authentication Failed"); - // this->server->needToRun = false; - // this->server->enabled = false; - // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server); - if (this->onErrorClb != nullptr) { - this->disconnect_reason = packet.GetData(); - this->onErrorClb(this->disconnect_reason); + this->sendMsg(responsePacket); + // send model list request + auto modelListPacket = sd_gui_utils::networks::Packet(); + modelListPacket.type = sd_gui_utils::networks::Packet::Type::REQUEST_TYPE; + modelListPacket.param = sd_gui_utils::networks::Packet::Param::PARAM_MODEL_LIST; + this->sendMsg(modelListPacket); + } else { + this->disconnect_reason = _("Authentication Failed"); + // this->server->needToRun = false; + // this->server->enabled = false; + // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server); + if (this->onErrorClb != nullptr) { + this->disconnect_reason = packet.GetData(); + this->onErrorClb(this->disconnect_reason); + } } } } } - } - void TcpClient::onDisconnect(const sockets::SocketRet& ret) { - this->connected.store(false); - if (ret.m_msg.length() > 0) { - this->disconnect_reason = ret.m_msg; - } - // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server); - if (this->onDisconnectClb != nullptr) { - this->onDisconnectClb(ret.m_msg); + void TcpClient::onDisconnect(const sockets::SocketRet& ret) { + this->connected.store(false); + if (ret.m_msg.length() > 0) { + this->disconnect_reason = ret.m_msg; + } + // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server); + if (this->onDisconnectClb != nullptr) { + this->onDisconnectClb(ret.m_msg); + } + this->stop(); } - this->stop(); - } - void TcpClient::sendMsg(const sd_gui_utils::networks::Packet& packet) { - auto raw_packet = sd_gui_utils::networks::Packet::Serialize(packet); - size_t packet_size = raw_packet.second; - this->sendMsg(reinterpret_cast(&packet_size), sizeof(packet_size)); - this->sendMsg(raw_packet.first, raw_packet.second); - delete[] raw_packet.first; + void TcpClient::sendMsg(const sd_gui_utils::networks::Packet& packet) { + auto raw_packet = sd_gui_utils::networks::Packet::Serialize(packet); + size_t packet_size = raw_packet.second; + this->sendMsg(reinterpret_cast(&packet_size), sizeof(packet_size)); + this->sendMsg(raw_packet.first, raw_packet.second); + delete[] raw_packet.first; + } } -} // namespace sd_gui_utils::networks \ No newline at end of file +} \ No newline at end of file diff --git a/src/libs/TcpClient.h b/src/libs/TcpClient.h index 52b1f10..70f1419 100644 --- a/src/libs/TcpClient.h +++ b/src/libs/TcpClient.h @@ -3,73 +3,75 @@ #include "../network/packets.h" -namespace sd_gui_utils::networks { - using TcpClientOnMessage = std::function; - using TcpClientOnConnect = std::function; - using TcpClientOntDisconnect = std::function; - using TcpClientOnError = std::function; +namespace sd_gui_utils { + inline namespace networks { + using TcpClientOnMessage = std::function; + using TcpClientOnConnect = std::function; + using TcpClientOntDisconnect = std::function; + using TcpClientOnError = std::function; - class TcpClient { - sockets::TcpClient m_client; - void sendMsg(const sd_gui_utils::networks::Packet& packet); - std::unordered_map receivedPackets; - std::vector buffer; - size_t expected_size = 0; - void HandlePackets(sd_gui_utils::networks::Packet& packet); - std::atomic connected{false}; - std::string host, disconnect_reason; - int port; + class TcpClient { + sockets::TcpClient m_client; + void sendMsg(const sd_gui_utils::networks::Packet& packet); + std::unordered_map receivedPackets; + std::vector buffer; + size_t expected_size = 0; + void HandlePackets(sd_gui_utils::networks::Packet& packet); + std::atomic connected{false}; + std::string host, disconnect_reason; + int port; - public: - TcpClientOnMessage onMessageClb = nullptr; - TcpClientOnConnect onConnectClb = nullptr; - TcpClientOntDisconnect onDisconnectClb = nullptr; - TcpClientOnError onErrorClb = nullptr; - TcpClient() - : m_client(*this) {}; + public: + TcpClientOnMessage onMessageClb = nullptr; + TcpClientOnConnect onConnectClb = nullptr; + TcpClientOntDisconnect onDisconnectClb = nullptr; + TcpClientOnError onErrorClb = nullptr; + TcpClient() + : m_client(*this) {}; - TcpClient(TcpClient&& other) = delete; - TcpClient& operator=(TcpClient&& other) = delete; - TcpClient(const TcpClient&) = delete; - TcpClient& operator=(const TcpClient&) = delete; - ~TcpClient() = default; - void Connect(const std::string& host, int port) { - this->host = host; - this->port = port; - sockets::SocketRet ret = m_client.connectTo(this->host.c_str(), this->port); - this->connected.store(ret.m_success); - if (ret.m_success) { - if (this->onConnectClb != nullptr) { - this->onConnectClb(); + TcpClient(TcpClient&& other) = delete; + TcpClient& operator=(TcpClient&& other) = delete; + TcpClient(const TcpClient&) = delete; + TcpClient& operator=(const TcpClient&) = delete; + ~TcpClient() = default; + void Connect(const std::string& host, int port) { + this->host = host; + this->port = port; + sockets::SocketRet ret = m_client.connectTo(this->host.c_str(), this->port); + this->connected.store(ret.m_success); + if (ret.m_success) { + if (this->onConnectClb != nullptr) { + this->onConnectClb(); + } + // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_CONNECTED, this->server); + } else { + // this->server->disconnect_reason = ret.m_msg; + // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server); + // this->m_client.finish(); + this->disconnect_reason = ret.m_msg; + if (this->onDisconnectClb != nullptr) { + this->onDisconnectClb(this->disconnect_reason); + } + this->stop(); } - // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_CONNECTED, this->server); - } else { - // this->server->disconnect_reason = ret.m_msg; - // this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server); - // this->m_client.finish(); - this->disconnect_reason = ret.m_msg; - if (this->onDisconnectClb != nullptr) { - this->onDisconnectClb(this->disconnect_reason); - } - this->stop(); } - } - void onReceiveData(const char* data, size_t size); - void onDisconnect(const sockets::SocketRet& ret); - void sendMsg(const char* data, size_t len); - bool IsConnected() { return this->connected.load(); } - std::string GetDisconnectReason() { return this->disconnect_reason; } - inline sd_gui_utils::networks::Packet getPacket(size_t Id) { - if (receivedPackets.find(Id) != receivedPackets.end()) { - return receivedPackets[Id]; + void onReceiveData(const char* data, size_t size); + void onDisconnect(const sockets::SocketRet& ret); + void sendMsg(const char* data, size_t len); + bool IsConnected() { return this->connected.load(); } + std::string GetDisconnectReason() { return this->disconnect_reason; } + inline sd_gui_utils::networks::Packet getPacket(size_t Id) { + if (receivedPackets.find(Id) != receivedPackets.end()) { + return receivedPackets[Id]; + } + return sd_gui_utils::networks::Packet(); } - return sd_gui_utils::networks::Packet(); - } - inline void stop() { - this->m_client.finish(); - this->connected.store(false); - } - }; + inline void stop() { + this->m_client.finish(); + this->connected.store(false); + } + }; + } } #endif // _LIBS_TCPCLIENT_H_ \ No newline at end of file diff --git a/src/network/RemoteModelInfo.h b/src/network/RemoteModelInfo.h index af1e2b5..c8f8ae6 100644 --- a/src/network/RemoteModelInfo.h +++ b/src/network/RemoteModelInfo.h @@ -5,36 +5,111 @@ #include "helpers/DirTypes.h" #endif -namespace sd_gui_utils::networks { +namespace sd_gui_utils { + inline namespace networks { - class RemoteModelInfo { - public: - int server_id = -1; // deleted server id - std::string name; - std::string path; - std::string root_path; - std::string sha256; - size_t size; - std::string size_f; - size_t hash_progress_size; - size_t hash_fullsize; - sd_gui_utils::DirTypes model_type = sd_gui_utils::DirTypes::UNKNOWN; + class RemoteModelInfo { + public: + int server_id = -1; // deleted server id + std::string name; + std::string path; + std::string root_path; + std::string sha256; + size_t size; + std::string size_f; + size_t hash_progress_size; + size_t hash_fullsize; + sd_gui_utils::DirTypes model_type = sd_gui_utils::DirTypes::UNKNOWN; - RemoteModelInfo() = default; - RemoteModelInfo(const wxFileName &path, sd_gui_utils::DirTypes type, const wxString &root_path) { - this->model_type = type; - this->name = path.GetFullName(); - this->size = path.GetSize().GetValue(); - this->size_f = path.GetHumanReadableSize().utf8_string(); - this->hash_fullsize = 0; - this->hash_progress_size = 0; - this->path = path.GetAbsolutePath().utf8_string(); - this->root_path = root_path.utf8_string(); - } - ~RemoteModelInfo() {} - NLOHMANN_DEFINE_TYPE_INTRUSIVE(RemoteModelInfo, server_id, name, path, root_path, sha256, size, size_f, hash_progress_size, hash_fullsize, model_type) - }; + RemoteModelInfo() = default; + RemoteModelInfo(const wxFileName& path, sd_gui_utils::DirTypes type, const wxString& root_path) { + this->model_type = type; + this->name = path.GetFullName(); + this->size = path.GetSize().GetValue(); + this->size_f = path.GetHumanReadableSize().utf8_string(); + this->hash_fullsize = 0; + this->hash_progress_size = 0; + this->path = path.GetAbsolutePath().utf8_string(); + this->root_path = root_path.utf8_string(); + } + ~RemoteModelInfo() {} + friend void to_json(nlohmann ::json& nlohmann_json_j, const RemoteModelInfo& nlohmann_json_t) { + nlohmann_json_j["server_id"] = nlohmann_json_t.server_id; + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["path"] = nlohmann_json_t.path; + nlohmann_json_j["root_path"] = nlohmann_json_t.root_path; + nlohmann_json_j["sha256"] = nlohmann_json_t.sha256; + nlohmann_json_j["size"] = nlohmann_json_t.size; + nlohmann_json_j["size_f"] = nlohmann_json_t.size_f; + nlohmann_json_j["hash_progress_size"] = nlohmann_json_t.hash_progress_size; + nlohmann_json_j["hash_fullsize"] = nlohmann_json_t.hash_fullsize; + nlohmann_json_j["model_type"] = nlohmann_json_t.model_type; + } + friend void from_json(const nlohmann ::json& nlohmann_json_j, RemoteModelInfo& nlohmann_json_t) { + { + auto iter = nlohmann_json_j.find("server_id"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.server_id); + } + { + auto iter = nlohmann_json_j.find("name"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.name); + } + { + auto iter = nlohmann_json_j.find("path"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.path); + } + { + auto iter = nlohmann_json_j.find("root_path"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.root_path); + } + { + auto iter = nlohmann_json_j.find("sha256"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.sha256); + } + { + auto iter = nlohmann_json_j.find("size"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.size); + } + { + auto iter = nlohmann_json_j.find("size_f"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.size_f); + } + { + auto iter = nlohmann_json_j.find("hash_progress_size"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.hash_progress_size); + } + { + auto iter = nlohmann_json_j.find("hash_fullsize"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.hash_fullsize); + } + { + auto iter = nlohmann_json_j.find("model_type"); + if (iter != nlohmann_json_j.end()) + if (!iter->is_null()) + iter->get_to(nlohmann_json_t.model_type); + } + } + }; - typedef std::vector RemoteModelList; -} // namespace sd_gui_utils::networks + typedef std::vector RemoteModelList; + } +} #endif // _NETWORK_REMOTE_MODEL_INFO_H_ \ No newline at end of file diff --git a/src/network/RemoteQueueJobItem.h b/src/network/RemoteQueueJobItem.h index 7700d75..ce95e65 100644 --- a/src/network/RemoteQueueJobItem.h +++ b/src/network/RemoteQueueJobItem.h @@ -1,32 +1,34 @@ #ifndef _NETWORK_REMOTE_QUEUE_JOB_ITEM_H_ #define _NETWORK_REMOTE_QUEUE_JOB_ITEM_H_ -namespace sd_gui_utils::networks { - struct RemoteQueueItem { - int id = 0, created_at = 0, updated_at = 0, finished_at = 0, started_at = 0; - SDParams params = SDParams(); - QueueStatus status = QueueStatus::PENDING; - QueueEvents event = QueueEvents::ITEM_ADDED; - QueueItemStats stats = QueueItemStats(); - int step = 0, steps = 0; - size_t hash_fullsize = 0, hash_progress_size = 0; - float time = 0; - std::string model = ""; - SDMode mode = SDMode::TXT2IMG; - std::string status_message = ""; - uint32_t upscale_factor = 4; - std::string sha256 = ""; - std::string app_version = SD_GUI_VERSION; - std::string git_version = GIT_HASH; - std::string original_prompt = ""; - std::string original_negative_prompt = ""; - bool keep_checkpoint_in_memory = false; - bool keep_upscaler_in_memory = false; - bool need_sha256 = false; - std::string generated_sha256 = ""; - int update_index = -1; - std::string server = ""; - }; -} +namespace sd_gui_utils { + inline namespace networks { + struct RemoteQueueItem { + int id = 0, created_at = 0, updated_at = 0, finished_at = 0, started_at = 0; + SDParams params = SDParams(); + QueueStatus status = QueueStatus::PENDING; + QueueEvents event = QueueEvents::ITEM_ADDED; + QueueItemStats stats = QueueItemStats(); + int step = 0, steps = 0; + size_t hash_fullsize = 0, hash_progress_size = 0; + float time = 0; + std::string model = ""; + SDMode mode = SDMode::TXT2IMG; + std::string status_message = ""; + uint32_t upscale_factor = 4; + std::string sha256 = ""; + std::string app_version = SD_GUI_VERSION; + std::string git_version = GIT_HASH; + std::string original_prompt = ""; + std::string original_negative_prompt = ""; + bool keep_checkpoint_in_memory = false; + bool keep_upscaler_in_memory = false; + bool need_sha256 = false; + std::string generated_sha256 = ""; + int update_index = -1; + std::string server = ""; + }; + } // namespace networks +} // namespace sd_gui_utils #endif // _NETWORK_REMOTE_QUEUE_JOB_ITEM_H_ \ No newline at end of file diff --git a/src/network/packets.h b/src/network/packets.h index 6a620ac..5fa570d 100644 --- a/src/network/packets.h +++ b/src/network/packets.h @@ -1,111 +1,95 @@ -#ifndef _LIBS_NETWORK_PACKETS_H_ -#define _LIBS_NETWORK_PACKETS_H_ +#ifndef _SDGUI_LIBS_NETWORK_PACKETS_H_ +#define _SDGUI_LIBS_NETWORK_PACKETS_H_ -#include "ver.hpp" +namespace sd_gui_utils { + inline namespace networks { -namespace sd_gui_utils::networks { + class Packet { + private: + std::vector data; + size_t data_size = 0; - enum class PacketType { - REQUEST = 0, - RESPONSE = 1 - }; + public: + enum class Type : int { + REQUEST_TYPE, + RESPONSE_TYPE + }; - enum class PacketParam { - ERROR = 0, - AUTH = 1, - MODEL_LIST = 2, - KEEPALIVE = 3 - }; + enum class Param : int { + PARAM_ERROR, + PARAM_AUTH, + PARAM_MODEL_LIST, + PARAM_KEEPALIVE + }; - struct Packet { - private: - std::vector data; - size_t data_size = 0; + Packet(Type type, Param param) + : type(type), param(param) {} + Packet() + : type(sd_gui_utils::networks::Packet::Type::REQUEST_TYPE), param(sd_gui_utils::networks::Packet::Param::PARAM_ERROR) {} + sd_gui_utils::networks::Packet::Type type; + sd_gui_utils::networks::Packet::Param param; + // std::string version = std::string(SD_GUI_VERSION); + std::string version = ""; + int source_idx = -1; + int target_idx = -1; + std::string server_id = ""; - public: - Packet(sd_gui_utils::networks::PacketType type, sd_gui_utils::networks::PacketParam param) - : type(type), param(param) {} - Packet() - : type(PacketType::REQUEST), param(PacketParam::ERROR) {} - sd_gui_utils::networks::PacketType type; - sd_gui_utils::networks::PacketParam param; - std::string version = std::string(SD_GUI_VERSION); - int source_idx = -1; - int target_idx = -1; - std::string server_id = ""; - - inline size_t GetSize() { return this->data_size; } - template - void SetData(const T& j) { - try { - nlohmann::json json_obj = j; - std::string json_str = json_obj.dump(); - this->data = std::vector(json_str.begin(), json_str.end()); - this->data_size = this->data.size(); - } catch (const std::exception& e) { - throw std::runtime_error("Failed to serialize object in SetData: " + std::string(e.what())); + inline size_t GetSize() { return this->data_size; } + template + inline void SetData(const T& j) { + try { + nlohmann::json json_obj = j; + std::string json_str = json_obj.dump(); + this->data = std::vector(json_str.begin(), json_str.end()); + this->data_size = this->data.size(); + } catch (const std::exception& e) { + throw std::runtime_error("Failed to serialize object in SetData: " + std::string(e.what())); + } } - } - template - T GetData() { - try { - nlohmann::json json_obj = nlohmann::json::parse(std::string(this->data.begin(), this->data.end())); - return json_obj.get(); - } catch (const std::exception& e) { - throw std::runtime_error("Failed to deserialize GetData data: " + std::string(e.what())); + template + inline T GetData() { + try { + nlohmann::json json_obj = nlohmann::json::parse(std::string(this->data.begin(), this->data.end())); + return json_obj.get(); + } catch (const std::exception& e) { + throw std::runtime_error("Failed to deserialize GetData data: " + std::string(e.what())); + } } - } - size_t GetDataSize() { return this->data_size; } + size_t GetDataSize() { return this->data_size; } - static std::pair Serialize(const Packet& packet) { - try { - nlohmann::json json_data = packet; - auto binary_data = nlohmann::json::to_ubjson(json_data); - char* d = new char[binary_data.size() + 1]; - std::copy(binary_data.begin(), binary_data.end(), d); - return std::make_pair(d, binary_data.size()); - } catch (const std::exception& e) { - throw std::runtime_error("Failed to serialize Packet to MessagePack: " + std::string(e.what())); + static std::pair Serialize(const Packet& packet) { + try { + nlohmann::json json_data = packet; + auto binary_data = nlohmann::json::to_ubjson(json_data); + char* d = new char[binary_data.size() + 1]; + std::copy(binary_data.begin(), binary_data.end(), d); + return std::make_pair(d, binary_data.size()); + } catch (const std::exception& e) { + throw std::runtime_error("Failed to serialize Packet to MessagePack: " + std::string(e.what())); + } } - } - static Packet DeSerialize(const char* data, size_t size) { - try { - std::vector binary_data(data, data + size); - auto json_data = nlohmann::json::from_ubjson(binary_data); + static Packet DeSerialize(const char* data, size_t size) { + try { + std::vector binary_data(data, data + size); + auto json_data = nlohmann::json::from_ubjson(binary_data); - if (json_data.is_null()) { - throw std::runtime_error("Received empty MessagePack data."); - } + if (json_data.is_null()) { + throw std::runtime_error("Received empty MessagePack data."); + } - return json_data.get(); - } catch (const std::exception& e) { - std::string received_data(data, data + size); - throw std::runtime_error( - "Failed to convert Packet: " + std::string(e.what()) + "\nRaw data size: " + std::to_string(size)); + return json_data.get(); + } catch (const std::exception& e) { + std::string received_data(data, data + size); + throw std::runtime_error( + "Failed to convert Packet: " + std::string(e.what()) + "\nRaw data size: " + std::to_string(size)); + } } - } - - friend void to_json(nlohmann ::json& nlohmann_json_j, const Packet& nlohmann_json_t) { - nlohmann_json_j["type"] = (int)nlohmann_json_t.type; - nlohmann_json_j["param"] = (int)nlohmann_json_t.param; - nlohmann_json_j["version"] = nlohmann_json_t.version; - nlohmann_json_j["data_size"] = nlohmann_json_t.data_size; - nlohmann_json_j["data"] = nlohmann_json_t.data; - nlohmann_json_j["server_id"] = nlohmann_json_t.server_id; - } - friend void from_json(const nlohmann ::json& nlohmann_json_j, Packet& nlohmann_json_t) { - const Packet nlohmann_json_default_obj{}; - nlohmann_json_t.type = (sd_gui_utils::networks::PacketType)nlohmann_json_j.value("type", sd_gui_utils::networks::PacketType::REQUEST); - nlohmann_json_t.param = (sd_gui_utils::networks::PacketParam)nlohmann_json_j.value("param", sd_gui_utils::networks::PacketParam::ERROR); - nlohmann_json_t.version = nlohmann_json_j.value("version", nlohmann_json_default_obj.version); - nlohmann_json_t.data_size = nlohmann_json_j.value("data_size", nlohmann_json_default_obj.data_size); - nlohmann_json_t.data = nlohmann_json_j.value("data", nlohmann_json_default_obj.data); - nlohmann_json_t.server_id = nlohmann_json_j.value("server_id", nlohmann_json_default_obj.server_id); - } - }; -}; // namespace sd_gui_utils::networks + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Packet, type, param, version, data_size, data, server_id) + }; // struct Packet + } // namespace networks +} // namespace sd_gui_utils -#endif // _LIBS_NETWORK_PACKETS_H_ \ No newline at end of file +#endif // _SDGUI_LIBS_NETWORK_PACKETS_H_ \ No newline at end of file diff --git a/src/network/sdServer.h b/src/network/sdServer.h index 9976fe5..da752cd 100644 --- a/src/network/sdServer.h +++ b/src/network/sdServer.h @@ -15,7 +15,7 @@ namespace sd_gui_utils { int internal_id = -1; std::string disconnect_reason; std::thread thread; - std::unique_ptr client = nullptr; + std::shared_ptr client = nullptr; wxEvtHandler* evt = nullptr; bool IsOk() const { return !host.empty() && port > 0; } @@ -94,7 +94,7 @@ namespace sd_gui_utils { void StartServer() { this->needToRun.store(true); if (this->client == nullptr) { - this->client = std::make_unique(); + this->client = std::make_shared(); } this->thread = std::thread([this]() { while (this->needToRun.load()) { @@ -158,7 +158,7 @@ namespace sd_gui_utils { } sdServer(const std::string& host, int port, wxEvtHandler* evt) : host(host), port(port), evt(evt) { - this->client = std::make_unique(); + this->client = std::make_shared(); this->client->onConnectClb = [this]() { this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_CONNECTED, this); }; @@ -167,7 +167,7 @@ namespace sd_gui_utils { if (msg.server_id.empty() == false && this->server_id.empty()) { this->SetId(msg.server_id); } - if (msg.param == sd_gui_utils::networks::PacketParam::MODEL_LIST) { + if (msg.param == sd_gui_utils::networks::Packet::Param::PARAM_MODEL_LIST) { this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_MODEL_LIST_UPDATE, this); } }; diff --git a/src/ui/ModelInfo.cpp b/src/ui/ModelInfo.cpp index ccb7983..365c8e0 100644 --- a/src/ui/ModelInfo.cpp +++ b/src/ui/ModelInfo.cpp @@ -323,25 +323,12 @@ wxString ModelInfo::Manager::GetMetaPath(const wxString& model_path, bool remote return wxEmptyString; } - wxString path_name = model_path; - - if (remote) { - path_name = path_name.SubString(65, path_name.Length()); - } + wxString path_name = remote ? model_path.SubString(65, model_path.Length()) : model_path; wxFileName path(path_name); - - if (!path.Exists() && !remote) { - return wxEmptyString; - } - path.SetExt("json"); wxFileName meta_path(this->MetaStorePath, path.GetFullName()); - if (!wxFileName::DirExists(meta_path.GetPath())) { - return wxEmptyString; - } if (remote) { - // add _remote suffix meta_path.SetName(meta_path.GetName() + "_remote"); } @@ -350,7 +337,7 @@ wxString ModelInfo::Manager::GetMetaPath(const wxString& model_path, bool remote wxString ModelInfo::Manager::GetFolderName(const wxString& model_path, const sd_gui_utils::DirTypes& type, wxString root_path, const sd_gui_utils::sdServer* server) { wxString path = model_path; - if (server == nullptr) { + if (server != nullptr) { path = path.SubString(65, path.Length()); } auto folderGroupName = wxFileName(path).GetPath();