@@ -3,7 +3,7 @@ module FlowOverCircle
33using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
44using NeuralOperators, Flux, GeometricFlux, Graphs
55using CUDA, FluxTraining, BSON
6- using GeometricFlux. GraphSignals: generate_coordinates
6+ using GeometricFlux. GraphSignals: generate_grid
77
88function circle (n, m; Re = 250 ) # copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
99 # Set physical parameters
@@ -32,29 +32,21 @@ function gen_data(ts::AbstractRange)
3232 return 𝐩s
3333end
3434
"""
    get_mno_dataloader(; ts = LinRange(100, 11000, 10000), ratio = 0.95, batchsize = 100)

Build train/test `DataLoader`s for the Markov Neural Operator.

Generates flow snapshots at the time steps in `ts` via `gen_data`, then pairs
each snapshot with its successor so the model learns the one-step transition
`𝐱ₜ → 𝐱ₜ₊₁`.

# Keywords
- `ts::AbstractRange`: simulation time steps forwarded to `gen_data`.
- `ratio::Float64`: fraction of samples used for training; the rest is test data.
- `batchsize`: number of samples per batch in both loaders.

Returns `(loader_train, loader_test)`; only the training loader shuffles.
"""
function get_mno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
                            ratio::Float64 = 0.95, batchsize = 100)
    data = gen_data(ts)
    # One-step supervision: input is every frame but the last,
    # target is the same sequence shifted forward by one frame.
    𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]

    data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)

    loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
    loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)

    return loader_train, loader_test
end
5648
57- function train (; cuda = true , η₀ = 1.0f-3 , λ = 1.0f-4 , epochs = 50 )
49+ function train_mno (; cuda = true , η₀ = 1.0f-3 , λ = 1.0f-4 , epochs = 50 )
5850 if cuda && CUDA. has_cuda ()
5951 device = gpu
6052 CUDA. allowscalar (false )
@@ -66,7 +58,7 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
6658
6759 model = MarkovNeuralOperator (ch = (1 , 64 , 64 , 64 , 64 , 64 , 1 ), modes = (24 , 24 ),
6860 σ = gelu)
69- data = get_dataloader ()
61+ data = get_mno_dataloader ()
7062 optimiser = Flux. Optimiser (WeightDecay (λ), Flux. Adam (η₀))
7163 loss_func = l₂loss
7264
@@ -79,6 +71,61 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
7971 return learner
8072end
8173
"""
    batch_featured_graph(data, graph, batchsize)

Chunk `data` (a 3-D array whose last dimension indexes samples) along its last
dimension into batches of at most `batchsize` samples, wrapping each chunk in a
`FeaturedGraph` over `graph`.

Each chunk is supplied both as node features (`nf`) and positional features
(`pf`); the final chunk is smaller when the sample count is not a multiple of
`batchsize`. Returns a `Vector{FeaturedGraph}`.
"""
function batch_featured_graph(data, graph, batchsize)
    tot_len = size(data)[end]
    bch_data = FeaturedGraph[]
    for i in 1:batchsize:tot_len
        # `min` caps the last chunk at the end of the data. The previous
        # `i + batchsize >= tot_len` test fired one step early and merged a
        # trailing remainder into the preceding chunk, yielding a batch of
        # batchsize + 1 samples.
        bch_rng = i:min(i + batchsize - 1, tot_len)
        fg = FeaturedGraph(graph, nf = data[:, :, bch_rng], pf = data[:, :, bch_rng])
        push!(bch_data, fg)
    end

    return bch_data
end
85+
"""
    batch_data(data, batchsize)

Chunk `data` (a 3-D array whose last dimension indexes samples) along its last
dimension into batches of at most `batchsize` samples.

The final chunk is smaller when the sample count is not a multiple of
`batchsize`. Returns a `Vector{Array{Float32, 3}}`.
"""
function batch_data(data, batchsize)
    tot_len = size(data)[end]
    bch_data = Array{Float32, 3}[]
    for i in 1:batchsize:tot_len
        # `min` caps the last chunk at the end of the data. The previous
        # `i + batchsize >= tot_len` test fired one step early and merged a
        # trailing remainder into the preceding chunk, yielding a batch of
        # batchsize + 1 samples.
        bch_rng = i:min(i + batchsize - 1, tot_len)
        push!(bch_data, data[:, :, bch_rng])
    end

    return bch_data
