Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes a bug with separate compilation. #2264

Merged
merged 4 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/sphinx/using/extending/cudaq_ir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Let's see the output of :code:`nvq++` in verbose mode. Consider a simple code li
$ nvq++ simple.cpp -v --save-temps

cudaq-quake --emit-llvm-file simple.cpp -o simple.qke
cudaq-opt --pass-pipeline=builtin.module(canonicalize,lambda-lifting,canonicalize,apply-op-specialization,kernel-execution,inline{default-pipeline=func.func(indirect-to-direct-calls)},func.func(quake-add-metadata),device-code-loader{use-quake=1},expand-measurements,func.func(lower-to-cfg),canonicalize,cse) simple.qke -o simple.qke.LpsXpu
cudaq-opt --pass-pipeline=builtin.module(canonicalize,lambda-lifting,canonicalize,apply-op-specialization,kernel-execution,indirect-to-direct-calls,inline,func.func(quake-add-metadata),device-code-loader{use-quake=1},expand-measurements,func.func(lower-to-cfg),canonicalize,cse) simple.qke -o simple.qke.LpsXpu
cudaq-translate --convert-to=qir simple.qke.LpsXpu -o simple.ll.p3De4L
fixup-linkage.pl simple.qke simple.ll
llc --relocation-model=pic --filetype=obj -O2 simple.ll.p3De4L -o simple.qke.o
Expand Down
4 changes: 4 additions & 0 deletions include/cudaq/Optimizer/Builder/Factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ mlir::Value createCast(mlir::OpBuilder &builder, mlir::Location loc,
std::vector<std::complex<double>>
readGlobalConstantArray(cudaq::cc::GlobalOp &global);

std::pair<mlir::func::FuncOp, /*alreadyDefined=*/bool>
getOrAddFunc(mlir::Location loc, mlir::StringRef funcName,
mlir::FunctionType funcTy, mlir::ModuleOp module);

} // namespace factory

std::size_t getDataSize(llvm::DataLayout &dataLayout, mlir::Type ty);
Expand Down
3 changes: 1 addition & 2 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ def ConstPropComplex : Pass<"const-prop-complex", "mlir::ModuleOp"> {
}];
}

