Skip to content

Commit 57503bc

Browse files
committed
improve repeat/! performance and allow passing view
1 parent 156a943 commit 57503bc

File tree

1 file changed

+168
-94
lines changed

1 file changed

+168
-94
lines changed

src/dataset/other.jl

Lines changed: 168 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -104,69 +104,79 @@ disallowmissing!(ds::Dataset, cols::Colon=:; error::Bool=false) =
104104
disallowmissing!(ds, axes(ds, 2), error=error)
105105

106106
"""
107-
repeat!(ds::Dataset; inner::Integer = 1, outer::Integer = 1)
107+
repeat!(ds::Dataset; inner::Integer = 1, outer::Integer = 1, freq = nothing)
108+
repeat!(ds::Dataset, count) = repeat!(ds, outer = count)
109+
110+
Update a data set `ds` in-place by repeating its rows.
111+
112+
* `inner` specifies how many times each row is repeated,
113+
* `outer` specifies how many times the full set of rows is repeated, and
114+
* `freq` allow user to pass a vector of integers or a column name or index which indicate how many times the corresponding row should be repeated.
115+
When `freq` is passed, the values must be positive integer values or zero (zero means the corresponding row should be dropped).
108116
109-
Update a data set `ds` in-place by repeating its rows. `inner` specifies how many
110-
times each row is repeated, and `outer` specifies how many times the full set
111-
of rows is repeated. Columns of `ds` are freshly allocated.
112117
113118
# Example
114119
```jldoctest
115120
julia> ds = Dataset(a = 1:2, b = 3:4)
116121
2×2 Dataset
117-
Row │ a b
118-
│ Int64 Int64
119-
─────┼──────────────
120-
1 │ 1 3
121-
2 │ 2 4
122+
Row │ a b
123+
│ identity identity
124+
│ Int64? Int64?
125+
─────┼────────────────────
126+
1 │ 1 3
127+
2 │ 2 4
122128
123129
julia> repeat!(ds, inner = 2, outer = 3);
124130
125131
julia> ds
126132
12×2 Dataset
127-
Row │ a b
128-
│ Int64 Int64
129-
─────┼──────────────
130-
1 │ 1 3
131-
2 │ 1 3
132-
3 │ 2 4
133-
4 │ 2 4
134-
5 │ 1 3
135-
6 │ 1 3
136-
7 │ 2 4
137-
8 │ 2 4
138-
9 │ 1 3
139-
10 │ 1 3
140-
11 │ 2 4
141-
12 │ 2 4
133+
Row │ a b
134+
│ identity identity
135+
│ Int64? Int64?
136+
─────┼────────────────────
137+
1 │ 1 3
138+
2 │ 1 3
139+
3 │ 2 4
140+
4 │ 2 4
141+
5 │ 1 3
142+
6 │ 1 3
143+
7 │ 2 4
144+
8 │ 2 4
145+
9 │ 1 3
146+
10 │ 1 3
147+
11 │ 2 4
148+
12 │ 2 4
149+
150+
julia> ds = Dataset(a = 1:2, b = 3:4)
151+
2×2 Dataset
152+
Row │ a b
153+
│ identity identity
154+
│ Int64? Int64?
155+
─────┼────────────────────
156+
1 │ 1 3
157+
2 │ 2 4
158+
159+
julia> repeat!(ds, freq = :a)
160+
3×2 Dataset
161+
Row │ a b
162+
│ identity identity
163+
│ Int64? Int64?
164+
─────┼────────────────────
165+
1 │ 1 3
166+
2 │ 2 4
167+
3 │ 2 4
142168
```
143169
"""
144170
function repeat!(ds::Dataset; inner::Integer = 1, outer::Integer = 1, freq::Union{AbstractVector, DatasetColumn, SubDatasetColumn, ColumnIndex, Nothing} = nothing)
145-
171+
T = nrow(ds) < typemax(Int8) ? Int8 : nrow(ds) < typemax(Int32) ? Int32 : Int64
146172
# Modify Dataset
147173
if freq === nothing
148174
inner <= 0 && throw(ArgumentError("inner keyword argument must be greater than zero"))
149175
outer <= 0 && throw(ArgumentError("outer keyword argument must be greater than zero"))
150-
if outer == 1
151-
for j in 1:ncol(ds)
152-
_columns(ds)[j] = repeat(_columns(ds)[j], inner = Int(inner), outer = 1)
153-
end
154-
_reset_grouping_info!(ds)
155-
# ngroups = index(ds).ngroups[]
156-
# diffs = diff(index(ds).starts[1:ngroups]) .* inner
157-
# @show diffs
158-
# cumsum!(diffs, diffs)
159-
# @show diffs
160-
# for j in 2:ngroups
161-
# index(ds).starts[j] = diffs[j-1]
162-
# end
163-
# @show index(ds).starts
164-
elseif outer > 1
165-
for j in 1:ncol(ds)
166-
_columns(ds)[j] = repeat(_columns(ds)[j], inner = Int(inner), outer = Int(outer))
167-
end
168-
_reset_grouping_info!(ds)
169-
end
176+
r_indx = repeat(T(1):T(nrow(ds)), inner = inner, outer = outer)
177+
_permute_ds_after_sort!(ds, r_indx, check = false)
178+
_reset_grouping_info!(ds)
179+
170180
_modified(_attributes(ds))
171181
ds
172182
else
@@ -177,24 +187,14 @@ function repeat!(ds::Dataset; inner::Integer = 1, outer::Integer = 1, freq::Unio
177187
elseif freq isa AbstractVector
178188
lengths = freq
179189
end
180-
if !(eltype(lengths) <: Union{Missing, Integer}) || any(ismissing, lengths) || any(x->isless(x, 1), lengths)
181-
throw(ArgumentError("The column selected for repeating must be an Intger column with all values greater than zero and no missing value"))
190+
if !(eltype(lengths) <: Union{Missing, Integer}) || any(ismissing, lengths) || any(x->isless(x, 0), lengths)
191+
throw(ArgumentError("The column selected for repeating must be an integer column with all values greater than or equal to zero and with no missing values"))
182192
end
183193
if length(lengths) != nrow(ds)
184-
throw(ArgumentError("The length of repeating weights must be the same as the number of row of the passed data set"))
185-
end
186-
lengths = copy(lengths)
187-
total_new = sum(lengths)
188-
for j in 1:ncol(ds)
189-
if DataAPI.refpool(_columns(ds)[j]) !== nothing
190-
_res = allocatecol(_columns(ds)[j], total_new, addmissing = false)
191-
_columns(ds)[j].refs = repeat_lengths_v2!(_res.refs, DataAPI.refarray(_columns(ds)[j]), lengths)
192-
else
193-
_res = allocatecol(_columns(ds)[j], total_new)
194-
_columns(ds)[j] = repeat_lengths_v2!(_res, _columns(ds)[j], lengths)
195-
end
196-
194+
throw(ArgumentError("The length of frequencies must match the number of rows in passed data set"))
197195
end
196+
r_index = _create_index_for_repeat(lengths, Val(T))
197+
_permute_ds_after_sort!(ds, r_index, check = false)
198198
_reset_grouping_info!(ds)
199199
_modified(_attributes(ds))
200200
ds
@@ -203,62 +203,109 @@ end
203203
function _fill_index_for_repeat!(res, w)
204204
counter = 1
205205
for i in 1:length(w)
206-
207-
l = w[i]
208-
fill!(view(res, counter:(counter + l - 1)), i)
209-
counter += l
206+
for j in 1:w[i]
207+
res[counter] = i
208+
counter += 1
209+
end
210210
end
211211
end
212212
# use this to create index for new data set
213213
# and then use getindex with the result of this function for repeating
214214
# This should be better for repeating large data set/ since getindex is threaded
215-
function _create_index_for_repeat(w)
216-
res = Vector{Int}(undef, sum(w))
215+
function _create_index_for_repeat(w, ::Val{T}) where T
216+
res = Vector{T}(undef, sum(w))
217217
_fill_index_for_repeat!(res, w)
218218
res
219219
end
220220

221+
function repeat!(ds::Dataset, count::Integer)
222+
223+
# Modify Dataset
224+
count <= 0 && throw(ArgumentError("count must be greater than zero"))
225+
repeat!(ds, inner = 1, outer = count)
226+
ds
227+
end
228+
221229
"""
222-
repeat!(ds::Dataset, count::Integer)
230+
repeat(ds::AbstractDataset; inner::Integer = 1, outer::Integer = 1, freq = nothing, view = false)
231+
repeat(ds::AbstractDataset, count) = repeat!(ds, outer = count, view = false)
223232
224-
Update a data set `ds` in-place by repeating its rows the number of times
225-
specified by `count`. Columns of `ds` are freshly allocated.
233+
Variant of `repeat!` which returns a fresh copy of passed data set. If `view = true` a view of the result will be returned.
226234
227235
# Example
228236
```jldoctest
229237
julia> ds = Dataset(a = 1:2, b = 3:4)
230238
2×2 Dataset
231-
Row │ a b
232-
│ Int64 Int64
233-
─────┼──────────────
234-
1 │ 1 3
235-
2 │ 2 4
239+
Row │ a b
240+
│ identity identity
241+
│ Int64? Int64?
242+
─────┼────────────────────
243+
1 │ 1 3
244+
2 │ 2 4
236245
237-
julia> repeat(ds, 2)
238-
4×2 Dataset
239-
Row │ a b
240-
│ Int64 Int64
241-
─────┼──────────────
242-
1 │ 1 3
243-
2 │ 2 4
244-
3 │ 1 3
245-
4 │ 2 4
246+
julia> repeat(ds, inner = 2, outer = 3)
247+
12×2 Dataset
248+
Row │ a b
249+
│ identity identity
250+
│ Int64? Int64?
251+
─────┼────────────────────
252+
1 │ 1 3
253+
2 │ 1 3
254+
3 │ 2 4
255+
4 │ 2 4
256+
5 │ 1 3
257+
6 │ 1 3
258+
7 │ 2 4
259+
8 │ 2 4
260+
9 │ 1 3
261+
10 │ 1 3
262+
11 │ 2 4
263+
12 │ 2 4
264+
265+
julia> repeat(ds, freq = :a)
266+
3×2 Dataset
267+
Row │ a b
268+
│ identity identity
269+
│ Int64? Int64?
270+
─────┼────────────────────
271+
1 │ 1 3
272+
2 │ 2 4
273+
3 │ 2 4
246274
```
247275
"""
248-
function repeat!(ds::Dataset, count::Integer)
249-
250-
# Modify Dataset
251-
count <= 0 && throw(ArgumentError("count must be greater than zero"))
252-
repeat!(ds, inner = 1, outer = count)
253-
ds
254-
end
276+
Base.repeat(ds::AbstractDataset, count::Integer; view = false) = repeat(ds, outer = count, view = view)
277+
function Base.repeat(ds::AbstractDataset; inner::Integer = 1, outer::Integer = 1, freq = nothing, view = false)
278+
T = nrow(ds) < typemax(Int8) ? Int8 : nrow(ds) < typemax(Int32) ? Int32 : Int64
255279

256-
Base.repeat(ds::AbstractDataset, count::Integer) = repeat!(copy(ds), count)
257-
function Base.repeat(ds::AbstractDataset; inner::Integer = 1, outer::Integer = 1, freq = nothing)
258280
if freq === nothing
259-
repeat!(copy(ds), inner = inner, outer = outer)
281+
if view
282+
inner <= 0 && throw(ArgumentError("inner keyword argument must be greater than zero"))
283+
outer <= 0 && throw(ArgumentError("outer keyword argument must be greater than zero"))
284+
r_indx = repeat(T(1):T(nrow(ds)), inner = inner, outer = outer)
285+
Base.view(ds, r_indx, :)
286+
else
287+
repeat!(copy(ds), inner = inner, outer = outer)
288+
end
260289
else
261-
repeat!(copy(ds), freq = freq)
290+
if view
291+
if freq isa SubDatasetColumn || freq isa DatasetColumn
292+
lengths = __!(freq)
293+
elseif freq isa ColumnIndex
294+
lengths = _columns(ds)[index(ds)[freq]]
295+
elseif freq isa AbstractVector
296+
lengths = freq
297+
end
298+
if !(eltype(lengths) <: Union{Missing, Integer}) || any(ismissing, lengths) || any(x->isless(x, 0), lengths)
299+
throw(ArgumentError("The column selected for repeating must be an integer column with all values greater than or equal to zero and with no missing values"))
300+
end
301+
if length(lengths) != nrow(ds)
302+
throw(ArgumentError("The length of frequencies must match the number of rows in passed data set"))
303+
end
304+
r_index = _create_index_for_repeat(lengths, Val(T))
305+
Base.view(ds, r_index, :)
306+
else
307+
repeat!(copy(ds), freq = freq)
308+
end
262309
end
263310
end
264311

@@ -1169,3 +1216,30 @@ function mapcols(ds::AbstractDataset, f::Vector{T}, cols = :) where T <: Union{F
11691216
end
11701217
return Dataset(vs, names(ds, colsidx), copycols=false)
11711218
end
1219+
1220+
function _permute_ds_after_sort!(ds, perm; check = true, cols = :)
1221+
if check
1222+
@assert nrow(ds) == length(perm) "the length of perm and the nrow of the data set must match"
1223+
1224+
if issorted(perm)
1225+
return ds
1226+
end
1227+
end
1228+
colsidx = index(ds)[cols]
1229+
for j in 1:length(colsidx)
1230+
if DataAPI.refpool(_columns(ds)[colsidx[j]]) !== nothing
1231+
# if _columns(ds)[colsidx[j] isa PooledArray
1232+
# pa = _columns(ds)[colsidx[j]
1233+
# _columns(ds)[colsidx[j] = PooledArray(PooledArrays.RefArray(_threaded_permute(pa.refs, perm)), DataAPI.invrefpool(pa), DataAPI.refpool(pa), PooledArrays.refcount(pa))
1234+
# else
1235+
# # TODO must be optimised
1236+
# _columns(ds)[colsidx[j] = _columns(ds)[colsidx[j][perm]
1237+
# end
1238+
# since we don't support copycols for external usage it is safe to only permute refs
1239+
_columns(ds)[colsidx[j]].refs = _threaded_permute(_columns(ds)[colsidx[j]].refs, perm)
1240+
else
1241+
_columns(ds)[colsidx[j]] = _threaded_permute(_columns(ds)[colsidx[j]], perm)
1242+
end
1243+
end
1244+
_modified(_attributes(ds))
1245+
end

0 commit comments

Comments
 (0)