Skip to content

Commit 14acd8a

Browse files
author
Chinmay Chandak
committed
Add groupby aggregate tests, along with a few additions for cudf integration for SDFs
1 parent f97057e commit 14acd8a

File tree

2 files changed

+344
-174
lines changed

2 files changed

+344
-174
lines changed

streamz/dataframe/aggregations.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from numbers import Number
55

66
import numpy as np
7+
import pandas as pd
78
from .utils import is_series_like, is_index_like, get_dataframe_package
89

910

@@ -202,10 +203,10 @@ def diff_loc(dfs, new, window=None):
202203
"""
203204
dfs = deque(dfs)
204205
dfs.append(new)
205-
mx = max(df.index.max() for df in dfs)
206-
mn = mx - window
206+
mx = pd.Timestamp(max(df.index.max() for df in dfs))
207+
mn = pd.Timestamp(mx) - window
207208
old = []
208-
while dfs[0].index.min() < mn:
209+
while pd.Timestamp(dfs[0].index.min()) < mn:
209210
o = dfs[0].loc[:mn]
210211
old.append(o) # TODO: avoid copy if fully lost
211212
dfs[0] = dfs[0].iloc[len(o):]
@@ -347,8 +348,8 @@ def windowed_groupby_accumulator(acc, new, diff=None, window=None, agg=None, gro
347348
for o, og in zip(old, old_groupers):
348349
if 'groupers' in acc:
349350
assert len(o) == len(og)
350-
if hasattr(og, 'index'):
351-
assert (o.index == og.index).all()
351+
# if hasattr(og, 'index'):
352+
# assert (o.index == og.index).all()
352353
if len(o):
353354
state, result = agg.on_old(state, o, grouper=og)
354355
size_state, _ = size.on_old(size_state, o, grouper=og)
@@ -407,11 +408,13 @@ class GroupbySum(GroupbyAggregation):
407408
def on_new(self, acc, new, grouper=None):
408409
g = self.grouped(new, grouper=grouper)
409410
result = acc.add(g.sum(), fill_value=0)
411+
result.index.name = acc.index.name
410412
return result, result
411413

412414
def on_old(self, acc, old, grouper=None):
413415
g = self.grouped(old, grouper=grouper)
414416
result = acc.sub(g.sum(), fill_value=0)
417+
result.index.name = acc.index.name
415418
return result, result
416419

417420
def initial(self, new, grouper=None):
@@ -427,12 +430,14 @@ def on_new(self, acc, new, grouper=None):
427430
g = self.grouped(new, grouper=grouper)
428431
result = acc.add(g.count(), fill_value=0)
429432
result = result.astype(int)
433+
result.index.name = acc.index.name
430434
return result, result
431435

432436
def on_old(self, acc, old, grouper=None):
433437
g = self.grouped(old, grouper=grouper)
434438
result = acc.sub(g.count(), fill_value=0)
435439
result = result.astype(int)
440+
result.index.name = acc.index.name
436441
return result, result
437442

438443
def initial(self, new, grouper=None):
@@ -448,12 +453,14 @@ def on_new(self, acc, new, grouper=None):
448453
g = self.grouped(new, grouper=grouper)
449454
result = acc.add(g.size(), fill_value=0)
450455
result = result.astype(int)
456+
result.index.name = acc.index.name
451457
return result, result
452458

453459
def on_old(self, acc, old, grouper=None):
454460
g = self.grouped(old, grouper=grouper)
455461
result = acc.sub(g.size(), fill_value=0)
456462
result = result.astype(int)
463+
result.index.name = acc.index.name
457464
return result, result
458465

459466
def initial(self, new, grouper=None):
@@ -467,10 +474,12 @@ def initial(self, new, grouper=None):
467474
class ValueCounts(Aggregation):
468475
def on_new(self, acc, new, grouper=None):
469476
result = acc.add(new.value_counts(), fill_value=0).astype(int)
477+
result.index.name = acc.index.name
470478
return result, result
471479

472480
def on_old(self, acc, new, grouper=None):
473481
result = acc.sub(new.value_counts(), fill_value=0).astype(int)
482+
result.index.name = acc.index.name
474483
return result, result
475484

476485
def initial(self, new, grouper=None):
@@ -483,15 +492,17 @@ def on_new(self, acc, new, grouper=None):
483492
g = self.grouped(new, grouper=grouper)
484493
totals = totals.add(g.sum(), fill_value=0)
485494
counts = counts.add(g.count(), fill_value=0)
486-
495+
totals.index.name = acc[0].index.name
496+
counts.index.name = acc[1].index.name
487497
return (totals, counts), totals / counts
488498

489499
def on_old(self, acc, old, grouper=None):
490500
totals, counts = acc
491501
g = self.grouped(old, grouper=grouper)
492502
totals = totals.sub(g.sum(), fill_value=0)
493503
counts = counts.sub(g.count(), fill_value=0)
494-
504+
totals.index.name = acc[0].index.name
505+
counts.index.name = acc[1].index.name
495506
return (totals, counts), totals / counts
496507

497508
def initial(self, new, grouper=None):

0 commit comments

Comments
 (0)