Skip to content

Commit 3cd3498

Browse files
author
chinmaychandak
committed
cudf tests for rolling aggregations
1 parent d909429 commit 3cd3498

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

streamz/dataframe/tests/dask-worker-space/global.lock

Whitespace-only changes.

streamz/dataframe/tests/dask-worker-space/purge.lock

Whitespace-only changes.

streamz/dataframe/tests/test_cudf.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,24 @@
77
"""
88
from __future__ import division, print_function
99

10-
import numpy as np
11-
import pytest
10+
import json
1211
import operator
12+
from time import sleep
13+
14+
import pytest
1315
from dask.dataframe.utils import assert_eq
14-
from distributed import Client
16+
import numpy as np
17+
import pandas as pd
18+
from tornado import gen
1519

1620
from streamz import Stream
17-
from streamz.dask import DaskStream
21+
from streamz.utils_test import gen_test
1822
from streamz.dataframe import DataFrame, Series, DataFrames, Aggregation
23+
import streamz.dataframe as sd
24+
from streamz.dask import DaskStream
25+
26+
from distributed import Client
27+
1928

2029
cudf = pytest.importorskip("cudf")
2130

@@ -358,7 +367,54 @@ def test_setitem_overwrites(stream):
358367

359368
assert_eq(L[-1], df.iloc[7:] * 2)
360369

370+
371+
@pytest.mark.parametrize('kwargs,op', [
372+
({}, 'sum'),
373+
({}, 'mean'),
374+
pytest.param({}, 'min'),
375+
pytest.param({}, 'median', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")),
376+
pytest.param({}, 'max'),
377+
pytest.param({}, 'var', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")),
378+
pytest.param({}, 'count'),
379+
pytest.param({'ddof': 0}, 'std', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")),
380+
pytest.param({'quantile': 0.5}, 'quantile', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")),
381+
pytest.param({'arg': {'A': 'sum', 'B': 'min'}}, 'aggregate', marks=pytest.mark.xfail(reason="Not implemented for rolling objects"))
382+
])
383+
@pytest.mark.parametrize('window', [
384+
pytest.param(2),
385+
7,
386+
pytest.param('3h'),
387+
pd.Timedelta('200 minutes')
388+
])
389+
@pytest.mark.parametrize('m', [
390+
2,
391+
pytest.param(5)
392+
])
393+
@pytest.mark.parametrize('pre_get,post_get', [
394+
(lambda df: df, lambda df: df),
395+
(lambda df: df.x, lambda x: x),
396+
pytest.param(lambda df: df, lambda df: df.x, marks=pytest.mark.xfail(reason="Cannot select columns for rolling over time objects"))
397+
])
398+
def test_rolling_count_aggregations(op, window, m, pre_get, post_get, kwargs,
399+
stream):
400+
index = pd.DatetimeIndex(start='2000-01-01', end='2000-01-03', freq='1h')
401+
df = cudf.DataFrame({'x': np.arange(len(index))}, index=index)
402+
403+
expected = getattr(post_get(pre_get(df).rolling(window)), op)(**kwargs)
361404

405+
sdf = DataFrame(example=df, stream=stream)
406+
roll = getattr(post_get(pre_get(sdf).rolling(window)), op)(**kwargs)
407+
L = roll.stream.gather().sink_to_list()
408+
assert len(L) == 0
409+
410+
for i in range(0, len(df), m):
411+
sdf.emit(df.iloc[i: i + m])
412+
413+
assert len(L) > 1
414+
415+
assert_eq(cudf.concat(L), expected)
416+
417+
362418
def test_stream_to_dataframe(stream):
363419
df = cudf.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
364420
source = stream

0 commit comments

Comments
 (0)