
Add batched productmap #280

Open
mj023 wants to merge 10 commits into main from batched_vmap

Conversation

Collaborator

@mj023 mj023 commented Mar 20, 2026

Problem

Sometimes running a model is not possible because of memory restrictions. The nested vmaps can lead JAX to create large arrays for intermediate results, which are temporarily stored in GPU memory. These arrays usually have the dimensions of the model's state-action space, so looping over batches of half the grid size along one of its dimensions can already halve peak memory usage. Batching comes at a cost, though: execution time gets progressively worse the smaller the batch size. For large batches, the slowdown is bigger than I would have expected, given that not all computations can happen at the same time anyway.
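To illustrate the memory issue with a toy example (not pylcm code; the function `q` and the grid sizes are made up): nesting two vmaps materializes the result, and intermediates, for the entire grid product at once.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-point objective, standing in for the real Q_and_F_Function.
def q(wealth, consumption):
    return jnp.log1p(consumption) + 0.95 * jnp.sqrt(wealth)

wealth_grid = jnp.linspace(0.0, 10.0, 500)
consumption_grid = jnp.linspace(0.0, 5.0, 400)

# Nested vmap evaluates all 500 x 400 grid combinations in one go, so
# intermediate arrays of that full size live in memory simultaneously.
values = jax.vmap(
    lambda w: jax.vmap(lambda c: q(w, c))(consumption_grid)
)(wealth_grid)

print(values.shape)  # (500, 400)
```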

New feature

This PR implements a batched version of productmap. The user can specify a batch size for each state grid. Instead of using vmap to map the Q_and_F_Function along such a grid, jax.lax.map is used, which either loops over batches of grid points or, if batch_size=0, behaves like vmap. The batched version is only used during the solution, as the state-action space for the simulation is already much smaller: it only depends on the number of simulated subjects.

Tasks

  • Add batched productmap
  • Refactor so that two versions of _base_productmap are no longer needed
  • Fix typing, tests
  • Investigate speed drop

@read-the-docs-community

read-the-docs-community bot commented Mar 20, 2026

Documentation build overview

📚 pylcm | 🛠️ Build #32047654 | 📁 Comparing c729e8e against latest (4c70a64)


🔍 Preview build

Files changed (32 files in total): 📝 32 modified | ➕ 0 added | ➖ 0 deleted

File Status
index.html 📝 modified
approximating-continuous-shocks/index.html 📝 modified
benchmarking/index.html 📝 modified
benchmarking-1/index.html 📝 modified
beta-delta/index.html 📝 modified
conventions/index.html 📝 modified
debugging/index.html 📝 modified
defining-models/index.html 📝 modified
dispatchers/index.html 📝 modified
function-representation/index.html 📝 modified
grids/index.html 📝 modified
index-1/index.html 📝 modified
index-2/index.html 📝 modified
index-3/index.html 📝 modified
index-4/index.html 📝 modified
installation/index.html 📝 modified
interpolation/index.html 📝 modified
mahler-yum-2024/index.html 📝 modified
mortality/index.html 📝 modified
pandas-interop/index.html 📝 modified
parameters/index.html 📝 modified
precautionary-savings/index.html 📝 modified
precautionary-savings-health/index.html 📝 modified
regimes/index.html 📝 modified
setup/index.html 📝 modified
shocks/index.html 📝 modified
solving-and-simulating/index.html 📝 modified
stochastic-transitions/index.html 📝 modified
tiny/index.html 📝 modified
tiny-example/index.html 📝 modified
transitions/index.html 📝 modified
write-economics/index.html 📝 modified

@github-actions

github-actions bot commented Mar 30, 2026

Benchmark comparison (main → HEAD)

Comparing 8e95dd99 (main) → c729e8e9 (HEAD)

| Benchmark | Statistic | Before | After | Ratio | Alert |
|---|---|---|---|---|---|
| Mahler-Yum | execution time | 3.48±0s | 3.57±0.03s | 1.03 | |
| | peak GPU mem | 262M | 262M | 1.00 | |
| | compilation time | 1.95m | 1.93m | 0.99 | |
| | peak CPU mem | 2.19G | 2.24G | 1.02 | |
| Mortality | execution time | 257±0.1ms | 242±1ms | 0.94 | |
| | peak GPU mem | 542M | 542M | 1.00 | |
| | compilation time | 10.3s | 10.8s | 1.04 | |
| | peak CPU mem | 1.25G | 1.24G | 1.00 | |
| Precautionary Savings - Solve | execution time | 29.6±3ms | 32.0±1ms | 1.08 | |
| | peak GPU mem | 8.44M | 8.44M | 1.00 | |
| | compilation time | 5.10s | 4.90s | 0.96 | |
| | peak CPU mem | 1.07G | 1.06G | 1.00 | |
| Precautionary Savings - Simulate | execution time | 139±0.4ms | 138±4ms | 0.99 | |
| | peak GPU mem | 138M | 138M | 1.00 | |
| | compilation time | 7.18s | 7.11s | 0.99 | |
| | peak CPU mem | 1.2G | 1.2G | 1.00 | |
| Precautionary Savings - Solve & Simulate | execution time | 152±4ms | 161±2ms | 1.06 | |
| | peak GPU mem | 565M | 565M | 1.00 | |
| | compilation time | 11.2s | 11.2s | 1.00 | |
| | peak CPU mem | 1.21G | 1.21G | 1.00 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 292±1ms | 289±2ms | 0.99 | |
| | peak GPU mem | 2.18G | 2.18G | 1.00 | |
| | compilation time | 12.0s | 12.0s | 1.00 | |
| | peak CPU mem | 1.27G | 1.26G | 1.00 | |

Collaborator Author

mj023 commented Mar 30, 2026

I wasn't yet able to create a good benchmark for a model where batching actually helps. It needs a model that is sufficiently complex that the compiler can't optimize it well, but that still runs reasonably quickly. Whether batching helps really seems to depend on the model: with Marvin's retirement model it nearly cut memory usage in half, while for the Mahler & Yum model it does very little.

I also fixed an error in the MY model input creation and removed one of the tests, because productmap can now handle scalar inputs.

@mj023 mj023 requested a review from hmgaudecker March 30, 2026 23:52