Skip to content

Commit 2e7596d

Browse files
authored
Addition of Siamese Contrastive Loss function ( Updated ) (#1892)
1 parent a67d184 commit 2e7596d

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

docs/src/models/losses.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,5 @@ Flux.Losses.dice_coeff_loss
4141
Flux.Losses.tversky_loss
4242
Flux.Losses.binary_focal_loss
4343
Flux.Losses.focal_loss
44+
Flux.losses.siamese_contrastive_loss
4445
```

src/losses/Losses.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export mse, mae, msle,
2020
poisson_loss,
2121
hinge_loss, squared_hinge_loss,
2222
ctc_loss,
23-
binary_focal_loss, focal_loss
23+
binary_focal_loss, focal_loss, siamese_contrastive_loss
2424

2525
include("utils.jl")
2626
include("functions.jl")

src/losses/functions.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,23 @@ function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
527527
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
528528
end
529529

530+
"""
531+
siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean)
532+
533+
Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
534+
which can be useful for training Siamese Networks. It is given by
535+
536+
agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
537+
538+
Specify `margin` to set the baseline for distance at which pairs are dissimilar.
539+
540+
"""
541+
function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
542+
_check_sizes(ŷ, y)
543+
margin < 0 && throw(DomainError(margin, "Margin must be non-negative"))
544+
return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
545+
end
546+
530547
```@meta
531548
DocTestFilters = nothing
532549
```

test/losses.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
1414
Flux.Losses.dice_coeff_loss,
1515
Flux.Losses.poisson_loss,
1616
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
17-
Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]
17+
Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss]
1818

1919

2020
@testset "xlogx & xlogy" begin
@@ -210,3 +210,41 @@ end
210210
@test Flux.focal_loss(ŷ1, y1) 0.45990566879720157
211211
@test Flux.focal_loss(ŷ, y; γ=0.0) Flux.crossentropy(ŷ, y)
212212
end
213+
214+
@testset "siamese_contrastive_loss" begin
215+
y = [1 0
216+
0 0
217+
0 1]
218+
ŷ = [0.4 0.2
219+
0.5 0.5
220+
0.1 0.3]
221+
y1 = [1 0 0 0 1
222+
0 1 0 1 0
223+
0 0 1 0 0]
224+
ŷ1 = softmax(reshape(-7:7, 3, 5) .* 1.0f0)
225+
y2 = [1
226+
0
227+
0
228+
1
229+
1]
230+
ŷ2 = [0.6
231+
0.4
232+
0.1
233+
0.2
234+
0.7]
235+
@test Flux.siamese_contrastive_loss(ŷ, y) 0.2333333333333333
236+
@test Flux.siamese_contrastive_loss(ŷ, y, margin = 0.5f0) 0.10000000000000002
237+
@test Flux.siamese_contrastive_loss(ŷ, y, margin = 1.5f0) 0.5333333333333333
238+
@test Flux.siamese_contrastive_loss(ŷ1, y1) 0.32554644f0
239+
@test Flux.siamese_contrastive_loss(ŷ1, y1, margin = 0.5f0) 0.16271012f0
240+
@test Flux.siamese_contrastive_loss(ŷ1, y1, margin = 1.5f0) 0.6532292f0
241+
@test Flux.siamese_contrastive_loss(ŷ, y, margin = 1) Flux.siamese_contrastive_loss(ŷ, y)
242+
@test Flux.siamese_contrastive_loss(y, y) 0.0
243+
@test Flux.siamese_contrastive_loss(y1, y1) 0.0
244+
@test Flux.siamese_contrastive_loss(ŷ, y, margin = 0) 0.09166666666666667
245+
@test Flux.siamese_contrastive_loss(ŷ1, y1, margin = 0) 0.13161165f0
246+
@test Flux.siamese_contrastive_loss(ŷ2, y2) 0.21200000000000005
247+
@test Flux.siamese_contrastive_loss(ŷ2, ŷ2) 0.18800000000000003
248+
@test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5)
249+
@test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1)
250+
end

0 commit comments

Comments
 (0)