Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

autodiff call runs fine with CPU backend but hangs with CUDABackend #2293

Open
SouthEndMusic opened this issue Feb 2, 2025 · 0 comments
Open

Comments

@SouthEndMusic
Copy link

Follow-up to #2290 (probably unrelated, but it arises in the same context).

MWE:

using SplineGrids
using Enzyme
using .EnzymeRules
using CUDA
using ConstructionBase

# Positional convenience method: `evaluate!(spline_grid, control_points)` forwards to
# the keyword form `evaluate!(spline_grid; control_points, kwargs...)`.
# Extra keyword arguments are passed through unchanged. If `kwargs` also carries a
# `control_points` entry, that later value wins (Julia keeps the rightmost duplicate
# when keyword arguments are splatted).
# NOTE(review): this extends `SplineGrids.evaluate!` from outside the package;
# consider upstreaming it to avoid method piracy.
# Always returns `nothing`.
function SplineGrids.evaluate!(
    spline_grid::SplineGrid,
    control_points::SplineGrids.AbstractControlPointArray;
    kwargs...
)::Nothing
    evaluate!(spline_grid; control_points = control_points, kwargs...)
    return nothing
end

# Build a zeroed shadow of `spline_grid`.
# The result is a copy made with ConstructionBase `setproperties`, in which `eval` and
# `control_points` are replaced by zeroed copies (via `make_zero`). Every other field
# is carried over from the primal as-is.
# NOTE(review): if any of those other fields are mutable buffers, the primal and the
# shadow share them. Confirm that this is intended.
function Enzyme.make_zero(spline_grid::SplineGrid)
    zeroed_eval = make_zero(spline_grid.eval)
    zeroed_control_points = make_zero(spline_grid.control_points)
    return setproperties(
        spline_grid;
        eval = zeroed_eval,
        control_points = zeroed_control_points
    )
end

# Zero the `eval` and `control_points` buffers of `spline_grid` in place, in that
# order. Any other fields are left untouched. Always returns `nothing`.
function Enzyme.make_zero!(spline_grid::SplineGrid)::Nothing
    for buffer in (spline_grid.eval, spline_grid.control_points)
        make_zero!(buffer)
    end
    return nothing
end

# Augmented forward pass for `evaluate!(spline_grid, control_points; kwargs...)`.
#
# This rule runs the primal evaluation on the primal values. It stores the keyword
# arguments as the tape so that the reverse rule can forward the same options to
# `evaluate_adjoint!`.
#
# `evaluate!` returns `nothing`, and the return activity here is `Const{Nothing}`:
# - When `needs_primal(config)` is true, the primal return value is `nothing`, which
#   is what the function actually returns. Returning `spline_grid.val` would hand
#   Enzyme a `SplineGrid` where it expects `nothing`.
# - A `Const` return has no shadow, so the shadow slot is always `nothing`.
function EnzymeRules.augmented_primal(
    config::RevConfigWidth{1},
    ::Const{typeof(evaluate!)},
    ::Type{Const{Nothing}},
    spline_grid::MixedDuplicated{<:SplineGrid},
    control_points::Duplicated{<:SplineGrids.AbstractControlPointArray};
    kwargs...
)
    evaluate!(spline_grid.val, control_points.val; kwargs...)
    return EnzymeRules.AugmentedReturn(nothing, nothing, kwargs)
end

# Reverse pass for `evaluate!`.
#
# This rule pulls the adjoint held in the shadow grid back onto the control-point
# shadow via `evaluate_adjoint!`, then clears the shadow grid with `make_zero!`.
# `tape` is whatever `augmented_primal` stored: the keyword arguments of the primal
# call. They are re-splatted so that the adjoint uses the same options.
#
# NOTE(review): whether `evaluate_adjoint!` accumulates into or overwrites
# `control_points.dval` is not visible here. Enzyme shadows normally accumulate;
# confirm.
# NOTE(review): `make_zero!` clears both `eval` and `control_points` of the shadow
# grid. Confirm that discarding the shadow `control_points` is intended.
#
# Neither argument is `Active`, so both gradient slots returned are `nothing`.
function EnzymeRules.reverse(
    ::RevConfigWidth{1},
    ::Const{typeof(evaluate!)},
    ::Type{Const{Nothing}},
    tape,
    spline_grid::MixedDuplicated{<:SplineGrid},
    control_points::Duplicated{<:SplineGrids.AbstractControlPointArray}
)
    shadow_grid = spline_grid.dval[]
    evaluate_adjoint!(shadow_grid; control_points = control_points.dval, tape...)
    make_zero!(shadow_grid)
    return (nothing, nothing)
end

# Problem size for the reproduction. The three tuples are broadcast elementwise
# below, so each holds one entry per dimension (two dimensions here).
n_control_points = (10, 10)
degree = (2, 2)
n_sample_points = (50, 50)
Nout = 2
# GPU backend. Per the issue title, the same script runs fine with the CPU backend
# but hangs with `CUDABackend`.
backend = CUDA.CUDABackend()

# One `SplineDimension` per dimension, built from the i-th entries of the three
# tuples. `backend` is passed as a keyword to every call and is not broadcast.
spline_dimensions = SplineDimension.(n_control_points, degree, n_sample_points; backend)
# NOTE(review): `Nout` is presumably the number of output channels; confirm against
# the SplineGrids docs.
spline_grid = SplineGrid(spline_dimensions, Nout)

# Scalar objective for the reproduction.
# The flat control-point vector is reshaped to the shape of
# `spline_grid.control_points`. The grid is then evaluated in place with those
# control points, and the sum of all evaluated values is returned.
function loss(control_points_flat, spline_grid)
    target_shape = size(spline_grid.control_points)
    control_points = reshape(control_points_flat, target_shape)
    evaluate!(spline_grid, control_points)
    total = sum(spline_grid.eval)
    return total
end

# Random control points on the GPU, stored as one flat vector. `loss` reshapes it to
# the shape of `spline_grid.control_points`.
control_points_flat = CUDA.rand(length(spline_grid.control_points))

# Pair each primal with a zeroed shadow. `make_zero(spline_grid)` dispatches to the
# custom `Enzyme.make_zero(::SplineGrid)` method defined in this script.
dcontrol_points_flat = Duplicated(control_points_flat, make_zero(control_points_flat))
# NOTE(review): the custom rules are written for `MixedDuplicated{<:SplineGrid}`,
# while the grid is passed here as `Duplicated`. Presumably Enzyme re-annotates it
# when the rule is invoked inside `loss`; confirm that the rule is actually hit.
dspline_grid = Duplicated(spline_grid, make_zero(spline_grid))

# Reverse-mode differentiation of the scalar `loss` (`Active` return). This is the
# call that, per the issue title, hangs on the CUDA backend.
autodiff(
    Reverse,
    loss,
    Active,
    dcontrol_points_flat,
    dspline_grid
)

Environment:

  [052768ef] CUDA v5.6.1
  [187b0558] ConstructionBase v1.5.8
  [7da242da] Enzyme v0.13.29
  [59c446ea] SplineGrids v0.1.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant