Skip to content

State-Centered Temporal Processes#828

Open
cdc-mitzimorris wants to merge 57 commits into
mainfrom
mem_810_centered_parameterization
Open

State-Centered Temporal Processes#828
cdc-mitzimorris wants to merge 57 commits into
mainfrom
mem_810_centered_parameterization

Conversation

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator

Added state-centered parameterizations for all three temporal-process
classes in pyrenew.latent:

  • AR1 — stationary AR(1) on log-Rt levels
  • DifferencedAR1 — AR(1) on first differences of log-Rt (the production
    process)
  • RandomWalk — unconstrained drift on log-Rt

Each class now takes a constructor argument
parameterization: Literal["innovation", "state"], defaulting to
"innovation" to preserve current behavior. Setting "state" switches
the internal sampling from standardized increments to the latent state
path directly.

The state-centered variants are implemented via:

  • For RandomWalk: NumPyro's built-in dist.GaussianRandomWalk, shifted
    by the initial value.
  • For AR1 and DifferencedAR1: two new custom NumPyro Distribution
    subclasses (StateAR1, StateDifferencedAR1) in
    pyrenew/latent/state_centered_distributions.py. Both have vectorized
    log_prob using slice arithmetic (no scan during MCMC) and
    lax.scan-based sample (only called for prior/posterior predictive,
    not on the MCMC gradient path).

Both parameterizations encode the same prior distribution over the
state path. They differ only in sampler geometry — which latent
variables HMC sees and operates on.

Code added

File Type Purpose
pyrenew/latent/state_centered_distributions.py new StateAR1, StateDifferencedAR1
pyrenew/latent/temporal_processes.py modified parameterization flag on all three classes; _prepare_initial_value helper
test/test_temporal_processes.py modified +31 unit tests (parameterization flag, state-centered shape/site/prior-equivalence)
test/test_helpers.py modified fixed_ar1_state, fixed_differenced_ar1_state factories
test/integration/conftest.py modified he_model_state_centered, he_weekly_rt_model_state_centered, he_weekly_model_state_centered fixtures
test/integration/test_population_infections_he_state_centered.py new 5 end-to-end tests, daily Rt
test/integration/test_population_infections_he_weekly_rt_state_centered.py new 5 end-to-end tests, weekly Rt via WeeklyTemporalProcess
_typos.toml modified Whitelist reparametrized_params (NumPyro upstream attribute name)

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

Benchmark tests on the models written for #819 show a 2x to 4x speedup.
Working on benchmark scripts for more systematic testing.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 19, 2026

Thank you for your contribution @cdc-mitzimorris 🚀! Your github-pages is ready for download 👉 here 👈!
(The artifact expires on 2026-05-26T21:51:47Z. You can re-generate it by re-running the workflow here.)

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds an opt-in state-centered parameterization to the AR1, DifferencedAR1, and RandomWalk temporal-process classes in pyrenew.latent. The default "innovation" parameterization preserves the existing behavior; passing parameterization="state" switches the model to sample the latent state path directly, which can offer better HMC geometry when posteriors are tightly informed by data. To support this, three new NumPyro Distribution subclasses (StateRandomWalk, StateAR1, StateDifferencedAR1) are added with vectorized log_prob and lax.scan-based sample methods. Unit tests verify exact log-density equivalence with the manual transition density and prior-moment equivalence between the two parameterizations; new integration tests exercise the state-centered path end-to-end through MultiSignalModel.

Changes:

  • New state_centered_distributions.py with three custom NumPyro Distributions used by the state-mode samplers.
  • Added a parameterization: Literal["innovation", "state"] flag (default "innovation") to AR1, DifferencedAR1, and RandomWalk, with validation, repr updates, and a shared _prepare_initial_value helper.
  • Added unit/integration tests, factory helpers (fixed_ar1_state, fixed_differenced_ar1_state), and shared conftest helpers (_build_he_population_model, three new fixtures). Whitelist reparametrized_params in _typos.toml.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.

Show a summary per file
File Description
pyrenew/latent/state_centered_distributions.py New file defining StateRandomWalk, StateAR1, StateDifferencedAR1 with sample and log_prob.
pyrenew/latent/temporal_processes.py Adds parameterization arg to the three classes, validation helper, initial-value helper, and state-mode sampling branches.
test/test_temporal_processes.py Adds exact log-prob tests, parameterization-flag tests, and per-class state-mode shape/trace/prior-moment tests.
test/test_helpers.py Adds fixed_ar1_state and fixed_differenced_ar1_state factories.
test/integration/conftest.py Refactors duplicated builder code into _build_he_population_model; adds three state-centered fixtures.
test/integration/test_population_infections_he_state_centered.py New integration test, daily Rt with state-centered AR1.
test/integration/test_population_infections_he_weekly_rt_state_centered.py New integration test, weekly Rt with state-centered DifferencedAR1.
_typos.toml Whitelists reparametrized_params (matches the NumPyro upstream attribute name).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 98.71%. Comparing base (0f223f5) to head (1e99920).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #828      +/-   ##
==========================================
+ Coverage   98.61%   98.71%   +0.10%     
==========================================
  Files          55       56       +1     
  Lines        2023     2182     +159     
