@@ -14,7 +14,7 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
1414 Flux. Losses. dice_coeff_loss,
1515 Flux. Losses. poisson_loss,
1616 Flux. Losses. hinge_loss, Flux. Losses. squared_hinge_loss,
17- Flux. Losses. binary_focal_loss, Flux. Losses. focal_loss]
17+ Flux. Losses. binary_focal_loss, Flux. Losses. focal_loss, Flux . Losses . siamese_contrastive_loss ]
1818
1919
2020@testset " xlogx & xlogy" begin
210210 @test Flux. focal_loss (ŷ1, y1) ≈ 0.45990566879720157
211211 @test Flux. focal_loss (ŷ, y; γ= 0.0 ) ≈ Flux. crossentropy (ŷ, y)
212212end
213+
214+ @testset " siamese_contrastive_loss" begin
215+ y = [1 0
216+ 0 0
217+ 0 1 ]
218+ ŷ = [0.4 0.2
219+ 0.5 0.5
220+ 0.1 0.3 ]
221+ y1 = [1 0 0 0 1
222+ 0 1 0 1 0
223+ 0 0 1 0 0 ]
224+ ŷ1 = softmax (reshape (- 7 : 7 , 3 , 5 ) .* 1.0f0 )
225+ y2 = [1
226+ 0
227+ 0
228+ 1
229+ 1 ]
230+ ŷ2 = [0.6
231+ 0.4
232+ 0.1
233+ 0.2
234+ 0.7 ]
235+ @test Flux. siamese_contrastive_loss (ŷ, y) ≈ 0.2333333333333333
236+ @test Flux. siamese_contrastive_loss (ŷ, y, margin = 0.5f0 ) ≈ 0.10000000000000002
237+ @test Flux. siamese_contrastive_loss (ŷ, y, margin = 1.5f0 ) ≈ 0.5333333333333333
238+ @test Flux. siamese_contrastive_loss (ŷ1, y1) ≈ 0.32554644f0
239+ @test Flux. siamese_contrastive_loss (ŷ1, y1, margin = 0.5f0 ) ≈ 0.16271012f0
240+ @test Flux. siamese_contrastive_loss (ŷ1, y1, margin = 1.5f0 ) ≈ 0.6532292f0
241+ @test Flux. siamese_contrastive_loss (ŷ, y, margin = 1 ) ≈ Flux. siamese_contrastive_loss (ŷ, y)
242+ @test Flux. siamese_contrastive_loss (y, y) ≈ 0.0
243+ @test Flux. siamese_contrastive_loss (y1, y1) ≈ 0.0
244+ @test Flux. siamese_contrastive_loss (ŷ, y, margin = 0 ) ≈ 0.09166666666666667
245+ @test Flux. siamese_contrastive_loss (ŷ1, y1, margin = 0 ) ≈ 0.13161165f0
246+ @test Flux. siamese_contrastive_loss (ŷ2, y2) ≈ 0.21200000000000005
247+ @test Flux. siamese_contrastive_loss (ŷ2, ŷ2) ≈ 0.18800000000000003
248+ @test_throws DomainError (- 0.5 , " Margin must be non-negative" ) Flux. siamese_contrastive_loss (ŷ1, y1, margin = - 0.5 )
249+ @test_throws DomainError (- 1 , " Margin must be non-negative" ) Flux. siamese_contrastive_loss (ŷ, y, margin = - 1 )
250+ end
0 commit comments