Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to re-create op with new variadic argument list. #101

Merged
merged 6 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ void createFunctionExample(Module &module, const Twine &name) {
b.create<xd::WriteVarArgOp>(p2, varArgs);
b.create<xd::HandleGetOp>();

auto *replaceable = b.create<xd::WriteVarArgOp>(p2, varArgs);
SmallVector<Value *> varArgs2 = varArgs;
varArgs2.push_back(p2);

replaceable->replaceArgs(varArgs2);
b.create<xd::SetReadOp>(FixedVectorType::get(b.getInt32Ty(), 2));
b.create<xd::SetWriteOp>(y6);

Expand Down
1 change: 1 addition & 0 deletions lib/TableGen/GenDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ void llvm_dialects::genDialectDefs(raw_ostream& out, RecordKeeper& records) {
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
)";

Expand Down
45 changes: 37 additions & 8 deletions lib/TableGen/Operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class AccessorBuilder final {
: m_fmt{fmt}, m_os{out}, m_arg{arg}, m_argTypeString{argTypeString} {}

void emitAccessorDefinitions() const;
void emitVarArgReplacementDefinition() const;

private:
FmtContext &m_fmt;
Expand Down Expand Up @@ -162,10 +163,18 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out,
const bool isVarArg = arg.type->isVarArgList();
std::string defaultDeclaration = "$0 get$1() $2;";

if (!isVarArg && !arg.type->isImmutable()) {
defaultDeclaration += R"(
void set$1($0 $3);
)";
if (!arg.type->isImmutable()) {
if (!isVarArg) {
defaultDeclaration += R"(
void set$1($0 $3);
)";
} else {
defaultDeclaration += R"(
/// Returns a new op with the same arguments and a new tail argument list.
/// The object on which this is called will be replaced and erased.
$_op *replace$1(::llvm::ArrayRef<Value *>);
)";
}
}

out << tgfmt(defaultDeclaration, &fmt, arg.type->getGetterCppType(),
Expand All @@ -174,8 +183,11 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out,
}

void AccessorBuilder::emitAccessorDefinitions() const {
// We do not generate a setter for variadic arguments for now.
emitGetterDefinition();

if (m_arg.type->isImmutable())
return;

if (!m_arg.type->isVarArgList())
emitSetterDefinition();
}
Expand Down Expand Up @@ -208,9 +220,6 @@ void AccessorBuilder::emitGetterDefinition() const {
}

void AccessorBuilder::emitSetterDefinition() const {
if (m_arg.type->isImmutable())
return;

std::string toLlvm = m_arg.name;

if (auto *attr = dyn_cast<Attr>(m_arg.type)) {
Expand All @@ -228,6 +237,24 @@ void AccessorBuilder::emitSetterDefinition() const {
&m_fmt);
}

void AccessorBuilder::emitVarArgReplacementDefinition() const {
std::string toLlvm = m_arg.name;

m_os << tgfmt(R"(

$_op *$_op::replace$Name(::llvm::ArrayRef<Value *> $name) {
::llvm::SmallVector<Value *> newArgs;
if ($index > 0)
newArgs.append(arg_begin(), arg_begin() + $index);
newArgs.append($name.begin(), $name.end());
$_op *newOp = ::llvm::cast<$_op>(::llvm::CallInst::Create(getCalledFunction(), newArgs, this->getName(), this->getIterator()));
this->replaceAllUsesWith(newOp);
this->eraseFromParent();
return newOp;
})",
&m_fmt);
}

void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
FmtContext &fmt) const {
unsigned numSuperclassArgs = 0;
Expand All @@ -247,6 +274,8 @@ void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
fmt.addSubst("Name", convertToCamelFromSnakeCase(arg.name, true));

builder.emitAccessorDefinitions();
if (!arg.type->isImmutable() && arg.type->isVarArgList())
builder.emitVarArgReplacementDefinition();
}
}

Expand Down
23 changes: 23 additions & 0 deletions test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"

#include "llvm/Support/ModRef.h"
Expand Down Expand Up @@ -1502,6 +1503,17 @@ instName
value_op_iterator(arg_begin() + 0),
value_op_iterator(arg_end()));
}

