Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ShardedLevels #654

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/src/docs/internals/virtualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
quote
_res_1 = (Finch.execute)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.declare_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(0)), begin
let i = index_instance(i)
(Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Dimensionless(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(Finch.FinchNotation.Updater()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(Finch.FinchNotation.initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))))))
(Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Dimensionless(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), (Finch.FinchNotation.updater_instance)(), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), (Finch.FinchNotation.reader_instance)(), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), (Finch.FinchNotation.reader_instance)(), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))))))
end
end), (Finch.FinchNotation.yieldbind_instance)(variable_instance(:C))); )
begin
Expand Down Expand Up @@ -91,7 +91,7 @@

```jldoctest example1; filter=r"Finch\.FinchNotation\."
julia> typeof(prgm)
Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}}
Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.UpdaterInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}}

julia> C = Finch.execute(prgm).C
5 Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}:
Expand Down Expand Up @@ -169,7 +169,7 @@
end

julia> typeof(inst)
Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}
Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.UpdaterInstance, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.ReaderInstance, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}

julia> Finch.virtualize(Finch.JuliaContext(), :inst, typeof(inst))
Finch program: for i = virtual(Finch.FinchNotation.Dimensionless)
Expand Down Expand Up @@ -283,13 +283,13 @@
structure of the program as one would call constructors to build it. For
example,

```jldoctest example2; setup = :(using Finch)

Check failure on line 286 in docs/src/docs/internals/virtualization.md

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in src/docs/internals/virtualization.md:286-311 ```jldoctest example2; setup = :(using Finch) julia> prgm_inst = Finch.@finch_program_instance for i = _ s[] += A[i] end; julia> println(prgm_inst) loop_instance(index_instance(i), Finch.FinchNotation.Dimensionless(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), updater_instance()), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader_instance(), tag_instance(variable_instance(:i), index_instance(i))))) julia> prgm_inst Finch program instance: for i = Dimensionless() tag(s, Scalar{0, Int64})[] <<tag(+, +)>>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)] end julia> prgm = Finch.@finch_program for i = _ s[] += A[i] end; julia> println(prgm) loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Finch.FinchNotation.Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Finch.FinchNotation.Reader()), index(i)))) julia> prgm Finch program: for i = virtual(Finch.FinchNotation.Dimensionless) Scalar{0, Int64}(0)[] <<+>>= Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))[i] end ``` Subexpression: println(prgm) Evaluated output: loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), updater()), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader(), index(i)))) Expected output: loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Finch.FinchNotation.Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Finch.FinchNotation.Reader()), index(i)))) diff = Warning: Diff output requires color. loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Finch.FinchNotation.Updater())), updater()), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Finch.FinchNotation.Reader()), reader(), index(i))))
julia> prgm_inst = Finch.@finch_program_instance for i = _
s[] += A[i]
end;

julia> println(prgm_inst)
loop_instance(index_instance(i), Finch.FinchNotation.Dimensionless(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), literal_instance(Finch.FinchNotation.Updater())), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal_instance(Finch.FinchNotation.Reader()), tag_instance(variable_instance(:i), index_instance(i)))))
loop_instance(index_instance(i), Finch.FinchNotation.Dimensionless(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), updater_instance()), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader_instance(), tag_instance(variable_instance(:i), index_instance(i)))))

