Skip to content

Commit

Permalink
Add support for immutable attributes.
Browse files Browse the repository at this point in the history
Sometimes we don't want or can't update operands with a auto-generated
setter method, so this change adds support to make attributes immutable.
  • Loading branch information
Thomas Symalla authored and tsymalla-AMD committed May 23, 2024
1 parent 28eac73 commit 7377ed4
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 19 deletions.
16 changes: 16 additions & 0 deletions example/ExampleDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def VectorKindLittleEndian : CppConstant<"xd::VectorKind::LittleEndian">;
def VectorKindBigEndian : CppConstant<"xd::VectorKind::BigEndian">;
def VectorKindMiddleEndian : CppConstant<"xd::VectorKind::MiddleEndian">;

def ImmutableAttrI1 : IntegerAttr<"bool"> {
let isImmutable = true;
}

def : AttrLlvmType<ImmutableAttrI1, I1>;

def isReasonableVectorKind : TgPredicate<
(args AttrVectorKind:$kind),
(eq $kind, (or VectorKindLittleEndian, VectorKindBigEndian))>;
Expand Down Expand Up @@ -301,3 +307,13 @@ def InstNameConflictVarargsOp : Op<ExampleDialect, "inst.name.conflict.varargs",
Like InstNameConflictOp but with varargs
}];
}

def ImmutableOp : Op<ExampleDialect, "immutable.op", [WillReturn]> {
let results = (outs);
let arguments = (ins ImmutableAttrI1:$val);

let summary = "demonstrate how an argument will not get a setter method";
let description = [{
Make an argument immutable
}];
}
3 changes: 3 additions & 0 deletions include/llvm-dialects/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class Attr<string cppType_> : MetaType {
// A check statement that is issued before using the C++ value in builders.
// $0 is the C++ value.
string check = "";

// Overriding prevents generating a setter method. Attributes are mutable by default.
bit isImmutable = false;
}

class IntegerAttr<string cppType_> : Attr<cppType_> {
Expand Down
3 changes: 3 additions & 0 deletions include/llvm-dialects/TableGen/Constraints.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class MetaType {
bool isTypeArg() const { return m_kind == Kind::Type; }
bool isValueArg() const { return m_kind == Kind::Value; }
bool isVarArgList() const { return m_kind == Kind::VarArgList; }
bool isImmutable() const;

protected:
MetaType(Kind kind) : m_kind(kind) {}
Expand Down Expand Up @@ -231,6 +232,7 @@ class Attr : public MetaType {
llvm::StringRef getToUnsigned() const { return m_toUnsigned; }
llvm::StringRef getFromUnsigned() const { return m_fromUnsigned; }
llvm::StringRef getCheck() const { return m_check; }
bool getIsImmutable() const { return m_isImmutable; }

// Set the LLVMType once -- used during initialization to break a circular
// dependency in how IntegerType is defined.
Expand All @@ -249,6 +251,7 @@ class Attr : public MetaType {
std::string m_toUnsigned;
std::string m_fromUnsigned;
std::string m_check;
bool m_isImmutable;
};

} // namespace llvm_dialects
8 changes: 8 additions & 0 deletions lib/TableGen/Constraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ StringRef MetaType::getBuilderCppType() const {
return getCppType();
}

bool MetaType::isImmutable() const {
if (auto *attr = dyn_cast<Attr>(this))
return attr->getIsImmutable();

return false;
}

/// Return the C++ expression @p value transformed to be suitable for printing
/// using LLVM's raw_ostream.
std::string MetaType::printable(const MetaType *type, llvm::StringRef value) {
Expand Down Expand Up @@ -394,6 +401,7 @@ std::unique_ptr<Attr> Attr::parse(raw_ostream &errs,
attr->m_toUnsigned = record->getValueAsString("toUnsigned");
attr->m_fromUnsigned = record->getValueAsString("fromUnsigned");
attr->m_check = record->getValueAsString("check");
attr->m_isImmutable = record->getValueAsBit("isImmutable");

return attr;
}
Expand Down
5 changes: 4 additions & 1 deletion lib/TableGen/Operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void OperationBase::emitArgumentAccessorDeclarations(llvm::raw_ostream &out,
FmtContext &fmt) const {
for (const auto &arg : m_arguments) {
std::string defaultDeclaration = "$0 get$1();";
if (!arg.type->isVarArgList()) {
if (!arg.type->isVarArgList() && !arg.type->isImmutable()) {
defaultDeclaration += R"(
void set$1($0 $2);
)";
Expand Down Expand Up @@ -205,6 +205,9 @@ void AccessorBuilder::emitGetterDefinition() const {
}

void AccessorBuilder::emitSetterDefinition() const {
if (m_arg.type->isImmutable())
return;

std::string toLlvm = m_arg.name;

if (auto *attr = dyn_cast<Attr>(m_arg.type)) {
Expand Down
118 changes: 100 additions & 18 deletions test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ namespace xd {
state.setError();
});

builder.add<ImmutableOp>([](::llvm_dialects::VerifierState &state, ImmutableOp &op) {
if (!op.verifier(state.out()))
state.setError();
});

builder.add<InsertElementOp>([](::llvm_dialects::VerifierState &state, InsertElementOp &op) {
if (!op.verifier(state.out()))
state.setError();
Expand Down Expand Up @@ -149,21 +154,21 @@ namespace xd {
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
m_attributeLists[0] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
m_attributeLists[1] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
m_attributeLists[2] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
Expand Down Expand Up @@ -324,7 +329,7 @@ return true;


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), {
lhs->getType(),
rhs->getType(),
Expand Down Expand Up @@ -446,7 +451,7 @@ uint32_t const extra = getExtra();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {lhs->getType()});
Expand Down Expand Up @@ -541,7 +546,7 @@ rhs


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {::llvm::cast<XdVectorType>(vector->getType())->getElementType()});
Expand Down Expand Up @@ -645,7 +650,7 @@ index


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -815,7 +820,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), {
}, false);

Expand Down Expand Up @@ -877,7 +882,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -975,7 +980,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -1064,6 +1069,75 @@ source



const ::llvm::StringLiteral ImmutableOp::s_name{"xd.immutable.op"};

ImmutableOp* ImmutableOp::create(llvm_dialects::Builder& b, bool val, const llvm::Twine &instName) {
::llvm::LLVMContext& context = b.getContext();
(void)context;
::llvm::Module& module = *b.GetInsertBlock()->getModule();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(4);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), {
::llvm::IntegerType::get(context, 1),
}, false);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
::llvm::SmallString<32> newName;
for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) ||
::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) {
// If a function with the same name but a different types already exists,
// we get a bitcast of a function or a function with the wrong type.
// Try new names until we get one with the correct type.
newName = "";
::llvm::raw_svector_ostream newNameStream(newName);
newNameStream << s_name << "_" << i;
fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs);
}
assert(::llvm::isa<::llvm::Function>(fn.getCallee()));
assert(fn.getFunctionType() == fnType);
assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType());


::llvm::SmallVector<::llvm::Value*, 1> args = {
::llvm::ConstantInt::get(::llvm::IntegerType::get(context, 1), val)
};

return ::llvm::cast<ImmutableOp>(b.CreateCall(fn, args, instName));
}


