Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Logit focal loss #2138

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -603,14 +603,56 @@ function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
end

"""
logit_focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=eps(ŷ))

Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
which can be used in classification tasks with highly imbalanced classes.
It down-weights well-classified examples and focuses on hard examples.
The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) output).

The modulating factor, `γ`, controls the down-weighting strength.
For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref).

# Example
```jldoctest
julia> y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
3×5 Matrix{Int64}:
1 0 0 0 1
0 1 0 1 0
0 0 1 0 0

julia> ŷ = reshape(-7:7, 3, 5) .* 1f0
3×5 Matrix{Float32}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241

julia> Flux.logit_focal_loss(ŷ, y) ≈ 1.1277571935622628
true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example output doesn't match what's written.

More importantly, the example is an opportunity to show exactly how this relates to focal_loss, i.e. where the softmax goes. And perhaps (if you can think of a compact & clear way) the relation to crossentropy (or rather logitcrossentropy?) too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah i still need to work through these tests, did not realize about the docstring tests until after already putting tests elsewhere :) Can do !

```

See also: [`Losses.focal_loss`](@ref)

"""
function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=epseltype(ŷ))
Copy link
Member

@mcabbott mcabbott Dec 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some have crept in & need fixing, but there should not be greek-letter keywords. These can be gamma and eps?

Also, as written, γ=1.5 will cause Float32 input to be promoted to Float64. Can you avoid this somehow? Perhaps there should be a line like γ = gamma isa Integer ? gamma : convert(eltype(logpt), gamma). (Integer powers are faster.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice idea, can do!

_check_sizes(ŷ, y)
logpt = logsoftmax(ŷ; dims=dims)
logpt .+= ϵ
loss = agg(sum(@. -y * (1 - exp.(logpt))^γ * logpt; dims=dims))
return loss
josephsdavid marked this conversation as resolved.
Show resolved Hide resolved
end

"""
siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean)

Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
which can be useful for training Siamese Networks. It is given by
agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)

agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)

Specify `margin` to set the baseline for distance at which pairs are dissimilar.

# Example
Expand Down
17 changes: 15 additions & 2 deletions test/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
Flux.Losses.dice_coeff_loss,
Flux.Losses.poisson_loss,
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss]
Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss, Flux.Losses.logit_focal_loss]


@testset "xlogx & xlogy" begin
Expand Down Expand Up @@ -210,7 +210,20 @@ end
@test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157
@test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y)
end


@testset "logit_focal_loss" begin
rng = Random.seed!(Random.default_rng(), 5)
y = rand(rng, Float32, 6, 40, 2)
yhat = rand(rng, Float32, 6, 40, 2)

@test logit_focal_loss(yhat, y; γ=0) ≈
Flux.Losses.logitcrossentropy(yhat, y)


@test logit_focal_loss(yhat, y; γ=2) ==
Flux.Losses.focal_loss(Flux.softmax(yhat; dims=1), y; γ=2)
end

@testset "siamese_contrastive_loss" begin
y = [1 0
0 0
Expand Down