diff --git a/docs/src/api.md b/docs/src/api.md
index 1017af94..f3b7fdaf 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -33,7 +33,7 @@ Optimisers.OptimiserChain
 ```@docs
 Optimisers.setup
 Optimisers.update
-Optimisers.update!
+Optimisers.update!!
 Optimisers.adjust(::Any, ::Real)
 ```
 
diff --git a/docs/src/index.md b/docs/src/index.md
index 65b441bb..4acec851 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -60,7 +60,7 @@ Notice that a completely new instance of the model is returned. Internally, this
 is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
 tree formed by the model and update the parameters using the gradients.
 
-There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
+There is also [`Optimisers.update!!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
 The method of `apply!` for each rule is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 8e8cb19f..4e29c4cd 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -16,6 +16,8 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain
+
+@deprecate update! update!! false
 
 ###
 ### one-array functions
 ###
@@ -70,7 +72,7 @@ init
 
 Initialises the given optimiser for every trainable parameter within the model.
 Returns a tree of the relevant states, which must be passed to [`update`](@ref)
-or [`update!`](@ref).
+or [`update!!`](@ref).
 
 # Example
 ```jldoctest
@@ -113,7 +115,7 @@ Uses the optimiser and the gradient to change the trainable parameters in the mo
 Returns the improved model, and the optimiser states needed for the next update.
 The initial tree of states comes from [`setup`](@ref).
 
-See also [`update!`](@ref), which will be faster for models of ordinary `Array`s or `CuArray`s.
+See also [`update!!`](@ref), which will be faster for models of ordinary `Array`s or `CuArray`s.
 
 # Example
 ```jldoctest
@@ -131,7 +133,7 @@ julia> Optimisers.update(t, m, g)
 update
 
 """
-    Optimisers.update!(tree, model, gradient) -> (tree, model)
+    Optimisers.update!!(tree, model, gradient) -> (tree, model)
 
 Uses the optimiser and the gradient to change the trainable parameters in the model.
 Returns the improved model, and the optimiser states needed for the next update.
@@ -154,12 +156,12 @@ julia> t = Optimisers.setup(Momentum(1/30, 0.9), m);
 julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]
 (x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
 
-julia> t2, m2 = Optimisers.update!(t, m, g);
+julia> t2, m2 = Optimisers.update!!(t, m, g);
 
-julia> m2 # after update or update!, this is the new model
+julia> m2 # this is the model with new parameters
 (x = Float32[0.6666666, 1.5333333], y = Float32[3.6666667, 4.5333333])
 
-julia> m2.x === m.x # update! has re-used this array, for efficiency
+julia> m2.x === m.x # update!! has re-used this array, for efficiency
 true
 
 julia> m # original should be discarded, may be mutated but no guarantee
@@ -169,6 +171,6 @@ julia> t == t2 # original state is in fact guaranteed to be mutated
 true
 ```
 """
-update!
+update!!
 
 end # module
diff --git a/src/interface.jl b/src/interface.jl
index 79d03396..5f331622 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -59,7 +59,9 @@ function update(tree, model, grad, higher...)
   update!(t′, x′, grad, higher...)
 end
 
-function update!(tree, model, grad, higher...)
+update!!(tree, model, grad, higher...) = old_update!(tree, model, grad, higher...)
+
+function old_update!(tree, model, grad, higher...)
   # First walk is to accumulate the gradient. This recursion visits every copy of
   # shared leaves, but stops when branches are absent from the gradient:
   grads = IdDict{Leaf, Any}()
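
For downstream code the change is a one-character rename. A minimal usage sketch, not part of the patch, assuming it is applied as-is; the model, rule, and gradient values are borrowed from the `jldoctest` example above:

```julia
using Optimisers

m = (x = Float32[1, 2], y = Float32[4, 5])      # toy model, as in the docstring example
t = Optimisers.setup(Momentum(1/30, 0.9), m)    # tree of optimiser states

g = (x = Float32[10, 14], y = Float32[10, 14])  # gradient with the same structure as `m`

t2, m2 = Optimisers.update!!(t, m, g)   # new name: may mutate `t` and re-use arrays of `m`

t3, m3 = Optimisers.update!(t2, m2, g)  # old name still works via the @deprecate stub,
                                        # but may emit a deprecation warning
```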