@@ -36,37 +36,37 @@ function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::
3636 data_train, data_validate = splitobs (shuffleobs ((𝐱= data[:, :, :, 1 : end - 1 ], 𝐲= data[:, :, :, 2 : end ])), at= ratio)
3737
3838 data = gen_data (ts, resolution= 2 )
39- _, data_test = splitobs (shuffleobs ((𝐱= data[:, :, :, 1 : end - 1 ], 𝐲= data[:, :, :, 2 : end ])), at= ratio)
39+ _, data_super_res = splitobs (shuffleobs ((𝐱= data[:, :, :, 1 : end - 1 ], 𝐲= data[:, :, :, 2 : end ])), at= ratio)
4040
4141 loader_train = DataLoader (data_train, batchsize= batchsize, shuffle= true )
4242 loader_validate = DataLoader (data_validate, batchsize= batchsize, shuffle= false )
43- loader_test = DataLoader (data_test , batchsize= batchsize, shuffle= false )
43+ loader_super_res = DataLoader (data_super_res , batchsize= batchsize, shuffle= false )
4444
45- return (training= loader_train, validation= loader_validate, testing = loader_test )
45+ return (training= loader_train, validation= loader_validate, super_res = loader_super_res )
4646end
4747
48- struct TestPhase <: FluxTraining.AbstractValidationPhase end
48+ struct SuperResPhase <: FluxTraining.AbstractValidationPhase end
4949
50- FluxTraining. phasedataiter (:: TestPhase ) = :testing
50+ FluxTraining. phasedataiter (:: SuperResPhase ) = :super_res
5151
52- function FluxTraining. step! (learner, phase:: TestPhase , batch)
52+ function FluxTraining. step! (learner, phase:: SuperResPhase , batch)
5353 xs, ys = batch
5454 FluxTraining. runstep (learner, phase, (xs= xs, ys= ys)) do _, state
5555 state. ŷs = learner. model (state. xs)
5656 state. loss = learner. lossfn (state. ŷs, state. ys)
5757 end
5858end
5959
60- function fit! (learner, nepochs:: Int , (trainiter, validiter, testiter ))
60+ function fit! (learner, nepochs:: Int , (loader_train, loader_validate, loader_super_res ))
6161 for i in 1 : nepochs
62- epoch! (learner, TrainingPhase (), trainiter )
63- epoch! (learner, ValidationPhase (), validiter )
64- epoch! (learner, TestPhase (), testiter )
62+ epoch! (learner, TrainingPhase (), loader_train )
63+ epoch! (learner, ValidationPhase (), loader_validate )
64+ epoch! (learner, SuperResPhase (), loader_super_res )
6565 end
6666end
6767
6868function fit! (learner, nepochs:: Int )
69- fit! (learner, nepochs, (learner. data. training, learner. data. validation, learner. data. testing ))
69+ fit! (learner, nepochs, (learner. data. training, learner. data. validation, learner. data. super_res ))
7070end
7171
7272function train (; cuda= true , η₀= 1f-3 , λ= 1f-4 , epochs= 50 )
0 commit comments