You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
We can further try to optimize and speed up the compile time of the above function by replacing `for` loop with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).
336
401
337
402
```{code-cell} ipython3
338
403
def wealth_time_series_jax(w_0, n, wdy, size, rand_seed=1):
339
404
"""
340
405
Generate a single time series of length n for wealth given
341
406
initial value w_0.
342
407
408
+
* This implementation uses for jax.lax.scan
409
+
343
410
The initial persistent state z_0 for each household is drawn from
344
411
the stationary distribution of the AR(1) process.
345
412
@@ -371,7 +438,7 @@ def wealth_time_series_jax(w_0, n, wdy, size, rand_seed=1):
Let's try simulating the model at different parameter values and investigate the implications for the wealth distribution.
441
+
Let's try simulating the model at different parameter values and investigate the implications for the wealth distribution and also observe the difference in time between `wealth_time_series_jax` and `wealth_time_series_for_loop_jax`.
0 commit comments