Skip to content

Commit

Permalink
Add support for immutable strings.
Browse files Browse the repository at this point in the history
We want a way to store strings in a global variable by passing in a
StringRef to an op, but we don't want to generate a setter for it, since
we currently don't inject the builder into the setter. So, add an
immutable string type based on the existing isImmutable option for
attributes.
  • Loading branch information
Thomas Symalla committed May 21, 2024
1 parent 45c128d commit da01bca
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 19 deletions.
10 changes: 10 additions & 0 deletions example/ExampleDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,13 @@ def ImmutableOp : Op<ExampleDialect, "immutable.op", [WillReturn]> {
Make an argument immutable
}];
}

def StringAttrOp : Op<ExampleDialect, "string.attr.op", [WillReturn]> {
let results = (outs);
let arguments = (ins ImmutableStringAttr:$val);

let summary = "demonstrate an argument that takes in a StringRef";
let description = [{
The argument should not have a setter method
}];
}
2 changes: 2 additions & 0 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ void createFunctionExample(Module &module, const Twine &name) {
moreVarArgs.push_back(b.getInt32(4));
b.create<xd::InstNameConflictVarargsOp>(moreVarArgs, "four.varargs");

b.create<xd::StringAttrOp>("Hello world!");

b.CreateRetVoid();
}

Expand Down
15 changes: 15 additions & 0 deletions include/llvm-dialects/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,21 @@ def : AttrLlvmType<AttrI16, I16>;
def : AttrLlvmType<AttrI32, I32>;
def : AttrLlvmType<AttrI64, I64>;

def ImmutableStringAttr : Attr<"::llvm::StringRef"> {
let toLlvmValue = [{ $_builder.CreateGlobalString($0) }];
let fromLlvmValue = [{ ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>($0)->getInitializer())->getAsString() }];
let isImmutable = "true";
}

def ImmutableStringType : BuiltinType {
let arguments = (args type:$self);
let evaluate = "::llvm::PointerType::get($_context, 0)";
let check = "$self->isPointerTy()";
let capture = [];
}

def : AttrLlvmType<ImmutableStringAttr, ImmutableStringType>;

// ============================================================================
/// More general attributes
// ============================================================================
Expand Down
118 changes: 100 additions & 18 deletions test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ namespace xd {
state.setError();
});

builder.add<StringAttrOp>([](::llvm_dialects::VerifierState &state, StringAttrOp &op) {
if (!op.verifier(state.out()))
state.setError();
});

builder.add<WriteOp>([](::llvm_dialects::VerifierState &state, WriteOp &op) {
if (!op.verifier(state.out()))
state.setError();
Expand All @@ -154,21 +159,21 @@ namespace xd {
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
m_attributeLists[0] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
m_attributeLists[1] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
m_attributeLists[2] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
Expand Down Expand Up @@ -329,7 +334,7 @@ return true;


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), {
lhs->getType(),
rhs->getType(),
Expand Down Expand Up @@ -451,7 +456,7 @@ uint32_t const extra = getExtra();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {lhs->getType()});
Expand Down Expand Up @@ -546,7 +551,7 @@ rhs


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {::llvm::cast<XdVectorType>(vector->getType())->getElementType()});
Expand Down Expand Up @@ -650,7 +655,7 @@ index


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -820,7 +825,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), {
}, false);

Expand Down Expand Up @@ -882,7 +887,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -980,7 +985,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -1147,7 +1152,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {vector->getType()});
Expand Down Expand Up @@ -1613,7 +1618,7 @@ instName


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1676,7 +1681,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 64), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1750,7 +1755,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1842,7 +1847,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1934,7 +1939,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -2017,6 +2022,75 @@ initial



const ::llvm::StringLiteral StringAttrOp::s_name{"xd.string.attr.op"};

StringAttrOp* StringAttrOp::create(llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName) {
::llvm::LLVMContext& context = b.getContext();
(void)context;
::llvm::Module& module = *b.GetInsertBlock()->getModule();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(4);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), {
::llvm::PointerType::get(context, 0),
}, false);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
::llvm::SmallString<32> newName;
for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) ||
::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) {
// If a function with the same name but a different types already exists,
// we get a bitcast of a function or a function with the wrong type.
// Try new names until we get one with the correct type.
newName = "";
::llvm::raw_svector_ostream newNameStream(newName);
newNameStream << s_name << "_" << i;
fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs);
}
assert(::llvm::isa<::llvm::Function>(fn.getCallee()));
assert(fn.getFunctionType() == fnType);
assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType());


::llvm::SmallVector<::llvm::Value*, 1> args = {
b.CreateGlobalString(val)
};

return ::llvm::cast<StringAttrOp>(b.CreateCall(fn, args, instName));
}


bool StringAttrOp::verifier(::llvm::raw_ostream &errs) {
::llvm::LLVMContext &context = getModule()->getContext();
(void)context;

using ::llvm_dialects::printable;

if (arg_size() != 1) {
errs << " wrong number of arguments: " << arg_size()
<< ", expected 1\n";
return false;
}

if (getArgOperand(0)->getType() != ::llvm::PointerType::get(context, 0)) {
errs << " argument 0 (val) has type: "
<< *getArgOperand(0)->getType() << '\n';
errs << " expected: " << *::llvm::PointerType::get(context, 0) << '\n';
return false;
}
::llvm::StringRef const val = getVal();
(void)val;
return true;
}


::llvm::StringRef StringAttrOp::getVal() {
return ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>(getArgOperand(0))->getInitializer())->getAsString() ;
}



const ::llvm::StringLiteral WriteOp::s_name{"xd.write"};

WriteOp* WriteOp::create(llvm_dialects::Builder& b, ::llvm::Value * data, const llvm::Twine &instName) {
Expand All @@ -2026,7 +2100,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2089,7 +2163,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2303,6 +2377,14 @@ data
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::StringAttrOp>() {
static const ::llvm_dialects::OpDescription desc{false, "xd.string.attr.op"};
return desc;
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::WriteOp>() {
Expand Down
20 changes: 20 additions & 0 deletions test/example/generated/ExampleDialect.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,26 @@ bool verifier(::llvm::raw_ostream &errs);
::llvm::Value * getResult();


};

class StringAttrOp : public ::llvm::CallInst {
static const ::llvm::StringLiteral s_name; //{"xd.string.attr.op"};

public:
static bool classof(const ::llvm::CallInst* i) {
return ::llvm_dialects::detail::isSimpleOperation(i, s_name);
}
static bool classof(const ::llvm::Value* v) {
return ::llvm::isa<::llvm::CallInst>(v) &&
classof(::llvm::cast<::llvm::CallInst>(v));
}
static StringAttrOp* create(::llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName = "");

bool verifier(::llvm::raw_ostream &errs);

::llvm::StringRef getVal();


};

class WriteOp : public ::llvm::CallInst {
Expand Down
6 changes: 5 additions & 1 deletion test/example/test-builder.test
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs --check-globals
; NOTE: stdin isn't used by the example program, but the redirect makes the UTC tool happy.
; RUN: llvm-dialects-example - | FileCheck --check-prefixes=CHECK %s

;.
; CHECK: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1
;.
; CHECK-LABEL: @example(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @xd.read__i32()
Expand Down Expand Up @@ -42,5 +45,6 @@
; CHECK-NEXT: [[TWO_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]])
; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3)
; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4)
; CHECK-NEXT: call void @xd.string.attr.op(ptr @[[GLOB0:[0-9]+]])
; CHECK-NEXT: ret void
;

0 comments on commit da01bca

Please sign in to comment.