From 7972aad7ef62a9839eb9ea13bc6c28ddab70bd79 Mon Sep 17 00:00:00 2001 From: EvilDunk Date: Fri, 6 Dec 2024 12:56:13 -0500 Subject: [PATCH 1/3] Added support for Cos --- src/load.jl | 4 ++++ src/ops.jl | 1 + src/save.jl | 5 +++++ test/saveload.jl | 5 +++++ 4 files changed, 15 insertions(+) diff --git a/src/load.jl b/src/load.jl index 3985972..71ed41b 100644 --- a/src/load.jl +++ b/src/load.jl @@ -51,6 +51,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Sin}, args::VarVec, attrs::At return push_call!(tape, _sin, args[1]) end +function load_node!(tape::Tape, ::OpConfig{:ONNX, :Cos}, args::VarVec, attrs::AttrDict) + return push_call!(tape, _cos, args[1]) +end + function load_node!(tape::Tape, nd::NodeProto, backend::Symbol) args = [tape.c.name2var[name] for name in nd.input] attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute)) diff --git a/src/ops.jl b/src/ops.jl index 67403e9..a45b1fb 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -48,6 +48,7 @@ end add(xs...) = .+(xs...) sub(xs...) = .-(xs...) _sin(x) = sin.(x) +_cos(x) = cos.(x) mul(xs...) = .*(xs...) relu(x) = NNlib.relu.(x) leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a) diff --git a/src/save.jl b/src/save.jl index 5205da4..9228738 100644 --- a/src/save.jl +++ b/src/save.jl @@ -116,6 +116,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_sin)}, op::Umlaut.C push!(g.node, nd) end +function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_cos)}, op::Umlaut.Call) + nd = NodeProto("Cos", op) + push!(g.node, nd) +end + function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call) nd = NodeProto( input=[onnx_name(v) for v in reverse(op.args)], diff --git a/test/saveload.jl b/test/saveload.jl index 3d53276..cf4697f 100644 --- a/test/saveload.jl +++ b/test/saveload.jl @@ -25,6 +25,11 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name ort_test(ONNX._sin, A) end + @testset "Cos" begin + A = rand(3, 4) + ort_test(ONNX._cos, A) + end + @testset "Gemm" begin A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3)) ort_test(ONNX.onnx_gemm, A, B') From aaace22d911f0aa1eab31fab52010333927ae5c1 Mon Sep 17 00:00:00 2001 From: EvilDunk Date: Fri, 6 Dec 2024 13:04:14 -0500 Subject: [PATCH 2/3] Changes to Cos implementation --- src/load.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/load.jl b/src/load.jl index 71ed41b..27bd53a 100644 --- a/src/load.jl +++ b/src/load.jl @@ -51,7 +51,7 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Sin}, args::VarVec, attrs::At return push_call!(tape, _sin, args[1]) end -function load_node!(tape::Tape, ::OpConfig{:ONNX, :Cos}, args::VarVec, attrs::AttrDict) +function load_node!(tape::Tape, @opconfig_kw{:ONNX, :Cos}, args::Union{VarVec, Int}, attrs::AttrDict) return push_call!(tape, _cos, args[1]) end From f0a1aac3229ff030b74730a81b1ab7e9193ccdb1 Mon Sep 17 00:00:00 2001 From: EvilDunk Date: Fri, 6 Dec 2024 15:25:39 -0500 Subject: [PATCH 3/3] Changes --- src/load.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/load.jl b/src/load.jl index 27bd53a..4d67f26 100644 --- a/src/load.jl +++ b/src/load.jl @@ -51,7 +51,7 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Sin}, args::VarVec, attrs::At return push_call!(tape, _sin, args[1]) end -function load_node!(tape::Tape, @opconfig_kw{:ONNX, :Cos}, args::Union{VarVec, Int}, attrs::AttrDict) +function load_node!(tape::Tape, @opconfig_kw{:ONNX, :Cos}, args::Union{VarVec, Umlaut.Variable}, attrs::AttrDict) return push_call!(tape, _cos, args[1]) end