Skip to content

Commit 7ae4d33

Browse files
committed
Move slice dispatcher functionality to subtensor.py
1 parent 37da5f6 commit 7ae4d33

File tree

2 files changed

+82
-78
lines changed

2 files changed

+82
-78
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import operator
2-
import sys
31
import warnings
42
from copy import copy
53
from functools import singledispatch
@@ -8,11 +6,10 @@
86
import numba
97
import numba.np.unsafe.ndarray as numba_ndarray
108
import numpy as np
11-
from llvmlite import ir
129
from numba import types
1310
from numba.core.errors import NumbaWarning, TypingError
1411
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
15-
from numba.extending import box, overload
12+
from numba.extending import overload
1613

1714
from pytensor import In, config
1815
from pytensor.compile import NUMBA
@@ -36,7 +33,7 @@
3633
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3734
from pytensor.tensor.sort import ArgSortOp, SortOp
3835
from pytensor.tensor.type import TensorType
39-
from pytensor.tensor.type_other import MakeSlice, NoneConst
36+
from pytensor.tensor.type_other import NoneConst
4037

4138

4239
def numba_njit(*args, fastmath=None, **kwargs):
@@ -149,69 +146,6 @@ def create_numba_signature(
149146
return numba.types.void(*input_types)
150147

151148

152-
def slice_new(self, start, stop, step):
153-
fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
154-
fn = self._get_function(fnty, name="PySlice_New")
155-
return self.builder.call(fn, [start, stop, step])
156-
157-
158-
def enable_slice_boxing():
159-
"""Enable boxing for Numba's native ``slice``s.
160-
161-
TODO: this can be removed when https://github.com/numba/numba/pull/6939 is
162-
merged and a release is made.
163-
"""
164-
165-
@box(types.SliceType)
166-
def box_slice(typ, val, c):
167-
"""Implement boxing for ``slice`` objects in Numba.
168-
169-
This makes it possible to return an Numba's internal representation of a
170-
``slice`` object as a proper ``slice`` to Python.
171-
"""
172-
start = c.builder.extract_value(val, 0)
173-
stop = c.builder.extract_value(val, 1)
174-
175-
none_val = ir.Constant(ir.IntType(64), sys.maxsize)
176-
177-
start_is_none = c.builder.icmp_signed("==", start, none_val)
178-
start = c.builder.select(
179-
start_is_none,
180-
c.pyapi.get_null_object(),
181-
c.box(types.int64, start),
182-
)
183-
184-
stop_is_none = c.builder.icmp_signed("==", stop, none_val)
185-
stop = c.builder.select(
186-
stop_is_none,
187-
c.pyapi.get_null_object(),
188-
c.box(types.int64, stop),
189-
)
190-
191-
if typ.has_step:
192-
step = c.builder.extract_value(val, 2)
193-
step_is_none = c.builder.icmp_signed("==", step, none_val)
194-
step = c.builder.select(
195-
step_is_none,
196-
c.pyapi.get_null_object(),
197-
c.box(types.int64, step),
198-
)
199-
else:
200-
step = c.pyapi.get_null_object()
201-
202-
slice_val = slice_new(c.pyapi, start, stop, step)
203-
204-
return slice_val
205-
206-
@numba.extending.overload(operator.contains)
207-
def in_seq_empty_tuple(x, y):
208-
if isinstance(x, types.Tuple) and not x.types:
209-
return lambda x, y: False
210-
211-
212-
enable_slice_boxing()
213-
214-
215149
def to_scalar(x):
216150
return np.asarray(x).item()
217151

@@ -388,15 +322,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
388322
return deepcopyop
389323

390324

391-
@numba_funcify.register(MakeSlice)
392-
def numba_funcify_MakeSlice(op, **kwargs):
393-
@numba_njit
394-
def makeslice(*x):
395-
return slice(*x)
396-
397-
return makeslice
398-
399-
400325
@numba_funcify.register(Shape)
401326
def numba_funcify_Shape(op, **kwargs):
402327
@numba_njit

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1+
import operator
2+
import sys
3+
4+
import numba
15
import numpy as np
6+
from llvmlite import ir
7+
from numba import types
8+
from numba.core.pythonapi import box
29

310
from pytensor.graph import Type
411
from pytensor.link.numba.dispatch import numba_funcify
@@ -14,7 +21,79 @@
1421
IncSubtensor,
1522
Subtensor,
1623
)
17-
from pytensor.tensor.type_other import NoneTypeT, SliceType
24+
from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
25+
26+
27+
def slice_new(self, start, stop, step):
28+
fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
29+
fn = self._get_function(fnty, name="PySlice_New")
30+
return self.builder.call(fn, [start, stop, step])
31+
32+
33+
def enable_slice_boxing():
34+
"""Enable boxing for Numba's native ``slice``s.
35+
36+
TODO: this can be removed when https://github.com/numba/numba/pull/6939 is
37+
merged and a release is made.
38+
"""
39+
40+
@box(types.SliceType)
41+
def box_slice(typ, val, c):
42+
"""Implement boxing for ``slice`` objects in Numba.
43+
44+
This makes it possible to return an Numba's internal representation of a
45+
``slice`` object as a proper ``slice`` to Python.
46+
"""
47+
start = c.builder.extract_value(val, 0)
48+
stop = c.builder.extract_value(val, 1)
49+
50+
none_val = ir.Constant(ir.IntType(64), sys.maxsize)
51+
52+
start_is_none = c.builder.icmp_signed("==", start, none_val)
53+
start = c.builder.select(
54+
start_is_none,
55+
c.pyapi.get_null_object(),
56+
c.box(types.int64, start),
57+
)
58+
59+
stop_is_none = c.builder.icmp_signed("==", stop, none_val)
60+
stop = c.builder.select(
61+
stop_is_none,
62+
c.pyapi.get_null_object(),
63+
c.box(types.int64, stop),
64+
)
65+
66+
if typ.has_step:
67+
step = c.builder.extract_value(val, 2)
68+
step_is_none = c.builder.icmp_signed("==", step, none_val)
69+
step = c.builder.select(
70+
step_is_none,
71+
c.pyapi.get_null_object(),
72+
c.box(types.int64, step),
73+
)
74+
else:
75+
step = c.pyapi.get_null_object()
76+
77+
slice_val = slice_new(c.pyapi, start, stop, step)
78+
79+
return slice_val
80+
81+
@numba.extending.overload(operator.contains)
82+
def in_seq_empty_tuple(x, y):
83+
if isinstance(x, types.Tuple) and not x.types:
84+
return lambda x, y: False
85+
86+
87+
enable_slice_boxing
88+
89+
90+
@numba_funcify.register(MakeSlice)
91+
def numba_funcify_MakeSlice(op, **kwargs):
92+
@numba_njit
93+
def makeslice(*x):
94+
return slice(*x)
95+
96+
return makeslice
1897

1998

2099
@numba_funcify.register(Subtensor)

0 commit comments

Comments
 (0)