Skip to content

Commit

Permalink
Ensure global variables have a safe name. (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Sep 6, 2023
1 parent c92ec6d commit d11d9ee
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 114 deletions.
1 change: 1 addition & 0 deletions src/GPUCompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Core: MethodInstance, CodeInstance, CodeInfo
const use_newpm = LLVM.has_newpm()

include("utils.jl")
include("mangling.jl")

# compiler interface and implementations
include("interface.jl")
Expand Down
121 changes: 7 additions & 114 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,13 @@ function irgen(@nospecialize(job::CompilerJob))
Base.depwarn("GPUCompiler.process_module! is deprecated; implement GPUCompiler.finish_module! instead", :process_module)
end

# sanitize function names
# FIXME: Julia should do this, but sometimes fails (see maleadt/LLVM.jl#201)
for f in functions(mod)
isdeclaration(f) && continue
llvmfn = LLVM.name(f)
llvmfn′ = safe_name(llvmfn)
if llvmfn != llvmfn′
@assert !haskey(functions(mod), llvmfn′) "Cannot rename $llvmfn to $llvmfn′, already exists"
LLVM.name!(f, llvmfn′)
# sanitize global values (Julia doesn't when using the external codegen policy)
for val in [collect(globals(mod)); collect(functions(mod))]
isdeclaration(val) && continue
old_name = LLVM.name(val)
new_name = safe_name(old_name)
if old_name != new_name
LLVM.name!(val, new_name)
end
end

Expand Down Expand Up @@ -123,111 +121,6 @@ function irgen(@nospecialize(job::CompilerJob))
end


## name mangling

# we generate function names that look like C++ functions, because many NVIDIA tools
# support them, e.g., grouping different instantiations of the same kernel together.

function mangle_param(t, substitutions=String[])
t == Nothing && return "v"

if isa(t, DataType) && t <: Ptr
tn = mangle_param(eltype(t), substitutions)
"P$tn"
elseif isa(t, DataType)
tn = safe_name(t)

# handle substitutions
sub = findfirst(isequal(tn), substitutions)
if sub === nothing
str = "$(length(tn))$tn"
push!(substitutions, tn)
elseif sub == 1
str = "S_"
else
str = "S$(sub-2)_"
end

# encode typevars as template parameters
if !isempty(t.parameters)
str *= "I"
for t in t.parameters
str *= mangle_param(t, substitutions)
end
str *= "E"
end

str
elseif isa(t, Union)
tn = "Union"

# handle substitutions
sub = findfirst(isequal(tn), substitutions)
if sub === nothing
str = "$(length(tn))$tn"
push!(substitutions, tn)
elseif sub == 1
str = "S_"
else
str = "S$(sub-2)_"
end

# encode union types as template parameters
if !isempty(Base.uniontypes(t))
str *= "I"
for t in Base.uniontypes(t)
str *= mangle_param(t, substitutions)
end
str *= "E"
end

str
elseif isa(t, Integer)
t > 0 ? "Li$(t)E" : "Lin$(abs(t))E"
else
tn = safe_name(t) # TODO: actually does support digits...
if startswith(tn, r"\d")
# C++ classes cannot start with a digit, so mangling doesn't support it
tn = "_$(tn)"
end
"$(length(tn))$tn"
end
end

function mangle_sig(sig)
ft, tt... = sig.parameters

# mangle the function name
fn = safe_name(ft)
str = "_Z$(length(fn))$fn"

# mangle each parameter
substitutions = String[]
for t in tt
str *= mangle_param(t, substitutions)
end

return str
end

# make names safe for ptxas
safe_name(fn::String) = replace(fn, r"[^A-Za-z0-9]"=>"_")
safe_name(t::DataType) = safe_name(String(nameof(t)))
function safe_name(t::Type{<:Function})
# like Base.nameof, but for function types
mt = t.name.mt
fn = if mt === Symbol.name.mt
# uses shared method table, so name is not unique to this function type
nameof(t)
else
mt.name
end
safe_name(string(fn))
end
safe_name(::Type{Union{}}) = "Bottom"
safe_name(x) = safe_name(repr(x))


## exception handling

# this pass lowers `jl_throw` and friends to GPU-compatible exceptions.
Expand Down
112 changes: 112 additions & 0 deletions src/mangling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# name mangling

# safe name generation

# LLVM doesn't like names with special characters, so we need to sanitize them.
# note that we are stricter than LLVM, because of `ptxas`.

safe_name(fn::String) = replace(fn, r"[^A-Za-z0-9]"=>"_")

safe_name(t::DataType) = safe_name(String(nameof(t)))
function safe_name(t::Type{<:Function})
# like Base.nameof, but for function types
mt = t.name.mt
fn = if mt === Symbol.name.mt
# uses shared method table, so name is not unique to this function type
nameof(t)
else
mt.name
end
safe_name(string(fn))
end
safe_name(::Type{Union{}}) = "Bottom"

safe_name(x) = safe_name(repr(x))


# C++ mangling

# we generate function names that look like C++ functions, because many tools, like NVIDIA's
# profilers, support them (grouping different instantiations of the same kernel together).

function mangle_param(t, substitutions=String[])
t == Nothing && return "v"

if isa(t, DataType) && t <: Ptr
tn = mangle_param(eltype(t), substitutions)
"P$tn"
elseif isa(t, DataType)
tn = safe_name(t)

# handle substitutions
sub = findfirst(isequal(tn), substitutions)
if sub === nothing
str = "$(length(tn))$tn"
push!(substitutions, tn)
elseif sub == 1
str = "S_"
else
str = "S$(sub-2)_"
end

# encode typevars as template parameters
if !isempty(t.parameters)
str *= "I"
for t in t.parameters
str *= mangle_param(t, substitutions)
end
str *= "E"
end

str
elseif isa(t, Union)
tn = "Union"

# handle substitutions
sub = findfirst(isequal(tn), substitutions)
if sub === nothing
str = "$(length(tn))$tn"
push!(substitutions, tn)
elseif sub == 1
str = "S_"
else
str = "S$(sub-2)_"
end

# encode union types as template parameters
if !isempty(Base.uniontypes(t))
str *= "I"
for t in Base.uniontypes(t)
str *= mangle_param(t, substitutions)
end
str *= "E"
end

str
elseif isa(t, Integer)
t > 0 ? "Li$(t)E" : "Lin$(abs(t))E"
else
tn = safe_name(t) # TODO: actually does support digits...
if startswith(tn, r"\d")
# C++ classes cannot start with a digit, so mangling doesn't support it
tn = "_$(tn)"
end
"$(length(tn))$tn"
end
end

function mangle_sig(sig)
ft, tt... = sig.parameters

# mangle the function name
fn = safe_name(ft)
str = "_Z$(length(fn))$fn"

# mangle each parameter
substitutions = String[]
for t in tt
str *= mangle_param(t, substitutions)
end

return str
end

0 comments on commit d11d9ee

Please sign in to comment.