end
96+
"""
    get_gno_dataloader(; ts = LinRange(100, 11000, 10000), ratio = 0.95, batchsize = 8)

Build train/test `DataLoader`s of pre-batched `FeaturedGraph`s for the graph
neural operator.

Each sample couples a flattened flow snapshot — with its grid coordinates
appended as extra channels — to the next snapshot, over the lattice graph of
the spatial grid.

# Keywords
- `ts::AbstractRange`: simulation time steps forwarded to `gen_data`.
- `ratio::Float64`: fraction of samples used for training; the rest is test data.
- `batchsize`: samples per pre-built batch. The loaders use `batchsize = -1`
  because batching is already done here.

Returns `(loader_train, loader_test)`; only the training loader shuffles.
"""
function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
                            ratio::Float64 = 0.95, batchsize = 8)
    data = gen_data(ts)
    # One-step supervision: input frames paired with their successors.
    𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
    n = length(ts) - 1

    # Lattice graph over the spatial dimensions of the simulation domain.
    graph = Graphs.grid(size(data)[2:3])

    # Append grid coordinates as additional input channels.
    # NOTE(review): this file imports `generate_grid` (not
    # `generate_coordinates`) from GraphSignals, so that is the name actually
    # in scope — confirm it matches the intended coordinate generator.
    grid = generate_grid(𝐱[1, :, :, 1])
    grid = repeat(grid, outer = (1, 1, 1, n))
    𝐱 = vcat(𝐱, grid)

    # Flatten the spatial dimensions: (channels, nodes, samples).
    𝐱, 𝐲 = reshape(𝐱, size(𝐱, 1), :, n), reshape(𝐲, 1, :, n)

    data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)

    batched_train_X = batch_featured_graph(data_train[1], graph, batchsize)
    batched_test_X = batch_featured_graph(data_test[1], graph, batchsize)
    batched_train_y = batch_data(data_train[2], batchsize)
    batched_test_y = batch_data(data_test[2], batchsize)

    # Batches were built above, so disable the loaders' own batching.
    loader_train = DataLoader((batched_train_X, batched_train_y), batchsize = -1,
                              shuffle = true)
    loader_test = DataLoader((batched_test_X, batched_test_y), batchsize = -1,
                             shuffle = false)

    return loader_train, loader_test
end
128+
82129function train_gno (; cuda = true , η₀ = 1.0f-3 , λ = 1.0f-4 , epochs = 50 )
83130 if cuda && CUDA. has_cuda ()
84131 device = gpu
@@ -91,23 +138,17 @@ function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
91138
92139 grid_dim = 2
93140 edge_dim = 2 (grid_dim + 1 )
94- featured_graph = FeaturedGraph (grid ([96 , 64 ]))
95- model = Chain (Flux. SkipConnection (Dense (grid_dim + 1 , 16 ), vcat),
96- # size(x) = (19, 6144, 8)
97- WithGraph (featured_graph,
98- GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 )),
99- WithGraph (featured_graph,
100- GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 )),
101- WithGraph (featured_graph,
102- GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 )),
103- WithGraph (featured_graph,
104- GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 )),
105- x -> x[1 : end - 3 , :, :],
141+ model = Chain (GraphParallel (node_layer = Dense (grid_dim + 1 , 16 )),
142+ GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 ),
143+ GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 ),
144+ GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 ),
145+ GraphKernel (Dense (edge_dim, abs2 (16 ), gelu), 16 ),
146+ node_feature,
106147 Dense (16 , 1 ))
107148
108149 optimiser = Flux. Optimiser (WeightDecay (λ), Flux. Adam (η₀))
109150 loss_func = l₂loss
110- data = get_dataloader (batchsize = 8 , flatten = true )
151+ data = get_gno_dataloader ( )
111152 learner = Learner (model, data, optimiser, loss_func,
112153 ToDevice (device, device),
113154 Checkpointer (joinpath (@__DIR__ , " ../model/" )))
0 commit comments