From 9040c6309a6b921da8c13a28eb6703ca05b4cc27 Mon Sep 17 00:00:00 2001 From: Mikael Slevinsky Date: Tue, 7 Jul 2020 10:10:16 -0500 Subject: [PATCH 1/3] fix some floating-point logic --- Project.toml | 2 +- src/dual.jl | 16 +++++++++------- test/automatic_differentiation_test.jl | 20 ++++++++++++++++++++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 26519af..744dd22 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DualNumbers" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.2" +version = "0.7" [deps] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" diff --git a/src/dual.jl b/src/dual.jl index c82e8c0..4e43937 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -42,9 +42,11 @@ dual(z::Dual) = z const realpart = value const dualpart = epsilon -Base.isnan(z::Dual) = isnan(value(z)) -Base.isinf(z::Dual) = isinf(value(z)) -Base.isfinite(z::Dual) = isfinite(value(z)) +Base.isfinite(z::Dual) = isfinite(value(z)) & isfinite(epsilon(z)) +Base.isnan(z::Dual) = isnan(value(z)) | isnan(epsilon(z)) +Base.isinf(z::Dual) = isinf(value(z)) | isinf(epsilon(z)) +Base.iszero(z::Dual) = iszero(value(z)) & iszero(epsilon(z)) +Base.isone(z::Dual) = isone(value(z)) & iszero(epsilon(z)) isdual(x::Dual) = true isdual(x::Number) = false Base.eps(z::Dual) = eps(value(z)) @@ -163,11 +165,11 @@ end Base.convert(::Type{Dual}, z::Dual) = z Base.convert(::Type{Dual}, x::Number) = Dual(x) -Base.:(==)(z::Dual, w::Dual) = value(z) == value(w) -Base.:(==)(z::Dual, x::Number) = value(z) == x -Base.:(==)(x::Number, z::Dual) = value(z) == x +Base.:(==)(z::Dual, w::Dual) = (value(z) == value(w)) & (epsilon(z) == epsilon(w)) +Base.:(==)(z::Dual, x::Number) = iszero(epsilon(z)) && value(z) == x +Base.:(==)(x::Number, z::Dual) = z == x -Base.isequal(z::Dual, w::Dual) = isequal(value(z),value(w)) && isequal(epsilon(z), epsilon(w)) +Base.isequal(z::Dual, w::Dual) = isequal(value(z), value(w)) && isequal(epsilon(z), epsilon(w)) Base.isequal(z::Dual, x::Number) = isequal(value(z), x) && isequal(epsilon(z), zero(x)) Base.isequal(x::Number, z::Dual) = isequal(z, x) diff --git a/test/automatic_differentiation_test.jl b/test/automatic_differentiation_test.jl index b6126be..c629c82 100644 --- a/test/automatic_differentiation_test.jl +++ b/test/automatic_differentiation_test.jl @@ -63,6 +63,26 @@ x = Dual(1.0,1.0) @test convert(Dual{Float64}, Inf) == convert(Float64, Inf) @test isnan(convert(Dual{Float64}, NaN)) +w = 1.0 +x = Dual(1.0, 0.0) +y = Dual(1.0, 1.0) +z = Dual(1.0, NaN) +@test w !== x +@test w == x +@test isequal(w, x) +@test x !== y +@test x != y +@test isfinite(x) +@test !isfinite(z) +@test isnan(z) +@test !isinf(z) +@test z != z +@test isequal(z, z) +x = Dual(0.0, 0.0) +@test x !== -x +@test x == -x +@test !isequal(x, -x) + @test convert(Dual{Float64},Dual(1,2)) == Dual(1.0,2.0) @test convert(Float64, Dual(10.0,0.0)) == 10.0 @test convert(Dual{Int}, Dual(10.0,0.0)) == Dual(10,0) From 922e86aa31434b09511cd6957eea0d40667d3d3f Mon Sep 17 00:00:00 2001 From: Mikael Slevinsky Date: Tue, 7 Jul 2020 11:47:38 -0500 Subject: [PATCH 2/3] fix up / of dual numbers to not use squaring --- src/dual.jl | 6 +++--- test/automatic_differentiation_test.jl | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index 4e43937..632bb65 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -242,8 +242,8 @@ Base.:*(z::Dual, w::Dual) = Dual(value(z)*value(w), epsilon(z)*value(w)+value(z) Base.:*(x::Number, z::Dual) = Dual(x*value(z), x*epsilon(z)) Base.:*(z::Dual, x::Number) = Dual(x*value(z), x*epsilon(z)) -Base.:/(z::Dual, w::Dual) = Dual(value(z)/value(w), (epsilon(z)*value(w)-value(z)*epsilon(w))/(value(w)*value(w))) -Base.:/(z::Number, w::Dual) = Dual(z/value(w), -z*epsilon(w)/value(w)^2) +Base.:/(z::Dual, w::Dual) = Dual(value(z)/value(w), (epsilon(z)-value(z)/value(w)*epsilon(w))/value(w)) +Base.:/(z::Number, w::Dual) = Dual(z/value(w), -z*epsilon(w)/value(w)/value(w)) Base.:/(z::Dual, x::Number) = Dual(value(z)/x, epsilon(z)/x) for f in [:(Base.:^), :(NaNMath.pow)] @@ -270,7 +270,7 @@ Base.:^(z::Dual, n::Number) = Dual(value(z)^n, epsilon(z)*n*value(z)^(n-1)) NaNMath.pow(z::Dual, n::Number) = Dual(NaNMath.pow(value(z),n), epsilon(z)*n*NaNMath.pow(value(z),n-1)) NaNMath.pow(z::Number, w::Dual) = Dual(NaNMath.pow(z,value(w)), epsilon(w)*NaNMath.pow(z,value(w))*log(z)) -Base.inv(z::Dual) = dual(inv(value(z)),-epsilon(z)/value(z)^2) +Base.inv(z::Dual) = dual(inv(value(z)),-epsilon(z)/value(z)/value(z)) # force use of NaNMath functions in derivative calculations function to_nanmath(x::Expr) diff --git a/test/automatic_differentiation_test.jl b/test/automatic_differentiation_test.jl index c629c82..4926c8a 100644 --- a/test/automatic_differentiation_test.jl +++ b/test/automatic_differentiation_test.jl @@ -97,6 +97,9 @@ x = Dual(1.2,1.0) @test trunc(Int, x) === 1 @test round(Int, x) === 1 +x = 2dual(sqrt(floatmax(Float64)), sqrt(floatmax(Float64))) +@test x/x == 1.0 + # test Dual{Complex} z = Dual(1.0+1.0im,1.0) From 942dab4796a5d1882760aa9c68b7c0ecba4c7eba Mon Sep 17 00:00:00 2001 From: Mikael Slevinsky Date: Tue, 7 Jul 2020 11:50:08 -0500 Subject: [PATCH 3/3] inv test --- test/automatic_differentiation_test.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/automatic_differentiation_test.jl b/test/automatic_differentiation_test.jl index 4926c8a..3f29438 100644 --- a/test/automatic_differentiation_test.jl +++ b/test/automatic_differentiation_test.jl @@ -99,6 +99,7 @@ x = Dual(1.2,1.0) x = 2dual(sqrt(floatmax(Float64)), sqrt(floatmax(Float64))) @test x/x == 1.0 +@test inv(x)*x == 1.0 # test Dual{Complex}