Skip to content

Commit 2b591ca

Browse files
committed
Update JAX Primitive to accept is_linear
1 parent 2fd3c8a commit 2b591ca

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

s2fft/utils/jax_primitive.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def register_primitive(
1313
batcher: Optional[Callable] = None,
1414
jacobian_vector_product: Optional[Callable] = None,
1515
transpose: Optional[Callable] = None,
16+
is_linear: bool = False,
1617
):
1718
"""
1819
Register a new custom JAX primitive.
@@ -44,5 +45,8 @@ def register_primitive(
4445
if jacobian_vector_product is not None:
4546
ad.primitive_jvps[primitive] = jacobian_vector_product
4647
if transpose is not None:
47-
ad.primitive_transposes[primitive] = transpose
48+
if is_linear:
49+
ad.deflinear(primitive, transpose)
50+
else:
51+
ad.primitive_transposes[primitive] = transpose
4852
return primitive

0 commit comments

Comments
 (0)