Skip to content

Commit b58a274

Browse files
jstacclaudeHumphreyYang
authored
[mccall_model] Update McCall model lecture --- take 2! (#663)
* Update mccall_model.md: Convert from Numba to JAX implementation Converted the McCall job search model lecture from using Numba to JAX for better performance and modern functional programming approach. Key changes: - Replaced Numba's @jit and @jitclass with JAX's @jax.jit - Converted NumPy arrays to JAX arrays (jnp) - Used NamedTuple instead of jitclass for model classes - Implemented JAX-style while_loop for iterations - Added vmap for efficient vectorized computations - Updated all code examples and exercises to use JAX The conversion maintains all functionality while providing improved performance and compatibility with modern Python scientific computing stack. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix JAX solution and add Numba vs JAX benchmark Fixed the JAX compute_mean_stopping_time function to avoid JIT compilation issues with dynamic num_reps parameter by moving jax.jit inside the function. Added benchmark_mccall.py to compare Numba vs JAX solutions for exercise mm_ex1. Results show Numba is significantly faster (~6.4x) for this CPU-bound Monte Carlo simulation. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * update benchmark code * fix lecture * Optimize McCall model implementations: Parallel Numba + Optimized JAX Major performance improvements to ex_mm1 exercise implementations: **Numba Optimizations (5.39x speedup):** - Added parallel execution with @numba.jit(parallel=True) - Replaced range() with numba.prange() for parallel iteration - Achieves near-linear scaling with CPU cores (8 threads) **JAX Optimizations (~10-15% improvement):** - Improved state management in while_loop - Eliminated redundant jnp.where operation - Removed unnecessary jax.jit wrapper - Added vmap for computing across multiple c values (1.13x speedup) **Performance Results:** - Parallel Numba: 0.0242 ± 0.0014 seconds (🏆 Winner) - Optimized JAX: 0.1529 ± 0.1584 seconds - Numba is 6.31x faster than JAX for this problem **Changes:** - Updated mccall_model.md with optimized implementations - Added comprehensive OPTIMIZATION_REPORT.md with analysis - Created benchmark_numba_vs_jax.py for clean comparison - Removed old benchmark files (superseded) - Deleted benchmark_mccall.py (superseded) Both implementations produce identical results with no bias introduced. For Monte Carlo simulations with sequential logic, parallel Numba is the recommended approach. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * update warmup run * add geometric * updates --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Humphrey Yang <u6474961@anu.edu.au>
1 parent b7c6464 commit b58a274

File tree

1 file changed

+325
-269
lines changed

1 file changed

+325
-269
lines changed

0 commit comments

Comments
 (0)