@@ -71,29 +71,6 @@ function train_mno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
7171 return learner
7272end
7373
74- function batch_featured_graph (data, graph, batchsize)
75- tot_len = size (data)[end ]
76- bch_data = FeaturedGraph[]
77- for i in 1 : batchsize: tot_len
78- bch_rng = (i + batchsize >= tot_len) ? (i: tot_len) : (i: (i + batchsize - 1 ))
79- fg = FeaturedGraph (graph, nf = data[:, :, bch_rng], pf = data[:, :, bch_rng])
80- push! (bch_data, fg)
81- end
82-
83- return bch_data
84- end
85-
86- function batch_data (data, batchsize)
87- tot_len = size (data)[end ]
88- bch_data = Array{Float32, 3 }[]
89- for i in 1 : batchsize: tot_len
90- bch_rng = (i + batchsize >= tot_len) ? (i: tot_len) : (i: (i + batchsize - 1 ))
91- push! (bch_data, data[:, :, bch_rng])
92- end
93-
94- return bch_data
95- end
96-
9774function get_gno_dataloader (; ts:: AbstractRange = LinRange (100 , 11000 , 10000 ),
9875 ratio:: Float64 = 0.95 , batchsize = 8 )
9976 data = gen_data (ts)
@@ -111,17 +88,11 @@ function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
11188 # flatten
11289 𝐱, 𝐲 = reshape (𝐱, size (𝐱, 1 ), :, n), reshape (𝐲, 1 , :, n)
11390
114- data_train, data_test = splitobs (shuffleobs ((𝐱, 𝐲)), at = ratio)
115-
116- batched_train_X = batch_featured_graph (data_train[1 ], graph, batchsize)
117- batched_test_X = batch_featured_graph (data_test[1 ], graph, batchsize)
118- batched_train_y = batch_data (data_train[2 ], batchsize)
119- batched_test_y = batch_data (data_test[2 ], batchsize)
91+ fg = FeaturedGraph (graph, nf = 𝐱, pf = 𝐱)
92+ data_train, data_test = splitobs (shuffleobs ((fg, 𝐲)), at = ratio)
12093
121- loader_train = DataLoader ((batched_train_X, batched_train_y), batchsize = - 1 ,
122- shuffle = true )
123- loader_test = DataLoader ((batched_test_X, batched_test_y), batchsize = - 1 ,
124- shuffle = false )
94+ loader_train = DataLoader (data_train, batchsize = batchsize, shuffle = true )
95+ loader_test = DataLoader (data_test, batchsize = batchsize, shuffle = false )
12596
12697 return loader_train, loader_test
12798end
0 commit comments