Skip to content

Commit

Permalink
[Snippets][CPU] Fixed aux GPR allocation for Loop emitters
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Aug 21, 2024
1 parent 54f58b8 commit 5c4ccea
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@ using namespace dnnl::impl::cpu::x64;

namespace ov {
namespace intel_cpu {
namespace {
/**
 * @brief Selects the auxiliary GPR used by the Loop emitters.
 *
 * If `pool_gpr_idxs` is non-empty, its first index is taken. Otherwise the GPR indices are scanned
 * from R15 down to 0 and the first index that is neither listed in `used_gprs` nor equal to RSP is taken.
 *
 * @param aux_gpr        [out] receives the selected register
 * @param pool_gpr_idxs  GPR indices provided by the pool; only the first one is consulted
 * @param used_gprs      GPR indices that the fallback scan must not select
 * @return `false` when the register came from the pool, `true` when it was found by the fallback scan.
 *         The callers in this file push the register before use and pop it afterwards when `true` is returned.
 * @throws raises via OV_CPU_JIT_EMITTER_THROW when the fallback scan finds no suitable register
 */
bool init_aux_gpr(Reg64& aux_gpr, const std::vector<size_t>& pool_gpr_idxs, const std::vector<size_t>& used_gprs) {
    if (!pool_gpr_idxs.empty()) {
        aux_gpr = Reg64(static_cast<int>(pool_gpr_idxs[0]));
        return false;
    }

    for (size_t gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) {
        const size_t candidate = Operand::R15 - gpr_idx;  // we allocate from the end
        // The stack pointer must never be handed out: the callers preserve the returned register
        // with push/pop and overwrite it in between, which cannot work on RSP itself.
        if (candidate == static_cast<size_t>(Operand::RSP))
            continue;
        if (std::find(used_gprs.cbegin(), used_gprs.cend(), candidate) != used_gprs.cend())
            continue;
        aux_gpr = Reg64(static_cast<int>(candidate));
        return true;
    }
    OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR");
}
}  // namespace

/* ================== jit_loop_begin_emitter ====================== */

Expand All @@ -30,12 +46,6 @@ jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generat
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
}

size_t jit_loop_begin_emitter::aux_gprs_count() const {
    // With a dynamic work amount one aux GPR is requested: it stores the Loop arguments
    // taken from `runtime_args`, from which the work amount of the current loop is read.
    if (is_work_amount_dynamic) {
        return 1;
    }
    // A static work amount needs no aux GPR.
    return 0;
}

void jit_loop_begin_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.empty(), "Invalid inputs size: expected 0 got " + std::to_string(in.size()));
// Note: the only expected output is work amount register (communicated to jit_loop_end_emitter)
Expand All @@ -48,22 +58,25 @@ void jit_loop_begin_emitter::validate_arguments(const std::vector<size_t> &in, c
// Entry point for emitting the loop-begin code.
// Checks the input/output register indices with validate_arguments() and then forwards
// all four index vectors unchanged to the base jit_emitter::emit_code().
void jit_loop_begin_emitter::emit_code(const std::vector<size_t> &in, const std::vector<size_t> &out,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
// Validate the arguments before any emission is delegated.
validate_arguments(in, out);
jit_emitter::emit_code(in, out, pool_vec_idxs, pool_gpr_idxs);
}

void jit_loop_begin_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
// If the loop is evaluated only once, we can skip loop begin code emission
// If work_amount is dynamic, we should get runtime `work_amount` - it might be `zero` and we should skip loop evaluation
if (evaluate_once && !is_work_amount_dynamic)
return;

Reg64 reg_work_amount = Reg64(static_cast<int>(out.back()));
if (is_work_amount_dynamic) {
Reg64 reg_runtime_params = abi_param1; // defined by jit_kernel_emitter
Reg64 reg_loop_args_ptr = Reg64(static_cast<int>(aux_gpr_idxs[0]));
Reg64 reg_loop_args_ptr;
const auto preserved = init_aux_gpr(reg_loop_args_ptr, pool_gpr_idxs, out); // loop_begin doesn't have input registers
if (preserved)
h->push(reg_loop_args_ptr);

const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t);
h->mov(reg_loop_args_ptr, h->ptr[reg_runtime_params + GET_OFF(loop_args)]);
h->mov(reg_loop_args_ptr, h->ptr[abi_param1 + GET_OFF(loop_args)]); // defined by jit_kernel_emitter
h->mov(reg_work_amount, h->ptr[reg_loop_args_ptr + id_offset + GET_OFF_LOOP_ARGS(m_work_amount)]);

if (preserved)
h->pop(reg_loop_args_ptr);
} else {
h->mov(reg_work_amount, work_amount);
}
Expand Down Expand Up @@ -138,26 +151,22 @@ void jit_loop_end_emitter::validate_arguments(const std::vector<size_t> &in, con
// Entry point for emitting the loop-end code.
// Checks the input/output register indices with validate_arguments() and then forwards
// all four index vectors unchanged to the base jit_emitter::emit_code().
void jit_loop_end_emitter::emit_code(const std::vector<size_t> &in, const std::vector<size_t> &out,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
// Validate the arguments before any emission is delegated.
validate_arguments(in, out);
jit_emitter::emit_code(in, out, pool_vec_idxs, pool_gpr_idxs);
}

size_t jit_loop_end_emitter::aux_gprs_count() const {
    // With dynamic pointer shifts one aux GPR is requested: it stores the Loop arguments
    // taken from `runtime_args`, from which the data pointer shifts of the current loop are read.
    if (are_ptr_shifts_dynamic) {
        return 1;
    }
    // Static pointer shifts need no aux GPR.
    return 0;
}

void jit_loop_end_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
std::vector<size_t> data_ptr_reg_idxs;
// the last input is actually a work_amount reg
data_ptr_reg_idxs.reserve(num_inputs + num_outputs);
std::copy(in.begin(), in.end() - 1, std::back_inserter(data_ptr_reg_idxs));

const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t);
Reg64 reg_increments = are_ptr_shifts_dynamic ? Reg64(static_cast<int>(aux_gpr_idxs[0])) : Reg64();

auto apply_increments = [&](bool use_runtime_args, size_t field_offset, const std::vector<int64_t>& increments, size_t scale) {
Reg64 reg_increments;
bool preserved = false;
if (use_runtime_args) {
preserved = init_aux_gpr(reg_increments, pool_gpr_idxs, in); // loop_end doesn't have output registers
if (preserved)
h->push(reg_increments);

Reg64 reg_runtime_params = abi_param1; /* defined by jit_kernel_emitter */
h->mov(reg_increments, h->ptr[reg_runtime_params + GET_OFF(loop_args)]);
h->mov(reg_increments, h->ptr[reg_increments + id_offset + field_offset]);
Expand All @@ -173,6 +182,9 @@ void jit_loop_end_emitter::emit_impl(const std::vector<size_t>& in, const std::v
}
}
}

if (preserved)
h->pop(reg_increments);
};

if (!evaluate_once) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ class jit_loop_begin_emitter: public jit_emitter {

protected:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

size_t aux_gprs_count() const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override {}

std::shared_ptr<Xbyak::Label> loop_begin_label = nullptr;
std::shared_ptr<const Xbyak::Label> loop_end_label = nullptr;
Expand Down Expand Up @@ -59,9 +57,7 @@ class jit_loop_end_emitter: public jit_emitter {

protected:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

size_t aux_gprs_count() const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override {}

static ov::snippets::lowered::ExpressionPtr get_loop_begin_expr(const ov::snippets::lowered::ExpressionPtr& expr);

Expand Down

0 comments on commit 5c4ccea

Please sign in to comment.