|
13 | 13 |
|
14 | 14 | loss(x, y) = Flux.mse(m(x), y) |
15 | 15 | data = [(𝐱, rand(Float32, 128, 1024, 5))] |
16 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 16 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
17 | 17 | end |
18 | 18 |
|
19 | 19 | @testset "permuted 1D OperatorConv" begin |
|
32 | 32 |
|
33 | 33 | loss(x, y) = Flux.mse(m(x), y) |
34 | 34 | data = [(𝐱, rand(Float32, 1024, 128, 5))] |
35 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 35 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
36 | 36 | end |
37 | 37 |
|
38 | 38 | @testset "1D OperatorKernel" begin |
|
49 | 49 |
|
50 | 50 | loss(x, y) = Flux.mse(m(x), y) |
51 | 51 | data = [(𝐱, rand(Float32, 128, 1024, 5))] |
52 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 52 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
53 | 53 | end |
54 | 54 |
|
55 | 55 | @testset "permuted 1D OperatorKernel" begin |
|
67 | 67 |
|
68 | 68 | loss(x, y) = Flux.mse(m(x), y) |
69 | 69 | data = [(𝐱, rand(Float32, 1024, 128, 5))] |
70 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 70 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
71 | 71 | end |
72 | 72 |
|
73 | 73 | @testset "2D OperatorConv" begin |
|
83 | 83 |
|
84 | 84 | loss(x, y) = Flux.mse(m(x), y) |
85 | 85 | data = [(𝐱, rand(Float32, 64, 22, 22, 5))] |
86 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 86 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
87 | 87 | end |
88 | 88 |
|
89 | 89 | @testset "permuted 2D OperatorConv" begin |
|
100 | 100 |
|
101 | 101 | loss(x, y) = Flux.mse(m(x), y) |
102 | 102 | data = [(𝐱, rand(Float32, 22, 22, 64, 5))] |
103 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 103 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
104 | 104 | end |
105 | 105 |
|
106 | 106 | @testset "2D OperatorKernel" begin |
|
115 | 115 |
|
116 | 116 | loss(x, y) = Flux.mse(m(x), y) |
117 | 117 | data = [(𝐱, rand(Float32, 64, 22, 22, 5))] |
118 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 118 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
119 | 119 | end |
120 | 120 |
|
121 | 121 | @testset "permuted 2D OperatorKernel" begin |
|
131 | 131 |
|
132 | 132 | loss(x, y) = Flux.mse(m(x), y) |
133 | 133 | data = [(𝐱, rand(Float32, 22, 22, 64, 5))] |
134 | | - Flux.train!(loss, Flux.params(m), data, Flux.ADAM()) |
| 134 | + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) |
135 | 135 | end |
136 | 136 |
|
137 | 137 | @testset "SpectralConv" begin |
|
0 commit comments