Skip to content

Commit

Permalink
Eliminate a bunch of branching, pass a temporary context instead
Browse files Browse the repository at this point in the history
  • Loading branch information
ZehMatt committed Jan 21, 2025
1 parent 585ec39 commit 090abd8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 67 deletions.
25 changes: 24 additions & 1 deletion zasm/src/zasm/src/encoder/encoder.context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ namespace zasm
Section::Attribs attribs{};
};

enum class EncoderFlags : std::uint32_t
{
none = 0,
temporary = 1U << 0,
};
ZASM_ENABLE_ENUM_OPERATORS(EncoderFlags);

struct EncoderContext
{
public:
EncoderFlags flags{};
detail::ProgramState* program{};
bool needsExtraPass{};
std::size_t nodeIndex{};
Expand All @@ -45,7 +53,6 @@ namespace zasm
std::int64_t va{};
std::int32_t offset{};
std::int32_t instrSize{};


struct LabelLink
{
Expand Down Expand Up @@ -105,6 +112,11 @@ namespace zasm
{
assert(id != Label::Id::Invalid);

if ((flags & EncoderFlags::temporary) != EncoderFlags::none)
{
return std::nullopt;
}

const auto& entry = getOrCreateLabelLink(id);
if (entry.boundVA == -1)
{
Expand All @@ -113,5 +125,16 @@ namespace zasm

return entry.boundVA;
}

std::uint32_t getNodeSize(std::size_t nodeIndex) const
{
if ((flags & EncoderFlags::temporary) != EncoderFlags::none)
{
return 0;
}

assert(nodeIndex < nodes.size());
return nodes[nodeIndex].length;
}
};
} // namespace zasm
121 changes: 55 additions & 66 deletions zasm/src/zasm/src/encoder/encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@ namespace zasm

struct EncoderState
{
EncoderContext* ctx{};
EncoderContext& ctx;
ZydisEncoderRequest req{};
std::size_t operandIndex{};
RelocationType relocKind{};
RelocationData relocData{};
Label::Id relocLabel{ Label::Id::Invalid };

EncoderState(EncoderContext& ctx_) noexcept
: ctx(ctx_)
{
}
};

// NOTE: This value has to be at least larger than 0xFFFF to be used with imm32/rel32 displacement.
Expand Down Expand Up @@ -97,8 +102,14 @@ namespace zasm
return encoderVariantData[mnemonic]; // NOLINT
}

