diff --git a/pytensor/scan/views.py b/pytensor/scan/views.py index b86476b330..e3e656f7cc 100644 --- a/pytensor/scan/views.py +++ b/pytensor/scan/views.py @@ -170,3 +170,46 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None) mode=mode, name=name, ) + + +def filter( + fn, + sequences, + non_sequences=None, + go_backwards=False, + mode=None, + name=None, +): + """Construct a `Scan` `Op` that functions like `filter`. + + Parameters + ---------- + fn : callable + Predicate function returning a boolean tensor. + sequences : list + Sequences to filter. + non_sequences : list + Non-iterated arguments passed to `fn`. + go_backwards : bool + Whether to iterate in reverse. + mode : str or None + See ``scan``. + name : str or None + See ``scan``. + """ + mask, _ = scan( + fn=fn, + sequences=sequences, + outputs_info=None, + non_sequences=non_sequences, + go_backwards=go_backwards, + mode=mode, + name=f"{name or ''}_mask", + ) + + if isinstance(sequences, (list, tuple)): + filtered_sequences = [seq[mask] for seq in sequences] + else: + filtered_sequences = sequences[mask] + + return filtered_sequences diff --git a/tests/scan/test_views.py b/tests/scan/test_views.py index 38c9b9cfcd..6717016704 100644 --- a/tests/scan/test_views.py +++ b/tests/scan/test_views.py @@ -133,3 +133,23 @@ def test_foldr_memory_consumption(): gx = grad(o, x) f2 = function([], gx) utt.assert_allclose(f2(), np.ones((10,))) + + +def test_filter(): + import pytensor.tensor as pt + + v = pt.vector("v") + + def fn(x): + return pt.eq(x % 2, 0) + + from pytensor.scan.views import filter as pt_filter + + filtered = pt_filter(fn, v) + f = function([v], filtered, allow_input_downcast=True) + + rng = np.random.default_rng(utt.fetch_seed()) + vals = rng.integers(0, 10, size=(10,)) + expected = vals[vals % 2 == 0] + result = f(vals) + utt.assert_allclose(expected, result)