From 1145e57bd28a43acaa334d953c8f2971ea6bef98 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Thu, 9 Jan 2025 12:59:16 -0500 Subject: [PATCH] another hack --- ext/SparseArraysExt.jl | 4 +-- src/execute.jl | 2 +- src/interface/abstract_arrays.jl | 8 +++--- src/interface/lazy.jl | 30 ++++++++++----------- src/interface/morgue.jl | 6 ++--- src/interface/traits.jl | 6 ++--- src/scheduler/LogicExecutor.jl | 2 +- src/scheduler/optimize.jl | 12 ++++----- src/tensors/levels/atomic_element_levels.jl | 2 +- src/tensors/levels/element_levels.jl | 10 +++---- src/tensors/levels/pattern_levels.jl | 4 +-- src/util/shims.jl | 2 +- src/util/staging.jl | 14 +++------- 13 files changed, 47 insertions(+), 55 deletions(-) diff --git a/ext/SparseArraysExt.jl b/ext/SparseArraysExt.jl index aae66c084..5059704c5 100644 --- a/ext/SparseArraysExt.jl +++ b/ext/SparseArraysExt.jl @@ -268,7 +268,7 @@ function Finch.unfurl(ctx, tns::VirtualSparseMatrixCSCColumn, ext, mode, ::Union end $dirty = false end, - body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), fgensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode), + body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), gensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode), epilogue = quote if $dirty $(arr.idx)[$qos] = $(ctx(idx)) @@ -463,7 +463,7 @@ function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode::Updater, ::Union end $dirty = false end, - body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), fgensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode), + body = (ctx) -> Finch.instantiate(ctx, Finch.VirtualSparseScalar(nothing, arr.Tv, zero(arr.Tv), gensym(), :($(arr.val)[$(ctx(qos))]), dirty), mode), epilogue = quote if $dirty $(arr.idx)[$qos] = $(ctx(idx)) diff --git a/src/execute.jl b/src/execute.jl index 09c2f2a58..c17bba248 100644 --- a/src/execute.jl +++ b/src/execute.jl @@ -228,7 +228,7 @@ See also: [`@finch`](@ref) """ function finch_kernel(fname, args, prgm; algebra = DefaultAlgebra(), mode = :safe, ctx = FinchCompiler(algebra=algebra, mode=mode)) maybe_typeof(x) = x isa Type ? x : typeof(x) - unreachable = fgensym(:unreachable) + unreachable = gensym(:unreachable) code = contain(ctx) do ctx_2 foreach(args) do (key, val) set_binding!(ctx_2, variable(key), finch_leaf(virtualize(ctx_2.code, key, maybe_typeof(val), key))) diff --git a/src/interface/abstract_arrays.jl b/src/interface/abstract_arrays.jl index 8f2e317ad..5e8b40d95 100644 --- a/src/interface/abstract_arrays.jl +++ b/src/interface/abstract_arrays.jl @@ -53,11 +53,11 @@ function unfurl(ctx, tns::VirtualAbstractArraySlice, ext, mode, proto) preamble = quote $val = $(arr.ex)[$(map(ctx, idx_2)...)] end, - body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, fgensym(), val), mode) + body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, gensym(), val), mode) ) else Thunk( - body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, fgensym(), :($(arr.ex)[$(map(ctx, idx_2)...)])), mode) + body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, gensym(), :($(arr.ex)[$(map(ctx, idx_2)...)])), mode) ) end else @@ -81,11 +81,11 @@ function instantiate(ctx::AbstractCompiler, arr::VirtualAbstractArray, mode) preamble = quote $val = $(arr.ex)[] end, - body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, fgensym(), val), mode) + body = (ctx) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here =#, gensym(), val), mode) ) else Thunk( - body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, fgensym(), :($(arr.ex)[])), mode) + body = (ctx,) -> instantiate(ctx, VirtualScalar(nothing, arr.eltype, nothing#=We don't know what init is, but it won't be used here=#, gensym(), :($(arr.ex)[])), mode) ) end else diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 11d8a2eaa..97af5f223 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -38,8 +38,8 @@ function expanddims(arr::LazyTensor{T}, dims) where {T} @assert allunique(dims) @assert issubset(dims,1:ndims(arr) + length(dims)) antidims = setdiff(1:ndims(arr) + length(dims), dims) - idxs_1 = [field(fgensym(:i)) for _ in 1:ndims(arr)] - idxs_2 = [field(fgensym(:i)) for _ in 1:ndims(arr) + length(dims)] + idxs_1 = [field(gensym(:i)) for _ in 1:ndims(arr)] + idxs_2 = [field(gensym(:i)) for _ in 1:ndims(arr) + length(dims)] idxs_2[antidims] .= idxs_1 data_2 = reorder(relabel(arr.data, idxs_1...), idxs_2...) extrude_2 = [false for _ in 1:ndims(arr) + length(dims)] @@ -48,7 +48,7 @@ function expanddims(arr::LazyTensor{T}, dims) where {T} end function identify(data) - lhs = alias(fgensym(:A)) + lhs = alias(gensym(:A)) subquery(lhs, data) end @@ -56,8 +56,8 @@ LazyTensor(data::Number) = LazyTensor{typeof(data), 0}(immediate(data), (), data LazyTensor{T}(data::Number) where {T} = LazyTensor{T, 0}(immediate(data), (), data) LazyTensor(arr::Base.AbstractArrayOrBroadcasted) = LazyTensor{eltype(arr)}(arr) function LazyTensor{T}(arr::Base.AbstractArrayOrBroadcasted) where {T} - name = alias(fgensym(:A)) - idxs = [field(fgensym(:i)) for _ in 1:ndims(arr)] + name = alias(gensym(:A)) + idxs = [field(gensym(:i)) for _ in 1:ndims(arr)] extrude = ntuple(n -> size(arr, n) == 1, ndims(arr)) tns = subquery(name, table(immediate(arr), idxs...)) LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, fill_value(arr)) @@ -65,8 +65,8 @@ end LazyTensor(arr::AbstractTensor) = LazyTensor{eltype(arr)}(arr) LazyTensor(swizzle_arr::SwizzleArray{dims, <:Tensor}) where {dims} = permutedims(LazyTensor(swizzle_arr.body), dims) function LazyTensor{T}(arr::AbstractTensor) where {T} - name = alias(fgensym(:A)) - idxs = [field(fgensym(:i)) for _ in 1:ndims(arr)] + name = alias(gensym(:A)) + idxs = [field(gensym(:i)) for _ in 1:ndims(arr)] extrude = ntuple(n -> size(arr)[n] == 1, ndims(arr)) tns = subquery(name, table(immediate(arr), idxs...)) LazyTensor{eltype(arr), ndims(arr)}(tns, extrude, fill_value(arr)) @@ -90,7 +90,7 @@ end function Base.map(f, src::LazyTensor, args...) largs = map(LazyTensor, (src, args...)) extrude = largs[something(findfirst(arg -> length(arg.extrude) > 0, largs), 1)].extrude - idxs = [field(fgensym(:i)) for _ in src.extrude] + idxs = [field(gensym(:i)) for _ in src.extrude] ldatas = map(largs) do larg if larg.extrude == extrude return relabel(larg.data, idxs...) @@ -135,7 +135,7 @@ end function Base.reduce(op, arg::LazyTensor{T, N}; dims=:, init = initial_value(op, T)) where {T, N} dims = dims == Colon() ? (1:N) : collect(dims) extrude = ((arg.extrude[n] for n in 1:N if !(n in dims))...,) - fields = [field(fgensym(:i)) for _ in 1:N] + fields = [field(gensym(:i)) for _ in 1:N] S = fixpoint_type(op, init, eltype(arg)) data = aggregate(immediate(op), immediate(init), relabel(arg.data, fields), fields[dims]...) LazyTensor{S}(identify(data), extrude, init) @@ -160,8 +160,8 @@ function tensordot(A::LazyTensor{T1, N1}, B::LazyTensor{T2, N2}, idxs; mult_op=* extrude = ((A.extrude[n] for n in 1:N1 if !(n in A_idxs))..., (B.extrude[n] for n in 1:N2 if !(n in B_idxs))...,) - A_fields = [field(fgensym(:i)) for _ in 1:N1] - B_fields = [field(fgensym(:i)) for _ in 1:N2] + A_fields = [field(gensym(:i)) for _ in 1:N1] + B_fields = [field(gensym(:i)) for _ in 1:N2] reduce_fields = [] for i in eachindex(A_idxs) B_fields[B_idxs[i]] = A_fields[A_idxs[i]] @@ -197,7 +197,7 @@ function broadcast_to_query(bc::Broadcast.Broadcasted, idxs) end function broadcast_to_query(tns::LazyTensor{T, N}, idxs) where {T, N} - idxs_2 = [tns.extrude[i] ? field(fgensym(:idx)) : idxs[i] for i in 1:N] + idxs_2 = [tns.extrude[i] ? field(gensym(:idx)) : idxs[i] for i in 1:N] data_2 = relabel(tns.data, idxs_2...) reorder(data_2, idxs[findall(!, tns.extrude)]...) end @@ -232,7 +232,7 @@ Base.copyto!(out, bc::Broadcasted{LazyStyle{N}}) where {N} = copyto!(out, copy(b function Base.copy(bc::Broadcasted{LazyStyle{N}}) where {N} bc_lgc = broadcast_to_logic(bc) - idxs = [field(fgensym(:i)) for _ in 1:N] + idxs = [field(gensym(:i)) for _ in 1:N] data = reorder(broadcast_to_query(bc_lgc, idxs), idxs) extrude = ntuple(n -> broadcast_to_extrude(bc_lgc, n), N) def = broadcast_to_default(bc_lgc) @@ -252,7 +252,7 @@ function Base.permutedims(arg::LazyTensor{T, N}, perm) where {T, N} length(perm) == N || throw(ArgumentError("permutedims given wrong number of dimensions")) isperm(perm) || throw(ArgumentError("permutedims given invalid permutation")) perm = collect(perm) - idxs = [field(fgensym(:i)) for _ in 1:N] + idxs = [field(gensym(:i)) for _ in 1:N] return LazyTensor{T, N}(reorder(relabel(arg.data, idxs...), idxs[perm]...), arg.extrude[perm], arg.fill_value) end Base.permutedims(arr::SwizzleArray, perm) = swizzle(arr, perm...) @@ -536,7 +536,7 @@ compute(arg; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kw compute(args::Tuple; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) function compute_parse(ctx, args::Tuple) args = collect(args) - vars = map(arg -> alias(fgensym(:A)), args) + vars = map(arg -> alias(gensym(:A)), args) bodies = map((arg, var) -> query(var, arg.data), args, vars) prgm = plan(bodies, produces(vars)) diff --git a/src/interface/morgue.jl b/src/interface/morgue.jl index 20b0df180..0d56e7b50 100644 --- a/src/interface/morgue.jl +++ b/src/interface/morgue.jl @@ -27,13 +27,13 @@ end #= root = Rewrite(Fixpoint(Postwalk(Chain([ (@rule plan(~a1..., query(~b, relabel(~c, ~i...)), ~a2...) => begin - d = alias(fgensym(:A)) + d = alias(gensym(:A)) bindings[d] = c rw = Rewrite(Postwalk(@rule b => relabel(d, i...))) plan(a1..., query(d, c), map(rw, a2)...) end), (@rule plan(~a1..., query(~b, reorder(~c, ~i...)), ~a2...) => begin - d = alias(fgensym(:A)) + d = alias(gensym(:A)) bindings[d] = c rw = Rewrite(Postwalk(@rule b => reorder(d, i...))) plan(a1..., query(d, c), map(rw, a2)...) @@ -54,7 +54,7 @@ end function push_reorders(root, bindings) Rewrite(Fixpoint(Postwalk(Chain([ (@rule plan(~a1..., query(~b, reorder(~c, ~i...)), ~a2...) => begin - d = alias(fgensym(:A)) + d = alias(gensym(:A)) bindings[d] = c rw = Rewrite(Postwalk(@rule b => reorder(d, i...))) plan(a1..., query(d, c), map(rw, a2)...) diff --git a/src/interface/traits.jl b/src/interface/traits.jl index 65b299126..1927bf101 100644 --- a/src/interface/traits.jl +++ b/src/interface/traits.jl @@ -214,7 +214,7 @@ function map_rep_def(::MapRepHollowStyle, f, args) end for (n, arg) in enumerate(args) if arg isa HollowData - args_2 = map(arg -> value(fgensym(), eltype(arg)), collect(args)) + args_2 = map(arg -> value(gensym(), eltype(arg)), collect(args)) args_2[n] = literal(fill_value(arg)) if finch_leaf(simplify(FinchCompiler(), call(f, args_2...))) == literal(fill_value(lvl)) return HollowData(lvl) @@ -231,7 +231,7 @@ function map_rep_def(::MapRepSparseStyle, f, args) end for (n, arg) in enumerate(args) if arg isa SparseData - args_2 = map(arg -> value(fgensym(), eltype(arg)), collect(args)) + args_2 = map(arg -> value(gensym(), eltype(arg)), collect(args)) args_2[n] = literal(fill_value(arg)) if finch_leaf(simplify(FinchCompiler(), call(f, args_2...))) == literal(fill_value(lvl)) return SparseData(lvl) @@ -248,7 +248,7 @@ function map_rep_def(::MapRepRepeatStyle, f, args) end for (n, arg) in enumerate(args) if arg isa RepeatData - args_2 = map(arg -> value(fgensym(), eltype(arg)), collect(args)) + args_2 = map(arg -> value(gensym(), eltype(arg)), collect(args)) args_2[n] = literal(fill_value(arg)) if finch_leaf(simplify(FinchCompiler(), call(f, args_2...))) == literal(fill_value(lvl)) return RepeatData(lvl) diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index b8e4f6b88..a10b972ae 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -44,7 +44,7 @@ function logic_executor_code(ctx, prgm) ctx(prgm) end code = pretty(code) - fname = fgensym(:compute) + fname = gensym(Symbol(:compute, hash(get_structure(prgm)))) #The fact that we need this hash is worrisome return :(function $fname(prgm) $code end) |> striplines diff --git a/src/scheduler/optimize.jl b/src/scheduler/optimize.jl index 38413fefa..0c169d4b3 100644 --- a/src/scheduler/optimize.jl +++ b/src/scheduler/optimize.jl @@ -5,21 +5,21 @@ flatten_plans = Rewrite(Postwalk(Fixpoint(Chain([ isolate_aggregates = Rewrite(Postwalk( @rule aggregate(~op, ~init, ~arg, ~idxs...) => begin - name = alias(fgensym(:A)) + name = alias(gensym(:A)) subquery(name, aggregate(~op, ~init, ~arg, ~idxs...)) end )) isolate_reformats = Rewrite(Postwalk( @rule reformat(~tns, ~arg) => begin - name = alias(fgensym(:A)) + name = alias(gensym(:A)) subquery(name, reformat(tns, arg)) end )) isolate_tables = Rewrite(Postwalk( @rule table(~tns, ~idxs...) => begin - name = alias(fgensym(:A)) + name = alias(gensym(:A)) subquery(name, table(tns, idxs...)) end )) @@ -193,7 +193,7 @@ function materialize_squeeze_expand_productions(root) preamble = [] args_2 = map(args) do arg if (@capture arg reorder(relabel(~tns::isalias, ~idxs_1...), ~idxs_2...)) && Set(idxs_1) != Set(idxs_2) - tns_2 = alias(fgensym(:A)) + tns_2 = alias(gensym(:A)) idxs_3 = withsubsequence(intersect(idxs_1, idxs_2), idxs_2) push!(preamble, query(tns_2, reorder(relabel(tns, idxs_1), idxs_3))) if idxs_3 == idxs_2 @@ -541,7 +541,7 @@ function set_loop_order(node, perms = Dict(), reps = Dict()) reps[lhs] = SuitableRep(reps)(rhs_2) query(lhs, reformat(tns, rhs_2)) elseif @capture node query(~lhs, reformat(~tns, ~rhs)) - arg = alias(fgensym(:A)) + arg = alias(gensym(:A)) set_loop_order(plan( query(A, rhs), query(lhs, reformat(tns, A)) @@ -593,7 +593,7 @@ function optimize(prgm) prgm = isolate_tables(prgm) prgm = lift_subqueries(prgm) - #I shouldn't use fgensym but I do, so this cleans up the names + #I shouldn't use gensym but I do, so this cleans up the names prgm = pretty_labels(prgm) #These steps fuse copy, permutation, and mapjoin statements diff --git a/src/tensors/levels/atomic_element_levels.jl b/src/tensors/levels/atomic_element_levels.jl index 365b038f0..6393cbebe 100644 --- a/src/tensors/levels/atomic_element_levels.jl +++ b/src/tensors/levels/atomic_element_levels.jl @@ -162,7 +162,7 @@ function instantiate(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode: preamble = quote $val = $(lvl.val)[$(ctx(pos))] end, - body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), val) + body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val) ) end diff --git a/src/tensors/levels/element_levels.jl b/src/tensors/levels/element_levels.jl index 47e9ce756..58f1f086b 100644 --- a/src/tensors/levels/element_levels.jl +++ b/src/tensors/levels/element_levels.jl @@ -170,26 +170,26 @@ function instantiate(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Reade preamble = quote $val = $(lvl.val)[$(ctx(pos))] end, - body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), val) + body = (ctx) -> VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), val) ) end function instantiate(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Updater) (lvl, pos) = (fbr.lvl, fbr.pos) - VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))])) + VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))])) end function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode::Updater) (lvl, pos) = (fbr.lvl, fbr.pos) - VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty) + VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty) end function lower_assign(ctx, fbr::VirtualHollowSubFiber{VirtualElementLevel}, mode::Updater, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) - lower_assign(ctx, VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty), mode, op, rhs) + lower_assign(ctx, VirtualSparseScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))]), fbr.dirty), mode, op, rhs) end function lower_assign(ctx, fbr::VirtualSubFiber{VirtualElementLevel}, mode::Updater, op, rhs) (lvl, pos) = (fbr.lvl, fbr.pos) - lower_assign(ctx, VirtualScalar(nothing, lvl.Tv, lvl.Vf, fgensym(), :($(lvl.val)[$(ctx(pos))])), mode, op, rhs) + lower_assign(ctx, VirtualScalar(nothing, lvl.Tv, lvl.Vf, gensym(), :($(lvl.val)[$(ctx(pos))])), mode, op, rhs) end \ No newline at end of file diff --git a/src/tensors/levels/pattern_levels.jl b/src/tensors/levels/pattern_levels.jl index 90e339d05..b6fdab12a 100644 --- a/src/tensors/levels/pattern_levels.jl +++ b/src/tensors/levels/pattern_levels.jl @@ -128,9 +128,9 @@ instantiate(ctx, ::VirtualSubFiber{VirtualPatternLevel}, mode::Reader) = FillLea function instantiate(ctx, fbr::VirtualSubFiber{VirtualPatternLevel}, mode::Updater) val = freshen(ctx, :null) push_preamble!(ctx, :($val = false)) - VirtualScalar(nothing, Bool, false, fgensym(), val) + VirtualScalar(nothing, Bool, false, gensym(), val) end function instantiate(ctx, fbr::VirtualHollowSubFiber{VirtualPatternLevel}, mode::Updater) - VirtualScalar(nothing, Bool, false, fgensym(), fbr.dirty) + VirtualScalar(nothing, Bool, false, gensym(), fbr.dirty) end \ No newline at end of file diff --git a/src/util/shims.jl b/src/util/shims.jl index 8b34159ce..fecce7119 100644 --- a/src/util/shims.jl +++ b/src/util/shims.jl @@ -64,7 +64,7 @@ ensuring that the variables in `ex` are not mutated by the arguments. """ macro barrier(args_ex...) (args, ex) = args_ex[1:end-1], args_ex[end] - f = fgensym() + f = gensym() esc(quote $f = Finch.@closure ($(args...),) -> $ex $f() diff --git a/src/util/staging.jl b/src/util/staging.jl index ca50887c5..532832168 100644 --- a/src/util/staging.jl +++ b/src/util/staging.jl @@ -21,9 +21,9 @@ This macro does not support type parameters, varargs, or keyword arguments. macro staged(def) (@capture def :function(:call(~name, ~args...), ~body)) || throw(ArgumentError("unrecognized function definition in @staged")) - name_generator = fgensym(Symbol(name, :_generator)) - name_invokelatest = fgensym(Symbol(name, :_invokelatest)) - name_eval_invokelatest = fgensym(Symbol(name, :_eval_invokelatest)) + name_generator = gensym(Symbol(name, :_generator)) + name_invokelatest = gensym(Symbol(name, :_invokelatest)) + name_eval_invokelatest = gensym(Symbol(name, :_eval_invokelatest)) def = quote function $name_generator($(args...)) @@ -84,11 +84,3 @@ function refresh() @eval $def end end - -""" - fgensym([tag]) - -Generate a new fgensym symbol with the given name, for use in Finch. -""" -fgensym(tag) = eval(Finch, :(gensym($(QuoteNode(tag))))) -fgensym() = eval(Finch, :(gensym())) \ No newline at end of file