Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions pytensor/scan/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,46 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
mode=mode,
name=name,
)


def filter(
fn,
sequences,
non_sequences=None,
go_backwards=False,
mode=None,
name=None,
):
"""Construct a `Scan` `Op` that functions like `filter`.

Parameters
----------
fn : callable
Predicate function returning a boolean tensor.
sequences : list
Sequences to filter.
non_sequences : list
Non-iterated arguments passed to `fn`.
go_backwards : bool
Whether to iterate in reverse.
mode : str or None
See ``scan``.
name : str or None
See ``scan``.
"""
mask, _ = scan(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may make sense to allow returning one a mask per sequence? If there's only one we use the same for every sequence like now

fn=fn,
sequences=sequences,
outputs_info=None,
non_sequences=non_sequences,
go_backwards=go_backwards,
mode=mode,
name=f"{name or ''}_mask",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use name directly

)

if isinstance(sequences, (list, tuple)):
filtered_sequences = [seq[mask] for seq in sequences]
else:
filtered_sequences = sequences[mask]

return filtered_sequences
20 changes: 20 additions & 0 deletions tests/scan/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,23 @@ def test_foldr_memory_consumption():
gx = grad(o, x)
f2 = function([], gx)
utt.assert_allclose(f2(), np.ones((10,)))


def test_filter():
import pytensor.tensor as pt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do imports at tho global lever not inside the tests


v = pt.vector("v")

def fn(x):
return pt.eq(x % 2, 0)

from pytensor.scan.views import filter as pt_filter

filtered = pt_filter(fn, v)
f = function([v], filtered, allow_input_downcast=True)

rng = np.random.default_rng(utt.fetch_seed())
vals = rng.integers(0, 10, size=(10,))
expected = vals[vals % 2 == 0]
result = f(vals)
utt.assert_allclose(expected, result)