Skip to content

Commit

Permalink
HookManager: Add support for recursion w/ new lock-free(-ish) method
Browse files Browse the repository at this point in the history
  • Loading branch information
praydog committed Mar 30, 2024
1 parent 10428d0 commit 1c45846
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 14 deletions.
93 changes: 80 additions & 13 deletions src/HookManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,33 @@ HookManager::HookedFn::~HookedFn() {
}

HookManager::PreHookResult HookManager::HookedFn::on_pre_hook() {
std::shared_lock _{this->access_mux};
//std::shared_lock _{this->access_mux};

auto any_skipped = false;

auto storage = get_storage(this);

if (storage->pre_depth == 0) {
// afaik, shared locks are not reentrant, so only lock it
// if we're not already in a pre-hook.
this->access_mux.lock_shared();
} else if (!storage->pre_warned_recursion) {
const auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
const auto declaring_type = fn_def->get_declaring_type();
const auto decltype_name = declaring_type != nullptr ? declaring_type->get_full_name() : "unknownclass";
spdlog::warn("[HookManager] (Pre) Recursive hook detected for '{}.{}' (thread ID: {:x})", decltype_name, fn_def->get_name(), tid);
storage->pre_warned_recursion = true;
}

if (storage->overall_depth > 0 && !storage->overall_warned_recursion) {
const auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
const auto declaring_type = fn_def->get_declaring_type();
const auto decltype_name = declaring_type != nullptr ? declaring_type->get_full_name() : "unknownclass";
spdlog::warn("[HookManager] (Overall) '{}.{}' appears to be calling itself in some way (thread ID: {:x})", decltype_name, fn_def->get_name(), tid);
storage->overall_warned_recursion = true;
}

++storage->pre_depth;
const auto ret_addr_pre = storage->ret_addr_pre;

for (const auto& cb : cbs) {
Expand All @@ -70,15 +92,38 @@ HookManager::PreHookResult HookManager::HookedFn::on_pre_hook() {
any_skipped = true;
}
}
}
}

++storage->overall_depth;
--storage->pre_depth;

if (storage->pre_depth == 0) {
this->access_mux.unlock_shared();
}

return any_skipped ? PreHookResult::SKIP_ORIGINAL : PreHookResult::CALL_ORIGINAL;
}

void HookManager::HookedFn::on_post_hook() {
std::shared_lock _{this->access_mux};
//std::shared_lock _{this->access_mux};

auto storage = get_storage(this);

if (storage->post_depth == 0) {
// afaik, shared locks are not reentrant, so only lock it
// if we're not already in a post-hook.
this->access_mux.lock_shared();
} else if (!storage->post_warned_recursion) {
const auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
const auto declaring_type = fn_def->get_declaring_type();
const auto decltype_name = declaring_type != nullptr ? declaring_type->get_full_name() : "unknownclass";
spdlog::warn("[HookManager] (Post) Recursive hook detected for '{}.{}' (thread ID: {:x})", decltype_name, fn_def->get_name(), tid);
storage->post_warned_recursion = true;
}

++storage->post_depth;
--storage->overall_depth;

auto& ret_val = storage->ret_val;
auto& ret_addr = storage->ret_addr;

Expand All @@ -87,6 +132,12 @@ void HookManager::HookedFn::on_post_hook() {
cb.post_fn(ret_val, ret_ty, ret_addr);
}
}

--storage->post_depth;

if (storage->post_depth == 0) {
this->access_mux.unlock_shared();
}
}

