Skip to content

Commit df981f9

Browse files
authored
Merge pull request #263 from chinmaychandak/master
Tests for rolling aggregations in cudf.
2 parents d909429 + f97057e commit df981f9

File tree

4 files changed

+65
-37
lines changed

4 files changed

+65
-37
lines changed

streamz/dataframe/tests/test_cudf.py

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
"""
2-
Tests for cudf DataFrame
3-
All these tests are taken from test_dataframes module in the same folder.
4-
Some of these tests pass with cudf as they are, and others are marked xfail
5-
where a pandas like method is not implemented yet in cudf.
6-
But these tests should pass as cudf implement more pandas like methods.
2+
Tests for cudf DataFrames:
3+
All tests have been cloned from the test_dataframes module in the same folder.
4+
Some of these tests pass with cudf, and others are marked with xfail
5+
where a pandas-like method is not yet implemented in cudf.
6+
But these tests should pass as and when cudf rolls out more pandas-like methods.
77
"""
88
from __future__ import division, print_function
99

10-
import numpy as np
11-
import pytest
1210
import operator
11+
12+
import pytest
1313
from dask.dataframe.utils import assert_eq
14-
from distributed import Client
14+
import numpy as np
15+
import pandas as pd
1516

1617
from streamz import Stream
17-
from streamz.dask import DaskStream
1818
from streamz.dataframe import DataFrame, Series, DataFrames, Aggregation
19+
from streamz.dask import DaskStream
20+
21+
from distributed import Client
22+
1923

2024
cudf = pytest.importorskip("cudf")
2125

@@ -70,8 +74,8 @@ def test_attributes():
7074
df = cudf.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
7175
sdf = DataFrame(example=df)
7276

73-
assert 'x' in dir(sdf)
74-
assert 'z' not in dir(sdf)
77+
assert getattr(sdf,'x',-1) != -1
78+
assert getattr(sdf,'z',-1) == -1
7579

7680
sdf.x
7781
with pytest.raises(AttributeError):
@@ -359,6 +363,53 @@ def test_setitem_overwrites(stream):
359363
assert_eq(L[-1], df.iloc[7:] * 2)
360364

361365

366+
@pytest.mark.parametrize('kwargs,op', [
    ({}, 'sum'),
    ({}, 'mean'),
    ({}, 'min'),
    pytest.param({}, 'median', marks=pytest.mark.xfail(
        reason="Not implemented for rolling objects")),
    ({}, 'max'),
    pytest.param({}, 'var', marks=pytest.mark.xfail(
        reason="Not implemented for rolling objects")),
    ({}, 'count'),
    pytest.param({'ddof': 0}, 'std', marks=pytest.mark.xfail(
        reason="Not implemented for rolling objects")),
    pytest.param({'quantile': 0.5}, 'quantile', marks=pytest.mark.xfail(
        reason="Not implemented for rolling objects")),
    pytest.param({'arg': {'A': 'sum', 'B': 'min'}}, 'aggregate',
                 marks=pytest.mark.xfail(reason="Not implemented")),
])
@pytest.mark.parametrize('window', [
    2,
    7,
    '3h',
    pd.Timedelta('200 minutes'),
])
@pytest.mark.parametrize('m', [2, 5])
@pytest.mark.parametrize('pre_get,post_get', [
    (lambda df: df, lambda df: df),
    (lambda df: df.x, lambda x: x),
    (lambda df: df, lambda df: df.x),
])
def test_rolling_count_aggregations(op, window, m, pre_get, post_get, kwargs,
                                    stream):
    """Check streaming rolling aggregations against the batch cudf result.

    A cudf DataFrame with an hourly datetime index is aggregated in one shot
    to produce the expected result, then emitted through a streaming
    DataFrame in slices of ``m`` rows; the concatenated streamed output must
    equal the expected result.

    Parameters
    ----------
    op : str
        Name of the aggregation method looked up on the rolling object.
    window : int, str or pandas.Timedelta
        Window passed to ``.rolling()``.
    m : int
        Number of rows per emitted slice.
    pre_get, post_get : callable
        Selectors applied before and after ``.rolling()``; they choose
        between the whole frame and column ``x``.
    kwargs : dict
        Keyword arguments forwarded to the aggregation method.
    stream
        Stream the streaming DataFrame is built on (fixture defined outside
        this view).
    """
    # pd.DatetimeIndex(start=..., end=..., freq=...) was deprecated in
    # pandas 0.24 and removed in 1.0; pd.date_range builds the same index.
    index = pd.date_range(start='2000-01-01', end='2000-01-03', freq='1h')
    df = cudf.DataFrame({'x': np.arange(len(index))}, index=index)

    # Reference result computed on the whole frame at once.
    expected = getattr(post_get(pre_get(df).rolling(window)), op)(**kwargs)

    sdf = DataFrame(example=df, stream=stream)
    roll = getattr(post_get(pre_get(sdf).rolling(window)), op)(**kwargs)
    L = roll.stream.gather().sink_to_list()
    # Nothing has been emitted yet, so the sink must still be empty.
    assert len(L) == 0

    for i in range(0, len(df), m):
        sdf.emit(df.iloc[i: i + m])

    assert len(L) > 1

    assert_eq(cudf.concat(L), expected)