InstNameConflictVarargsOp *InstNameConflictVarargsOp::replaceInstName_0(::llvm::ArrayRef<Value *> instName_0) {
::llvm::SmallVector<Value *> newArgs;
if (0 > 0)
newArgs.append(arg_begin(), arg_begin() + 0);
newArgs.append(instName_0.begin(), instName_0.end());
InstNameConflictVarargsOp *newOp = ::llvm::cast<InstNameConflictVarargsOp>(::llvm::CallInst::Create(getCalledFunction(), newArgs, this->getName(), this->getIterator()));
this->replaceAllUsesWith(newOp);
this->eraseFromParent();
return newOp;
}
::llvm::Value *InstNameConflictVarargsOp::getResult() {return this;}


Expand Down Expand Up @@ -2233,6 +2245,17 @@ data
value_op_iterator(arg_end()));
}

WriteVarArgOp *WriteVarArgOp::replaceArgs(::llvm::ArrayRef<Value *> args) {
::llvm::SmallVector<Value *> newArgs;
if (1 > 0)
newArgs.append(arg_begin(), arg_begin() + 1);
newArgs.append(args.begin(), args.end());
WriteVarArgOp *newOp = ::llvm::cast<WriteVarArgOp>(::llvm::CallInst::Create(getCalledFunction(), newArgs, this->getName(), this->getIterator()));
this->replaceAllUsesWith(newOp);
this->eraseFromParent();
return newOp;
}


} // namespace xd

Expand Down
100 changes: 54 additions & 46 deletions test/example/generated/ExampleDialect.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ uint32_t getNumElements() const;
classof(::llvm::cast<::llvm::CallInst>(v));
}
::llvm::Value * getPtr() const;
void setPtr(::llvm::Value * ptr);
::llvm::Value * getCount() const;
void setCount(::llvm::Value * count);
::llvm::Value * getInitial() const;
void setInitial(::llvm::Value * initial);

void setPtr(::llvm::Value * ptr);
::llvm::Value * getCount() const;
void setCount(::llvm::Value * count);
::llvm::Value * getInitial() const;
void setInitial(::llvm::Value * initial);
};

class Add32Op : public ::llvm::CallInst {
Expand All @@ -123,12 +123,12 @@ uint32_t getNumElements() const;
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getLhs() const;
void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);
uint32_t getExtra() const;
void setExtra(uint32_t extra);

void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);
uint32_t getExtra() const;
void setExtra(uint32_t extra);
::llvm::Value * getResult();


Expand All @@ -150,10 +150,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getLhs() const;
void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);

void setLhs(::llvm::Value * lhs);
::llvm::Value * getRhs() const;
void setRhs(::llvm::Value * rhs);
::llvm::Value * getResult();


Expand All @@ -175,10 +175,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getVector() const;
void setVector(::llvm::Value * vector);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);

void setVector(::llvm::Value * vector);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);
::llvm::Value * getResult();


Expand All @@ -200,8 +200,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getSource() const;
void setSource(::llvm::Value * source);

void setSource(::llvm::Value * source);
::llvm::Value * getResult();


Expand Down Expand Up @@ -244,8 +244,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getSource() const;
void setSource(::llvm::Value * source);

void setSource(::llvm::Value * source);
::llvm::Value * getResult();


Expand All @@ -267,8 +267,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getSource() const;
void setSource(::llvm::Value * source);

void setSource(::llvm::Value * source);
::llvm::Value * getResult();


Expand Down Expand Up @@ -310,12 +310,12 @@ bool getVal() const;
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getVector() const;
void setVector(::llvm::Value * vector);
::llvm::Value * getValue() const;
void setValue(::llvm::Value * value);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);

void setVector(::llvm::Value * vector);
::llvm::Value * getValue() const;
void setValue(::llvm::Value * value);
::llvm::Value * getIndex() const;
void setIndex(::llvm::Value * index);
::llvm::Value * getResult();


Expand All @@ -337,10 +337,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getInstName() const;
void setInstName(::llvm::Value * instName);
::llvm::Value * getInstName_0() const;
void setInstName_0(::llvm::Value * instName_0);