def ConvertToDirectCalls :
Pass<"indirect-to-direct-calls", "mlir::func::FuncOp"> {
def ConvertToDirectCalls : Pass<"indirect-to-direct-calls", "mlir::ModuleOp"> {
let summary = "Convert calls to direct calls to Quake routines.";
let description = [{
Rewrite the calls in the IR so that they point to the generated code and not
Expand Down
13 changes: 8 additions & 5 deletions lib/Frontend/nvqpp/ASTBridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ class QPUCodeFinder : public clang::RecursiveASTVisitor<QPUCodeFinder> {
return result;
}

bool VisitFunctionDecl(clang::FunctionDecl *func) {
bool VisitFunctionDecl(clang::FunctionDecl *x) {
if (ignoreTemplate)
return true;
func = func->getDefinition();
auto *func = x->getDefinition();

if (func) {
bool runChecks = false;
Expand All @@ -248,6 +248,9 @@ class QPUCodeFinder : public clang::RecursiveASTVisitor<QPUCodeFinder> {
processQpu(cudaq::details::getTagNameOfFunctionDecl(func, mangler),
func);
}
} else if (cudaq::ASTBridgeAction::ASTBridgeConsumer::isQuantum(x)) {
// Add declarations to support separate compilation.
processQpu(cudaq::details::getTagNameOfFunctionDecl(x, mangler), x);
}
return true;
}
Expand All @@ -268,9 +271,9 @@ class QPUCodeFinder : public clang::RecursiveASTVisitor<QPUCodeFinder> {
if (ignoreTemplate)
return true;
if (const auto *cxxMethodDecl = lambda->getCallOperator())
if (const auto *f = cxxMethodDecl->getAsFunction()->getDefinition();
f && cudaq::ASTBridgeAction::ASTBridgeConsumer::isQuantum(f))
processQpu(cudaq::details::getTagNameOfFunctionDecl(f, mangler), f);
if (const auto *f = cxxMethodDecl->getAsFunction()->getDefinition())
if (cudaq::ASTBridgeAction::ASTBridgeConsumer::isQuantum(f))
processQpu(cudaq::details::getTagNameOfFunctionDecl(f, mangler), f);
return true;
}

Expand Down
20 changes: 2 additions & 18 deletions lib/Frontend/nvqpp/ConvertDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,10 @@ void QuakeBridgeVisitor::createEntryBlock(func::FuncOp func,
addArgumentSymbols(entryBlock, x->parameters());
}

std::pair<func::FuncOp, /*alreadyDefined=*/bool>
std::pair<func::FuncOp, bool>
QuakeBridgeVisitor::getOrAddFunc(Location loc, StringRef funcName,
FunctionType funcTy) {
auto func = module.lookupSymbol<func::FuncOp>(funcName);
if (func) {
if (!func.empty()) {
// Already lowered function func, skip it.
return {func, /*defined=*/true};
}
// Function was declared but not defined.
return {func, /*defined=*/false};
}
// Function not found, so add it to the module.
OpBuilder build(module.getBodyRegion());
OpBuilder::InsertionGuard guard(build);
build.setInsertionPointToEnd(module.getBody());
SmallVector<NamedAttribute> attrs;
func = build.create<func::FuncOp>(loc, funcName, funcTy, attrs);
func.setPrivate();
return {func, /*defined=*/false};
return cudaq::opt::factory::getOrAddFunc(loc, funcName, funcTy, module);
}

bool QuakeBridgeVisitor::interceptRecordDecl(clang::RecordDecl *x) {
Expand Down
22 changes: 22 additions & 0 deletions lib/Optimizer/Builder/Factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,4 +614,26 @@ factory::readGlobalConstantArray(cudaq::cc::GlobalOp &global) {
return result;
}

std::pair<mlir::func::FuncOp, bool>
factory::getOrAddFunc(mlir::Location loc, mlir::StringRef funcName,
mlir::FunctionType funcTy, mlir::ModuleOp module) {
auto func = module.lookupSymbol<func::FuncOp>(funcName);
if (func) {
if (!func.empty()) {
// Already lowered function func, skip it.
return {func, /*defined=*/true};
}
// Function was declared but not defined.
return {func, /*defined=*/false};
}
// Function not found, so add it to the module.
OpBuilder build(module.getBodyRegion());
OpBuilder::InsertionGuard guard(build);
build.setInsertionPointToEnd(module.getBody());
SmallVector<NamedAttribute> attrs;
func = build.create<func::FuncOp>(loc, funcName, funcTy, attrs);
func.setPrivate();
return {func, /*defined=*/false};
}

} // namespace cudaq::opt
48 changes: 28 additions & 20 deletions lib/Optimizer/Transforms/AggressiveEarlyInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ namespace cudaq::opt {

using namespace mlir;

static bool isIndirectFunc(llvm::StringRef funcName,
llvm::StringMap<llvm::StringRef> indirectMap) {
static bool isIndirectFunc(StringRef funcName,
llvm::StringMap<StringRef> indirectMap) {
return indirectMap.find(funcName) != indirectMap.end();
}

// Return the inverted mangled name map.
static std::optional<llvm::StringMap<llvm::StringRef>>
static std::optional<llvm::StringMap<StringRef>>
getConversionMap(ModuleOp module) {
llvm::StringMap<llvm::StringRef> result;
llvm::StringMap<StringRef> result;
if (auto mangledNameMap = module->getAttrOfType<DictionaryAttr>(
cudaq::runtime::mangledNameMap)) {
for (auto namedAttr : mangledNameMap) {
Expand All @@ -53,25 +53,34 @@ namespace {
/// dialect calls and callables as well.]
class RewriteCall : public OpRewritePattern<func::CallOp> {
public:
RewriteCall(MLIRContext *ctx, llvm::StringMap<llvm::StringRef> &indirectMap)
: OpRewritePattern(ctx), indirectMap(indirectMap) {}
RewriteCall(MLIRContext *ctx, llvm::StringMap<StringRef> &indirectMap,
ModuleOp m)
: OpRewritePattern(ctx), indirectMap(indirectMap), module(m) {}

LogicalResult matchAndRewrite(func::CallOp op,
LogicalResult matchAndRewrite(func::CallOp call,
PatternRewriter &rewriter) const override {
if (!isIndirectFunc(op.getCallee(), indirectMap))
if (!isIndirectFunc(call.getCallee(), indirectMap))
return failure();

rewriter.startRootUpdate(op);
auto callee = op.getCallee();
llvm::StringRef directName = indirectMap[callee];
op.setCalleeAttr(SymbolRefAttr::get(op.getContext(), directName));
auto callee = call.getCallee();
StringRef directName = indirectMap[callee];
auto *ctx = rewriter.getContext();
auto loc = call.getLoc();
auto funcTy = call.getCalleeType();
auto [fn, defn] =
cudaq::opt::factory::getOrAddFunc(loc, directName, funcTy, module);
if (!defn)
fn.setPrivate();
rewriter.startRootUpdate(call);
call.setCalleeAttr(SymbolRefAttr::get(ctx, directName));
rewriter.finalizeRootUpdate(call);
LLVM_DEBUG(llvm::dbgs() << "Rewriting " << directName << '\n');
rewriter.finalizeRootUpdate(op);
return success();
}

private:
llvm::StringMap<llvm::StringRef> &indirectMap;
llvm::StringMap<StringRef> &indirectMap;
ModuleOp module;
};

/// Translate indirect calls to direct calls.
Expand All @@ -81,14 +90,13 @@ class ConvertToDirectCalls
using ConvertToDirectCallsBase::ConvertToDirectCallsBase;

void runOnOperation() override {
auto op = getOperation();
ModuleOp module = getOperation();
auto *ctx = &getContext();
auto module = op->template getParentOfType<ModuleOp>();
if (auto indirectMapOpt = getConversionMap(module)) {
LLVM_DEBUG(llvm::dbgs() << "Processing: " << op << '\n');
LLVM_DEBUG(llvm::dbgs() << "Processing: " << module << '\n');
RewritePatternSet patterns(ctx);
patterns.insert<RewriteCall>(ctx, *indirectMapOpt);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
patterns.insert<RewriteCall>(ctx, *indirectMapOpt, module);
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
signalPassFailure();
}
}
Expand Down Expand Up @@ -136,7 +144,7 @@ static void defaultInlinerOptPipeline(OpPassManager &pm) {
/// graph.
void cudaq::opt::addAggressiveEarlyInlining(OpPassManager &pm) {
llvm::StringMap<OpPassManager> opPipelines;
pm.addNestedPass<func::FuncOp>(cudaq::opt::createConvertToDirectCalls());
pm.addPass(cudaq::opt::createConvertToDirectCalls());
pm.addPass(createInlinerPass(opPipelines, defaultInlinerOptPipeline));
pm.addNestedPass<func::FuncOp>(cudaq::opt::createCheckKernelCalls());
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Optimizer/Transforms/GenDeviceCodeLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class GenerateDeviceCodeLoaderPass
continue;
if (!funcOp.getName().startswith(cudaq::runtime::cudaqGenPrefixName))
continue;
if (funcOp->hasAttr(cudaq::generatorAnnotation))
if (funcOp->hasAttr(cudaq::generatorAnnotation) || funcOp.empty())
continue;
auto className =
funcOp.getName().drop_front(cudaq::runtime::cudaqGenPrefixLength);
Expand Down
50 changes: 50 additions & 0 deletions targettests/SeparateCompilation/pure_device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

// clang-format off
// RUN: if [ command -v split-file ]; then \
// RUN: split-file %s %t && \
// RUN: nvq++ %cpp_std --enable-mlir -c %t/pd_lib.cpp -o %t/pd_lib.o && \
// RUN: nvq++ %cpp_std --enable-mlir -c %t/pd_main.cpp -o %t/pd_main.o && \
// RUN: nvq++ %cpp_std --enable-mlir %t/pd_lib.o %t/pd_main.o -o %t/pd.a.out && \
// RUN: %t/pd.a.out | FileCheck %s ; else \
// RUN: echo "skipping" ; fi
// clang-format on

//--- pd_lib.h

#pragma once

#include "cudaq.h"

// NB: The __qpu__ here on this declaration cannot be omitted!
__qpu__ void callMe(cudaq::qvector<> &q, int i);

//--- pd_lib.cpp

#include "pd_lib.h"

void send_bat_signal() { std::cout << "na na na na na ... BATMAN!\n"; }

__qpu__ void callMe(cudaq::qvector<> &q, int i) {
ry(2.2, q[0]);
send_bat_signal();
}

//--- pd_main.cpp

#include "pd_lib.h"

__qpu__ void entry() {
cudaq::qvector q(2);
callMe(q, 5);
}

int main() { entry(); }

// CHECK: na ... BATMAN!
2 changes: 1 addition & 1 deletion tools/nvqpp/nvq++.in
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ if ${ENABLE_AGGRESSIVE_EARLY_INLINE}; then
if ${DO_LINK}; then
OPT_PASSES=$(add_pass_to_pipeline "${OPT_PASSES}" "aggressive-early-inlining")
else
OPT_PASSES=$(add_pass_to_pipeline "${OPT_PASSES}" "func.func(indirect-to-direct-calls),inline")
OPT_PASSES=$(add_pass_to_pipeline "${OPT_PASSES}" "indirect-to-direct-calls,inline")
fi
fi
if ${ENABLE_DEVICE_CODE_LOADERS}; then
Expand Down
Loading