From 8bf84a736379febf5a3e5a62f97f0c9df1a93463 Mon Sep 17 00:00:00 2001 From: ArrogantGao Date: Tue, 6 Aug 2024 10:56:06 +0800 Subject: [PATCH] add interface for the viz tool by OMEinsumContractionOrders v0.9 --- Project.toml | 5 +++-- src/OMEinsum.jl | 4 +++- src/contractionorder.jl | 3 +++ test/contractionorder.jl | 19 +++++++++++++++++++ 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index bcb7d2b..ac2dac9 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ CUDA = "4, 5" ChainRulesCore = "1" Combinatorics = "1.0" MacroTools = "0.5" -OMEinsumContractionOrders = "0.8, 0.9" +OMEinsumContractionOrders = "0.9" TupleTools = "1.2, 1.3" julia = "1" @@ -39,6 +39,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc" Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -48,4 +49,4 @@ TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "CUDA", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"] +test = ["Test", "CUDA", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials", "LuxorGraphPlot"] diff --git a/src/OMEinsum.jl b/src/OMEinsum.jl index e617382..6c40e86 100644 --- a/src/OMEinsum.jl +++ b/src/OMEinsum.jl @@ -26,7 +26,9 @@ export CodeOptimizer, CodeSimplifier, peak_memory, timespace_complexity, timespacereadwrite_complexity, flop, contraction_complexity, # file io writejson, readjson, - label_elimination_order + label_elimination_order, + # visualization + viz_eins, viz_contraction include("Core.jl") include("loop_einsum.jl") diff --git a/src/contractionorder.jl b/src/contractionorder.jl index b9c73dd..759fdaf 100644 --- a/src/contractionorder.jl +++ b/src/contractionorder.jl @@ -38,6 +38,9 @@ OMEinsumContractionOrders.contraction_complexity(code::AbstractEinsum, size_dict OMEinsumContractionOrders.uniformsize(code::AbstractEinsum, size) = Dict([l=>size for l in uniquelabels(code)]) OMEinsumContractionOrders.label_elimination_order(code::AbstractEinsum) = label_elimination_order(rawcode(code)) +OMEinsumContractionOrders.viz_eins(code::AbstractEinsum, args...; kwargs...) = viz_eins(rawcode(code), args...; kwargs...) +OMEinsumContractionOrders.viz_contraction(code::AbstractEinsum, args...; kwargs...) = viz_contraction(rawcode(code), args...; kwargs...) + # save load function writejson(filename::AbstractString, ne::Union{NestedEinsum, SlicedEinsum}) OMEinsumContractionOrders.writejson(filename, rawcode(ne)) diff --git a/test/contractionorder.jl b/test/contractionorder.jl index f52de02..2d07dba 100644 --- a/test/contractionorder.jl +++ b/test/contractionorder.jl @@ -100,4 +100,23 @@ end @test optcode == code2 end end +end + +using LuxorGraphPlot + +@testset "visualization tool" begin + eincode = ein"ab,acd,bcef,e,df->" + nested_ein = optein"ab,acd,bcef,e,df->" + + graph_1 = viz_eins(eincode) + @test graph_1 isa LuxorGraphPlot.Luxor.Drawing + + graph_2 = viz_eins(nested_ein) + @test graph_2 isa LuxorGraphPlot.Luxor.Drawing + + gif = viz_contraction(nested_ein, filename = tempname() * ".gif") + @test gif isa String + + video = viz_contraction(nested_ein) + @test video isa String end \ No newline at end of file