11@testsetup module SharedTestSetup
22import Reexport: @reexport
33
4- @reexport using Lux, Zygote, Optimisers, Random, StableRNGs
5- using LuxTestUtils : @jet , @test_gradients
4+ @reexport using Lux, Zygote, Optimisers, Random, StableRNGs, LuxTestUtils
5+ using MLDataDevices
66
77const BACKEND_GROUP = lowercase (get (ENV , " BACKEND_GROUP" , " All" ))
88
1717cpu_testing () = BACKEND_GROUP == " all" || BACKEND_GROUP == " cpu"
1818function cuda_testing ()
1919 return (BACKEND_GROUP == " all" || BACKEND_GROUP == " cuda" ) &&
20- LuxDeviceUtils . functional (LuxCUDADevice )
20+ MLDataDevices . functional (CUDADevice )
2121end
2222function amdgpu_testing ()
2323 return (BACKEND_GROUP == " all" || BACKEND_GROUP == " amdgpu" ) &&
24- LuxDeviceUtils . functional (LuxAMDGPUDevice )
24+ MLDataDevices . functional (AMDGPUDevice )
2525end
2626
2727const MODES = begin
2828 modes = []
29- cpu_testing () && push! (modes, (" CPU" , Array, LuxCPUDevice (), false ))
30- cuda_testing () && push! (modes, (" CUDA" , CuArray, LuxCUDADevice (), true ))
31- amdgpu_testing () && push! (modes, (" AMDGPU" , ROCArray, LuxAMDGPUDevice (), true ))
29+ cpu_testing () && push! (modes, (" CPU" , Array, CPUDevice (), false ))
30+ cuda_testing () && push! (modes, (" CUDA" , CuArray, CUDADevice (), true ))
31+ amdgpu_testing () && push! (modes, (" AMDGPU" , ROCArray, AMDGPUDevice (), true ))
3232 modes
3333end
3434
@@ -47,7 +47,7 @@ function train!(loss, backend, model, ps, st, data; epochs=10)
4747 return l2, l1
4848end
4949
50- export @jet , @test_gradients , check_approx
50+ export check_approx
5151export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, train!
5252
5353end
0 commit comments