bool ImmutableOp::verifier(::llvm::raw_ostream &errs) {
::llvm::LLVMContext &context = getModule()->getContext();
(void)context;

using ::llvm_dialects::printable;

if (arg_size() != 1) {
errs << " wrong number of arguments: " << arg_size()
<< ", expected 1\n";
return false;
}

if (getArgOperand(0)->getType() != ::llvm::IntegerType::get(context, 1)) {
errs << " argument 0 (val) has type: "
<< *getArgOperand(0)->getType() << '\n';
errs << " expected: " << *::llvm::IntegerType::get(context, 1) << '\n';
return false;
}
bool const val = getVal();
(void)val;
return true;
}


bool ImmutableOp::getVal() {
return ::llvm::cast<::llvm::ConstantInt>(getArgOperand(0))->getZExtValue() ;
}



const ::llvm::StringLiteral InsertElementOp::s_name{"xd.insertelement"};

InsertElementOp* InsertElementOp::create(llvm_dialects::Builder& b, ::llvm::Value * vector, ::llvm::Value * value, ::llvm::Value * index, const llvm::Twine &instName) {
Expand All @@ -1073,7 +1147,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {vector->getType()});
Expand Down Expand Up @@ -1539,7 +1613,7 @@ instName


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1602,7 +1676,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 64), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1676,7 +1750,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1768,7 +1842,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1860,7 +1934,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1952,7 +2026,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2015,7 +2089,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2133,6 +2207,14 @@ data
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::ImmutableOp>() {
static const ::llvm_dialects::OpDescription desc{false, "xd.immutable.op"};
return desc;
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::InsertElementOp>() {
Expand Down
20 changes: 20 additions & 0 deletions test/example/generated/ExampleDialect.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,26 @@ bool verifier(::llvm::raw_ostream &errs);
::llvm::Value * getResult();


};

class ImmutableOp : public ::llvm::CallInst {
static const ::llvm::StringLiteral s_name; //{"xd.immutable.op"};

public:
static bool classof(const ::llvm::CallInst* i) {
return ::llvm_dialects::detail::isSimpleOperation(i, s_name);
}
static bool classof(const ::llvm::Value* v) {
return ::llvm::isa<::llvm::CallInst>(v) &&
classof(::llvm::cast<::llvm::CallInst>(v));
}
static ImmutableOp* create(::llvm_dialects::Builder& b, bool val, const llvm::Twine &instName = "");

bool verifier(::llvm::raw_ostream &errs);

bool getVal();


};

class InsertElementOp : public ::llvm::CallInst {
Expand Down

0 comments on commit 7377ed4

Please sign in to comment.