Skip to content

Commit faf61ca

Browse files
authored
Merge pull request #243 from CJ-Wright/filter_args_kwargs
allow filter to take args and kwargs
2 parents 6445fa9 + 1d78f9a commit faf61ca

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

streamz/core.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,10 @@ class filter(Stream):
622622
predicate : function
623623
The predicate. Should return True or False, where
624624
True means that the predicate is satisfied.
625+
*args :
626+
The arguments to pass to the predicate.
627+
**kwargs:
628+
Keyword arguments to pass to predicate
625629
626630
Examples
627631
--------
@@ -633,15 +637,19 @@ class filter(Stream):
633637
2
634638
4
635639
"""
636-
def __init__(self, upstream, predicate, **kwargs):
640+
641+
def __init__(self, upstream, predicate, *args, **kwargs):
637642
if predicate is None:
638643
predicate = _truthy
639644
self.predicate = predicate
645+
stream_name = kwargs.pop("stream_name", None)
646+
self.kwargs = kwargs
647+
self.args = args
640648

641-
Stream.__init__(self, upstream, **kwargs)
649+
Stream.__init__(self, upstream, stream_name=stream_name)
642650

643651
def update(self, x, who=None):
644-
if self.predicate(x):
652+
if self.predicate(x, *self.args, **self.kwargs):
645653
return self._emit(x)
646654

647655

streamz/tests/test_core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ def test_filter():
8383
assert L == [0, 2, 4, 6, 8]
8484

8585

86+
def test_filter_args():
87+
source = Stream()
88+
L = source.filter(lambda x, n: x % n == 0, 2).sink_to_list()
89+
90+
for i in range(10):
91+
source.emit(i)
92+
93+
assert L == [0, 2, 4, 6, 8]
94+
95+
96+
def test_filter_kwargs():
97+
source = Stream()
98+
L = source.filter(lambda x, n=1: x % n == 0, n=2).sink_to_list()
99+
100+
for i in range(10):
101+
source.emit(i)
102+
103+
assert L == [0, 2, 4, 6, 8]
104+
105+
86106
def test_filter_none():
87107
source = Stream()
88108
L = source.filter(None).sink_to_list()

0 commit comments

Comments
 (0)