Skip to content

Commit

Permalink
updat the spirv backend to the new vector type ctors
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugobros3 committed Jul 5, 2024
1 parent 7ddd364 commit 4f0756a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
37 changes: 23 additions & 14 deletions src/thorin/be/spirv/spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ case PrimType_##T: \
default: THORIN_UNREACHABLE;
}
}
inline const PrimType* get_primtype(World& world, PrimTypeKind kind, int bitwidth, int length) {
inline const PrimType* get_scalar_primtype(World& world, PrimTypeKind kind, int bitwidth) {
#define GET_PRIMTYPE_WITH_KIND(kind) \
switch (bitwidth) { \
case 8: return world.type_p##kind##8 (length); \
case 16: return world.type_p##kind##16(length); \
case 32: return world.type_p##kind##32(length); \
case 64: return world.type_p##kind##64(length); \
case 8: return world.type_p##kind##8 (); \
case 16: return world.type_p##kind##16(); \
case 32: return world.type_p##kind##32(); \
case 64: return world.type_p##kind##64(); \
}

#define GET_PRIMTYPE_WITH_KIND_F(kind) \
switch (bitwidth) { \
case 8: world.ELOG("8-bit floats do not exist"); \
case 16: return world.type_p##kind##16(length); \
case 32: return world.type_p##kind##32(length); \
case 64: return world.type_p##kind##64(length); \
case 16: return world.type_p##kind##16(); \
case 32: return world.type_p##kind##32(); \
case 64: return world.type_p##kind##64(); \
}

switch (kind) {
Expand All @@ -75,6 +75,11 @@ switch (bitwidth) { \
#undef GET_PRIMTYPE_WITH_KIND_F
}


inline const Type* get_primtype(World& world, PrimTypeKind kind, int bitwidth, int length) {
return world.vector_or_scalar_type(get_scalar_primtype(world, kind, bitwidth), length);
}

BasicBlockBuilder::BasicBlockBuilder(FnBuilder& fn_builder)
: builder::SpvBasicBlockBuilder(fn_builder.file_builder), fn_builder(fn_builder), file_builder(fn_builder.file_builder) {
label = file_builder.generate_fresh_id();
Expand Down Expand Up @@ -625,13 +630,15 @@ SpvId CodeGen::emit_bb(BasicBlockBuilder* bb, const Def* def) {
} else if (auto slot = def->isa<Slot>()) {
emit_unsafe(slot->frame());
auto type = slot->type();
auto id = bb->fn_builder.variable(convert(world().ptr_type(type->pointee(), 1, AddrSpace::Function)).id, spv::StorageClass::StorageClassFunction);
auto id = bb->fn_builder.variable(convert(world().ptr_type(type->pointee(), AddrSpace::Function)).id, spv::StorageClass::StorageClassFunction);
id = bb->convert(spv::Op::OpBitcast, convert(type).id, id);
return id;
} else if (auto enter = def->isa<Enter>()) {
return emit_unsafe(enter->mem());
} else if (auto lea = def->isa<LEA>()) {
switch (lea->type()->addr_space()) {
auto [ptr_t, len] = deconstruct_vector_type<PtrType>(lea->type());
assert(len == 1 && "TODO: spirv only supports scalar LEA");
switch (ptr_t->addr_space()) {
case AddrSpace::Global:
case AddrSpace::Shared:
break;
Expand Down Expand Up @@ -729,12 +736,14 @@ SpvId CodeGen::emit_bb(BasicBlockBuilder* bb, const Def* def) {
return bb->convert(spv::OpBitcast, convert(bitcast->type()).id, emit(bitcast->from()));
} else if (auto cast = def->isa<Cast>()) {
// NB: all ops used here are scalar/vector agnostic
auto src_prim = src_type->isa<PrimType>();
auto dst_prim = dst_type->isa<PrimType>();
if (!src_prim || !dst_prim || src_prim->length() != dst_prim->length())
auto [src_scalar_type, src_len] = deconstruct_vector_type(src_type);
auto [dst_scalar_type, dst_len] = deconstruct_vector_type(dst_type);
auto src_prim = src_scalar_type->isa<PrimType>();
auto dst_prim = dst_scalar_type->isa<PrimType>();
if (!src_prim || !dst_prim || src_len != dst_len)
world().ELOG("Illegal cast: % to %, casts are only supported between primitives with identical vector length", src_type->to_string(), dst_type->to_string());

auto length = src_prim->length();
auto length = src_len;

auto src_kind = classify_primtype(src_prim);
auto dst_kind = classify_primtype(dst_prim);
Expand Down
10 changes: 5 additions & 5 deletions src/thorin/be/spirv/spirv_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ConvertedType CodeGen::convert(const Type* type) {
switch (type->tag()) {
#define THORIN_Q_TYPE(T, M) \
case PrimType_##T: \
type = world().prim_type(PrimType_p##M, type->as<VectorType>()->length()); \
type = world().prim_type(PrimType_p##M); \
break;
#include "thorin/tables/primtypetable.h"
#undef THORIN_Q_TYPE
Expand All @@ -54,16 +54,16 @@ ConvertedType CodeGen::convert(const Type* type) {
if (target_info_.dialect == Target::OpenCL) {
switch (type->tag()) {
case Node_PrimType_ps8:
type = world().prim_type(PrimType_pu8, type->as<VectorType>()->length()); \
type = world().prim_type(PrimType_pu8); \
break;
case Node_PrimType_ps16:
type = world().prim_type(PrimType_pu16, type->as<VectorType>()->length()); \
type = world().prim_type(PrimType_pu16); \
break;
case Node_PrimType_ps32:
type = world().prim_type(PrimType_pu32, type->as<VectorType>()->length()); \
type = world().prim_type(PrimType_pu32); \
break;
case Node_PrimType_ps64:
type = world().prim_type(PrimType_pu64, type->as<VectorType>()->length()); \
type = world().prim_type(PrimType_pu64); \
break;
default: break;
}
Expand Down

0 comments on commit 4f0756a

Please sign in to comment.