@@ -130,13 +130,18 @@ def non_materializable4(x: Array) -> Array:
130130 return non_materializable (x )
131131
132132
133+ def non_materializable5 (x : Array ) -> Array :
134+ return non_materializable (x )
135+
136+
133137lazy_xp_function (good_lazy )
134138# Works on JAX and Dask
135139lazy_xp_function (non_materializable2 , jax_jit = False , allow_dask_compute = 2 )
140+ lazy_xp_function (non_materializable3 , jax_jit = False , allow_dask_compute = True )
136141# Works on JAX, but not Dask
137- lazy_xp_function (non_materializable3 , jax_jit = False , allow_dask_compute = 1 )
142+ lazy_xp_function (non_materializable4 , jax_jit = False , allow_dask_compute = 1 )
138143# Works neither on Dask nor JAX
139- lazy_xp_function (non_materializable4 )
144+ lazy_xp_function (non_materializable5 )
140145
141146
142147def test_lazy_xp_function (xp : ModuleType ):
@@ -147,29 +152,30 @@ def test_lazy_xp_function(xp: ModuleType):
147152 xp_assert_equal (non_materializable (x ), xp .asarray ([1.0 , 2.0 ]))
148153 # Wrapping explicitly disabled
149154 xp_assert_equal (non_materializable2 (x ), xp .asarray ([1.0 , 2.0 ]))
155+ xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
150156
151157 if is_jax_namespace (xp ):
152- xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
158+ xp_assert_equal (non_materializable4 (x ), xp .asarray ([1.0 , 2.0 ]))
153159 with pytest .raises (
154160 TypeError , match = "Attempted boolean conversion of traced array"
155161 ):
156- _ = non_materializable4 (x ) # Wrapped
162+ _ = non_materializable5 (x ) # Wrapped
157163
158164 elif is_dask_namespace (xp ):
159165 with pytest .raises (
160166 AssertionError ,
161167 match = r"dask\.compute.* 2 times, but only up to 1 calls are allowed" ,
162168 ):
163- _ = non_materializable3 (x )
169+ _ = non_materializable4 (x )
164170 with pytest .raises (
165171 AssertionError ,
166172 match = r"dask\.compute.* 1 times, but no calls are allowed" ,
167173 ):
168- _ = non_materializable4 (x )
174+ _ = non_materializable5 (x )
169175
170176 else :
171- xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
172177 xp_assert_equal (non_materializable4 (x ), xp .asarray ([1.0 , 2.0 ]))
178+ xp_assert_equal (non_materializable5 (x ), xp .asarray ([1.0 , 2.0 ]))
173179
174180
175181def static_params (x : Array , n : int , flag : bool = False ) -> Array :
0 commit comments