void HookManager::create_jitted_facilitator(std::unique_ptr<HookManager::HookedFn>& hook, sdk::REMethodDefinition* fn, std::function<uintptr_t ()> hook_initialization, std::function<void ()> hook_create) {
Expand Down Expand Up @@ -116,6 +167,8 @@ void HookManager::create_jitted_facilitator(std::unique_ptr<HookManager::HookedF
auto on_post_hook_label = a.newLabel();
auto orig_label = a.newLabel();
auto get_storage_label = a.newLabel();
auto push_rbx_label = a.newLabel();
auto pop_rbx_label = a.newLabel();
auto lock_label = a.newLabel();
auto unlock_label = a.newLabel();

Expand All @@ -141,6 +194,7 @@ void HookManager::create_jitted_facilitator(std::unique_ptr<HookManager::HookedF
a.mov(rcx, ptr(hook_label));
a.call(ptr(get_storage_label));


// restore stack
a.mov(rsp, rbx);

Expand Down Expand Up @@ -237,15 +291,21 @@ void HookManager::create_jitted_facilitator(std::unique_ptr<HookManager::HookedF
save_arg(args_offset, is_float);
}

// Call on_pre_hook.
a.mov(rcx, ptr(hook_label));

// fix stack
a.push(r10); // push storage
a.mov(rbx, rsp);
a.sub(rsp, STACK_STORAGE_AMOUNT);
a.and_(rsp, -16);

// Use this moment to push RBX to our pseudo-stack.
// because the pre-hook may call this function recursively, clobbering RBX.
a.mov(rcx, r10); // storage ptr.
a.mov(rdx, ptr(r10, rbx_offset)); // original rbx.
a.call(ptr(push_rbx_label));

// Call on_pre_hook.
a.mov(rcx, ptr(hook_label));
a.call(ptr(on_pre_hook_label));

// Save the return value so we can see if we need to call the original later.
Expand Down Expand Up @@ -364,23 +424,26 @@ void HookManager::create_jitted_facilitator(std::unique_ptr<HookManager::HookedF
a.mov(ptr(r10, ret_val_offset), rax);
}

a.push(r10);

// Call on_post_hook.
a.push(r12); // R12 being used as a cross-call register for storage of the storage ptr.
a.mov(r12, r10);

a.mov(rbx, rsp);
a.sub(rsp, STACK_STORAGE_AMOUNT);
a.and_(rsp, -16);

a.mov(rcx, ptr(hook_label));
a.call(ptr(on_post_hook_label));

// Now use this moment to pop RBX from our pseudo-stack.
a.mov(rcx, r12); // storage ptr.
a.call(ptr(pop_rbx_label));

a.mov(rsp, rbx);

a.pop(r10);
a.mov(rsp, rbx); // Restore stack ptr.
a.mov(rbx, rax); // restore original RBX, the return value from pop_rbx.

// Restore return value and RBX.
//a.mov(rcx, ptr(ret_val_label));
a.mov(rbx, ptr(r10, rbx_offset));
a.mov(r10, r12); // storage ptr.
a.pop(r12);

if (is_ret_ty_float) {
a.movq(xmm0, ptr(r10, ret_val_offset));
Expand All @@ -401,6 +464,10 @@ void HookManager::create_jitted_facilitator(std::unique_ptr<HookManager::HookedF
a.dq((uint64_t)&HookedFn::on_post_hook_static);
a.bind(get_storage_label);
a.dq((uint64_t)&HookedFn::get_storage);
a.bind(push_rbx_label);
a.dq((uint64_t)&HookedFn::push_rbx);
a.bind(pop_rbx_label);
a.dq((uint64_t)&HookedFn::pop_rbx);
a.bind(lock_label);
a.dq((uint64_t)&HookedFn::lock_static);
a.bind(unlock_label);
Expand Down
23 changes: 22 additions & 1 deletion src/HookManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <vector>
#include <memory>
#include <mutex>
#include <stack>

#include <asmjit/asmjit.h>

Expand Down Expand Up @@ -62,14 +63,24 @@ class HookManager {
bool is_virtual{false};
HookedVTable* vtable{nullptr};

// Per-thread storage for hooked function.
struct HookStorage {
size_t* args{};
uintptr_t This{};
uintptr_t ret_addr_pre{};
uintptr_t ret_addr{};
uintptr_t ret_val{};
uintptr_t rbx{};

uintptr_t rbx; // temp storage for rbx.
std::stack<uintptr_t> rbx_stack{}; // full storage for rbx.
std::vector<size_t> args_impl{};

uint32_t pre_depth{0};
uint32_t overall_depth{0};
uint32_t post_depth{0};
bool pre_warned_recursion{false}; // for logging recursion.
bool overall_warned_recursion{false}; // for logging recursion.
bool post_warned_recursion{false}; // for logging recursion.
};

// Thread->storage
Expand All @@ -82,6 +93,16 @@ class HookManager {
PreHookResult on_pre_hook();
void on_post_hook();

__declspec(noinline) static void push_rbx(HookStorage* storage, uintptr_t rbx) {
storage->rbx_stack.push(rbx);
}

__declspec(noinline) static uintptr_t pop_rbx(HookStorage* storage) {
auto rbx = storage->rbx_stack.top();
storage->rbx_stack.pop();
return rbx;
}

__declspec(noinline) static HookStorage* get_storage(HookedFn* fn) {
auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
{
Expand Down

0 comments on commit 1c45846

Please sign in to comment.