==========================================
+ Hits         1995     2154     +159     
  Misses         28       28              
Flag Coverage Δ
unittests 98.71% <100.00%> (+0.10%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

ran the benchmarks on my machine - here are the results:

time python -m benchmarks.suites.rt_params --candidate he --prior both --repeats 3
rt_params suite: 4 candidate(s) x 2 prior(s) x 3 repeat(s) = 24 fits
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 1/3): 62.9s, divergences=0, min ESS/s=0.15
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 2/3): 66.4s, divergences=0, min ESS/s=0.25
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 3/3): 68.4s, divergences=0, min ESS/s=0.14
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 1/3): 63.5s, divergences=0, min ESS/s=5.79
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 2/3): 62.7s, divergences=0, min ESS/s=5.59
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 3/3): 63.4s, divergences=0, min ESS/s=6.61
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 1/3): 68.9s, divergences=0, min ESS/s=1.05
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 2/3): 69.3s, divergences=0, min ESS/s=0.12
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 3/3): 70.7s, divergences=0, min ESS/s=0.47
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 1/3): 17.9s, divergences=0, min ESS/s=28.92
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 2/3): 16.6s, divergences=0, min ESS/s=30.93
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 3/3): 16.8s, divergences=0, min ESS/s=32.74
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 1/3): 79.7s, divergences=0, min ESS/s=0.03
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 2/3): 79.4s, divergences=0, min ESS/s=0.03
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 3/3): 80.2s, divergences=0, min ESS/s=0.03
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 1/3): 30.0s, divergences=0, min ESS/s=10.49
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 2/3): 31.4s, divergences=0, min ESS/s=9.56
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 3/3): 29.4s, divergences=0, min ESS/s=10.88
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 1/3): 72.2s, divergences=0, min ESS/s=0.03
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 2/3): 72.8s, divergences=0, min ESS/s=0.04
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 3/3): 73.8s, divergences=0, min ESS/s=0.04
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 1/3): 22.6s, divergences=0, min ESS/s=42.72
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 2/3): 22.3s, divergences=0, min ESS/s=44.03
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 3/3): 22.3s, divergences=0, min ESS/s=52.17

--- synthetic_he_weekly_hospital | cadence=daily | innovation_sd=0.01 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  65.9         63.2        0.96x
ESS/s Rt (median)             0.748       27.329     36.53x *
ESS/s Rt (min)                0.183        5.997     32.78x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         9.91        0.99x
Tree depth (max)                 10           10        1.00x
E-BFMI (min)                  0.888        0.943      1.06x *
R-hat Rt (max)                1.275        1.006      0.79x *

--- synthetic_he_weekly_hospital | cadence=weekly | innovation_sd=0.01 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  69.7         17.1      0.25x *
ESS/s Rt (median)             1.856       98.404     53.02x *
ESS/s Rt (min)                0.546       30.864     56.49x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         7.17      0.72x *
Tree depth (max)                 10            9      0.90x *
E-BFMI (min)                  0.896        0.925        1.03x
R-hat Rt (max)                1.150        1.005      0.87x *

--- synthetic_he_weekly_hospital | cadence=daily | innovation_sd=0.1 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  79.8         30.2      0.38x *
ESS/s Rt (median)             0.078       72.302    928.36x *
ESS/s Rt (min)                0.032       10.311    322.37x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         8.06      0.81x *
Tree depth (max)                 10           10        1.00x
E-BFMI (min)                  0.901        0.920        1.02x
R-hat Rt (max)                2.350        1.014      0.43x *

--- synthetic_he_weekly_hospital | cadence=weekly | innovation_sd=0.1 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  73.0         22.4      0.31x *
ESS/s Rt (median)             0.098       75.878    772.58x *
ESS/s Rt (min)                0.038       46.309   1226.16x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         7.58      0.76x *
Tree depth (max)                 10            9      0.90x *
E-BFMI (min)                  0.980        0.941        0.96x
R-hat Rt (max)                2.165        1.004      0.46x *

(* marks an improvement over innovation; ratios are state / innovation)

Wrote results to benchmarks/results

real	21m15.997s
user	80m23.152s
sys	0m7.630s

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants