|
1 | 1 | import Adapt: adapt, adapt_storage |
2 | 2 | using LinearAlgebra: Cholesky |
3 | 3 | using Zygote: IdSet |
4 | | -import Functors: Functors, @functor, functor, fmap, isleaf |
5 | 4 | using SparseArrays: AbstractSparseArray |
6 | 5 |
|
7 | 6 | trainable(m) = functor(m)[1] |
@@ -38,23 +37,31 @@ Possible values include: |
38 | 37 | """ |
39 | 38 | trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode) |
40 | 39 |
|
41 | | -# push!(::Params, x) automatically discards already seen arrays |
42 | | -params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) |
| 40 | +# # push!(::Params, x) automatically discards already seen arrays |
| 41 | +# params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) |
43 | 42 |
|
44 | | -function params!(p::Params, x, seen = IdSet()) |
45 | | - x in seen && return |
46 | | - push!(seen, x) |
47 | | - for child in trainable(x) |
48 | | - params!(p, child, seen) |
49 | | - end |
50 | | -end |
| 43 | +# function params!(p::Params, x, seen = IdSet()) |
| 44 | +# x in seen && return |
| 45 | +# push!(seen, x) |
| 46 | +# for child in trainable(x) |
| 47 | +# params!(p, child, seen) |
| 48 | +# end |
| 49 | +# end |
| 50 | + |
| 51 | +# function params(m...) |
| 52 | +# ps = Params() |
| 53 | +# params!(ps, m) |
| 54 | +# return ps |
| 55 | +# end |
51 | 56 |
|
52 | 57 | function params(m...) |
53 | 58 | ps = Params() |
54 | | - params!(ps, m) |
| 59 | + collect_params!(ps, m) |
55 | 60 | return ps |
56 | 61 | end |
57 | 62 |
|
| 63 | + |
| 64 | + |
58 | 65 | function loadparams!(m, xs) |
59 | 66 | for (p, x) in zip(params(m), xs) |
60 | 67 | size(p) == size(x) || |
|
0 commit comments