Skip to content

Commit

Permalink
refactored tcp server and client
Browse files Browse the repository at this point in the history
  • Loading branch information
fszontagh committed Jan 10, 2025
1 parent d6af564 commit 7defe9e
Show file tree
Hide file tree
Showing 23 changed files with 514 additions and 454 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@ vcpkg_installed/
/external/stable-diffusion
/ui/ver.hpp
locale/**/*.mo
CMakeUserPresets.json
CMakeUserPresets.json
CMakePresets.json
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ if (NOT WIN32)
set(APPDEPENDS wx::base wx::core wx::xrc wx::aui wx::richtext)

else()
set( wxWidgets::wxWidgets)
set(APPDEPENDS wxWidgets::wxWidgets)
if (MSVC)
set_target_properties(${PROJECT_BINARY_NAME} PROPERTIES
COMPILE_FLAGS "/DwxUSE_RC_MANIFEST"
Expand Down
2 changes: 1 addition & 1 deletion cmake/stable_diffusion.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ if(NOT SD_HIPBLAS)
GIT_TAG ${SD_GIT_TAG}
#WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/sdcpp_${variant_name}
BINARY_DIR ${CMAKE_BINARY_DIR}/sdcpp_${variant_name}
CMAKE_ARGS "-DGGML_NATIVE=ON"
CMAKE_ARGS "-DGGML_NATIVE=OFF"
"-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}"
"-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}"
"-DCMAKE_CXX_FLAGS=${DISABLE_WARNINGS_FLAGS}"
Expand Down
6 changes: 3 additions & 3 deletions graphics/window.fbp
Original file line number Diff line number Diff line change
Expand Up @@ -17946,7 +17946,7 @@
<property name="ellipsize"></property>
<property name="flags">wxDATAVIEW_COL_RESIZABLE</property>
<property name="label">Host</property>
<property name="mode">wxDATAVIEW_CELL_INERT</property>
<property name="mode">wxDATAVIEW_CELL_EDITABLE</property>
<property name="name">m_dataViewListColumn36</property>
<property name="permission">protected</property>
<property name="type">Text</property>
Expand All @@ -17957,7 +17957,7 @@
<property name="ellipsize"></property>
<property name="flags">wxDATAVIEW_COL_RESIZABLE</property>
<property name="label">Port</property>
<property name="mode">wxDATAVIEW_CELL_INERT</property>
<property name="mode">wxDATAVIEW_CELL_EDITABLE</property>
<property name="name">m_dataViewListColumn37</property>
<property name="permission">protected</property>
<property name="type">Text</property>
Expand All @@ -17968,7 +17968,7 @@
<property name="ellipsize">wxELLIPSIZE_MIDDLE</property>
<property name="flags">wxDATAVIEW_COL_RESIZABLE</property>
<property name="label">Auth key</property>
<property name="mode">wxDATAVIEW_CELL_INERT</property>
<property name="mode">wxDATAVIEW_CELL_EDITABLE</property>
<property name="name">m_dataViewListColumn38</property>
<property name="permission">protected</property>
<property name="type">Text</property>
Expand Down
2 changes: 1 addition & 1 deletion server/platforms/Docker/Dockerfile.in
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ RUN mkdir /app/models/esrgan

RUN if [ "$BACKEND" = "cuda" ]; then \
echo "CUDA support enabled"; \
apt update; apt install -y --no-install-recommends libcudart12 libcublas12; \
apt update; apt install -y --no-install-recommends libcudart12 libcublas12 libnvidia-compute-565-server; \
fi


Expand Down
5 changes: 1 addition & 4 deletions server/src/SocketApp.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
#ifndef _SERVER_SOCKETAPP_H
#define _SERVER_SOCKETAPP_H
#include <sstream>
#include <map>

#include <wx/log.h>
#include <wx/time.h>
#include <wx/timer.h>

#include "ver.hpp"

#include "libs/json.hpp"

#include "network/packets.h"
Expand Down
15 changes: 0 additions & 15 deletions src/libs/MessageHandler.cpp

This file was deleted.

22 changes: 0 additions & 22 deletions src/libs/MessageHandler.h

This file was deleted.

61 changes: 32 additions & 29 deletions src/libs/TcpClient.cpp
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
#include "TcpClient.h"
namespace sd_gui_utils::networks {
TcpClient::TcpClient(sd_gui_utils::sdServer* server, wxEvtHandler* evt)
: m_client(*this), server(server), evt(evt) {
sockets::SocketRet ret = m_client.connectTo(this->server->host.c_str(), this->server->port);
this->server->connected = ret.m_success;
if (ret.m_success) {
this->server->disconnect_reason = "";
std::cout << "Connected to " << this->server->host << ":" << this->server->port << "\n";
this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_CONNECTED, this->server);
} else {
this->server->disconnect_reason = ret.m_msg;
this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server);
this->m_client.finish();
}
}

