
Complex + dualcache #47

Open
moble opened this issue Nov 1, 2022 · 2 comments

Comments

@moble

moble commented Nov 1, 2022

Nowadays ForwardDiff seems to work nicely with Complex numbers (internally, at least), so I was trying to use dualcache with a complex array. Unfortunately, it seems that the methods of get_tmp that handle dual-number inputs choose the returned eltype from the type of u (which is always some kind of Real, e.g. a ForwardDiff.Dual) rather than from the element type of the array stored in the DiffCache. So a complex cache comes back as a real array of Duals instead of a complex one. I wonder if it might be possible to generalize that.

It is possible to work around this by constructing the cache as zd = dualcache(reinterpret(reshape, T, z)), where T is the real element type of z (e.g. Float64), and then unpacking it inside the function as

    ztmp = get_tmp(zd, t)
    z = reinterpret(reshape, Complex{eltype(ztmp)}, ztmp)

Obviously, this is ugly and easy to forget. I've fallen into this trap several times, getting confusing error messages far from the source of the problem. Automatic handling of Complex would make things much nicer for me. :)
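The workaround relies on a ComplexF64 being stored as two adjacent Float64 values, so reinterpret(reshape, ...) can switch between a length-n complex vector and a 2×n real matrix without copying. The sketch below uses only Base Julia (no PreallocationTools) to show that layout and the round trip; the array size is arbitrary.

```julia
# Base Julia only (no PreallocationTools): inspect the memory layout the
# workaround relies on. A ComplexF64 is stored as two adjacent Float64s.
z = zeros(ComplexF64, 4)

# View the complex vector as a 2×4 real matrix: row 1 holds the real parts,
# row 2 holds the imaginary parts. No copy is made.
zr = reinterpret(reshape, Float64, z)
println(size(zr))          # (2, 4)

# Writes through the real view show up in the original complex vector.
zr[1, 3] = 1.5
zr[2, 3] = -2.0
println(z[3])              # 1.5 - 2.0im

# Reinterpreting back with Complex{eltype(zr)} recovers a complex vector.
zc = reinterpret(reshape, Complex{eltype(zr)}, zr)
println(zc isa AbstractVector{ComplexF64})   # true
```

Inside the differentiated function, eltype(ztmp) is presumably a ForwardDiff.Dual type rather than Float64, so the same reinterpret call should produce complex numbers whose real and imaginary parts are Duals.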

@thomvet
Contributor

thomvet commented Nov 3, 2022

Sorry, I don't usually deal with models where complex numbers appear (unless I am doing something wrong :)). Can you provide an MWE (minimal working example) that illustrates the issue and what you are trying to accomplish?

@moble
Author

moble commented Nov 3, 2022

Sorry, here's an MWE. The uncommented code is what I expected to be able to write; the commented-out lines are what I currently have to write to make it work.

    using PreallocationTools
    using ForwardDiff

    z = zeros(ComplexF64, 20)
    zd = dualcache(z)
    #zd = dualcache(reinterpret(reshape, real(eltype(z)), z))

    function sum_cis(θ)
        z = get_tmp(zd, θ)
        #ztmp = get_tmp(zd, θ)
        #z = reinterpret(reshape, Complex{eltype(ztmp)}, ztmp)
        for i ∈ eachindex(z)
            z[i] = cis(i*θ)
        end
        abs(sum(z))
    end

    ForwardDiff.derivative(sum_cis, 1.1)
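Until something like this is handled automatically, one way to make the workaround harder to forget is to wrap the reinterpret step in a small function. The sketch below is a hypothetical helper, complex_view, which is not part of PreallocationTools; it assumes the cached buffer is a 2×n real array, as produced by caching reinterpret(reshape, T, z). It is exercised here with a plain Float64 buffer standing in for the output of get_tmp.

```julia
# Hypothetical helper (not part of PreallocationTools): wrap the reinterpret
# step so it cannot be forgotten. `buf` is assumed to be a 2×n real array,
# e.g. what get_tmp returns for a cache built from reinterpret(reshape, T, z).
function complex_view(buf::AbstractMatrix{T}) where {T<:Real}
    size(buf, 1) == 2 ||
        throw(DimensionMismatch("expected a 2×n buffer, got size $(size(buf))"))
    # Row 1 becomes the real parts, row 2 the imaginary parts (no copy).
    return reinterpret(reshape, Complex{T}, buf)
end

# Usage with a plain Float64 buffer standing in for the output of get_tmp.
buf = zeros(Float64, 2, 5)
w = complex_view(buf)            # length-5 complex view of buf
for k in 1:length(w)
    w[k] = cis(k * 0.3)          # writes cos(0.3k) into row 1, sin(0.3k) into row 2
end
```

In the MWE, the two commented-out lines would then become z = complex_view(get_tmp(zd, θ)); since T is a Dual type there, Complex{T} should be a complex number with Dual parts.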
