Skip to content

Conversation

@emekaokoli19
Copy link

Description

This PR adds a filter helper function to pytensor/scan/views.py, complementing the existing map, reduce, foldl, and foldr utilities.

The new filter function applies a user-defined condition (predicate) over a sequence using PyTensor’s scan operation and returns only the elements that satisfy the condition (i.e., the ones for which the predicate evaluates to True).

This addition restores functionality discussed in Theano/Theano#5365, now updated for PyTensor where boolean indexing is supported natively.

A corresponding unit test (test_filter) was added in tests/scan/test_views.py to ensure correct behavior.

Related Issue

Checklist

  • Checked that pre-commit linting/style checks pass
  • Included a test (test_filter) that verifies the correctness of the new function
  • Added necessary documentation (docstring for filter)
  • Each commit corresponds to a relevant logical change

Type of change

  • New feature / enhancement

@ricardoV94, please take a look when available

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, just some small suggestions/request

name : str or None
See ``scan``.
"""
mask, _ = scan(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may make sense to allow returning one a mask per sequence? If there's only one we use the same for every sequence like now

non_sequences=non_sequences,
go_backwards=go_backwards,
mode=mode,
name=f"{name or ''}_mask",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use name directly



def test_filter():
import pytensor.tensor as pt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do imports at tho global lever not inside the tests

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add filter to scan/views.py

2 participants