     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1580,6 +1583,8 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
@@ -1600,64 +1605,150 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite a Blockwise of IncSubtensor/AdvancedIncSubtensor as a plain inc_subtensor.

-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites,

-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), which will later be rewritten as x[:w.shape[1]].set(y)
+    and can then be safely rewritten without Blockwise.
+    """
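+    # For example, with a single batch dimension of size b and one integer vector index,
+    # Blockwise(AdvancedIncSubtensor)(x, y, idx) becomes x[:, idx].inc(y) when idx has no
+    # batch dimension, or x[arange(b)[:, None], idx].inc(y) when it does
+    # (.set(y) instead of .inc(y) when the core op sets rather than increments).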
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None

-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case.
+                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
+                # are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None

-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
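+        # e.g. if x has batch broadcastable pattern (True, False) while y has (False, False),
+        # the length of the first batch dimension is taken from y and x is broadcast to it,
+        # so the buffer matches the batch shape of the Blockwise output.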
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice.
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indices on the right so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices.
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
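+        # e.g. with batch_ndim=1 and max_idx_core_ndim=1, a batched scalar index of shape (b,)
+        # is padded to (b, 1) so it does not touch the core dimension of vector indices,
+        # while a non-batched index of shape (1, k) is squeezed back to (k,).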
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new dimensions of x.
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimensions,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
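+        # e.g. with batch_ndim=2 and max_idx_core_ndim=1 this builds
+        # arange(d0)[:, None, None] and arange(d1)[:, None], which broadcast with each other
+        # and with the right-padded core indices without mixing batch and core dimensions.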
+    else:
+        # In the case we don't have batched indices, plain slice(None) entries suffice for the new batch dimensions
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
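+    # e.g. if the core op relied on implicit broadcasting of y, the missing core dimensions
+    # are inserted right after the batch dimensions so y lines up with x_view; leading
+    # broadcastable batch dimensions of y are then squeezed, as set/inc broadcasts y anyway.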
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]


 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])