Skip to content

fix repeated recompilation in no-progress-bar fori_collect path#2171

Merged
fehiepsi merged 3 commits intopyro-ppl:masterfrom
simeoncampos:fix-fori-collect-recompilation
Apr 17, 2026
Merged

fix repeated recompilation in no-progress-bar fori_collect path#2171
fehiepsi merged 3 commits intopyro-ppl:masterfrom
simeoncampos:fix-fori-collect-recompilation

Conversation

@simeoncampos
Copy link
Copy Markdown
Contributor

fori_collect(..., progbar=False) recreates loop_fn as a fresh closure on each call. JAX JIT caching depends on function identity, so recreating the closure defeats cache reuse and triggers full XLA recompilation on every MCMC.run().

Root cause

loop_fn was introduced in #1802 to wrap fori_loop for donate_argnums via maybe_jit. Unlike _body_fn which was already cached via @cached_by, the new wrapper was not. Each call produces a new function object, cache miss, recompile.

Fix

Cache loop_fn via @cached_by(fori_collect, body_fun, transform, upper, start_idx, thinning) and move init_val from the closure into an explicit argument. Matches the existing _body_fn caching pattern at util.py:387.

Speedups on repeated MCMC.run() with progress_bar=False (CPU, median of runs 2-5):

  • normal_normal (2p): 0.44s to 0.08s
  • eight_schools (10p): 0.56s to 0.12s
  • logistic_regression (11p): 0.54s to 0.08s
  • stochastic_volatility (~1000p): 1.42s to 0.81s

On GPU (2x RTX A5000): 1.2-2.0x.

Tests

  • test_fori_collect_no_recompilation: regression test for repeated fori_collect(..., progbar=False) calls with correct non-stale results and a populated cached wrapper path
  • test_fori_collect_repeated_mcmc_no_recompilation: end-to-end regression coverage through repeated MCMC.run()

Done with the help of Claude Code.

@Qazalbash
Copy link
Copy Markdown
Collaborator

@simeoncampos CI is failing on lint and formatting. Please run prek run -a or pre-commit run -a and push the new changes.

@simeoncampos
Copy link
Copy Markdown
Contributor Author

done, thanks for flagging.

@juanitorduz juanitorduz requested a review from fehiepsi April 7, 2026 07:50
Copy link
Copy Markdown
Collaborator

@juanitorduz juanitorduz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @simeoncampos 1

Comment thread numpyro/util.py Outdated
def loop_fn(collection):
# Cache loop_fn so jit() reuses the compiled trace across calls.
# Without this, loop_fn is a fresh closure each call and jit recompiles.
@cached_by(fori_collect, body_fun, transform, upper, start_idx, thinning)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel cache fori_collect is unnecessary. Since _body_fn is unique, maybe we can create a global partial loop fn instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

@simeoncampos simeoncampos force-pushed the fix-fori-collect-recompilation branch from 8e8c3b1 to a92ea2a Compare April 16, 2026 09:42
Copy link
Copy Markdown
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @simeoncampos!

@fehiepsi fehiepsi merged commit e10f664 into pyro-ppl:master Apr 17, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants