From 4f0756ad578b8e5ebaa70db6739dc9f983cda118 Mon Sep 17 00:00:00 2001 From: Hugo Devillers Date: Fri, 5 Jul 2024 11:35:45 +0200 Subject: [PATCH] updat the spirv backend to the new vector type ctors --- src/thorin/be/spirv/spirv.cpp | 37 ++++++++++++++++++----------- src/thorin/be/spirv/spirv_types.cpp | 10 ++++---- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/thorin/be/spirv/spirv.cpp b/src/thorin/be/spirv/spirv.cpp index 1e4e66532..a14af8452 100644 --- a/src/thorin/be/spirv/spirv.cpp +++ b/src/thorin/be/spirv/spirv.cpp @@ -47,21 +47,21 @@ case PrimType_##T: \ default: THORIN_UNREACHABLE; } } -inline const PrimType* get_primtype(World& world, PrimTypeKind kind, int bitwidth, int length) { +inline const PrimType* get_scalar_primtype(World& world, PrimTypeKind kind, int bitwidth) { #define GET_PRIMTYPE_WITH_KIND(kind) \ switch (bitwidth) { \ - case 8: return world.type_p##kind##8 (length); \ - case 16: return world.type_p##kind##16(length); \ - case 32: return world.type_p##kind##32(length); \ - case 64: return world.type_p##kind##64(length); \ + case 8: return world.type_p##kind##8 (); \ + case 16: return world.type_p##kind##16(); \ + case 32: return world.type_p##kind##32(); \ + case 64: return world.type_p##kind##64(); \ } #define GET_PRIMTYPE_WITH_KIND_F(kind) \ switch (bitwidth) { \ case 8: world.ELOG("8-bit floats do not exist"); \ - case 16: return world.type_p##kind##16(length); \ - case 32: return world.type_p##kind##32(length); \ - case 64: return world.type_p##kind##64(length); \ + case 16: return world.type_p##kind##16(); \ + case 32: return world.type_p##kind##32(); \ + case 64: return world.type_p##kind##64(); \ } switch (kind) { @@ -75,6 +75,11 @@ switch (bitwidth) { \ #undef GET_PRIMTYPE_WITH_KIND_F } + +inline const Type* get_primtype(World& world, PrimTypeKind kind, int bitwidth, int length) { + return world.vector_or_scalar_type(get_scalar_primtype(world, kind, bitwidth), length); +} + BasicBlockBuilder::BasicBlockBuilder(FnBuilder& fn_builder) : builder::SpvBasicBlockBuilder(fn_builder.file_builder), fn_builder(fn_builder), file_builder(fn_builder.file_builder) { label = file_builder.generate_fresh_id(); @@ -625,13 +630,15 @@ SpvId CodeGen::emit_bb(BasicBlockBuilder* bb, const Def* def) { } else if (auto slot = def->isa()) { emit_unsafe(slot->frame()); auto type = slot->type(); - auto id = bb->fn_builder.variable(convert(world().ptr_type(type->pointee(), 1, AddrSpace::Function)).id, spv::StorageClass::StorageClassFunction); + auto id = bb->fn_builder.variable(convert(world().ptr_type(type->pointee(), AddrSpace::Function)).id, spv::StorageClass::StorageClassFunction); id = bb->convert(spv::Op::OpBitcast, convert(type).id, id); return id; } else if (auto enter = def->isa()) { return emit_unsafe(enter->mem()); } else if (auto lea = def->isa()) { - switch (lea->type()->addr_space()) { + auto [ptr_t, len] = deconstruct_vector_type(lea->type()); + assert(len == 1 && "TODO: spirv only supports scalar LEA"); + switch (ptr_t->addr_space()) { case AddrSpace::Global: case AddrSpace::Shared: break; @@ -729,12 +736,14 @@ SpvId CodeGen::emit_bb(BasicBlockBuilder* bb, const Def* def) { return bb->convert(spv::OpBitcast, convert(bitcast->type()).id, emit(bitcast->from())); } else if (auto cast = def->isa()) { // NB: all ops used here are scalar/vector agnostic - auto src_prim = src_type->isa(); - auto dst_prim = dst_type->isa(); - if (!src_prim || !dst_prim || src_prim->length() != dst_prim->length()) + auto [src_scalar_type, src_len] = deconstruct_vector_type(src_type); + auto [dst_scalar_type, dst_len] = deconstruct_vector_type(dst_type); + auto src_prim = src_scalar_type->isa(); + auto dst_prim = dst_scalar_type->isa(); + if (!src_prim || !dst_prim || src_len != dst_len) world().ELOG("Illegal cast: % to %, casts are only supported between primitives with identical vector length", src_type->to_string(), dst_type->to_string()); - auto length = src_prim->length(); + auto length = src_len; auto src_kind = classify_primtype(src_prim); auto dst_kind = classify_primtype(dst_prim); diff --git a/src/thorin/be/spirv/spirv_types.cpp b/src/thorin/be/spirv/spirv_types.cpp index 90a718697..63f32ab8e 100644 --- a/src/thorin/be/spirv/spirv_types.cpp +++ b/src/thorin/be/spirv/spirv_types.cpp @@ -43,7 +43,7 @@ ConvertedType CodeGen::convert(const Type* type) { switch (type->tag()) { #define THORIN_Q_TYPE(T, M) \ case PrimType_##T: \ - type = world().prim_type(PrimType_p##M, type->as()->length()); \ + type = world().prim_type(PrimType_p##M); \ break; #include "thorin/tables/primtypetable.h" #undef THORIN_Q_TYPE @@ -54,16 +54,16 @@ ConvertedType CodeGen::convert(const Type* type) { if (target_info_.dialect == Target::OpenCL) { switch (type->tag()) { case Node_PrimType_ps8: - type = world().prim_type(PrimType_pu8, type->as()->length()); \ + type = world().prim_type(PrimType_pu8); \ break; case Node_PrimType_ps16: - type = world().prim_type(PrimType_pu16, type->as()->length()); \ + type = world().prim_type(PrimType_pu16); \ break; case Node_PrimType_ps32: - type = world().prim_type(PrimType_pu32, type->as()->length()); \ + type = world().prim_type(PrimType_pu32); \ break; case Node_PrimType_ps64: - type = world().prim_type(PrimType_pu64, type->as()->length()); \ + type = world().prim_type(PrimType_pu64); \ break; default: break; }