From 4abb65e65d9ae2ae159b1e37f30f4921b9448b16 Mon Sep 17 00:00:00 2001 From: Duncan Starkenburg <45837337+dstarkenburg@users.noreply.github.com> Date: Sun, 8 Dec 2024 16:06:51 -0500 Subject: [PATCH] Adding support for Acos (#110) * Adding support for Acos * Float64 not supported, changed to Float32 --- src/load.jl | 3 +++ src/ops.jl | 1 + src/save.jl | 5 +++++ test/saveload.jl | 6 ++++++ 4 files changed, 15 insertions(+) diff --git a/src/load.jl b/src/load.jl index c0b401e..d4a5cfa 100644 --- a/src/load.jl +++ b/src/load.jl @@ -59,6 +59,9 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Abs}, args::VarVec, attrs::At return push_call!(tape, _abs, args[1]) end +function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acos}, args::VarVec, attrs::AttrDict) + return push_call!(tape, _acos, args[1]) +end function load_node!(tape::Tape, nd::NodeProto, backend::Symbol) args = [tape.c.name2var[name] for name in nd.input] diff --git a/src/ops.jl b/src/ops.jl index 7ee8620..2cbcef9 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -50,6 +50,7 @@ sub(xs...) = .-(xs...) _sin(x) = sin.(x) _cos(x) = cos.(x) _abs(x) = abs.(x) +_acos(x) = acos.(x) mul(xs...) = .*(xs...) relu(x) = NNlib.relu.(x) leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a) diff --git a/src/save.jl b/src/save.jl index a52342d..f956fa5 100644 --- a/src/save.jl +++ b/src/save.jl @@ -126,6 +126,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_abs)}, op::Umlaut.C push!(g.node, nd) end +function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acos)}, op::Umlaut.Call) + nd = NodeProto("Acos", op) + push!(g.node, nd) +end + function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call) nd = NodeProto( input=[onnx_name(v) for v in reverse(op.args)], diff --git a/test/saveload.jl b/test/saveload.jl index 82422c3..55407e6 100644 --- a/test/saveload.jl +++ b/test/saveload.jl @@ -36,6 +36,12 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name ort_test(ONNX._abs, A) end + @testset "Acos" begin + # ONNXRunTime has no implementation for Acos(x::Float64), using Float32 + A = rand(Float32, 3, 4) + ort_test(ONNX._acos, A) + end + @testset "Gemm" begin A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3)) ort_test(ONNX.onnx_gemm, A, B')