void TcpClient::sendMsg(const char* data, size_t len) {
auto ret = m_client.sendMsg(data, len);
if (!ret.m_success) {
std::cout << "Send Error: " << ret.m_msg << "\n";
if (this->onErrorClb != nullptr) {
this->onErrorClb("Send Error: " + ret.m_msg);
}
}
}

Expand All @@ -44,25 +32,32 @@ namespace sd_gui_utils::networks {
this->expected_size = 0;
}
}
void TcpClient::HandlePackets(sd_gui_utils::networks::Packet& packet) {
void TcpClient::HandlePackets(Packet& packet) {
if (packet.type == sd_gui_utils::networks::PacketType::RESPONSE) {
if (packet.param == sd_gui_utils::networks::PacketParam::ERROR) {
this->server->disconnect_reason = packet.GetData<std::string>();
this->server->needToRun = false;
this->server->enabled = false;
this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_ERROR, this->server);
if (this->onErrorClb != nullptr) {
this->disconnect_reason = packet.GetData<std::string>();
this->onErrorClb(this->disconnect_reason);
}
// this->server->disconnect_reason = packet.GetData<std::string>();
// this->server->needToRun = false;
// this->server->enabled = false;
// this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_ERROR, this->server);
}
if (packet.param == sd_gui_utils::networks::PacketParam::MODEL_LIST) {
size_t packet_id = this->receivedPackets.size();
this->receivedPackets[packet_id] = packet;
this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_MODEL_LIST_UPDATE, server, packet_id);
// this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_MODEL_LIST_UPDATE, server, packet_id);
if (this->onMessageClb != nullptr) {
this->onMessageClb(packet_id);
}
}
}
if (packet.type == sd_gui_utils::networks::PacketType::REQUEST) {
if (packet.param == sd_gui_utils::networks::PacketParam::AUTH) {
// get the auth token from secret store
wxSecretStore store = wxSecretStore::GetDefault();
wxString serviceName = wxString::Format(wxT("%s/%s_%d"), PROJECT_NAME, wxString::FromUTF8Unchecked(server->host), server->port);
wxString serviceName = wxString::Format(wxT("%s/%s_%d"), PROJECT_NAME, wxString::FromUTF8Unchecked(this->host), this->port);
wxString username;
wxSecretValue authkey;

Expand All @@ -79,21 +74,29 @@ namespace sd_gui_utils::networks {
modelListPacket.param = sd_gui_utils::networks::PacketParam::MODEL_LIST;
this->sendMsg(modelListPacket);
} else {
this->server->disconnect_reason = "Authentication Failed";
this->server->needToRun = false;
this->server->enabled = false;
this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server);
this->disconnect_reason = _("Authentication Failed");
// this->server->needToRun = false;
// this->server->enabled = false;
// this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server);
if (this->onErrorClb != nullptr) {
this->disconnect_reason = packet.GetData<std::string>();
this->onErrorClb(this->disconnect_reason);
}
}
}
}
}

void TcpClient::onDisconnect(const sockets::SocketRet& ret) {
this->server->connected = false;
this->connected.store(false);
if (ret.m_msg.length() > 0) {
this->server->disconnect_reason = ret.m_msg;
this->disconnect_reason = ret.m_msg;
}
// this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server);
if (this->onDisconnectClb != nullptr) {
this->onDisconnectClb();
}
this->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server);
this->stop();
}

