@@ -13,59 +13,66 @@ struct DataLoader{D,R<:AbstractRNG}
1313end
1414
1515"""
16- DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
16+ Flux.DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
1717
18- An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
18+ An object that iterates over mini-batches of `data`,
19+ each mini-batch containing `batchsize` observations
1920(except possibly the last one).
2021
2122Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
22- The last dimension in each tensor is considered to be the observation dimension.
23+ The last dimension in each tensor is the observation dimension, i.e. the one
24+ divided into mini-batches.
2325
24- If `shuffle=true`, shuffles the observations each time iterations are re-started.
25- If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
26+ If `shuffle=true`, it shuffles the observations each time iterations are re-started.
27+ If `partial=false` and the number of observations is not divisible by the batchsize,
28+ then the last mini-batch is dropped.
2629
2730The original data is preserved in the `data` field of the DataLoader.
2831
29- Usage example:
32+ # Examples
33+ ```jldoctest
34+ julia> Xtrain = rand(10, 100);
3035
31- Xtrain = rand(10, 100)
32- train_loader = DataLoader(Xtrain, batchsize=2)
33- # iterate over 50 mini-batches of size 2
34- for x in train_loader
35- @assert size(x) == (10, 2)
36- ...
37- end
36+ julia> array_loader = Flux.DataLoader(Xtrain, batchsize=2);
3837
39- train_loader.data # original dataset
38+ julia> for x in array_loader
39+ @assert size(x) == (10, 2)
40+ # do something with x, 50 times
41+ end
4042
41- # similar, but yielding tuples
42- train_loader = DataLoader((Xtrain,), batchsize=2)
43- for (x,) in train_loader
44- @assert size(x) == (10, 2)
45- ...
46- end
43+ julia> array_loader.data === Xtrain
44+ true
4745
48- Xtrain = rand(10, 100)
49- Ytrain = rand(100)
50- train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
51- for epoch in 1:100
52- for (x, y) in train_loader
53- @assert size(x) == (10, 2)
54- @assert size(y) == (2,)
55- ...
56- end
57- end
46+ julia> tuple_loader = Flux.DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples
5847
59- # train for 10 epochs
60- using IterTools: ncycle
61- Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
48+ julia> for x in tuple_loader
49+ @assert x isa Tuple{Matrix}
50+ @assert size(x[1]) == (10, 2)
51+ end
6252
63- # can use NamedTuple to name tensors
64- train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
65- for datum in train_loader
66- @assert size(datum.images) == (10, 2)
67- @assert size(datum.labels) == (2,)
68- end
53+ julia> Ytrain = rand('a':'z', 100); # now make a DataLoader yielding 2-element named tuples
54+
55+ julia> train_loader = Flux.DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);
56+
57+ julia> for epoch in 1:100
58+ for (x, y) in train_loader # access via tuple destructuring
59+ @assert size(x) == (10, 5)
60+ @assert size(y) == (5,)
61+ # loss += f(x, y) # etc, runs 100 * 20 times
62+ end
63+ end
64+
65+ julia> first(train_loader).label isa Vector{Char} # access via property name
66+ true
67+
68+ julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
69+ false
70+
71+ julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
72+ 10×30 Matrix{Int8}
73+ 10×30 Matrix{Int8}
74+ 10×4 Matrix{Int8}
75+ ```
6976"""
7077function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
7178    batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
@@ -100,8 +107,10 @@ _nobs(data::AbstractArray) = size(data)[end]
100107function _nobs(data::Union{Tuple, NamedTuple})
101108    length(data) > 0 || throw(ArgumentError("Need at least one data input"))
102109    n = _nobs(data[1])
103-    if !all(x -> _nobs(x) == n, Base.tail(data))
104-        throw(DimensionMismatch("All data should contain same number of observations"))
110+    for i in keys(data)
111+        ni = _nobs(data[i])
112+        n == ni || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. " *
113+            "But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni."))
105114    end
106115    return n
107116end
0 commit comments