Skip to content

Commit

Permalink
limit token if user is suspended
Browse files Browse the repository at this point in the history
  • Loading branch information
dr3mro committed Nov 27, 2024
1 parent 3099038 commit 1200c71
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/controllers/base/controller/controller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Controller
void Logout(T &entity, CALLBACK_ &&callback)
requires(std::is_base_of_v<Client, T>)
{
TokenManager::LoggedClientInfo loggedClientInfo;
SessionManager::LoggedClientInfo loggedClientInfo;

try
{
Expand Down
17 changes: 14 additions & 3 deletions src/controllers/clientcontroller/clientcontroller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,24 @@ void ClientController<T>::Login(CALLBACK_&& callback, std::string_view data)

if (success)
{
TokenManager::LoggedClientInfo loggedClientInfo;
SessionManager::LoggedClientInfo loggedClientInfo;

loggedClientInfo.clientId = client_id;
loggedClientInfo.userName = credentials.username;
loggedClientInfo.group = client.getGroupName();
sessionManager->setNowLoginTime(client_id.value(), loggedClientInfo.group.value());
loggedClientInfo.llodt = sessionManager->getLastLogoutTime(loggedClientInfo.clientId.value(), loggedClientInfo.group.value()).value();

std::optional<std::string> last_logout = sessionManager->setNowLoginTime(client_id.value(), loggedClientInfo.group.value());
if (last_logout.has_value())
{
loggedClientInfo.llodt = last_logout.value();
LOG_TRACE << loggedClientInfo.llodt.value();
}
else
{
Message::ErrorMessage("Failed to set login time and get last logout time");
callback(api::v2::Http::Status::UNAUTHORIZED, "Failed to set login time");
return;
}

jsoncons::json token_object;
token_object["token"] = tokenManager->GenerateToken(loggedClientInfo);
Expand Down
7 changes: 4 additions & 3 deletions src/filters/auth.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace api
}

// Extract the token
TokenManager::LoggedClientInfo clientInfo;
SessionManager::LoggedClientInfo clientInfo;
clientInfo.token = auth_header.substr(7); // Efficient token extraction

// Get TokenManager object
Expand All @@ -38,9 +38,10 @@ namespace api
// Validate token in a single step
if (!tokenManager->ValidateToken(clientInfo))
{
auto resp = drogon::HttpResponse::newHttpResponse();
auto message = clientInfo.is_active ? "Token is invalid or expired" : "User is suspended";
auto resp = drogon::HttpResponse::newHttpResponse();
resp->setStatusCode(drogon::HttpStatusCode::k401Unauthorized);
resp->setBody(JsonHelper::stringify(JsonHelper::jsonify("Access Denied.")));
resp->setBody(JsonHelper::stringify(JsonHelper::jsonify(message)));
fcb(resp);
return;
}
Expand Down
42 changes: 37 additions & 5 deletions src/utils/sessionmanager/sessionmanager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
#include <ctime>
#include <string>

#include "controllers/clientcontroller/clientcontroller.hpp"
#include "fmt/core.h"
#include "utils/message/message.hpp"

void SessionManager::setNowLoginTime(uint64_t id, const std::string &group)
std::optional<std::string> SessionManager::setNowLoginTime(uint64_t id, const std::string &group)
{
try
{
std::string login_time = current_time_to_utc_string();

std::string query = fmt::format(
"INSERT INTO {}_sessions (id, last_login,last_logout) VALUES ({}, '{}', '{}') "
"ON CONFLICT (id) DO UPDATE SET last_login = EXCLUDED.last_login;",
"ON CONFLICT (id) DO UPDATE SET last_login = EXCLUDED.last_login RETURNING last_logout;",
group, id, login_time, login_time);

auto result = databaseController->executeQuery(query);
Expand All @@ -26,12 +27,18 @@ void SessionManager::setNowLoginTime(uint64_t id, const std::string &group)
Message::ErrorMessage("Error updating login time.");
Message::CriticalMessage(result.value().to_string());
}
else if (!result->empty())
{
return result.value().at("last_logout").as_string();
}
return std::nullopt;
}
catch (const std::exception &e)
{
Message::ErrorMessage("Error updating login time.");
Message::CriticalMessage(e.what());
}
return std::nullopt;
}

void SessionManager::setNowLogoutTime(uint64_t id, const std::string &group)
Expand Down Expand Up @@ -64,10 +71,35 @@ std::optional<std::string> SessionManager::getLastLoginTime(uint64_t id, const s
std::string query = fmt::format("SELECT last_login FROM {}_sessions WHERE id = {};", group, id);
return databaseController->doReadQuery(query);
}
std::optional<std::string> SessionManager::getLastLogoutTime(uint64_t id, const std::string &group)

bool SessionManager::getLastLogoutTimeIfActive(LoggedClientInfo &loggedClientInfo)
{
std::string query = fmt::format("SELECT last_logout FROM {}_sessions WHERE id = {};", group, id);
return databaseController->doReadQuery(query);
try
{
std::string query = fmt::format(
"WITH client_data AS (SELECT ses.id, ses.last_logout, client.active FROM {}_sessions ses\
LEFT JOIN {} client ON ses.id = client.id) \
SELECT id,\
last_logout, active FROM client_data WHERE id = {};",
loggedClientInfo.group.value(), loggedClientInfo.group.value(), loggedClientInfo.clientId.value());

auto results = databaseController->executeReadQuery(query);
if (results.has_value())
{
if (!results.value().empty())
{
loggedClientInfo.is_active = results.value().at("active").as_bool();
loggedClientInfo.llodt = results.value().at("last_logout").as_string();
return true;
}
}
return false;
}
catch (std::exception &e)
{
CRITICALMESSAGE
}
return false;
}

std::string SessionManager::current_time_to_utc_string()
Expand Down
13 changes: 11 additions & 2 deletions src/utils/sessionmanager/sessionmanager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,22 @@
class SessionManager
{
public:
using LoggedClientInfo = struct LoggedClientInfo
{
std::optional<std::string> token;
std::optional<std::string> userName;
std::optional<std::string> group;
std::optional<uint64_t> clientId;
std::optional<std::string> llodt; // used to invalidate tokens on logout
bool is_active;
};
SessionManager() : databaseController(Store::getObject<DatabaseController>()) {}

virtual ~SessionManager() = default;
void setNowLoginTime(uint64_t id, const std::string &group);
std::optional<std::string> setNowLoginTime(uint64_t id, const std::string &group);
void setNowLogoutTime(uint64_t id, const std::string &group);
std::optional<std::string> getLastLoginTime(uint64_t id, const std::string &group);
std::optional<std::string> getLastLogoutTime(uint64_t id, const std::string &group);
bool getLastLogoutTimeIfActive(LoggedClientInfo &loggedClientInfo);

private:
std::shared_ptr<DatabaseController> databaseController;
Expand Down
22 changes: 15 additions & 7 deletions src/utils/tokenmanager/tokenmanager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using jwt::error::token_verification_exception;

std::optional<std::string> TokenManager::GenerateToken(const LoggedClientInfo &loggedinClientInfo) const
std::optional<std::string> TokenManager::GenerateToken(const SessionManager::LoggedClientInfo &loggedinClientInfo) const
{
try
{
Expand All @@ -31,7 +31,7 @@ std::optional<std::string> TokenManager::GenerateToken(const LoggedClientInfo &l
return std::nullopt;
}

bool TokenManager::ValidateToken(LoggedClientInfo &loggedinClientInfo) const
bool TokenManager::ValidateToken(SessionManager::LoggedClientInfo &loggedinClientInfo) const
{
try
{
Expand All @@ -48,6 +48,10 @@ bool TokenManager::ValidateToken(LoggedClientInfo &loggedinClientInfo) const
// Update user info from token
fillUserInfo(loggedinClientInfo, token);

if (!loggedinClientInfo.is_active)
{
return false;
}
// Validate token claims
auto verifier = createTokenVerifier(loggedinClientInfo);
verifier.verify(token);
Expand All @@ -67,7 +71,8 @@ bool TokenManager::ValidateToken(LoggedClientInfo &loggedinClientInfo) const
return false;
}

jwt::verifier<jwt::default_clock, jwt::traits::kazuho_picojson> TokenManager::createTokenVerifier(const LoggedClientInfo &loggedinClientInfo) const
jwt::verifier<jwt::default_clock, jwt::traits::kazuho_picojson> TokenManager::createTokenVerifier(
const SessionManager::LoggedClientInfo &loggedinClientInfo) const
{
return jwt::verify<jwt::traits::kazuho_picojson>()
.allow_algorithm(jwt::algorithm::hs256{tokenManagerParameters_.secret.data()})
Expand All @@ -79,7 +84,8 @@ jwt::verifier<jwt::default_clock, jwt::traits::kazuho_picojson> TokenManager::cr
.with_claim("llodt", jwt::basic_claim<jwt::traits::kazuho_picojson>(loggedinClientInfo.llodt.value()));
}

void TokenManager::fillUserInfo(LoggedClientInfo &loggedinClientInfo, const jwt::decoded_jwt<jwt::traits::kazuho_picojson> &token) const
void TokenManager::fillUserInfo(SessionManager::LoggedClientInfo &loggedinClientInfo,
const jwt::decoded_jwt<jwt::traits::kazuho_picojson> &token) const
{
if (!loggedinClientInfo.group)
{
Expand All @@ -88,6 +94,7 @@ void TokenManager::fillUserInfo(LoggedClientInfo &loggedinClientInfo, const jwt:
else if (loggedinClientInfo.group != token.get_payload_claim("group").as_string())
{
throw std::runtime_error("Group mismatch in token");
return;
}

loggedinClientInfo.clientId = std::stoull(token.get_id());
Expand All @@ -96,12 +103,13 @@ void TokenManager::fillUserInfo(LoggedClientInfo &loggedinClientInfo, const jwt:
if (!loggedinClientInfo.group || !loggedinClientInfo.clientId || !loggedinClientInfo.userName)
{
throw std::runtime_error("Missing required user information in token");
return;
}
// Get last logout time
loggedinClientInfo.llodt = sessionManager->getLastLogoutTime(loggedinClientInfo.clientId.value(), loggedinClientInfo.group.value());

sessionManager->getLastLogoutTimeIfActive(loggedinClientInfo);
}

bool TokenManager::validateUserInDatabase(const LoggedClientInfo &loggedinClientInfo) const
bool TokenManager::validateUserInDatabase(const SessionManager::LoggedClientInfo &loggedinClientInfo) const
{
return databaseController->findIfUserID(loggedinClientInfo.userName.value(), loggedinClientInfo.group.value()) ==
loggedinClientInfo.clientId.value();
Expand Down
20 changes: 6 additions & 14 deletions src/utils/tokenmanager/tokenmanager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,20 @@
class TokenManager
{
public:
using LoggedClientInfo = struct LoggedClientInfo
{
std::optional<std::string> token;
std::optional<std::string> userName;
std::optional<std::string> group;
std::optional<uint64_t> clientId;
std::optional<std::string> llodt; // used to invalidate tokens on logout
};

TokenManager() : databaseController(Store::getObject<DatabaseController>()), sessionManager(Store::getObject<SessionManager>()) {}
virtual ~TokenManager() = default;

std::optional<std::string> GenerateToken(const LoggedClientInfo &loggedinClientInfo) const;
bool ValidateToken(LoggedClientInfo &loggedinClientInfo) const;
std::optional<std::string> GenerateToken(const SessionManager::LoggedClientInfo &loggedinClientInfo) const;
bool ValidateToken(SessionManager::LoggedClientInfo &loggedinClientInfo) const;

private:
std::shared_ptr<Configurator> configurator_ = Store::getObject<Configurator>();
const Configurator::TokenManagerParameters &tokenManagerParameters_ = configurator_->get<Configurator::TokenManagerParameters>();

std::shared_ptr<DatabaseController> databaseController;
std::shared_ptr<SessionManager> sessionManager;
jwt::verifier<jwt::default_clock, jwt::traits::kazuho_picojson> createTokenVerifier(const LoggedClientInfo &loggedinClientInfo) const;
void fillUserInfo(LoggedClientInfo &loggedinClientInfo, const jwt::decoded_jwt<jwt::traits::kazuho_picojson> &token) const;
bool validateUserInDatabase(const LoggedClientInfo &loggedinClientInfo) const;
jwt::verifier<jwt::default_clock, jwt::traits::kazuho_picojson> createTokenVerifier(
const SessionManager::LoggedClientInfo &loggedinClientInfo) const;
void fillUserInfo(SessionManager::LoggedClientInfo &loggedinClientInfo, const jwt::decoded_jwt<jwt::traits::kazuho_picojson> &token) const;
bool validateUserInDatabase(const SessionManager::LoggedClientInfo &loggedinClientInfo) const;
};

0 comments on commit 1200c71

Please sign in to comment.