This repository was archived by the owner on Sep 28, 2024. It is now read-only.
File tree Expand file tree Collapse file tree 5 files changed +27
-7
lines changed
example/FlowOverCircle/src Expand file tree Collapse file tree 5 files changed +27
-7
lines changed Original file line number Diff line number Diff line change @@ -31,7 +31,7 @@ function train()
3131 Dense (64 , 1 ),
3232 ) |> device
3333
34- loss (𝐱, 𝐲) = sum (abs2, 𝐲 .- m (𝐱)) / size (𝐱)[ end ]
34+ loss (𝐱, 𝐲) = l₂loss ( m (𝐱), 𝐲)
3535
3636 opt = Flux. Optimiser (WeightDecay (1f-4 ), Flux. ADAM (1f-3 ))
3737
Original file line number Diff line number Diff line change @@ -14,6 +14,7 @@ module NeuralOperators
1414
1515 include (" Transform/Transform.jl" )
1616 include (" operator_kernel.jl" )
17+ include (" loss.jl" )
1718 include (" model.jl" )
1819 include (" DeepONet.jl" )
1920 include (" subnets.jl" )
Original file line number Diff line number Diff line change 1+ export l₂loss
2+
3+ function l₂loss (𝐲̂, 𝐲; agg= mean, grid_normalize= true )
4+ feature_dims = 2 : (ndims (𝐲)- 1 )
5+ loss = agg (.√ (sum (abs2, 𝐲̂- 𝐲, dims= feature_dims)))
6+
7+ return grid_normalize ? loss/ prod (feature_dims) : loss
8+ end
Original file line number Diff line number Diff line change 1+ @testset " loss" begin
2+ 𝐲 = rand (1 , 3 , 3 , 5 )
3+ 𝐲̂ = rand (1 , 3 , 3 , 5 )
4+
5+ feature_dims = 2 : 3
6+ loss = mean (.√ (sum (abs2, 𝐲̂- 𝐲, dims= feature_dims)))
7+
8+ @test l₂loss (𝐲̂, 𝐲) ≈ loss/ prod (feature_dims)
9+ end
Original file line number Diff line number Diff line change @@ -4,19 +4,21 @@ using Flux
44using GeometricFlux
55using Graphs
66using Zygote
7+ using Statistics
78using Test
89
910CUDA. allowscalar (false )
1011
1112cuda_tests = [
12- " cuda" ,
13+ " cuda.jl " ,
1314]
1415
1516tests = [
16- " Transform/Transform" ,
17- " operator_kernel" ,
18- " model" ,
19- " deeponet" ,
17+ " Transform/Transform.jl" ,
18+ " operator_kernel.jl" ,
19+ " loss.jl" ,
20+ " model.jl" ,
21+ " deeponet.jl" ,
2022]
2123
2224if CUDA. functional ()
2729
2830@testset " NeuralOperators.jl" begin
2931 for t in tests
30- include (" $(t) .jl " )
32+ include (t )
3133 end
3234end
3335
You can’t perform that action at this time.
0 commit comments