diff --git a/R/mcmc-parcoord.R b/R/mcmc-parcoord.R index 6891ed46..d5590d2b 100644 --- a/R/mcmc-parcoord.R +++ b/R/mcmc-parcoord.R @@ -132,8 +132,8 @@ mcmc_parcoord <- divg <- sym("Divergent") - draws <- dplyr::filter(data, UQ(divg) == 0) - div_draws <- dplyr::filter(data, UQ(divg) == 1) + draws <- dplyr::filter(data, !!divg == 0) + div_draws <- dplyr::filter(data, !!divg == 1) has_divs <- isTRUE(nrow(div_draws) > 0) graph <- ggplot(draws, aes( @@ -191,7 +191,7 @@ mcmc_parcoord_data <- # 'Parameter' and 'Value' so need to be a little careful) divs <- np %>% validate_nuts_data_frame() %>% - dplyr::filter(UQ(param) == "divergent__") %>% + dplyr::filter(!!param == "divergent__") %>% select(- !!param) %>% rename("Divergent" = !!value) diff --git a/R/mcmc-scatterplots.R b/R/mcmc-scatterplots.R index d4f939e1..dc07b979 100644 --- a/R/mcmc-scatterplots.R +++ b/R/mcmc-scatterplots.R @@ -352,10 +352,10 @@ mcmc_pairs <- function(x, param <- sym("Parameter") val <- sym("Value") np <- validate_nuts_data_frame(np, lp) - divs <- dplyr::filter(np, UQ(param) == "divergent__") %>% pull(UQ(val)) + divs <- dplyr::filter(np, !!param == "divergent__") %>% pull(!!val) divergent__ <- matrix(divs, nrow = n_iter * n_chain, ncol = n_param)[, 1] if (!no_max_td) { - gt_max_td <- (dplyr::filter(np, UQ(param) == "treedepth__") %>% pull(UQ(val))) >= max_treedepth + gt_max_td <- (dplyr::filter(np, !!param == "treedepth__") %>% pull(!!val)) >= max_treedepth max_td_hit__ <- matrix(gt_max_td, nrow = n_iter * n_chain, ncol = n_param)[, 1] } } @@ -674,11 +674,11 @@ pairs_condition <- function(chains = NULL, draws = NULL, nuts = NULL) { divg <- sym("Divergent") xydata$Divergent <- np %>% - dplyr::filter(UQ(param) == "divergent__") %>% - pull(UQ(val)) + dplyr::filter(!!param == "divergent__") %>% + pull(!!val) - divdata <- dplyr::filter(xydata, UQ(divg) == 1) - xydata <- dplyr::filter(xydata, UQ(divg) == 0) + divdata <- dplyr::filter(xydata, !!divg == 1) + xydata <- dplyr::filter(xydata, !!divg == 0) } graph <- ggplot(data = xydata, aes(x = .data$x, y = .data$y)) + @@ -880,7 +880,7 @@ handle_condition <- function(x, condition=NULL, np=NULL, lp=NULL) { } else { param <- sym("Parameter") - mark <- dplyr::filter(np, UQ(param) == condition) + mark <- dplyr::filter(np, !!param == condition) mark <- unstack_to_matrix(mark, Value ~ Chain) } if (condition == "divergent__") { diff --git a/R/mcmc-traces.R b/R/mcmc-traces.R index 21252085..0050a8d2 100644 --- a/R/mcmc-traces.R +++ b/R/mcmc-traces.R @@ -789,7 +789,7 @@ divergence_rug <- function(np, np_style, n_iter, n_chain) { divg <- sym("Divergent") div_info <- np %>% - dplyr::filter(UQ(param) == "divergent__") %>% + dplyr::filter(!!param == "divergent__") %>% group_by(!! iter) %>% summarise( Divergent = ifelse(sum(!! val) > 0, !! iter, NA)