void setInstName(::llvm::Value * instName);
::llvm::Value * getInstName_0() const;
void setInstName_0(::llvm::Value * instName_0);
::llvm::Value * getResult();


Expand All @@ -362,8 +362,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getInstName() const;
void setInstName(::llvm::Value * instName);

void setInstName(::llvm::Value * instName);
::llvm::Value * getResult();


Expand All @@ -385,6 +385,10 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::iterator_range<::llvm::User::value_op_iterator> getInstName_0() ;
/// Returns a new op with the same arguments and a new tail argument list.
/// The object on which this is called will be replaced and erased.
InstNameConflictVarargsOp *replaceInstName_0(::llvm::ArrayRef<Value *>);

::llvm::Value * getResult();


Expand Down Expand Up @@ -448,8 +452,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getData() const;
void setData(::llvm::Value * data);

void setData(::llvm::Value * data);


};
Expand All @@ -470,8 +474,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Type * getSizeofType() const;
void setSizeofType(::llvm::Type * sizeof_type);

void setSizeofType(::llvm::Type * sizeof_type);
::llvm::Value * getResult();


Expand Down Expand Up @@ -576,8 +580,8 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getData() const;
void setData(::llvm::Value * data);

void setData(::llvm::Value * data);


};
Expand All @@ -598,8 +602,12 @@ bool verifier(::llvm::raw_ostream &errs);
bool verifier(::llvm::raw_ostream &errs);

::llvm::Value * getData() const;
void setData(::llvm::Value * data);
::llvm::iterator_range<::llvm::User::value_op_iterator> getArgs() ;
void setData(::llvm::Value * data);
::llvm::iterator_range<::llvm::User::value_op_iterator> getArgs() ;
/// Returns a new op with the same arguments and a new tail argument list.
/// The object on which this is called will be replaced and erased.
WriteVarArgOp *replaceArgs(::llvm::ArrayRef<Value *>);



};
Expand Down
14 changes: 11 additions & 3 deletions test/example/test-builder.test
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs --check-globals
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --check-globals --include-generated-funcs
; NOTE: stdin isn't used by the example program, but the redirect makes the UTC tool happy.
; RUN: llvm-dialects-example - | FileCheck --check-prefixes=CHECK %s

;.
; CHECK: @[[GLOB0:.*]] = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1
; CHECK: @str = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1
;.
; CHECK-LABEL: @example(
; CHECK-NEXT: entry:
Expand All @@ -28,6 +28,7 @@
; CHECK-NEXT: call void (...) @xd.write(i8 [[P2]])
; CHECK-NEXT: call void (...) @xd.write.vararg(i8 [[P2]], ptr [[P1]], i8 [[P2]])
; CHECK-NEXT: [[TMP14:%.*]] = call target("xd.handle") @xd.handle.get()
; CHECK-NEXT: call void (...) @xd.write.vararg(i8 [[P2]], ptr [[P1]], i8 [[P2]], i8 [[P2]])
; CHECK-NEXT: [[TMP15:%.*]] = call <2 x i32> @xd.set.read__v2i32()
; CHECK-NEXT: call void (...) @xd.set.write(target("xd.vector", i32, 1, 2) [[TMP13]])
; CHECK-NEXT: [[TMP16:%.*]] = call [[TMP0]] @xd.read__s_s()
Expand All @@ -45,6 +46,13 @@
; CHECK-NEXT: [[TWO_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]])
; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3)
; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4)
; CHECK-NEXT: call void @xd.string.attr.op(ptr @[[GLOB0]])
; CHECK-NEXT: call void @xd.string.attr.op(ptr @str)
; CHECK-NEXT: ret void
;
;.
; CHECK: attributes #[[ATTR0:[0-9]+]] = { nounwind memory(inaccessiblemem: readwrite) }
; CHECK: attributes #[[ATTR1:[0-9]+]] = { nounwind willreturn memory(none) }
; CHECK: attributes #[[ATTR2:[0-9]+]] = { nounwind willreturn memory(inaccessiblemem: write) }
; CHECK: attributes #[[ATTR3:[0-9]+]] = { nounwind willreturn memory(read) }
; CHECK: attributes #[[ATTR4:[0-9]+]] = { willreturn }
;.
Loading