Skip to content

Commit 8f45291

Browse files
authored
Merge pull request #531 from utkarshpawade/fix/nuts-params-list-edge-cases
Fix edge-case validation in nuts_params.list()
2 parents 7d5ffd7 + 563720f commit 8f45291

File tree

3 files changed

+25
-0
lines changed

3 files changed

+25
-0
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# bayesplot (development version)
22

3+
* Validate empty list and zero-row matrix inputs in `nuts_params.list()`.
34
* Validate user-provided `pit` values in `ppc_loo_pit_data()` and `ppc_loo_pit_qq()`, rejecting non-numeric inputs, missing values, and values outside `[0, 1]`.
45
* New `show_marginal` argument to `ppd_*()` functions to show the PPD - the marginal predictive distribution by @mattansb (#425)
56
* `ppc_ecdf_overlay()`, `ppc_ecdf_overlay_grouped()`, and `ppd_ecdf_overlay()` now always use `geom_step()`. The `discrete` argument is deprecated.

R/bayesplot-extractors.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,18 @@ nuts_params.stanreg <-
145145
#' @export
146146
#' @method nuts_params list
147147
nuts_params.list <- function(object, pars = NULL, ...) {
148+
if (length(object) == 0) {
149+
abort("'object' must be a non-empty list.")
150+
}
151+
148152
if (!all(sapply(object, is.matrix))) {
149153
abort("All list elements should be matrices.")
150154
}
151155

156+
if (any(vapply(object, nrow, integer(1)) == 0)) {
157+
abort("All matrices in the list must have at least one row.")
158+
}
159+
152160
dd <- lapply(object, dim)
153161
if (length(unique(dd)) != 1) {
154162
abort("All matrices in the list must have the same dimensions.")

tests/testthat/test-extractors.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ x <- list(cbind(a = 1:3, b = rnorm(3)), cbind(a = 1:3, b = rnorm(3)))
99

1010
# nuts_params and log_posterior methods -----------------------------------
1111
test_that("nuts_params.list throws errors", {
12+
expect_error(nuts_params.list(list()), "non-empty list")
13+
1214
x[[3]] <- c(a = 1:3, b = rnorm(3))
1315
expect_error(nuts_params.list(x), "list elements should be matrices")
1416

@@ -17,6 +19,20 @@ test_that("nuts_params.list throws errors", {
1719

1820
x[[3]] <- cbind(a = 1:4, b = rnorm(4))
1921
expect_error(nuts_params.list(x), "same dimensions")
22+
23+
zero_row <- list(cbind(a = numeric(0), b = numeric(0)))
24+
expect_error(nuts_params.list(zero_row), "at least one row")
25+
26+
zero_row_nonfirst <- list(cbind(a = 1:3, b = rnorm(3)), cbind(a = numeric(0), b = numeric(0)))
27+
expect_error(nuts_params.list(zero_row_nonfirst), "at least one row")
28+
})
29+
30+
test_that("nuts_params.list works with single-chain list", {
31+
single <- list(cbind(a = 1:3, b = rnorm(3)))
32+
np <- nuts_params.list(single)
33+
expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value"))
34+
expect_true(all(np$Chain == 1L))
35+
expect_equal(nrow(np), 6L)
2036
})
2137

2238
test_that("nuts_params.list parameter selection ok", {

0 commit comments

Comments
 (0)