Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit a971dcf

Browse files
committed
feat: update to newest versions
1 parent 2bf0090 commit a971dcf

File tree

10 files changed

+71
-56
lines changed

10 files changed

+71
-56
lines changed

Project.toml

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,22 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1111
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1212
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
13-
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
13+
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
1414
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1515
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1616
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1717
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
1818

1919
[compat]
20-
Aqua = "0.8.7"
2120
ArgCheck = "2.3.0"
2221
ChainRulesCore = "1.24.0"
2322
ConcreteStructs = "0.2.3"
24-
Documenter = "1.4.1"
25-
ExplicitImports = "1.9.0"
2623
FFTW = "1.8.0"
27-
Lux = "0.5.56"
24+
Lux = "0.5.62"
2825
LuxCore = "0.1.15"
29-
LuxDeviceUtils = "0.1.24"
30-
LuxTestUtils = "0.1.15"
26+
LuxLib = "0.3.40"
3127
NNlib = "0.9.17"
32-
Optimisers = "0.3.3"
33-
Pkg = "1.10"
3428
Random = "1.10"
35-
ReTestItems = "1.24.0"
3629
Reexport = "1.2.2"
37-
StableRNGs = "1.0.2"
38-
Test = "1.10"
39-
WeightInitializers = "0.1.7, 1"
40-
Zygote = "0.6.70"
30+
WeightInitializers = "1"
4131
julia = "1.10"
42-
43-
[extras]
44-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
45-
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
46-
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
47-
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
48-
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
49-
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
50-
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
51-
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
52-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
53-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
54-
55-
[targets]
56-
test = ["Aqua", "Documenter", "ExplicitImports", "LuxTestUtils", "Optimisers", "Pkg", "ReTestItems", "StableRNGs", "Test", "Zygote"]

src/NeuralOperators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ using ConcreteStructs: @concrete
66
using FFTW: FFTW, irfft, rfft
77
using Lux
88
using LuxCore: LuxCore, AbstractExplicitLayer
9-
using LuxDeviceUtils: get_device, LuxAMDGPUDevice
10-
using NNlib: NNlib, ⊠, batched_adjoint
9+
using LuxLib: batched_matmul
10+
using NNlib: NNlib, batched_adjoint
1111
using Random: Random, AbstractRNG
1212
using Reexport: @reexport
1313

src/functional.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ end
1212
x_size = size(x_tr)
1313
x_flat = reshape(x_tr, :, x_size[N - 1], x_size[N])
1414

15-
x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m
16-
x_weighted = permutedims(__batched_mul(weights, x_flat_t), (3, 1, 2)) # m x o x b
15+
x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m
16+
x_weighted = permutedims(batched_matmul(weights, x_flat_t), (3, 1, 2)) # m x o x b
1717

1818
return reshape(x_weighted, x_size[1:(N - 2)]..., size(x_weighted)[2:3]...)
1919
end

src/utils.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,3 @@
1-
# Temporarily capture certain calls like AMDGPU for ComplexFloats
2-
@inline __batched_mul(x, y) = __batched_mul(x, y, get_device((x, y)))
3-
@inline function __batched_mul(
4-
x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}, _)
5-
return x ⊠ y
6-
end
7-
@inline function __batched_mul(
8-
x::AbstractArray{<:Complex, 3}, y::AbstractArray{<:Complex, 3}, ::LuxAMDGPUDevice)
9-
# FIXME: This is not good for performance but that is okay for now
10-
return stack(*, eachslice(x; dims=3), eachslice(y; dims=3))
11-
end
12-
131
@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
142
additional::Nothing) where {T1, T2}
153
# b : p x nb
@@ -25,7 +13,7 @@ end
2513
if size(b, 2) == 1 || size(t, 2) == 1
2614
return sum(b .* t; dims=1) # 1 x N x nb
2715
else
28-
return __batched_mul(batched_adjoint(b), t) # u x N x b
16+
return batched_matmul(batched_adjoint(b), t) # u x N x b
2917
end
3018
end
3119