void TcpClient::sendMsg(const sd_gui_utils::networks::Packet& packet) {
Expand Down
75 changes: 47 additions & 28 deletions src/libs/TcpClient.h
Original file line number Diff line number Diff line change
@@ -1,55 +1,74 @@
#ifndef _LIBS_TCPCLIENT_H_
#define _LIBS_TCPCLIENT_H_
namespace sd_gui_utils {
enum class ThreadEvents;
}
#include "../network/sdServer.h"

#include "../network/packets.h"

namespace sd_gui_utils::networks {
using TcpClientOnMessage = std::function<void(int)>;
using TcpClientOnConnect = std::function<void()>;
using TcpClientOntDisconnect = std::function<void()>;
using TcpClientOnError = std::function<void(const std::string&)>;

class TcpClient {
sd_gui_utils::sdServer* server = nullptr;
wxEvtHandler* evt = nullptr;
sockets::TcpClient<TcpClient> m_client;
std::atomic_int conn_counter = 0;
void sendMsg(const sd_gui_utils::networks::Packet& packet);
std::unordered_map<size_t, sd_gui_utils::networks::Packet> receivedPackets;
std::vector<char> buffer;
size_t expected_size = 0;

template <typename T>
void SendThreadEvent(sd_gui_utils::ThreadEvents eventType, const T& payload, std::string text = "") {
wxThreadEvent* e = new wxThreadEvent();
e->SetString(wxString::Format("%d:%s", (int)eventType, text));
e->SetPayload(payload);
wxQueueEvent(this->evt, e);
};

template <typename T>
void SendThreadEvent(sd_gui_utils::ThreadEvents eventType, const T& payload, size_t packet_id, std::string text = "") {
wxThreadEvent* e = new wxThreadEvent();
e->SetString(wxString::Format("%d:%s", (int)eventType, text));
e->SetInt(packet_id);
e->SetPayload(payload);
wxQueueEvent(this->evt, e);
};

void HandlePackets(sd_gui_utils::networks::Packet& packet);
std::atomic<bool> connected{false};
std::string host, disconnect_reason;
int port;

public:
TcpClient(sd_gui_utils::sdServer* server, wxEvtHandler* evt);
virtual ~TcpClient() = default;
TcpClientOnMessage onMessageClb = nullptr;
TcpClientOnConnect onConnectClb = nullptr;
TcpClientOntDisconnect onDisconnectClb = nullptr;
TcpClientOnError onErrorClb = nullptr;
TcpClient()
: m_client(*this) {};

TcpClient(TcpClient&& other) = delete;
TcpClient& operator=(TcpClient&& other) = delete;
TcpClient(const TcpClient&) = delete;
TcpClient& operator=(const TcpClient&) = delete;
~TcpClient() = default;
void Connect(const std::string& host, int port) {
this->host = host;
this->port = port;
sockets::SocketRet ret = m_client.connectTo(this->host.c_str(), this->port);
this->connected.store(ret.m_success);
if (ret.m_success) {
if (this->onConnectClb != nullptr) {
this->onConnectClb();
}
// this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_CONNECTED, this->server);
} else {
// this->server->disconnect_reason = ret.m_msg;
// this->server->SendThreadEvent(sd_gui_utils::ThreadEvents::SERVER_DISCONNECTED, this->server);
// this->m_client.finish();
this->disconnect_reason = ret.m_msg;
if (this->onDisconnectClb != nullptr) {
this->onDisconnectClb();
}
this->stop();
}
}
void onReceiveData(const char* data, size_t size);
void onDisconnect(const sockets::SocketRet& ret);
void sendMsg(const char* data, size_t len);
bool IsConnected() { return this->connected.load(); }
std::string GetDisconnectReason() { return this->disconnect_reason; }
inline sd_gui_utils::networks::Packet getPacket(size_t Id) {
if (receivedPackets.find(Id) != receivedPackets.end()) {
return receivedPackets[Id];
}
return sd_gui_utils::networks::Packet();
}

inline void stop() {
this->server->needToRun = false;
this->m_client.finish();
this->connected.store(false);
}
};
}
Expand Down
Loading

0 comments on commit 7defe9e

Please sign in to comment.