412+
362413
def test_stream_to_dataframe(stream):
363414
df = cudf.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
364415
source = stream
@@ -382,24 +433,6 @@ def test_to_frame(stream):
382433
assert list(a.columns) == ['x']
383434

384435

385-
def test_instantiate_with_dict(stream):
386-
df = cudf.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
387-
sdf = DataFrame(example=df, stream=stream)
388-
389-
sdf2 = DataFrame({'a': sdf.x, 'b': sdf.x * 2,
390-
'c': sdf.y % 2})
391-
L = sdf2.stream.gather().sink_to_list()
392-
assert len(sdf2.columns) == 3
393-
394-
sdf.emit(df)
395-
sdf.emit(df)
396-
397-
assert len(L) == 2
398-
for x in L:
399-
assert_eq(x[['a', 'b', 'c']],
400-
cudf.DataFrame({'a': df.x, 'b': df.x * 2, 'c': df.y % 2}))
401-
402-
403436
@pytest.mark.parametrize('op', ['cumsum', 'cummax', 'cumprod', 'cummin'])
404437
@pytest.mark.parametrize('getter', [lambda df: df, lambda df: df.x])
405438
def test_cumulative_aggregations(op, getter, stream):

streamz/dataframe/tests/test_dataframe_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@
66
import numpy as np
77

88

9-
def test_utils_is_dataframe_like():
10-
test_utils_dataframe = pytest.importorskip('dask.dataframe.tests.test_utils_dataframe')
11-
test_utils_dataframe.test_is_dataframe_like()
12-
13-
149
def test_utils_get_base_frame_type_pandas():
1510
with pytest.raises(TypeError):
1611
get_base_frame_type("DataFrame", is_dataframe_like, None)

streamz/dataframe/tests/test_dataframes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def test_attributes():
6969
df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
7070
sdf = DataFrame(example=df)
7171

72-
assert 'x' in dir(sdf)
73-
assert 'z' not in dir(sdf)
72+
assert getattr(sdf,'x',-1) != -1
73+
assert getattr(sdf,'z',-1) == -1
7474

7575
sdf.x
7676
with pytest.raises(AttributeError):

streamz/tests/test_dask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from distributed.utils_test import gen_cluster, inc, cluster, loop, slowinc # noqa: F401
1515

1616

17-
@gen_cluster(client=True, check_new_threads=False)
17+
@gen_cluster(client=True)
1818
def test_map(c, s, a, b):
1919
source = Stream(asynchronous=True)
2020
futures = scatter(source).map(inc)

0 commit comments

Comments
 (0)