We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2fd3c8a commit 2b591caCopy full SHA for 2b591ca
s2fft/utils/jax_primitive.py
@@ -13,6 +13,7 @@ def register_primitive(
13
batcher: Optional[Callable] = None,
14
jacobian_vector_product: Optional[Callable] = None,
15
transpose: Optional[Callable] = None,
16
+ is_linear: bool = False,
17
):
18
"""
19
Register a new custom JAX primitive.
@@ -44,5 +45,8 @@ def register_primitive(
44
45
if jacobian_vector_product is not None:
46
ad.primitive_jvps[primitive] = jacobian_vector_product
47
if transpose is not None:
- ad.primitive_transposes[primitive] = transpose
48
+ if is_linear:
49
+ ad.deflinear(primitive, transpose)
50
+ else:
51
+ ad.primitive_transposes[primitive] = transpose
52
return primitive
0 commit comments