@@ -3,6 +3,7 @@ module FlowOverCircle
 using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
 using NeuralOperators, Flux, GeometricFlux, Graphs
 using CUDA, FluxTraining, BSON
+using GeometricFlux.GraphSignals: generate_coordinates
 
 function circle(n, m; Re = 250) # copied from [WaterLily](https://github.com/weymouth/WaterLily.jl)
     # Set physical parameters
@@ -31,16 +32,14 @@ function gen_data(ts::AbstractRange)
     return 𝐩s
 end
 
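+# Assemble train/test dataloaders of one-step (𝐱, 𝐲) pressure-field pairs for
+# the Markov neural operator.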
-function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
-                        ratio::Float64 = 0.95, batchsize = 100, flatten = false)
+function get_mno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
+                            ratio::Float64 = 0.95, batchsize = 100)
     data = gen_data(ts)
     𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
     n = length(ts) - 1
 
-    if flatten
-        𝐱, 𝐲 = reshape(𝐱, 1, :, n), reshape(𝐲, 1, :, n)
-    end
-
     data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)
 
     loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
@@ -49,7 +48,9 @@ function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
     return loader_train, loader_test
 end
 
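+# Train a Markov neural operator that learns the one-step map from the
+# pressure field at time t to the field at t + 1.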
-function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
+function train_mno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
@@ -61,7 +62,7 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
 
     model = MarkovNeuralOperator(ch = (1, 64, 64, 64, 64, 64, 1), modes = (24, 24),
                                  σ = gelu)
-    data = get_dataloader()
+    data = get_mno_dataloader()
     optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
     loss_func = l₂loss
 
@@ -74,6 +75,37 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     return learner
 end
 
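+# Assemble dataloaders for the graph neural operator: each sample is a
+# FeaturedGraph over the simulation grid whose node features are the pressure
+# field concatenated with each node's grid coordinates.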
+function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
+                            ratio::Float64 = 0.95, batchsize = 8)
+    data = gen_data(ts)
+    𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
+    n = length(ts) - 1
+
+    # build a lattice graph with one node per grid point
+    graph = Graphs.grid(size(data)[2:3])
+
+    # append grid coordinates as extra input channels
+    grid = generate_coordinates(𝐱[1, :, :, 1])
+    grid = repeat(grid, outer = (1, 1, 1, n))
+    𝐱 = vcat(𝐱, grid)
+
+    # flatten the spatial dimensions into a single node dimension
+    𝐱, 𝐲 = reshape(𝐱, size(𝐱, 1), :, n), reshape(𝐲, 1, :, n)
+
+    fg = FeaturedGraph(graph, nf = 𝐱, pf = 𝐱)
+    data_train, data_test = splitobs(shuffleobs((fg, 𝐲)), at = ratio)
+
+    loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
+    loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)
+
+    return loader_train, loader_test
+end
+
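+# Train the graph neural operator; mirrors train_mno but consumes the
+# FeaturedGraph batches produced above.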
 function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     if cuda && CUDA.has_cuda()
         device = gpu
@@ -84,17 +116,21 @@ function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
         @info "Training on CPU"
     end
 
-    featured_graph = FeaturedGraph(grid([96, 64]))
-    model = Chain(Dense(1, 16),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
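+    # node input: 1 field channel + 2 coordinate channels; edge input is the
+    # concatenated features of both endpoints, hence 2(grid_dim + 1) = 6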
+    grid_dim = 2
+    edge_dim = 2(grid_dim + 1)
+    model = Chain(GraphParallel(node_layer = Dense(grid_dim + 1, 16)),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  node_feature,
                   Dense(16, 1))
-    data = get_dataloader(batchsize = 16, flatten = true)
+
     optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
     loss_func = l₂loss
-
+    data = get_gno_dataloader()
     learner = Learner(model, data, optimiser, loss_func,
                       ToDevice(device, device),
                       Checkpointer(joinpath(@__DIR__, "../model/")))