julia> prgm_inst
Finch program instance: for i = Dimensionless()
Expand All @@ -316,7 +316,7 @@
[AbstractTrees.jl](https://github.com/JuliaCollections/AbstractTrees.jl)
representations, so you can use the standard `operation`, `arguments`, `istree`, and `children` functions to inspect the structure of the program, as well as the rewriters defined by [RewriteTools.jl](https://github.com/willow-ahrens/RewriteTools.jl)

```jldoctest example2; setup = :(using Finch, AbstractTrees, SyntaxInterface, RewriteTools)

Check failure on line 319 in docs/src/docs/internals/virtualization.md

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in src/docs/internals/virtualization.md:319-331 ```jldoctest example2; setup = :(using Finch, AbstractTrees, SyntaxInterface, RewriteTools) julia> using Finch.FinchNotation; julia> PostOrderDFS(prgm) PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Reader()), index(i))))) julia> (@capture prgm loop(~idx, ~ext, ~val)) true julia> idx Finch program: i ``` Subexpression: PostOrderDFS(prgm) Evaluated output: PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), updater()), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), reader(), index(i))))) Expected output: PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Reader()), index(i))))) diff = Warning: Diff output requires color. PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Updater())), updater()), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Reader()), reader(), index(i)))))
julia> using Finch.FinchNotation;


Expand Down
8 changes: 4 additions & 4 deletions ext/SparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ end
#Finch.is_atomic(ctx, tns::VirtualSparseMatrixCSCColumn) = is_atomic(ctx, tns.mtx)[1]
#Finch.is_concurrent(ctx, tns::VirtualSparseMatrixCSCColumn) = is_concurrent(ctx, tns.mtx)[1]

function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)})
function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode, ::Union{typeof(defaultread), typeof(walk)})
tag = arr.ex
Unfurled(
arr = arr,
Expand All @@ -230,7 +230,7 @@ function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, m
)
end

function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)})
function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseMatrixCSC, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)})
tag = arr.ex
Unfurled(
arr = arr,
Expand Down Expand Up @@ -383,7 +383,7 @@ function Finch.thaw!(ctx::AbstractCompiler, arr::VirtualSparseVector)
return arr
end

function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode::Reader, ::Union{typeof(defaultread), typeof(walk)})
function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode, ::Union{typeof(defaultread), typeof(walk)})
tag = arr.ex
Ti = arr.Ti
my_i = freshen(ctx, tag, :_i)
Expand Down Expand Up @@ -439,7 +439,7 @@ function Finch.unfurl(ctx::AbstractCompiler, arr::VirtualSparseVector, ext, mode
)
end

function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode::Updater, ::Union{typeof(defaultupdate), typeof(extrude)})
function Finch.unfurl(ctx, arr::VirtualSparseVector, ext, mode, ::Union{typeof(defaultupdate), typeof(extrude)})
tag = arr.ex
Tp = arr.Ti
qos = freshen(ctx, tag, :_qos)
Expand Down
2 changes: 1 addition & 1 deletion src/FinchNotation/FinchNotation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module FinchNotation
export tag
export call
export cached
export reader, Reader, updater, Updater, access
export reader, updater, access
export define, declare, thaw, freeze
export block
export protocol
Expand Down
12 changes: 10 additions & 2 deletions src/FinchNotation/instances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct LoopInstance{Idx, Ext, Body} <: FinchNodeInstance idx::Idx; ext::Ext; bod
struct SieveInstance{Cond, Body} <: FinchNodeInstance cond::Cond; body::Body end
struct AssignInstance{Lhs, Op, Rhs} <: FinchNodeInstance lhs::Lhs; op::Op; rhs::Rhs end
struct CallInstance{Op, Args<:Tuple} <: FinchNodeInstance op::Op; args::Args end
struct ReaderInstance{} <: FinchNodeInstance end
struct UpdaterInstance{} <: FinchNodeInstance end
struct AccessInstance{Tns, Mode, Idxs} <: FinchNodeInstance tns::Tns; mode::Mode; idxs::Idxs end
struct TagInstance{Var, Bind} <: FinchNodeInstance var::Var; bind::Bind end
struct YieldBindInstance{Args} <: FinchNodeInstance args::Args end
Expand All @@ -35,15 +37,15 @@ Base.getproperty(::VariableInstance{val}, name::Symbol) where {val} = name == :n
@inline sieve_instance(cond, args...) = SieveInstance(cond, sieve_instance(args...))
@inline assign_instance(lhs, op, rhs) = AssignInstance(lhs, op, rhs)
@inline call_instance(op, args...) = CallInstance(op, args)
@inline reader_instance() = ReaderInstance()
@inline updater_instance() = UpdaterInstance()
@inline access_instance(tns, mode, idxs...) = AccessInstance(tns, mode, idxs)
@inline tag_instance(var, bind) = TagInstance(var, bind)
@inline yieldbind_instance(args...) = YieldBindInstance(args)

@inline finch_leaf_instance(arg::Type) = literal_instance(arg)
@inline finch_leaf_instance(arg::Function) = literal_instance(arg)
@inline finch_leaf_instance(arg::FinchNodeInstance) = arg
@inline finch_leaf_instance(arg::Reader) = literal_instance(arg)
@inline finch_leaf_instance(arg::Updater) = literal_instance(arg)
@inline finch_leaf_instance(arg) = arg

SyntaxInterface.istree(node::FinchNodeInstance) = Int(operation(node)) & IS_TREE != 0
Expand All @@ -62,6 +64,8 @@ instance_ctrs = Dict(
sieve => sieve_instance,
assign => assign_instance,
call => call_instance,
reader => reader_instance,
updater => updater_instance,
access => access_instance,
variable => variable_instance,
tag => tag_instance,
Expand All @@ -83,6 +87,8 @@ SyntaxInterface.operation(::LoopInstance) = loop
SyntaxInterface.operation(::SieveInstance) = sieve
SyntaxInterface.operation(::AssignInstance) = assign
SyntaxInterface.operation(::CallInstance) = call
SyntaxInterface.operation(::ReaderInstance) = reader
SyntaxInterface.operation(::UpdaterInstance) = updater
SyntaxInterface.operation(::AccessInstance) = access
SyntaxInterface.operation(::VariableInstance) = variable
SyntaxInterface.operation(::TagInstance) = tag
Expand All @@ -97,6 +103,8 @@ SyntaxInterface.arguments(node::LoopInstance) = [node.idx, node.ext, node.body]
SyntaxInterface.arguments(node::SieveInstance) = [node.cond, node.body]
SyntaxInterface.arguments(node::AssignInstance) = [node.lhs, node.op, node.rhs]
SyntaxInterface.arguments(node::CallInstance) = [node.op, node.args...]
SyntaxInterface.arguments(node::ReaderInstance) = []
SyntaxInterface.arguments(node::UpdaterInstance) = []
SyntaxInterface.arguments(node::AccessInstance) = [node.tns, node.mode, node.idxs...]
SyntaxInterface.arguments(node::TagInstance) = [node.var, node.bind]
SyntaxInterface.arguments(node::YieldBindInstance) = node.args
Expand Down
28 changes: 19 additions & 9 deletions src/FinchNotation/nodes.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
struct Reader end
struct Updater end

const reader = Reader()
const updater = Updater()

const IS_TREE = 1
const IS_STATEFUL = 2
const IS_CONST = 4
Expand All @@ -17,7 +11,9 @@ const ID = 8
virtual = 4ID
tag = 5ID | IS_TREE
call = 6ID | IS_TREE
access = 7ID | IS_TREE
reader = 7ID | IS_TREE
updater = 8ID | IS_TREE
access = 9ID | IS_TREE
cached = 10ID | IS_TREE
assign = 11ID | IS_TREE | IS_STATEFUL
loop = 12ID | IS_TREE | IS_STATEFUL
Expand Down Expand Up @@ -88,6 +84,20 @@ Finch AST expression for the result of calling the function `op` on `args...`.
"""
call

"""
reader()

Finch AST expression representing a read mode for an access operation.
"""
reader

"""
updater()

Finch AST expression representing an update mode for an access operation.
"""
updater

"""
access(tns, mode, idx...)

Expand Down Expand Up @@ -259,6 +269,8 @@ function FinchNode(kind::FinchNodeKind, args::Vector)
elseif (kind === value || kind === literal || kind === index || kind === variable || kind === virtual) && length(args) == 2
return FinchNode(kind, args[1], args[2], FinchNode[])
elseif (kind === cached && length(args) == 2) ||
(kind === reader && length(args) == 0) ||
(kind === updater && length(args) == 0) ||
(kind === access && length(args) >= 2) ||
(kind === tag && length(args) == 2) ||
(kind === call && length(args) >= 1) ||
Expand Down Expand Up @@ -392,8 +404,6 @@ virtual.
finch_leaf(arg) = literal(arg)
finch_leaf(arg::Type) = literal(arg)
finch_leaf(arg::Function) = literal(arg)
finch_leaf(arg::Reader) = literal(arg)
finch_leaf(arg::Updater) = literal(arg)
finch_leaf(arg::FinchNode) = arg

Base.convert(::Type{FinchNode}, x) = finch_leaf(x)
Expand Down
14 changes: 7 additions & 7 deletions src/FinchNotation/syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ const program_nodes = (
call = call,
access = access,
yieldbind = yieldbind,
reader = literal(reader),
updater = literal(updater),
reader = reader,
updater = updater,
variable = variable,
tag = (ex) -> :(finch_leaf($(esc(ex)))),
literal = literal,
Expand All @@ -36,8 +36,8 @@ const instance_nodes = (
call = call_instance,
access = access_instance,
yieldbind = yieldbind_instance,
reader = literal_instance(reader),
updater = literal_instance(updater),
reader = reader_instance,
updater = updater_instance,
variable = variable_instance,
tag = (ex) -> :($tag_instance($(variable_instance(ex)), $finch_leaf_instance($(esc(ex))))),
literal = literal_instance,
Expand Down Expand Up @@ -199,17 +199,17 @@ function (ctx::FinchParserVisitor)(ex::Expr)
return :($(ctx.nodes.yieldbind)($(ctx(arg))))
elseif @capture ex :ref(~tns, ~idxs...)
mode = ctx.nodes.reader
return :($(ctx.nodes.access)($(ctx(tns)), $mode, $(map(ctx, idxs)...)))
return :($(ctx.nodes.access)($(ctx(tns)), $mode(), $(map(ctx, idxs)...)))
elseif (@capture ex (~op)(~lhs, ~rhs)) && haskey(incs, op)
return ctx(:($lhs << $(incs[op]) >>= $rhs))
elseif @capture ex :(=)(:ref(~tns, ~idxs...), ~rhs)
mode = ctx.nodes.updater
lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode, $(map(ctx, idxs)...)))
lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode(), $(map(ctx, idxs)...)))
op = :($(ctx.nodes.literal)($initwrite))
return :($(ctx.nodes.assign)($lhs, $op, $(ctx(rhs))))
elseif @capture ex :>>=(:call(:<<, :ref(~tns, ~idxs...), ~op), ~rhs)
mode = ctx.nodes.updater
lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode, $(map(ctx, idxs)...)))
lhs = :($(ctx.nodes.access)($(ctx(tns)), $mode(), $(map(ctx, idxs)...)))
return :($(ctx.nodes.assign)($lhs, $(ctx(op)), $(ctx(rhs))))
elseif @capture ex :>>=(:call(:<<, ~lhs, ~op), ~rhs)
error("Finch doesn't support incrementing definitions of variables")
Expand Down
2 changes: 2 additions & 0 deletions src/FinchNotation/virtualize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ function Finch.virtualize(ctx, ex, ::Type{FinchNotation.CallInstance{Op, Args}})
end
call(op, args...)
end
Finch.virtualize(ctx, ex, ::Type{FinchNotation.ReaderInstance}) = reader()
Finch.virtualize(ctx, ex, ::Type{FinchNotation.UpdaterInstance}) = updater()
function Finch.virtualize(ctx, ex, ::Type{FinchNotation.AccessInstance{Tns, Mode, Idxs}}) where {Tns, Mode, Idxs}
tns = virtualize(ctx, :($ex.tns), Tns)
idxs = map(enumerate(Idxs.parameters)) do (n, Idx)
Expand Down
Loading
Loading