Skip to content

Commit

Permalink
[CIR][OpenMP] Taskwait, Taskyield and Barrier implementation (#555)
Browse files Browse the repository at this point in the history
This PR is the final fix for issue #499.
  • Loading branch information
eZWALT authored May 1, 2024
1 parent 09c86ed commit c9b4fdc
Show file tree
Hide file tree
Showing 14 changed files with 243 additions and 3 deletions.
5 changes: 5 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

namespace clang {
class Expr;
Expand Down Expand Up @@ -993,6 +994,10 @@ class CIRGenFunction : public CIRGenTypeCache {

// OpenMP gen functions:
mlir::LogicalResult buildOMPParallelDirective(const OMPParallelDirective &S);
mlir::LogicalResult buildOMPTaskwaitDirective(const OMPTaskwaitDirective &S);
mlir::LogicalResult
buildOMPTaskyieldDirective(const OMPTaskyieldDirective &S);
mlir::LogicalResult buildOMPBarrierDirective(const OMPBarrierDirective &S);

LValue buildOpaqueValueLValue(const OpaqueValueExpr *e);

Expand Down
52 changes: 52 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,55 @@ bool CIRGenOpenMPRuntime::emitTargetGlobal(clang::GlobalDecl &GD) {
assert(!UnimplementedFeature::openMPRuntime());
return false;
}

void CIRGenOpenMPRuntime::emitTaskWaitCall(CIRGenBuilderTy &builder,
CIRGenFunction &CGF,
mlir::Location Loc,
const OMPTaskDataTy &Data) {

if (!CGF.HaveInsertPoint())
return;

if (CGF.CGM.getLangOpts().OpenMPIRBuilder && Data.Dependences.empty()) {
// TODO: Need to support taskwait with dependences in the OpenMPIRBuilder.
// TODO(cir): This could change in the near future when OpenMP 5.0 gets
// supported by MLIR
builder.create<mlir::omp::TaskwaitOp>(Loc);
} else {
llvm_unreachable("NYI");
}
assert(!UnimplementedFeature::openMPRegionInfo());
}

void CIRGenOpenMPRuntime::emitBarrierCall(CIRGenBuilderTy &builder,
CIRGenFunction &CGF,
mlir::Location Loc) {

assert(!UnimplementedFeature::openMPRegionInfo());

if (CGF.CGM.getLangOpts().OpenMPIRBuilder) {
builder.create<mlir::omp::BarrierOp>(Loc);
return;
}

if (!CGF.HaveInsertPoint())
return;

llvm_unreachable("NYI");
}

void CIRGenOpenMPRuntime::emitTaskyieldCall(CIRGenBuilderTy &builder,
CIRGenFunction &CGF,
mlir::Location Loc) {

if (!CGF.HaveInsertPoint())
return;

if (CGF.CGM.getLangOpts().OpenMPIRBuilder) {
builder.create<mlir::omp::TaskyieldOp>(Loc);
} else {
llvm_unreachable("NYI");
}

assert(!UnimplementedFeature::openMPRegionInfo());
}
36 changes: 36 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenOpenMPRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,21 @@
#ifndef LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENOPENMPRUNTIME_H
#define LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENOPENMPRUNTIME_H

#include "CIRGenBuilder.h"
#include "CIRGenValue.h"

#include "clang/AST/Redeclarable.h"
#include "clang/Basic/OpenMPKinds.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"

#include "llvm/Support/ErrorHandling.h"

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"

#include "UnimplementedFeatureGuarding.h"

namespace clang {
class Decl;
class Expr;
Expand All @@ -27,6 +39,20 @@ namespace cir {
class CIRGenModule;
class CIRGenFunction;

struct OMPTaskDataTy final {
struct DependData {
clang::OpenMPDependClauseKind DepKind = clang::OMPC_DEPEND_unknown;
const clang::Expr *IteratorExpr = nullptr;
llvm::SmallVector<const clang::Expr *, 4> DepExprs;
explicit DependData() = default;
DependData(clang::OpenMPDependClauseKind DepKind,
const clang::Expr *IteratorExpr)
: DepKind(DepKind), IteratorExpr(IteratorExpr) {}
};
llvm::SmallVector<DependData, 4> Dependences;
bool HasNowaitClause = false;
};

class CIRGenOpenMPRuntime {
public:
explicit CIRGenOpenMPRuntime(CIRGenModule &CGM);
Expand Down Expand Up @@ -69,6 +95,16 @@ class CIRGenOpenMPRuntime {
/// \param GD Global to scan.
virtual bool emitTargetGlobal(clang::GlobalDecl &D);

/// Emit code for 'taskwait' directive
virtual void emitTaskWaitCall(CIRGenBuilderTy &builder, CIRGenFunction &CGF,
mlir::Location Loc, const OMPTaskDataTy &Data);

virtual void emitBarrierCall(CIRGenBuilderTy &builder, CIRGenFunction &CGF,
mlir::Location Loc);

virtual void emitTaskyieldCall(CIRGenBuilderTy &builder, CIRGenFunction &CGF,
mlir::Location Loc);

protected:
CIRGenModule &CGM;
};
Expand Down
9 changes: 6 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ mlir::LogicalResult CIRGenFunction::buildStmt(const Stmt *S,
// OMP directives:
case Stmt::OMPParallelDirectiveClass:
return buildOMPParallelDirective(cast<OMPParallelDirective>(*S));
case Stmt::OMPTaskwaitDirectiveClass:
return buildOMPTaskwaitDirective(cast<OMPTaskwaitDirective>(*S));
case Stmt::OMPTaskyieldDirectiveClass:
return buildOMPTaskyieldDirective(cast<OMPTaskyieldDirective>(*S));
case Stmt::OMPBarrierDirectiveClass:
return buildOMPBarrierDirective(cast<OMPBarrierDirective>(*S));
// Unsupported AST nodes:
case Stmt::CapturedStmtClass:
case Stmt::ObjCAtTryStmtClass:
Expand All @@ -205,9 +211,6 @@ mlir::LogicalResult CIRGenFunction::buildStmt(const Stmt *S,
case Stmt::OMPParallelMasterDirectiveClass:
case Stmt::OMPParallelSectionsDirectiveClass:
case Stmt::OMPTaskDirectiveClass:
case Stmt::OMPTaskyieldDirectiveClass:
case Stmt::OMPBarrierDirectiveClass:
case Stmt::OMPTaskwaitDirectiveClass:
case Stmt::OMPTaskgroupDirectiveClass:
case Stmt::OMPFlushDirectiveClass:
case Stmt::OMPDepobjDirectiveClass:
Expand Down
76 changes: 76 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,63 @@
// This contains code to emit OpenMP Stmt nodes as MLIR code.
//
//===----------------------------------------------------------------------===//
#include "clang/AST/ASTFwd.h"
#include "clang/AST/StmtIterator.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/Basic/OpenMPKinds.h"

#include "CIRGenFunction.h"
#include "CIRGenOpenMPRuntime.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

using namespace cir;
using namespace clang;
using namespace mlir::omp;

static void buildDependences(const OMPExecutableDirective &S,
OMPTaskDataTy &Data) {

// First look for 'omp_all_memory' and add this first.
bool OmpAllMemory = false;
if (llvm::any_of(
S.getClausesOfKind<OMPDependClause>(), [](const OMPDependClause *C) {
return C->getDependencyKind() == OMPC_DEPEND_outallmemory ||
C->getDependencyKind() == OMPC_DEPEND_inoutallmemory;
})) {
OmpAllMemory = true;
// Since both OMPC_DEPEND_outallmemory and OMPC_DEPEND_inoutallmemory are
// equivalent to the runtime, always use OMPC_DEPEND_outallmemory to
// simplify.
OMPTaskDataTy::DependData &DD =
Data.Dependences.emplace_back(OMPC_DEPEND_outallmemory,
/*IteratorExpr=*/nullptr);
// Add a nullptr Expr to simplify the codegen in emitDependData.
DD.DepExprs.push_back(nullptr);
}
// Add remaining dependences skipping any 'out' or 'inout' if they are
// overridden by 'omp_all_memory'.
for (const auto *C : S.getClausesOfKind<OMPDependClause>()) {
OpenMPDependClauseKind Kind = C->getDependencyKind();
if (Kind == OMPC_DEPEND_outallmemory || Kind == OMPC_DEPEND_inoutallmemory)
continue;
if (OmpAllMemory && (Kind == OMPC_DEPEND_out || Kind == OMPC_DEPEND_inout))
continue;
OMPTaskDataTy::DependData &DD =
Data.Dependences.emplace_back(C->getDependencyKind(), C->getModifier());
DD.DepExprs.append(C->varlist_begin(), C->varlist_end());
}
}

mlir::LogicalResult
CIRGenFunction::buildOMPParallelDirective(const OMPParallelDirective &S) {
mlir::LogicalResult res = mlir::success();
Expand All @@ -43,3 +91,31 @@ CIRGenFunction::buildOMPParallelDirective(const OMPParallelDirective &S) {
builder.create<TerminatorOp>(getLoc(S.getSourceRange().getEnd()));
return res;
}

mlir::LogicalResult
CIRGenFunction::buildOMPTaskwaitDirective(const OMPTaskwaitDirective &S) {
mlir::LogicalResult res = mlir::success();
OMPTaskDataTy Data;
buildDependences(S, Data);
Data.HasNowaitClause = S.hasClausesOfKind<OMPNowaitClause>();
CGM.getOpenMPRuntime().emitTaskWaitCall(builder, *this,
getLoc(S.getSourceRange()), Data);
return res;
}
mlir::LogicalResult
CIRGenFunction::buildOMPTaskyieldDirective(const OMPTaskyieldDirective &S) {
mlir::LogicalResult res = mlir::success();
// Creation of an omp.taskyield operation
CGM.getOpenMPRuntime().emitTaskyieldCall(builder, *this,
getLoc(S.getSourceRange()));
return res;
}

mlir::LogicalResult
CIRGenFunction::buildOMPBarrierDirective(const OMPBarrierDirective &S) {
mlir::LogicalResult res = mlir::success();
// Creation of an omp.barrier operation
CGM.getOpenMPRuntime().emitBarrierCall(builder, *this,
getLoc(S.getSourceRange()));
return res;
}
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ struct UnimplementedFeature {
static bool CUDA() { return false; }
static bool openMP() { return false; }
static bool openMPRuntime() { return false; }
static bool openMPRegionInfo() { return false; }
static bool openMPTarget() { return false; }
static bool isVarArg() { return false; }
static bool setNonGC() { return false; }
Expand Down
8 changes: 8 additions & 0 deletions clang/test/CIR/CodeGen/OpenMP/barrier.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fopenmp-enable-irbuilder -fopenmp -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

// CHECK: cir.func
void omp_barrier_1(){
// CHECK: omp.barrier
#pragma omp barrier
}
File renamed without changes.
8 changes: 8 additions & 0 deletions clang/test/CIR/CodeGen/OpenMP/taskwait.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fopenmp-enable-irbuilder -fopenmp -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

// CHECK: cir.func
void omp_taskwait_1(){
// CHECK: omp.taskwait
#pragma omp taskwait
}
8 changes: 8 additions & 0 deletions clang/test/CIR/CodeGen/OpenMP/taskyield.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fopenmp-enable-irbuilder -fopenmp -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s

// CHECK: cir.func
void omp_taskyield_1(){
// CHECK: omp.taskyield
#pragma omp taskyield
}
15 changes: 15 additions & 0 deletions clang/test/CIR/Lowering/OpenMP/barrier.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

// RUN: cir-translate %s -cir-to-llvmir | FileCheck %s


module {
cir.func @omp_barrier_1() {
omp.barrier
cir.return
}
}

// CHECK: define void @omp_barrier_1()
// CHECK: call i32 @__kmpc_global_thread_num(ptr {{.*}})
// CHECK: call void @__kmpc_barrier(ptr {{.*}}, i32 {{.*}})
// CHECK: ret void
File renamed without changes.
14 changes: 14 additions & 0 deletions clang/test/CIR/Lowering/OpenMP/taskwait.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: cir-translate %s -cir-to-llvmir | FileCheck %s


module {
cir.func @omp_taskwait_1() {
omp.taskwait
cir.return
}
}

// CHECK: define void @omp_taskwait_1()
// CHECK: call i32 @__kmpc_global_thread_num(ptr {{.*}})
// CHECK: call i32 @__kmpc_omp_taskwait(ptr {{.*}}, i32 {{.*}})
// CHECK: ret void
14 changes: 14 additions & 0 deletions clang/test/CIR/Lowering/OpenMP/taskyield.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: cir-translate %s -cir-to-llvmir | FileCheck %s


module {
cir.func @omp_taskyield_1() {
omp.taskyield
cir.return
}
}

// CHECK: define void @omp_taskyield_1()
// CHECK: call i32 @__kmpc_global_thread_num(ptr {{.*}})
// CHECK: call i32 @__kmpc_omp_taskyield(ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
// CHECK: ret void

0 comments on commit c9b4fdc

Please sign in to comment.