Skip to content

Commit 4d25c23

Browse files
committed
some updates for eachgroup
1 parent 85632b2 commit 4d25c23

File tree

3 files changed

+91
-7
lines changed

3 files changed

+91
-7
lines changed

docs/src/man/grouping.md

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,57 @@ Similar to `groupby!/groupby` functions, `gatherby` can be passed to functions w
336336
As mentioned before, the result of `gatherby` is stable, i.e. the order of observations within each group is the order of their appearance in the original data set. However, when this stability is not needed and there are many groups in the data set, passing `stable = false` improves performance at the cost of stability.
337337

338338
The `gatherby` function has two extra keyword arguments, `isgathered` and `eachrow`, which are set to `false` by default. When the `isgathered` argument is set to `true`, InMemoryDatasets assumes that the observations are already gathered by some rules, so it only finds the starts and ends of each group and marks the data set as gathered. This allows users to manually group observations by setting this keyword argument. When the `eachrow` argument is set to `true`, InMemoryDatasets does the gathering and then marks each row of the input data set as an individual group. This option is handy for transposing data sets.
339+
340+
## Iterate `eachgroup`
341+
342+
Users can use `eachgroup` to iterate over the groups of a grouped data set. Each element of `eachgroup` is a `SubDataset`.
343+
344+
345+
### Examples
346+
347+
```jldoctest
348+
julia> ds = Dataset(rand(1:10, 10, 3), :auto)
349+
10×3 Dataset
350+
Row │ x1 x2 x3
351+
│ identity identity identity
352+
│ Int64? Int64? Int64?
353+
─────┼──────────────────────────────
354+
1 │ 7 8 10
355+
2 │ 4 1 5
356+
3 │ 7 2 5
357+
4 │ 4 7 4
358+
5 │ 5 9 6
359+
6 │ 9 5 3
360+
7 │ 9 8 2
361+
8 │ 7 9 6
362+
9 │ 2 3 8
363+
10 │ 1 6 2
364+
365+
julia> i_gds = eachgroup(groupby(ds, 1));
366+
367+
julia> map(nrow, i_gds)
368+
6-element Vector{Int64}:
369+
1
370+
1
371+
2
372+
1
373+
3
374+
2
375+
376+
julia> i_gds[1]
377+
1×3 SubDataset
378+
Row │ x1 x2 x3
379+
│ identity identity identity
380+
│ Int64? Int64? Int64?
381+
─────┼──────────────────────────────
382+
1 │ 1 6 2
383+
384+
julia> i_gds[end]
385+
2×3 SubDataset
386+
Row │ x1 x2 x3
387+
│ identity identity identity
388+
│ Int64? Int64? Int64?
389+
─────┼──────────────────────────────
390+
1 │ 9 5 3
391+
2 │ 9 8 2
392+
```

src/abstractdataset/iteration.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,28 +152,34 @@ end
152152

153153

154154
# Groups are addressed by a single integer position, so declare linear indexing.
function Base.IndexStyle(::Type{<:GroupedDataset})
    return Base.IndexLinear()
end
155-
Base.size(itr::GroupedDataset{Dataset}) = (index(itr.ds).ngroups[], )
156-
Base.size(itr::GroupedDataset{<:Union{GroupBy, GatherBy}}) = (itr.ds.lastvalid, )
157-
Base.length(itr::GroupedDataset{Dataset}) = index(itr.ds).ngroups[]
158-
Base.length(itr::GroupedDataset{<:Union{GroupBy, GatherBy}}) = itr.ds.lastvalid
155+
# One-dimensional size of the iterator: the group count stored in the index of
# the wrapped data set, returned as a 1-tuple.
function Base.size(itr::GroupedDataset{Dataset})::Tuple{Int64}
    ngroups = index(itr.ds).ngroups[]
    return (ngroups,)
end
156+
# One-dimensional size of the iterator for `GroupBy`/`GatherBy` inputs: the
# `lastvalid` field of the wrapped object is used as the group count.
function Base.size(itr::GroupedDataset{<:Union{GroupBy, GatherBy}})::Tuple{Int64}
    ngroups = itr.ds.lastvalid
    return (ngroups,)
end
157+
# Number of groups, read from the index of the wrapped data set.
function Base.length(itr::GroupedDataset{Dataset})::Int64
    return index(itr.ds).ngroups[]
end
158+
# Number of groups for `GroupBy`/`GatherBy` inputs, taken from `lastvalid`.
function Base.length(itr::GroupedDataset{<:Union{GroupBy, GatherBy}})::Int64
    return itr.ds.lastvalid
end
159159
# Iteration protocol: the state is the position of the next group to return.
# Yields `(group, next_state)` until the position passes the last group.
function Base.iterate(itr::GroupedDataset, i::Integer=1)
    i > length(itr) && return nothing
    return (itr[i], i + 1)
