feat: auto-training mode and strict checks
avik-pal committed Aug 29, 2024
1 parent 56e40d8 commit fb000d0
Showing 8 changed files with 86 additions and 39 deletions.
4 changes: 4 additions & 0 deletions ext/LuxLibReverseDiffExt.jl
@@ -58,6 +58,10 @@ Utils.remove_tracking(x::TrackedArray) = ReverseDiff.value(x)
Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x)
Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T)

Utils.within_gradient(::TrackedReal) = True()
Utils.within_gradient(::TrackedArray) = True()
Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True()

# Traits extensions
Traits.is_tracked(::Type{<:TrackedReal}) = True()

4 changes: 4 additions & 0 deletions ext/LuxLibTrackerExt.jl
@@ -93,6 +93,10 @@ Utils.remove_tracking(x::TrackedArray) = Tracker.data(x)
Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x)
Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T)

Utils.within_gradient(::TrackedReal) = True()
Utils.within_gradient(::TrackedArray) = True()
Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True()

# Traits extensions
Traits.is_tracked(::Type{<:TrackedReal}) = True()

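For illustration, a minimal sketch of what the new `within_gradient` methods in the two extensions above enable. `LuxLib.Utils` is an internal module and its path is assumed here; this snippet is not part of the commit.

using LuxLib, Tracker

x = Tracker.param(randn(Float32, 4))             # TrackedArray
LuxLib.Utils.within_gradient(x)                  # True(), via the TrackedArray method above
LuxLib.Utils.within_gradient(randn(Float32, 4))  # False(), via the generic fallback in src/utils.jl
# As a result, `training = nothing` in the public API resolves to training mode
# whenever the inputs are ReverseDiff/Tracker tracked values.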
3 changes: 0 additions & 3 deletions ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -29,9 +29,6 @@ end

function CRC.rrule(
::typeof(Impl.batchnorm_cudnn), γ, β, x, rμ, rσ², m, ϵ, training::StaticBool)
# TODO: Transition this to an error in the future
unsafe_known(training) ||
@warn "`training=Val(false)` but gradient was called." maxlog=1
y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, rμ, rσ², m, ϵ, training)
𝒫x, 𝒫γ, 𝒫β = CRC.ProjectTo(x), CRC.ProjectTo(γ), CRC.ProjectTo(β)
∇batchnorm_cudnn = @closure Δ -> begin
4 changes: 3 additions & 1 deletion src/api/API.jl
@@ -8,10 +8,12 @@ using Static: Static, StaticBool, static

using ..LuxLib: Optional
using ..Impl: Impl, select_fastest_activation
using ..Utils: default_epsilon, expand_batchdim, remove_tracking
using ..Utils: default_epsilon, expand_batchdim, remove_tracking, static_training_mode

const CRC = ChainRulesCore

const TrainingType = Union{Val{true}, Val{false}, StaticBool, Nothing}

# The names are aliased so we define constants for them
for op in (:batched_matmul, :batchnorm, :bias_activation, :bias_activation!!,
:dropout, :alpha_dropout, :groupnorm, :instancenorm, :layernorm,
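For reference, a sketch of the values accepted by the `TrainingType` union introduced above (illustrative only, not part of the commit):

using Static: True, False

# Any of these may be passed as the `training` argument of the public API;
# `nothing` opts into automatic detection via `static_training_mode`:
valid_training_values = (Val(true), Val(false), True(), False(), nothing)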
14 changes: 8 additions & 6 deletions src/api/batchnorm.jl
@@ -1,5 +1,5 @@
@doc doc"""
batchnorm(x, scale, bias, running_mean, running_var, training::Union{Val, StaticBool},
batchnorm(x, scale, bias, running_mean, running_var, training,
σ=identity, momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7))
Batch Normalization. For details see [1].
@@ -15,7 +15,9 @@ accordingly.
- `bias`: Bias factor (``\beta``) (can be `nothing`)
- `running_mean`: Running mean (can be `nothing`)
- `running_var`: Running variance (can be `nothing`)
- `training`: Set to `Val(true)` if running in training mode
- `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to
`nothing` to automatically determine if the function is being called within an autodiff
context
- `σ`: Activation function (default: `identity`)
- `momentum`: Momentum for updating running mean and variance (default: `0.1f0`)
- `epsilon`: Value added to the denominator for numerical stability
@@ -34,11 +36,11 @@ mean and variance.
"""
function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector},
β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
rσ²::Optional{<:AbstractVector}, training::Union{Val, StaticBool},
act::F=identity, momentum::Real=0.1f0,
epsilon::Real=default_epsilon(x)) where {F, T, N}
rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity,
momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N}
σ = select_fastest_activation(act, x, γ, β, rμ, rσ²)
y, rμ, rσ² = batchnorm_impl(
x, γ, β, rμ, rσ², static(training), σ, momentum, epsilon)
x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²),
σ, momentum, epsilon)
return y, (; running_mean=remove_tracking(rμ), running_var=remove_tracking(rσ²))
end
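For illustration, a minimal usage sketch of the automatic mode; the shapes and the Zygote pairing are assumptions, not part of the commit.

using LuxLib, Zygote

x = randn(Float32, 4, 4, 3, 2)                   # channels along dim 3
γ, β = ones(Float32, 3), zeros(Float32, 3)
rμ, rσ² = zeros(Float32, 3), ones(Float32, 3)

# Outside of autodiff, `training = nothing` resolves to test mode:
y, stats = batchnorm(x, γ, β, rμ, rσ², nothing)

# Inside a ChainRules-based AD such as Zygote, the `within_gradient` rrule in
# src/utils.jl should flip the same call to training mode automatically:
Zygote.gradient(x -> sum(first(batchnorm(x, γ, β, rμ, rσ², nothing))), x)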
45 changes: 23 additions & 22 deletions src/api/dropout.jl
@@ -1,7 +1,7 @@
"""
dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, invp, dims)
dropout(rng::AbstractRNG, x, mask, p, training::Union{Val, StaticBool},
update_mask::Union{Val, StaticBool}, invp, dims)
dropout(rng::AbstractRNG, x, p, training, invp, dims)
dropout(rng::AbstractRNG, x, mask, p, training, update_mask::Union{Val, StaticBool},
invp, dims)
Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1].
@@ -11,10 +11,11 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
- `x`: Input Array
- `mask`: Dropout Mask. If not used then it is constructed automatically
- `p`: Probability of an element to be dropped out
- `Val(training)`: If `true` then dropout is applied on `x` with probability `p` along
`dims`. Else, `x` is returned
- `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask`
provided is directly used
- `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to
`nothing` to automatically determine if the function is being called within an autodiff
context
- `update_mask`: If `Val(true)` or `True()` then the mask is generated and used. Else, the
`mask` provided is directly used
- `invp`: Inverse multiplied to the mask. Calculated as `invp = 1 / (1 - p)`.
## Returns
@@ -28,20 +29,20 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from
overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
"""
function dropout(rng::AbstractRNG, x::AbstractArray, p::T,
training::Union{Val, StaticBool}, invp::T, dims) where {T}
return dropout_impl(rng, x, p, static(training), invp, dims)
function dropout(rng::AbstractRNG, x::AbstractArray, p::T, training::TrainingType, invp::T,
dims) where {T}
return dropout_impl(rng, x, p, static_training_mode(training, x), invp, dims)
end

function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
p::T, training::Union{Val, StaticBool},
update_mask::Union{Val, StaticBool}, invp::T, dims) where {T}
return dropout_impl(rng, x, mask, p, static(training), static(update_mask), invp, dims)
p::T, training::TrainingType, update_mask::TrainingType, invp::T, dims) where {T}
return dropout_impl(rng, x, mask, p, static_training_mode(training, x),
static(update_mask), invp, dims)
end

"""
alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool})
alpha_dropout(rng::AbstractRNG, x, p, training::Union{Val, StaticBool}, α, A, B)
alpha_dropout(rng::AbstractRNG, x, p, training)
alpha_dropout(rng::AbstractRNG, x, p, training, α, A, B)
Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the
input. For details see [1]. Use the second call signature to avoid recomputing the constants
@@ -52,8 +53,9 @@ for a fixed dropout probability.
- `rng`: Random number generator
- `x`: Input Array
- `p`: Probability of an element to be dropped out
- `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else,
`x` is returned
- `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to
`nothing` to automatically determine if the function is being called within an autodiff
context
- `α`: `-1.7580993408473766`. Computed at limit x tends to infinity, `selu(x) = -λβ = α`
- `A`: Scaling factor for the mean
- `B`: Scaling factor for the variance
@@ -68,12 +70,11 @@ for a fixed dropout probability.
[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural
information processing systems 30 (2017).
"""
function alpha_dropout(
rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool})
return alpha_dropout_impl(rng, x, p, static(training))
function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training::TrainingType)
return alpha_dropout_impl(rng, x, p, static_training_mode(training, x))
end

function alpha_dropout(
rng::AbstractRNG, x::AbstractArray, p, training::Union{Val, StaticBool}, α, A, B)
return alpha_dropout_impl(rng, x, p, static(training), α, A, B)
rng::AbstractRNG, x::AbstractArray, p, training::TrainingType, α, A, B)
return alpha_dropout_impl(rng, x, p, static_training_mode(training, x), α, A, B)
end
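A usage sketch of the relaxed `training` argument, assuming the usual `(output, mask, rng)` return of `dropout`; the values are made up and not part of the commit.

using LuxLib, Random

rng = Random.default_rng()
x, p = randn(Float32, 8, 4), 0.5f0

# Explicit flag, exactly as before:
y, mask, rng′ = dropout(rng, x, p, Val(true), inv(1 - p), :)

# `training = nothing` defers the decision to `static_training_mode`; outside of
# autodiff this resolves to test mode and `x` passes through unchanged:
y′, _, _ = dropout(rng, x, p, nothing, inv(1 - p), :)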
14 changes: 8 additions & 6 deletions src/api/instancenorm.jl
@@ -1,5 +1,5 @@
@doc doc"""
instancenorm(x, scale, bias, training::Union{Val, StaticBool}, σ = identity,
instancenorm(x, scale, bias, training, σ = identity,
epsilon = eps(eltype(x)) ^ (5 // 7))
Instance Normalization. For details see [1].
@@ -16,7 +16,9 @@ accordingly.
- `σ`: Activation function (default: `identity`)
- `epsilon`: Value added to the denominator for numerical stability
(default: `eps(eltype(x)) ^ (5 / 7)`)
- `training`: Set to `Val(true)` if running in training mode
- `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to
`nothing` to automatically determine if the function is being called within an autodiff
context
## Returns
@@ -29,13 +31,13 @@ mean and variance.
missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
"""
function instancenorm(x::AbstractArray, scale::Optional{<:AbstractVector},
bias::Optional{<:AbstractVector}, training::Union{Val, StaticBool}=Val(false),
bias::Optional{<:AbstractVector}, training::TrainingType,
σ::F=identity, epsilon::Real=default_epsilon(x)) where {F}
assert_valid_instancenorm_arguments(x)

σ′ = select_fastest_activation(σ, x, scale, bias)
y, xμ, xσ² = instancenorm_impl(
x, nothing, nothing, scale, bias, static(training), nothing, epsilon, σ′)
y, xμ, xσ² = instancenorm_impl(x, nothing, nothing, scale, bias,
static_training_mode(training, x, scale, bias), nothing, epsilon,
select_fastest_activation(σ, x, scale, bias))

return y, (; running_mean=xμ, running_var=xσ²)
end
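Note that `training` no longer has a `Val(false)` default here, so it must be supplied explicitly. A call sketch (values assumed, not part of the commit):

using LuxLib

x = randn(Float32, 6, 6, 3, 2)
# Positional arguments are (x, scale, bias, training); passing `nothing` for the
# last one opts into automatic training-mode detection:
y, stats = instancenorm(x, nothing, nothing, nothing)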
37 changes: 36 additions & 1 deletion src/utils.jl
@@ -8,7 +8,7 @@ using KernelAbstractions: KernelAbstractions
using LinearAlgebra: LinearAlgebra, BLAS
using MLDataDevices: get_device_type, CPUDevice
using NNlib: NNlib
using Static: Static, False, True
using Static: Static, StaticBool, False, True, static
using StaticArraysCore: SVector, SMatrix

using ..LuxLib: Optional, ∂∅
@@ -231,4 +231,39 @@ end
return
end

within_gradient_vararg(args...) = unrolled_any(within_gradient, args)

within_gradient(_) = False()
within_gradient(::ForwardDiff.Dual) = True()
within_gradient(::AbstractArray{<:ForwardDiff.Dual}) = True()

CRC.rrule(::typeof(within_gradient), x) = True(), _ -> (∂∅, ∂∅)

static_training_mode(::Nothing, args...) = within_gradient_vararg(args...)

function static_training_mode(
training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...)
return static_training_mode_check(
training, static(training), within_gradient_vararg(args...))
end

static_training_mode_check(_, ::True, ::True) = True()
static_training_mode_check(_, ::False, ::False) = False()

function static_training_mode_check(training, ::True, ::False)
safe_warning(
"`training` is set to `$(training)` but is not being used within an autodiff call \
(gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` \
model, set it to inference (test) mode using `LuxCore.testmode`.", -1)
return True()
end

function static_training_mode_check(training, ::False, ::True)
safe_warning(
"`training` is set to `$(training)` but is being used within an autodiff call \
(gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` \
model, set it to training mode using `LuxCore.trainmode`.", -1)
return False()
end

end
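A behaviour sketch of the resolution and strict-check logic added above. The `LuxLib.Utils` path is internal and assumed here; the outcomes follow directly from the definitions in this hunk.

using LuxLib

x = randn(Float32, 3)

LuxLib.Utils.static_training_mode(nothing, x)     # False(): no dual/tracked inputs, not inside autodiff
LuxLib.Utils.static_training_mode(Val(false), x)  # False(): requested mode matches reality, no warning
LuxLib.Utils.static_training_mode(Val(true), x)   # True():  honoured, but the new strict check warns
                                                  #          that no autodiff call is active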
