fix repeated recompilation in no-progress-bar fori_collect path#2171
Merged
fehiepsi merged 3 commits intopyro-ppl:masterfrom Apr 17, 2026
Merged
Conversation
Collaborator
|
@simeoncampos CI is failing on lint and formatting. Please run |
Contributor
Author
|
done, thanks for flagging. |
juanitorduz
approved these changes
Apr 7, 2026
Collaborator
juanitorduz
left a comment
There was a problem hiding this comment.
thanks @simeoncampos 1
fehiepsi
reviewed
Apr 7, 2026
| def loop_fn(collection): | ||
| # Cache loop_fn so jit() reuses the compiled trace across calls. | ||
| # Without this, loop_fn is a fresh closure each call and jit recompiles. | ||
| @cached_by(fori_collect, body_fun, transform, upper, start_idx, thinning) |
Member
There was a problem hiding this comment.
I feel cache fori_collect is unnecessary. Since _body_fn is unique, maybe we can create a global partial loop fn instead?
8e8c3b1 to
a92ea2a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
fori_collect(..., progbar=False)recreatesloop_fnas a fresh closure on each call. JAX JIT caching depends on function identity, so recreating the closure defeats cache reuse and triggers full XLA recompilation on everyMCMC.run().Root cause
loop_fnwas introduced in #1802 to wrapfori_loopfordonate_argnumsviamaybe_jit. Unlike_body_fnwhich was already cached via@cached_by, the new wrapper was not. Each call produces a new function object, cache miss, recompile.Fix
Cache
loop_fnvia@cached_by(fori_collect, body_fun, transform, upper, start_idx, thinning)and moveinit_valfrom the closure into an explicit argument. Matches the existing_body_fncaching pattern atutil.py:387.Speedups on repeated
MCMC.run()withprogress_bar=False(CPU, median of runs 2-5):On GPU (2x RTX A5000): 1.2-2.0x.
Tests
test_fori_collect_no_recompilation: regression test for repeatedfori_collect(..., progbar=False)calls with correct non-stale results and a populated cached wrapper pathtest_fori_collect_repeated_mcmc_no_recompilation: end-to-end regression coverage through repeatedMCMC.run()Done with the help of Claude Code.