|
7 | 7 | """ |
8 | 8 | from __future__ import division, print_function |
9 | 9 |
|
10 | | -import numpy as np |
11 | | -import pytest |
| 10 | +import json |
12 | 11 | import operator |
| 12 | +from time import sleep |
| 13 | + |
| 14 | +import pytest |
13 | 15 | from dask.dataframe.utils import assert_eq |
14 | | -from distributed import Client |
| 16 | +import numpy as np |
| 17 | +import pandas as pd |
| 18 | +from tornado import gen |
15 | 19 |
|
16 | 20 | from streamz import Stream |
17 | | -from streamz.dask import DaskStream |
| 21 | +from streamz.utils_test import gen_test |
18 | 22 | from streamz.dataframe import DataFrame, Series, DataFrames, Aggregation |
| 23 | +import streamz.dataframe as sd |
| 24 | +from streamz.dask import DaskStream |
| 25 | + |
| 26 | +from distributed import Client |
| 27 | + |
19 | 28 |
|
20 | 29 | cudf = pytest.importorskip("cudf") |
21 | 30 |
|
@@ -358,7 +367,54 @@ def test_setitem_overwrites(stream): |
358 | 367 |
|
359 | 368 | assert_eq(L[-1], df.iloc[7:] * 2) |
360 | 369 |
|
| 370 | + |
| 371 | +@pytest.mark.parametrize('kwargs,op', [ |
| 372 | + ({}, 'sum'), |
| 373 | + ({}, 'mean'), |
| 374 | + pytest.param({}, 'min'), |
| 375 | + pytest.param({}, 'median', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")), |
| 376 | + pytest.param({}, 'max'), |
| 377 | + pytest.param({}, 'var', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")), |
| 378 | + pytest.param({}, 'count'), |
| 379 | + pytest.param({'ddof': 0}, 'std', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")), |
| 380 | + pytest.param({'quantile': 0.5}, 'quantile', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")), |
| 381 | + pytest.param({'arg': {'A': 'sum', 'B': 'min'}}, 'aggregate', marks=pytest.mark.xfail(reason="Not implemented for rolling objects")) |
| 382 | +]) |
| 383 | +@pytest.mark.parametrize('window', [ |
| 384 | + pytest.param(2), |
| 385 | + 7, |
| 386 | + pytest.param('3h'), |
| 387 | + pd.Timedelta('200 minutes') |
| 388 | +]) |
| 389 | +@pytest.mark.parametrize('m', [ |
| 390 | + 2, |
| 391 | + pytest.param(5) |
| 392 | +]) |
| 393 | +@pytest.mark.parametrize('pre_get,post_get', [ |
| 394 | + (lambda df: df, lambda df: df), |
| 395 | + (lambda df: df.x, lambda x: x), |
| 396 | + pytest.param(lambda df: df, lambda df: df.x, marks=pytest.mark.xfail(reason="Cannot select columns for rolling over time objects")) |
| 397 | +]) |
| 398 | +def test_rolling_count_aggregations(op, window, m, pre_get, post_get, kwargs, |
| 399 | + stream): |
| 400 | + index = pd.DatetimeIndex(start='2000-01-01', end='2000-01-03', freq='1h') |
| 401 | + df = cudf.DataFrame({'x': np.arange(len(index))}, index=index) |
| 402 | + |
| 403 | + expected = getattr(post_get(pre_get(df).rolling(window)), op)(**kwargs) |
361 | 404 |
|
| 405 | + sdf = DataFrame(example=df, stream=stream) |
| 406 | + roll = getattr(post_get(pre_get(sdf).rolling(window)), op)(**kwargs) |
| 407 | + L = roll.stream.gather().sink_to_list() |
| 408 | + assert len(L) == 0 |
| 409 | + |
| 410 | + for i in range(0, len(df), m): |
| 411 | + sdf.emit(df.iloc[i: i + m]) |
| 412 | + |
| 413 | + assert len(L) > 1 |
| 414 | + |
| 415 | + assert_eq(cudf.concat(L), expected) |
| 416 | + |
| 417 | + |
362 | 418 | def test_stream_to_dataframe(stream): |
363 | 419 | df = cudf.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]}) |
364 | 420 | source = stream |
|
0 commit comments