Skip to content

Commit

Permalink
[flang][cuda] Support data transfer from descriptor to a pointer (#11…
Browse files Browse the repository at this point in the history
…5023)

Data transfer from a variable with a descriptor to a pointer. We create
a descriptor for the pointer so we can use the flang runtime to perform
the transfer. The Assign function handles all corner cases. We add a new
entry points `CUFDataTransferDescDescNoRealloc` to avoid reallocation
since the variable on the LHS is not an allocatable.
  • Loading branch information
clementval authored Nov 5, 2024
1 parent 17d9565 commit db69d69
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 49 deletions.
4 changes: 4 additions & 0 deletions flang/include/flang/Runtime/CUDA/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

/// Data transfer from a descriptor to a descriptor.
void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

/// Data transfer from a descriptor to a global descriptor.
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
Expand Down
51 changes: 14 additions & 37 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,50 +581,27 @@ struct CUFDataTransferOpConversion
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else {
// Type used to compute the width.
mlir::Type computeType = dstTy;
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
computeType = srcTy;
seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
}
int width = computeWidth(loc, computeType, kindMap);
// Transfer from a descriptor.

mlir::Value nbElement;
mlir::Type idxTy = rewriter.getIndexType();
if (!op.getShape()) {
nbElement = rewriter.create<mlir::arith::ConstantOp>(
loc, idxTy,
rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
} else {
auto shapeOp =
mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
nbElement =
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
auto operand =
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
nbElement =
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
}
}
mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Type boxTy = fir::BoxType::get(dstTy);
llvm::SmallVector<mlir::Value> lenParams;
mlir::Value box =
builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
/*slice=*/nullptr, lenParams,
/*tdesc=*/nullptr);
mlir::Value memBox = builder.createTemporary(loc, box.getType());
builder.create<fir::StoreOp>(loc, box, memBox);

mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
mlir::Value bytes =
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
CUFDataTransferDescDescNoRealloc)>(loc, builder);

mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
llvm::SmallVector<mlir::Value> args{
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
Expand Down
18 changes: 18 additions & 0 deletions flang/runtime/CUDA/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,24 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}

void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
int sourceLine) {
MemmoveFct memmoveFct;
Terminator terminator{sourceFile, sourceLine};
if (mode == kHostToDevice) {
memmoveFct = &MemmoveHostToDevice;
} else if (mode == kDeviceToHost) {
memmoveFct = &MemmoveDeviceToHost;
} else if (mode == kDeviceToDevice) {
memmoveFct = &MemmoveDeviceToDevice;
} else {
terminator.Crash("host to host copy not supported");
}
Fortran::runtime::Assign(
*dstDesc, *srcDesc, terminator, NoAssignFlags, memmoveFct);
}

void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
int sourceLine) {
Expand Down
22 changes: 10 additions & 12 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func.func @_QPsub4() {
return
}
// CHECK-LABEL: func.func @_QPsub4()
// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
Expand All @@ -81,13 +82,11 @@ func.func @_QPsub4() {
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
%0 = fir.dummy_scope : !fir.dscope
Expand Down Expand Up @@ -115,6 +114,7 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
}

// CHECK-LABEL: func.func @_QPsub5
// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
// CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
Expand All @@ -124,13 +124,11 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

func.func @_QPsub6() {
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>
Expand Down

0 comments on commit db69d69

Please sign in to comment.