1- using Test, Random, Flux
1+ using Test, Random, Flux, MAT
22
33@testset " DeepONet" begin
44 @testset " dimensions" begin
@@ -14,4 +14,50 @@ using Test, Random, Flux
1414 # Accept only Int as architecture parameters
1515 @test_throws MethodError DeepONet ((32.5 ,64 ,72 ), (24 ,48 ,72 ), σ, tanh)
1616 @test_throws MethodError DeepONet ((32 ,64 ,72 ), (24.1 ,48 ,72 ))
17- end
17+ end
18+
19+ # Just the first 16 datapoints from the Burgers' equation dataset
20+ a = [0.83541104 , 0.83479851 , 0.83404712 , 0.83315711 , 0.83212979 , 0.83096755 , 0.82967374 , 0.82825263 , 0.82670928 , 0.82504949 , 0.82327962 , 0.82140651 , 0.81943734 , 0.81737952 , 0.8152405 , 0.81302771 ]
21+ sensors = collect (range (0 , 1 , length= 16 ))'
22+
23+ model = DeepONet ((16 , 22 , 30 ), (1 , 16 , 24 , 30 ), σ, tanh; init_branch= Flux. glorot_normal, bias_trunk= false )
24+
25+ model (a,sensors)
26+
27+ # forward pass
28+ @test size (model (a, sensors)) == (1 , 16 )
29+
30+ mgrad = Flux. Zygote. gradient ((x,p)-> sum (model (x,p)),a,sensors)
31+
32+ # gradients
33+ @test ! iszero (Flux. Zygote. gradient ((x,p)-> sum (model (x,p)),a,sensors)[1 ])
34+ @test ! iszero (Flux. Zygote. gradient ((x,p)-> sum (model (x,p)),a,sensors)[2 ])
35+
36+ # training
37+ vars = matread (" burgerset.mat" )
38+
39+ xtrain = vars[" a" ][1 : 280 , :]'
40+ xval = vars[" a" ][end - 19 : end , :]'
41+
42+ ytrain = vars[" u" ][1 : 280 , :]
43+ yval = vars[" u" ][end - 19 : end , :]
44+
45+ grid = collect (range (0 , 1 , length= 1024 ))'
46+ model = DeepONet ((1024 ,1024 ,1024 ),(1 ,1024 ,1024 ),gelu,gelu)
47+
48+ learning_rate = 0.001
49+ opt = ADAM (learning_rate)
50+
51+ parameters = params (model)
52+
53+ loss (xtrain,ytrain,sensor) = Flux. Losses. mse (model (xtrain,sensor),ytrain)
54+
55+ evalcb () = @show (loss (xval,yval,grid))
56+
57+ Flux. @epochs 400 Flux. train! (loss, parameters, [(xtrain,ytrain,grid)], opt, cb = evalcb)
58+
59+ ỹ = model (xval, grid)
60+
61+ diffvec = vec (abs .((yval .- ỹ)))
62+ mean_diff = sum (diffvec)/ length (diffvec)
63+ @test mean_diff < 0.4
0 commit comments