From 66b92e6c1da58008267dcd5ccaebb86bb688cbd0 Mon Sep 17 00:00:00 2001 From: Peter Goodman Date: Thu, 7 Dec 2023 16:09:08 -0500 Subject: [PATCH] Enables VAST codegen to be used directly on a TranslatioUnitDecl, independent of an ASTConsumer. Thrust of the changes focuses on codegening function bodies only when the recursive decl visitor actually visits those bodies. --- include/vast/CodeGen/CodeGenDeclVisitor.hpp | 63 +++++++++---------- include/vast/CodeGen/CodeGenStmtVisitor.hpp | 36 +++++------ include/vast/CodeGen/CodeGenTypeVisitor.hpp | 2 +- .../vast/Dialect/HighLevel/HighLevelTypes.td | 10 ++- 4 files changed, 54 insertions(+), 57 deletions(-) diff --git a/include/vast/CodeGen/CodeGenDeclVisitor.hpp b/include/vast/CodeGen/CodeGenDeclVisitor.hpp index 431b366bf3..5202f2488c 100644 --- a/include/vast/CodeGen/CodeGenDeclVisitor.hpp +++ b/include/vast/CodeGen/CodeGenDeclVisitor.hpp @@ -332,18 +332,9 @@ namespace vast::cg { operation VisitFunctionLikeDecl(const Decl *decl) { auto gdecl = get_gdecl(decl); auto mangled = context().get_mangled_name(gdecl); - - if (auto fn = context().lookup_function(mangled, false /* emit no error */)) { - return fn; - } - - auto guard = insertion_guard(); auto is_definition = decl->isThisDeclarationADefinition(); - - // emit definition instead of declaration - if (!is_definition && decl->getDefinition()) { - return visit(decl->getDefinition()); - } + auto fn = context().lookup_function(mangled, false /* emit no error */); + auto guard = insertion_guard(); auto is_terminator = [] (auto &op) { return op.template hasTrait< mlir::OpTrait::IsTerminator >() || @@ -354,13 +345,13 @@ namespace vast::cg { // In MLIR the entry block of the function must have the same // argument list as the function itself. // FIXME: driver solves this already - auto params = llvm::zip(decl->getDefinition()->parameters(), entry->getArguments()); + auto params = llvm::zip(decl->parameters(), entry->getArguments()); for (const auto &[arg, earg] : params) { context().declare(arg, mlir_value(earg)); } }; - auto emit_function_terminator = [&] (auto fn) { + auto emit_function_terminator = [&] () { auto loc = fn.getLoc(); if (decl->getReturnType()->isVoidType()) { auto void_val = constant(loc); @@ -377,7 +368,7 @@ namespace vast::cg { } }; - auto emit_function_body = [&] (auto fn) { + auto emit_function_body = [&] () { auto entry = fn.addEntryBlock(); set_insertion_point_to_start(entry); @@ -430,32 +421,38 @@ namespace vast::cg { || !last_op || !is_terminator(*last_op)) { - emit_function_terminator(fn); + emit_function_terminator(); } }; llvm::ScopedHashTableScope scope(context().vars); - auto linkage = core::get_function_linkage(gdecl); - - auto fn = context().declare(mangled, [&] () { - auto loc = meta_location(decl); - auto type = visit_function_type(decl->getFunctionType(), decl->isVariadic()); - // make function header, that will be later filled with function body - // or returned as declaration in the case of external function - return make< hl::FuncOp >(loc, mangled.name, type, linkage); - }); + auto def = decl->getDefinition(); + auto linkage = core::get_function_linkage(def ? get_gdecl(def) : gdecl); + + if (!fn) { + fn = context().declare(mangled, [&] () { + auto loc = meta_location(decl); + auto type = visit_function_type(decl->getFunctionType(), decl->isVariadic()); + + // make function header, that will be later filled with function body + // or returned as declaration in the case of external function + auto ret = make< hl::FuncOp >(loc, mangled.name, type, linkage); + + // MLIR requires declrations to have private visibility + ret.setVisibility(mlir::SymbolTable::Visibility::Private); + + return ret; + }); + } if (!is_definition) { - // MLIR requires declrations to have private visibility - fn.setVisibility(mlir::SymbolTable::Visibility::Private); return fn; } fn.setVisibility(core::get_visibility_from_linkage(linkage)); - if (fn.empty()) { - emit_function_body(fn); + emit_function_body(); } return fn; @@ -564,12 +561,10 @@ namespace vast::cg { // operation VisitLinkageSpecDecl(const clang::LinkageSpecDecl *decl) operation VisitTranslationUnitDecl(const clang::TranslationUnitDecl *tu) { - auto loc = meta_location(tu); - return derived().template make_scoped< TranslationUnitScope >(loc, [&] { - for (const auto &decl : tu->decls()) { - visit(decl); - } - }); + for (const auto &decl : tu->decls()) { + visit(decl); + } + return {}; } // operation VisitTypedefNameDecl(const clang::TypedefNameDecl *decl) diff --git a/include/vast/CodeGen/CodeGenStmtVisitor.hpp b/include/vast/CodeGen/CodeGenStmtVisitor.hpp index 4e5f29b671..9f75ab149a 100644 --- a/include/vast/CodeGen/CodeGenStmtVisitor.hpp +++ b/include/vast/CodeGen/CodeGenStmtVisitor.hpp @@ -663,16 +663,12 @@ namespace vast::cg { return visit_as_lvalue_type(expr->getType()); } - const clang::VarDecl * getDeclForVarRef(const clang::DeclRefExpr *expr) { - return clang::cast< clang::VarDecl >(expr->getDecl()->getUnderlyingDecl()); - } - hl::VarDeclOp getDefiningOpOfGlobalVar(const clang::VarDecl *decl) { return context().vars.lookup(decl).template getDefiningOp< hl::VarDeclOp >(); } - operation VisitEnumDeclRefExpr(const clang::DeclRefExpr *expr) { - auto decl = clang::cast< clang::EnumConstantDecl >(expr->getDecl()->getUnderlyingDecl()); + operation VisitEnumDeclRefExpr(const clang::DeclRefExpr *expr, const clang::Decl *underlying_decl) { + auto decl = clang::cast< clang::EnumConstantDecl >( underlying_decl )->getFirstDecl(); if (auto val = context().enumconsts.lookup(decl)) { auto rty = visit(expr->getType()); return make< hl::EnumRefOp >(meta_location(expr), rty, val.getName()); @@ -689,8 +685,8 @@ namespace vast::cg { return make< hl::DeclRefOp >(meta_location(expr), rty, var); } - operation VisitVarDeclRefExpr(const clang::DeclRefExpr *expr) { - auto decl = getDeclForVarRef(expr); + operation VisitVarDeclRefExpr(const clang::DeclRefExpr *expr, const clang::Decl *underlying_decl) { + auto decl = clang::cast< clang::VarDecl >( underlying_decl )->getFirstDecl(); if (auto var = context().vars.lookup(decl)) { return VisitVarDeclRefExprImpl(expr, var); } @@ -701,8 +697,8 @@ namespace vast::cg { return nullptr; } - operation VisitFileVarDeclRefExpr(const clang::DeclRefExpr *expr) { - auto decl = getDeclForVarRef(expr); + operation VisitFileVarDeclRefExpr(const clang::DeclRefExpr *expr, const clang::Decl *underlying_decl) { + auto decl = clang::cast< clang::VarDecl >( underlying_decl )->getFirstDecl(); if (!context().vars.lookup(decl)) { // Ref: https://github.com/trailofbits/vast/issues/384 // github issue to avoid emitting error if declaration is missing @@ -716,8 +712,8 @@ namespace vast::cg { return make< hl::GlobalRefOp >(meta_location(expr), rty, name); } - operation VisitFunctionDeclRefExpr(const clang::DeclRefExpr *expr) { - auto decl = clang::cast< clang::FunctionDecl >( expr->getDecl()->getUnderlyingDecl() ); + operation VisitFunctionDeclRefExpr(const clang::DeclRefExpr *expr, const clang::Decl *underlying_decl) { + auto decl = clang::cast< clang::FunctionDecl >( underlying_decl )->getFirstDecl(); auto mangled = context().get_mangled_name(decl); auto fn = context().lookup_function(mangled, false); if (!fn) { @@ -734,17 +730,17 @@ namespace vast::cg { auto underlying = expr->getDecl()->getUnderlyingDecl(); if (clang::isa< clang::EnumConstantDecl >(underlying)) { - return VisitEnumDeclRefExpr(expr); + return VisitEnumDeclRefExpr(expr, underlying); } if (auto decl = clang::dyn_cast< clang::VarDecl >(underlying)) { if (decl->isFileVarDecl()) - return VisitFileVarDeclRefExpr(expr); - return VisitVarDeclRefExpr(expr); + return VisitFileVarDeclRefExpr(expr, underlying); + return VisitVarDeclRefExpr(expr, underlying); } if (clang::isa< clang::FunctionDecl >(underlying)) { - return VisitFunctionDeclRefExpr(expr); + return VisitFunctionDeclRefExpr(expr, underlying); } VAST_UNIMPLEMENTED_MSG("unknown underlying declaration to be referenced"); @@ -969,8 +965,8 @@ namespace vast::cg { return args; } - operation VisitDirectCall(const clang::CallExpr *expr) { - auto callee = VisitDirectCallee(expr->getDirectCallee()); + operation VisitDirectCall(const clang::CallExpr *expr, const clang::Decl *decl) { + auto callee = VisitDirectCallee(clang::cast< clang::FunctionDecl >( decl )); auto args = VisitArguments(expr); return make< hl::CallOp >(meta_location(expr), callee, args); } @@ -989,8 +985,8 @@ namespace vast::cg { } operation VisitCallExpr(const clang::CallExpr *expr) { - if (expr->getDirectCallee()) { - return VisitDirectCall(expr); + if (auto callee = expr->getDirectCallee()) { + return VisitDirectCall(expr, callee->getFirstDecl()); } return VisitIndirectCall(expr); diff --git a/include/vast/CodeGen/CodeGenTypeVisitor.hpp b/include/vast/CodeGen/CodeGenTypeVisitor.hpp index 4a4a5beaf7..e651617752 100644 --- a/include/vast/CodeGen/CodeGenTypeVisitor.hpp +++ b/include/vast/CodeGen/CodeGenTypeVisitor.hpp @@ -117,7 +117,7 @@ namespace vast::cg { auto with_qualifiers(const clang::EnumType *ty, qualifiers quals) -> mlir_type { auto name = make_name_attr( context().decl_name(ty->getDecl()) ); - return with_cv_qualifiers( type_builder< hl::RecordType >().bind(name), quals ).freeze(); + return with_cv_qualifiers( type_builder< hl::EnumType >().bind(name), quals ).freeze(); } auto with_qualifiers(const clang::TypedefType *ty, qualifiers quals) -> mlir_type { diff --git a/include/vast/Dialect/HighLevel/HighLevelTypes.td b/include/vast/Dialect/HighLevel/HighLevelTypes.td index 12c46eebcf..651dc36b63 100644 --- a/include/vast/Dialect/HighLevel/HighLevelTypes.td +++ b/include/vast/Dialect/HighLevel/HighLevelTypes.td @@ -170,6 +170,11 @@ class IsMaybeWrappedTypedef< string arg > : PredOpTrait< CPred< "::vast::hl::strip_elaborated($" # arg # ").hasTrait< mlir::TypeTrait::TypedefTrait >()" > >; +class IsMaybeWrappedEnum< string arg > : PredOpTrait< + "type is an enum, maybe wrapped in the elaborated type", + CPred< "::vast::hl::strip_elaborated($" # arg # ").isa< ::vast::hl::EnumType >()" > +>; + class ContainsTypedef< list< string > tl > : PredOpTrait< "contains typedef type", Or< !foreach(type, tl, IsMaybeWrappedTypedef< type >.predicate) > >; @@ -264,12 +269,13 @@ def LongLongType : IntegerType< "LongLong", "longlong", [LongLongTypeTrait] >; def Int128Type : IntegerType< "Int128", "int128", [Int128TypeTrait] >; def HLIntegerType : AnyTypeOf<[ - CharType, ShortType, IntType, LongType, LongLongType, Int128Type, TypedefType + CharType, ShortType, IntType, LongType, LongLongType, Int128Type, TypedefType, EnumType ]>; def IntegerLikeType : TypeConstraint< Or< [HLIntegerType.predicate, AnyInteger.predicate, - IsMaybeWrappedTypedef< "_self" >.predicate] >, + IsMaybeWrappedTypedef< "_self" >.predicate, + IsMaybeWrappedEnum< "_self" >.predicate] >, "integer like type" >;