test/Project.toml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
[deps]
2+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
5+
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
6+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7+
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
8+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
9+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
10+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
11+
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
12+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
14+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
16+
17+
[compat]
18+
Aqua = "0.8.7"
19+
Documenter = "1.5.0"
20+
ExplicitImports = "1.9.0"
21+
Hwloc = "3.2.0"
22+
InteractiveUtils = "<0.0.1, 1"
23+
LuxTestUtils = "1.1.2"
24+
MLDataDevices = "1.0.0"
25+
Optimisers = "0.3.3"
26+
Pkg = "1.10"
27+
Reexport = "1.2.2"
28+
ReTestItems = "1.24.0"
29+
StableRNGs = "1.0.2"
30+
Test = "1.10"
31+
Zygote = "0.6.70"

test/deeponet_tests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848

4949
pred = first(deeponet((u, y), ps, st))
5050
@test setup.out_size == size(pred)
51+
52+
__f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st)))
53+
test_gradients(__f, u, y, ps; atol=1f-3, rtol=1f-3)
5154
end
5255

5356
@testset "Embedding layer mismatch" begin
@@ -59,6 +62,9 @@
5962

6063
ps, st = Lux.setup(rng, deeponet) |> dev
6164
@test_throws ArgumentError deeponet((u, y), ps, st)
65+
66+
__f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st)))
67+
test_gradients(__f, u, y, ps; atol=1f-3, rtol=1f-3)
6268
end
6369
end
6470
end

test/fno_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
l2, l1 = train!(fno, ps, st, data; epochs=10)
2828
l2 < l1
2929
end broken=broken
30+
31+
__f = (x, ps) -> sum(abs2, first(fno(x, ps, st)))
32+
test_gradients(__f, x, ps; atol=1f-3, rtol=1f-3)
3033
end
3134
end
3235
end

test/layers_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
l2, l1 = train!(m, ps, st, data; epochs=10)
3737
l2 < l1
3838
end broken=broken
39+
40+
__f = (x, ps) -> sum(abs2, first(m(x, ps, st)))
41+
test_gradients(__f, x, ps; atol=1f-3, rtol=1f-3)
3942
end
4043
end
4144
end

test/runtests.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using ReTestItems, Pkg, ReTestItems, Test
1+
using ReTestItems, Pkg, Test
2+
using InteractiveUtils, Hwloc
3+
using NeuralOperators
24

35
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))
46

@@ -14,6 +16,13 @@ if !isempty(EXTRA_PKGS)
1416
Pkg.instantiate()
1517
end
1618

19+
const RETESTITEMS_NWORKERS = parse(
20+
Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16))))
21+
const RETESTITEMS_NWORKER_THREADS = parse(Int,
22+
get(ENV, "RETESTITEMS_NWORKER_THREADS",
23+
string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1))))
24+
1725
@testset "NeuralOperators.jl Tests" begin
18-
ReTestItems.runtests(@__DIR__)
26+
ReTestItems.runtests(NeuralOperators; nworkers=RETESTITEMS_NWORKERS,
27+
nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600)
1928
end

test/shared_testsetup.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
@testsetup module SharedTestSetup
22
import Reexport: @reexport
33

4-
@reexport using Lux, Zygote, Optimisers, Random, StableRNGs
5-
using LuxTestUtils: @jet, @test_gradients
4+
@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, LuxTestUtils
5+
using MLDataDevices
66

77
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))
88

@@ -17,18 +17,18 @@ end
1717
cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
1818
function cuda_testing()
1919
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
20-
LuxDeviceUtils.functional(LuxCUDADevice)
20+
MLDataDevices.functional(CUDADevice)
2121
end
2222
function amdgpu_testing()
2323
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") &&
24-
LuxDeviceUtils.functional(LuxAMDGPUDevice)
24+
MLDataDevices.functional(AMDGPUDevice)
2525
end
2626

2727
const MODES = begin
2828
modes = []
29-
cpu_testing() && push!(modes, ("CPU", Array, LuxCPUDevice(), false))
30-
cuda_testing() && push!(modes, ("CUDA", CuArray, LuxCUDADevice(), true))
31-
amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true))
29+
cpu_testing() && push!(modes, ("CPU", Array, CPUDevice(), false))
30+
cuda_testing() && push!(modes, ("CUDA", CuArray, CUDADevice(), true))
31+
amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, AMDGPUDevice(), true))
3232
modes
3333
end
3434

@@ -47,7 +47,7 @@ function train!(loss, backend, model, ps, st, data; epochs=10)
4747
return l2, l1
4848
end
4949

50-
export @jet, @test_gradients, check_approx
50+
export check_approx
5151
export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, train!
5252

5353
end

0 commit comments

Comments
 (0)