end
161-
function Base.getindex(itr::GroupedDataset{Dataset}, i::Int)
161+
# Return the `i`-th group of a grouped `Dataset` as a view of the rows
# `lo:hi` of `itr.ds` (all columns).
#
# Throws `BoundsError(itr, i)` when `i` is outside `1:length(itr)`.
function Base.getindex(itr::GroupedDataset{Dataset}, i::Integer)
    ngroups = length(itr)
    # Check both bounds up front so that an index below 1 is reported against
    # the iterator itself rather than surfacing from the `starts` lookup below.
    1 <= i <= ngroups || throw(BoundsError(itr, i))
    st = index(itr.ds).starts
    lo = st[i]
    # The last group runs to the final row; any other group ends right before
    # the first row of the next group.
    hi = i == ngroups ? nrow(itr.ds) : st[i + 1] - 1
    return view(itr.ds, lo:hi, :)
end
168-
function Base.getindex(itr::GroupedDataset{<:Union{GroupBy, GatherBy}}, i::Int)
168+
# Return the `i`-th group of a `GroupBy`/`GatherBy` object as a view of its
# parent data set. Rows are selected through positions `lo:hi` of the vector
# returned by `_get_perms` (presumably the row permutation that orders the
# parent by group — confirm against `_get_perms`).
#
# Throws `BoundsError(itr, i)` when `i` is outside `1:length(itr)`.
function Base.getindex(itr::GroupedDataset{<:Union{GroupBy, GatherBy}}, i::Integer)
    ngroups = length(itr)
    # Check both bounds up front so that an index below 1 is reported against
    # the iterator itself rather than surfacing from the group-starts lookup.
    1 <= i <= ngroups || throw(BoundsError(itr, i))
    st = _group_starts(itr.ds)
    prm = _get_perms(itr.ds)
    lo = st[i]
    # The last group runs to the final row of the parent; any other group ends
    # right before the first position of the next group.
    hi = i == ngroups ? nrow(parent(itr.ds)) : st[i + 1] - 1
    return view(parent(itr.ds), view(prm, lo:hi), :)
end
176-
176+
# Groups are numbered from 1; this is what `itr[begin]` resolves to.
function Base.firstindex(::GroupedDataset)
    return 1
end
177+
# The last valid position equals the number of groups; this is what
# `itr[end]` resolves to.
function Base.lastindex(itr::GroupedDataset)
    return length(itr)
end
178+
# Element type of the iterator. It is defined on the *type*, as the iteration
# interface expects, so that `eltype(typeof(itr))` also reports `SubDataset`.
# `eltype(itr)` keeps working through Base's instance-to-type fallback.
Base.eltype(::Type{<:GroupedDataset}) = SubDataset
179+
# The keys of the iterator are its linear indices.
function Base.keys(itr::GroupedDataset)
    return LinearIndices(itr)
end
180+
# Pair every key of the iterator with the group stored at that key.
function Base.pairs(itr::GroupedDataset)
    idx = keys(itr)
    return Base.Iterators.Pairs(itr, idx)
end
181+
# A single axis covering positions 1 through the number of groups.
function Base.axes(itr::GroupedDataset)
    n = length(itr)
    return (Base.OneTo(n),)
end
182+
# Build the linear indices of the iterator from its single axis.
function Base.LinearIndices(itr::GroupedDataset)
    return LinearIndices(axes(itr))
end
177183
# Iteration by columns
178184

179185
const DATASETCOLUMNS_DOCSTR = """

test/grouping.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,27 @@ end
590590
@test combine(gatherby(sds, 1), r"x"=>var) == combine(groupby(sds, 1), r"x"=>var)
591591

592592
end
593+
594+
@testset "eachgroup iterator" begin
    # Grouping (sorted groups) versus gathering (groups in order of appearance)
    # on the same data set.
    ds = Dataset(x = [2,1,1,2], y = [1.0, -10.0, 2.0, 0.5])
    i_gds_1 = eachgroup(groupby(ds, 1))
    i_gds_2 = eachgroup(gatherby(ds, 1))
    @test length(i_gds_1) == length(i_gds_2) == 2
    @test i_gds_1[1] == Dataset(x = [1,1], y = [-10.0, 2.0])
    @test i_gds_1[1] == i_gds_2[2]
    @test i_gds_1[end] == Dataset(x = [2,2], y = [1.0, 0.5])
    @test i_gds_2[begin] == i_gds_1[end]
    # Iterating visits every group exactly once.
    @test map(nrow, i_gds_1) == [2, 2]
    @test map(nrow, i_gds_2) == [2, 2]

    # Grouping a view, indexed with several integer types.
    sds = view(ds, [2,3,4], [2,1])
    i_gds = eachgroup(groupby(sds, 2))
    @test length(i_gds) == 2
    @test i_gds[Int32(1)] == Dataset(y = [-10.0, 2.0], x = [1,1])
    @test i_gds[Int8(2)] == i_gds[Int16(2)] == i_gds[2] == Dataset(y = [0.5], x = [2])
    @test map(nrow, i_gds) == [2, 1]
    # Out-of-range positions on either side must raise a BoundsError.
    @test_throws BoundsError i_gds[3]
    @test_throws BoundsError i_gds[0]
    @test_throws BoundsError i_gds[-1]
    @test_throws BoundsError i_gds_1[0]
    @test_throws BoundsError i_gds_2[3]
end
614+
615+
616+

0 commit comments

Comments
 (0)