FreesearchR/R/regression_plot.R

161 lines
4.3 KiB
R
Raw Normal View History

2025-01-30 14:32:11 +01:00
#' Regression coef plot from gtsummary. Slightly modified to pass on arguments
#'
#' Variant of gtsummary's `plot.tbl_regression()` that forwards extra
#' arguments to `ggstats::ggcoef_plot()`, can draw the reference level of
#' categorical variables, and applies `symmetrical_scale_x_log10()` when the
#' table estimates are exponentiated.
#'
#' @param x (`tbl_regression`, `tbl_uvregression`)\cr
#'   A 'tbl_regression' or 'tbl_uvregression' object
#' @param plot_ref (scalar `logical`)\cr
#'   plot reference values. When `TRUE`, reference rows of categorical and
#'   dichotomous variables get the null estimate (1 when the table is
#'   exponentiated, otherwise 0) so they are drawn. Default is `TRUE`.
#' @param remove_header_rows (scalar `logical`)\cr
#'   logical indicating whether to remove header rows
#'   for categorical variables. Default is `TRUE`
#' @param remove_reference_rows (scalar `logical`)\cr
#'   logical indicating whether to remove reference rows
#'   for categorical variables. Default is `FALSE`.
#' @param ... arguments passed to `ggstats::ggcoef_plot(...)`
#'
#' @returns ggplot object
#' @export
#'
#' @examples
#' \dontrun{
#' mod <- lm(mpg ~ ., default_parsing(mtcars))
#' p <- mod |>
#'   gtsummary::tbl_regression() |>
#'   plot(colour = "variable")
#' }
#'
plot.tbl_regression <- function(x,
                                plot_ref = TRUE,
                                remove_header_rows = TRUE,
                                remove_reference_rows = FALSE,
                                ...) {
  # Dots are intentionally not checked: they are forwarded to ggcoef_plot().
  gtsummary:::check_pkg_installed("ggstats")
  gtsummary:::check_not_missing(x)

  # Read once; a missing/NULL input is treated as "not exponentiated" instead
  # of erroring inside `if ()` further down.
  exponentiate <- isTRUE(x$inputs$exponentiate)

  df_coefs <- x$table_body

  if (isTRUE(remove_header_rows)) {
    df_coefs <- df_coefs |> dplyr::filter(!header_row %in% TRUE)
  }
  if (isTRUE(remove_reference_rows)) {
    df_coefs <- df_coefs |> dplyr::filter(!reference_row %in% TRUE)
  }

  # Removes redundant label
  df_coefs$label[df_coefs$row_type == "label"] <- ""

  # Add estimate value to reference level.
  # `== TRUE` keeps accepting TRUE, 1 and "TRUE" as before; isTRUE() makes
  # NA/NULL fall back to FALSE instead of raising an error in `if ()`.
  if (isTRUE(plot_ref == TRUE)) {
    is_ref <- df_coefs$var_type %in% c("categorical", "dichotomous") &
      df_coefs$reference_row %in% TRUE
    df_coefs[is_ref, "estimate"] <- if (exponentiate) 1 else 0
  }

  p <- df_coefs |>
    ggstats::ggcoef_plot(exponentiate = exponentiate, ...)

  if (exponentiate) {
    p <- symmetrical_scale_x_log10(p)
  }
  p
}
2025-01-30 14:32:11 +01:00
#' Wrapper to pivot gtsummary table data to long for plotting
#'
#' Merges the selected tables with `tbl_merge()` and stacks the per-model
#' columns (suffixed `_1`, `_2`, ... by position) into one long `table_body`
#' with a `model` column, so all models can be drawn in one coefficient plot.
#'
#' @param list a custom regression models list. Its `tables` element holds
#'   the gtsummary tables to merge; `models$Multivariable$model` is used to
#'   decide whether estimates are exponentiated.
#' @param model.names names of models to include
#'
#' @returns list: the merged table with `table_body` in long format and
#'   `inputs$exponentiate` set
#' @export
#'
merge_long <- function(list, model.names) {
  l_subset <- list$tables[model.names]
  l_merged <- l_subset |> tbl_merge()
  df_body <- l_merged$table_body

  # For each model, flag the columns carrying that model's position suffix.
  sel_list <- lapply(seq_along(l_subset), \(.i) {
    endsWith(names(df_body), paste0("_", .i))
  }) |>
    setNames(names(l_subset))

  # Columns without any model suffix are shared by all models.
  common <- !Reduce(`|`, sel_list)

  df_body_long <- sel_list |>
    purrr::imap(\(.l, .i) {
      d <- dplyr::bind_cols(
        df_body[common],
        df_body[.l],
        model = .i
      )
      # Strip the numeric model suffix ("_1", "_12", ...) so every model ends
      # up with the same column names. `+` is used because the `{,}` bound is
      # non-standard and engine-dependent.
      setNames(d, sub("_[0-9]+$", "", names(d)))
    }) |>
    dplyr::bind_rows() |>
    dplyr::mutate(model = REDCapCAST::as_factor(model))

  l_merged$table_body <- df_body_long
  # identical() rather than inherits(): only a plain "lm" is left
  # un-exponentiated; e.g. glm objects (class c("glm", "lm")) are exponentiated.
  # NOTE(review): relies on a model named "Multivariable"; if it is missing,
  # class(NULL) is "NULL" and estimates are treated as exponentiated.
  l_merged$inputs$exponentiate <- !identical(class(list$models$Multivariable$model), "lm")
  l_merged
}
2025-03-19 13:10:56 +01:00
#' Easily round log scale limits for nice plots
#'
#' Converts values given in log10 units back to the original scale
#' (`10^data`) and rounds them to one significant digit with `fun`: use
#' `floor` for a lower limit and `ceiling` for an upper limit.
#'
#' @param data numeric vector of values in log10 units
#' @param fun rounding function (floor/ceiling)
#' @param ... ignored
#'
#' @returns numeric vector on the original (non-log) scale
#' @export
#'
#' @examples
#' limit_log(-.1, floor)
#' limit_log(.1, ceiling)
#' limit_log(-2.1, ceiling)
#' limit_log(2.1, ceiling)
limit_log <- function(data, fun, ...) {
  # Scaling factor that moves 10^data into [1, 10), i.e. puts its leading
  # digit in the ones place.
  shift <- 10^-floor(data)
  leading_digit <- fun(shift * 10^data)
  leading_digit / shift
}
#' Create symmetric log ticks
#'
#' Mirrors a set of values around 1 by adding each value's reciprocal and 1
#' itself, for use as breaks on a logarithmic ratio axis.
#'
#' @param data numeric vector
#' @param digits number of decimal places the ticks are rounded to.
#'   Default is 2.
#'
#' @returns numeric vector of sorted, unique, positive and finite ticks
#' @export
#'
#' @examples
#' c(sample(seq(.1, 1, .1), 3), sample(1:10, 3)) |> create_log_tics()
#' create_log_tics(c(2, 3, 5))
create_log_tics <- function(data, digits = 2) {
  # Round before de-duplicating so values that only differ beyond `digits`
  # (e.g. 1/3 and 0.33) do not yield duplicate breaks.
  tics <- unique(round(c(1 / data, data, 1), digits))
  # Ticks rounded down to 0, negative values and Inf (from a 0 in `data`)
  # cannot be placed on a log scale; sort() also drops NA.
  sort(tics[is.finite(tics) & tics > 0])
}
#' Ensure symmetrical plot around 1 on a logarithmic x scale for ratio plots
#'
#' Replaces the x scale of `plot` with `ggplot2::scale_x_log10()`. The
#' breaks are mirrored around 1 via `create_log_tics()`, while the limits
#' are the plot's current limits rounded outwards with `limit_log()`.
#' Assumes the plot already has a log10 x scale, so that the current limits
#' are in log10 units (TODO confirm for non-log inputs).
#'
#' @param plot ggplot2 plot
#' @param breaks breaks used and mirrored. They are shifted into the decade
#'   holding the largest absolute log10 limit; only those not exceeding it are
#'   kept.
#' @param ... ignored
#'
#' @returns ggplot2 object
#' @export
#'
symmetrical_scale_x_log10 <- function(plot, breaks = c(1, 2, 3, 5, 10), ...) {
  lims <- ggplot2::layer_scales(plot)$x$get_limits()

  # Largest distance from 0 (ratio 1) on the log10 axis, with both limits
  # first rounded outwards to one decimal.
  half_width <- max(abs(c(
    floor(10 * lims[1]) / 10,
    ceiling(10 * lims[2]) / 10
  )))

  # Move the base breaks into the decade that contains `half_width`, keep
  # those within it, and mirror them around 1.
  decade <- ceiling(half_width) - 1
  log_breaks <- log10(breaks) + decade
  tick_values <- create_log_tics(10^log_breaks[log_breaks <= half_width])

  lower <- limit_log(lims[1], floor)
  upper <- limit_log(lims[2], ceiling)

  new_scale <- ggplot2::scale_x_log10(
    limits = c(lower, upper),
    breaks = tick_values
  )
  plot + new_scale
}