1+ """
2+ Flux.Losses
3+
4+ This sub-module contains many loss functions, all of which accept two arguments,
5+ with the model output as the fist argument: `loss(model(x), y)`.
6+ It also contains a few related utilities, such as `label_smoothing`.
7+ The complete list of exports is:
8+
9+ label_smoothing,
10+ mse, mae, msle,
11+ crossentropy,
12+ logitcrossentropy,
13+ binarycrossentropy,
14+ logitbinarycrossentropy,
15+ kldivergence,
16+ huber_loss,
17+ tversky_loss,
18+ dice_coeff_loss,
19+ poisson_loss,
20+ hinge_loss,
21+ squared_hinge_loss,
22+ binary_focal_loss,
23+ focal_loss,
24+ siamese_contrastive_loss
25+ """
 module Losses
 
 using Statistics
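As a quick illustration of the call convention described in the docstring added above, a minimal usage sketch follows; the model, data, and sizes are hypothetical and not part of this diff.

    using Flux
    using Flux.Losses: crossentropy, label_smoothing

    # Hypothetical toy model and data, for illustration only.
    model = Chain(Dense(4 => 2), softmax)
    x = rand(Float32, 4, 8)                  # 8 samples, 4 features each
    y = Flux.onehotbatch(rand(0:1, 8), 0:1)  # one-hot targets

    # Every loss takes the model output as its first argument.
    crossentropy(model(x), y)

    # `label_smoothing` is one of the related utilities: it softens hard targets.
    y_smooth = label_smoothing(y, 0.1f0)
    crossentropy(model(x), y_smooth)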
@@ -9,8 +34,8 @@ using CUDA
 using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
 import Base.Broadcast: broadcasted
 
-export mse, mae, msle,
-    label_smoothing,
+export label_smoothing,
+    mse, mae, msle,
     crossentropy, logitcrossentropy,
     binarycrossentropy, logitbinarycrossentropy,
     kldivergence,
@@ -19,9 +44,33 @@ export mse, mae, msle,
     dice_coeff_loss,
     poisson_loss,
     hinge_loss, squared_hinge_loss,
-    binary_focal_loss, focal_loss, siamese_contrastive_loss
+    binary_focal_loss, focal_loss,
+    siamese_contrastive_loss
 
 include("utils.jl")
 include("functions.jl")
 
+for loss in Symbol.([
+    mse, mae, msle,
+    crossentropy, logitcrossentropy,
+    binarycrossentropy, logitbinarycrossentropy,
+    kldivergence,
+    huber_loss,
+    tversky_loss,
+    dice_coeff_loss,
+    poisson_loss,
+    hinge_loss, squared_hinge_loss,
+    binary_focal_loss, focal_loss,
+    siamese_contrastive_loss,
+])
+    @eval begin
+        """
+            $($loss)(model, x, y)
+
+        This method calculates `ŷ = model(x)` and then evaluates `$($loss)(ŷ, y)`.
+        It accepts the same keyword arguments as the two-argument method.
+        """
+        $loss(f, x::AbstractArray, y::AbstractArray; kw...) = $loss(f(x), y; kw...)
+    end
+end
+
 end # module
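For reference, a minimal sketch of how the three-argument methods generated by the @eval loop above would be used once this patch is applied; the model, data, and sizes are hypothetical and not part of the diff.

    using Flux
    using Flux.Losses: mse

    # Hypothetical model and data, for illustration only.
    model = Dense(3 => 1)
    x = rand(Float32, 3, 16)
    y = rand(Float32, 1, 16)

    # Existing two-argument form: the caller applies the model.
    mse(model(x), y)

    # New three-argument form: ŷ = model(x) is computed internally,
    # and keyword arguments are forwarded to the two-argument method.
    mse(model, x, y)
    mse(model, x, y; agg = sum)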