static bool isLabelExternal(detail::ProgramState* state, Label::Id labelId)
static bool isLabelExternal(EncoderContext& ctx, Label::Id labelId)
{
if ((ctx.flags & EncoderFlags::temporary) != EncoderFlags::none)
{
return false;
}

const auto state = ctx.program;
const auto idx = static_cast<std::size_t>(labelId);
if (idx >= state->labels.size())
{
Expand Down Expand Up @@ -137,7 +148,8 @@ namespace zasm
return res;
}

static Error buildOperand_(ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state, const Reg& src) noexcept
static Error buildOperand_(
EncoderContext&, ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state, const Reg& src) noexcept
{
dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_REGISTER;
dst.reg.value = static_cast<ZydisRegister>(src.getId());
Expand All @@ -147,8 +159,6 @@ namespace zasm

static int64_t getTemporaryRel(EncoderState& state, const EncodeVariantsInfo& encodeInfo) noexcept
{
auto* ctx = state.ctx;

std::int64_t tempRel = 0;

if (encodeInfo.canEncodeRel32())
Expand All @@ -163,22 +173,19 @@ namespace zasm
return tempRel;
}

static Error buildOperand_(ZydisEncoderOperand& dst, EncoderState& state, const Label& src)
static Error buildOperand_(EncoderContext& ctx, ZydisEncoderOperand& dst, EncoderState& state, const Label& src)
{
auto* ctx = state.ctx;

const auto curVA = ctx != nullptr ? ctx->va : 0;
const auto& encodeInfo = getEncodeVariantInfo(state.req.mnemonic);

// Initially a temporary placeholder.
std::int64_t immValue = curVA + getTemporaryRel(state, encodeInfo);
std::int64_t immValue = ctx.va + getTemporaryRel(state, encodeInfo);

if (ctx != nullptr && !isLabelExternal(ctx->program, src.getId()))
if (!isLabelExternal(ctx, src.getId()))
{
auto labelVA = ctx->getLabelAddress(src.getId());
auto labelVA = ctx.getLabelAddress(src.getId());
if (!labelVA.has_value())
{
ctx->needsExtraPass = true;
ctx.needsExtraPass = true;
}
else
{
Expand All @@ -188,8 +195,7 @@ namespace zasm

if (encodeInfo.isControlFlow)
{
const auto instrSize = ctx != nullptr ? ctx->instrSize : 0;
const auto rel = immValue - (curVA + instrSize);
const auto rel = immValue - (ctx.va + ctx.instrSize);

if (!encodeInfo.canEncodeRel32())
{
Expand Down Expand Up @@ -228,10 +234,8 @@ namespace zasm
return ErrorCode::None;
}

static Error buildOperand_(ZydisEncoderOperand& dst, EncoderState& state, const Imm& src)
static Error buildOperand_(EncoderContext&, ZydisEncoderOperand& dst, EncoderState& state, const Imm& src)
{
auto* ctx = state.ctx;

auto desiredBranchType = ZydisBranchType::ZYDIS_BRANCH_TYPE_NONE;

dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_IMMEDIATE;
Expand All @@ -240,10 +244,8 @@ namespace zasm
return ErrorCode::None;
}

static Error buildOperand_(ZydisEncoderOperand& dst, EncoderState& state, const Mem& src)
static Error buildOperand_(EncoderContext& ctx, ZydisEncoderOperand& dst, EncoderState& state, const Mem& src)
{
auto* ctx = state.ctx;

dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_MEMORY;
dst.mem.base = static_cast<ZydisRegister>(src.getBase().getId());
dst.mem.index = static_cast<ZydisRegister>(src.getIndex().getId());
Expand All @@ -258,38 +260,29 @@ namespace zasm

std::int64_t displacement = src.getDisplacement();

const auto address = ctx != nullptr ? ctx->va : 0;

bool usingLabel = false;
bool externalLabel = false;
bool isDisplacementValid = true;

if (const auto labelId = src.getLabelId(); labelId != Label::Id::Invalid)
{
if (ctx != nullptr)
{
externalLabel = isLabelExternal(ctx->program, labelId);
externalLabel = isLabelExternal(ctx, labelId);

auto labelVA = ctx->getLabelAddress(labelId);
if (labelVA.has_value())
{
displacement += *labelVA;
}
else
{
displacement += address + kTemporaryRel32Value;
isDisplacementValid = false;
if (!externalLabel)
{
ctx->needsExtraPass = true;
}
}
auto labelVA = ctx.getLabelAddress(labelId);
if (labelVA.has_value())
{
displacement += *labelVA;
}
else
{
displacement = kTemporaryRel32Value;
displacement += ctx.va + kTemporaryRel32Value;
isDisplacementValid = false;
if (!externalLabel)
{
ctx.needsExtraPass = true;
}
}

usingLabel = true;
}

Expand All @@ -307,8 +300,7 @@ namespace zasm
{
if (isDisplacementValid)
{
const auto instrSize = ctx != nullptr ? ctx->instrSize : 0;
const auto rel = displacement - (address + instrSize);
const auto rel = displacement - (ctx.va + ctx.instrSize);
if (std::abs(rel) > std::numeric_limits<std::int32_t>::max())
{
char msg[128];
Expand Down Expand Up @@ -343,15 +335,16 @@ namespace zasm
}

static Error buildOperand_(
ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state, [[maybe_unused]] const Operand::None& src) noexcept
EncoderContext&, ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state,
[[maybe_unused]] const Operand::None& src) noexcept
{
dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_UNUSED;
return ErrorCode::None;
}

static Error buildOperand(ZydisEncoderOperand& dst, EncoderState& state, const Operand& src)
static Error buildOperand(EncoderContext& ctx, ZydisEncoderOperand& dst, EncoderState& state, const Operand& src)
{
return src.visit([&dst, &state](auto&& src2) { return buildOperand_(dst, state, src2); });
return src.visit([&](auto&& src2) { return buildOperand_(ctx, dst, state, src2); });
}

static void fixupIs4Operands(ZydisEncoderRequest& req) noexcept
Expand Down Expand Up @@ -432,7 +425,7 @@ namespace zasm
}

static Error encode_(
EncoderResult& res, EncoderContext* ctx, MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic,
EncoderResult& res, EncoderContext& ctx, MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic,
size_t numOps, const Operand* operands)
{
if (!validateMachineMode(mode))
Expand All @@ -442,8 +435,7 @@ namespace zasm

res.buffer.length = 0;

EncoderState state{};
state.ctx = ctx;
EncoderState state{ ctx };

ZydisEncoderRequest& req = state.req;
if (mode == MachineMode::AMD64)
Expand Down Expand Up @@ -481,7 +473,7 @@ namespace zasm
{
auto& dstOp = req.operands[state.operandIndex]; // NOLINT
const auto& srcOp = operands[state.operandIndex]; // NOLINT
if (auto opStatus = buildOperand(dstOp, state, srcOp); opStatus != ErrorCode::None)
if (auto opStatus = buildOperand(ctx, dstOp, state, srcOp); opStatus != ErrorCode::None)
{
return opStatus;
}
Expand All @@ -491,8 +483,7 @@ namespace zasm
fixupIs4Operands(req);

std::size_t bufLen = res.buffer.data.size();
const auto curAddress = ctx != nullptr ? ctx->va : 0;
switch (auto status = ZydisEncoderEncodeInstructionAbsolute(&req, res.buffer.data.data(), &bufLen, curAddress); status)
switch (auto status = ZydisEncoderEncodeInstructionAbsolute(&req, res.buffer.data.data(), &bufLen, ctx.va); status)
{
case ZYAN_STATUS_SUCCESS:
break;
Expand All @@ -509,26 +500,14 @@ namespace zasm
return ErrorCode::None;
}

Expected<EncoderResult, Error> encode(
MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic, std::size_t numOps,
const Operand* operands)
{
EncoderResult res;
if (auto err = encode_(res, nullptr, mode, attribs, mnemonic, numOps, operands); err != ErrorCode::None)
{
return makeUnexpected(err);
}
return res;
}

static Expected<EncoderResult, Error> encodeWithContext(
EncoderContext& ctx, MachineMode mode, Instruction::Attribs prefixes, Instruction::Mnemonic mnemonic,
std::size_t numOps, const Operand* operands)
{
EncoderResult res;
ctx.instrSize = ctx.nodes[ctx.nodeIndex].length;
ctx.instrSize = ctx.getNodeSize(ctx.nodeIndex);

if (const auto encodeError = encode_(res, &ctx, mode, prefixes, mnemonic, numOps, operands);
if (const auto encodeError = encode_(res, ctx, mode, prefixes, mnemonic, numOps, operands);
encodeError != ErrorCode::None)
{
return makeUnexpected(encodeError);
Expand All @@ -537,6 +516,16 @@ namespace zasm
return res;
}

Expected<EncoderResult, Error> encode(
MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic, std::size_t numOps,
const Operand* operands)
{
EncoderContext tempCtx{};
tempCtx.flags |= EncoderFlags::temporary;

return encodeWithContext(tempCtx, mode, attribs, mnemonic, numOps, operands);
}

Expected<EncoderResult, Error> encode(EncoderContext& ctx, MachineMode mode, const Instruction& instr)
{
const auto& ops = instr.getOperands();
Expand Down

0 comments on commit 090abd8

Please sign in to comment.