FreesearchR/R/plot_sankey.R

315 lines
9.9 KiB
R

#' Readying data for sankey plot
#'
#' Counts every combination of two variables and adds axis labels carrying
#' the size of each group, ordered by the level order of each variable.
#'
#' @param data data.frame
#' @param pri name of the primary (first axis) variable
#' @param sec name of the secondary (second axis) variable
#' @param numbers how group sizes are written in the labels: "count"
#' (default, e.g. "a\n(n=12)") or "percentage" (share of all observations,
#' e.g. "a\n(12%)").
#'
#' @name data-plots
#'
#' @returns data.frame with the count `n` per combination, the group totals
#' `gx.sum`/`gy.sum` and the factor labels `lx`/`ly`. Missing values keep
#' their own label ("NA\n(...)") as the last level.
#' @export
#'
#' @examples
#' ds <- data.frame(g = sample(LETTERS[1:2], 100, TRUE), first = REDCapCAST::as_factor(sample(letters[1:4], 100, TRUE)), last = sample(c(letters[1:4], NA), 100, TRUE, prob = c(rep(.23, 4), .08)))
#' ds |> sankey_ready("first", "last")
#' ds |> sankey_ready("first", "last", numbers = "percentage")
#' data.frame(
#' g = sample(LETTERS[1:2], 100, TRUE),
#' first = REDCapCAST::as_factor(sample(letters[1:4], 100, TRUE)),
#' last = sample(c(TRUE, FALSE, FALSE), 100, TRUE)
#' ) |>
#' sankey_ready("first", "last")
sankey_ready <- function(data, pri, sec, numbers = c("count", "percentage"), ...) {
  numbers <- match.arg(numbers)

  ## Ensure both variables are factors; their level order drives axis order
  data <- data[c(pri, sec)] |>
    dplyr::mutate(dplyr::across(!dplyr::where(is.factor), forcats::as_factor))

  ## .drop = FALSE keeps empty level combinations, so every level is present
  out <- dplyr::count(data, !!dplyr::sym(pri), !!dplyr::sym(sec), .drop = FALSE)
  out <- out |>
    dplyr::group_by(!!dplyr::sym(pri)) |>
    dplyr::mutate(gx.sum = sum(n)) |>
    dplyr::ungroup() |>
    dplyr::group_by(!!dplyr::sym(sec)) |>
    dplyr::mutate(gy.sum = sum(n)) |>
    dplyr::ungroup()

  total <- sum(out$n)
  # Suffix describing the size of a group, appended to the level name
  size_label <- function(group_sum) {
    if (numbers == "count") {
      paste0("\n(n=", group_sum, ")")
    } else {
      paste0("\n(", round((group_sum / total) * 100, 1), "%)")
    }
  }
  # Turn labels into a factor ordered like the underlying factor's levels.
  # order() puts NA values last, so the NA group's label is kept as the final
  # level instead of being turned into NA (no parsing of labels needed).
  ordered_labels <- function(values, labels) {
    factor(labels, levels = unique(labels[order(values)]))
  }

  out <- out |>
    dplyr::mutate(
      lx = paste0(!!dplyr::sym(pri), size_label(gx.sum)),
      ly = paste0(!!dplyr::sym(sec), size_label(gy.sum))
    )
  out$lx <- ordered_labels(out[[pri]], out$lx)
  out$ly <- ordered_labels(out[[sec]], out$ly)
  out
}
# Drop the last `pattern`-delimited segment of each string and rejoin the
# rest with `pattern`, e.g. "a\n(n=12)" -> "a" and "x\ny\n(n=2)" -> "x\ny".
# `pattern` is used as a regular expression (as in strsplit()). A string with
# no delimiter becomes "", NA stays NA, and the result is always a character
# vector of the same length as `data`.
str_remove_last <- function(data, pattern = "\n") {
  parts <- strsplit(data, split = pattern)
  vapply(parts, function(.x) {
    if (length(.x) == 1L && is.na(.x)) {
      return(NA_character_)
    }
    # `[` (not `[[`) so strings with several delimiters keep all but the last
    paste(.x[-length(.x)], collapse = pattern)
  }, character(1))
}
#' Beautiful sankey plot with option to split by a tertiary group
#'
#' @param ter optional name of a tertiary variable. The data is split by its
#' (non-empty) levels and one sankey plot is drawn per level; rows with a
#' missing `ter` value are dropped by the split.
#' @param color.group variable to colour by: "pri" or "sec".
#' @param colors optional manual colours, passed to `plot_sankey_single()`.
#' @param color.palette palette used to generate colours.
#' @param default.color fill colour for the strata of the variable that is not
#' coloured.
#' @param box.color passed to `plot_sankey_single()`.
#' @param na.color colour used for missing values.
#' @param missing.level label used for missing values.
#' @param ... passed on to `plot_sankey_single()` (and from there to
#' `sankey_ready()`), e.g. `numbers = "percentage"`.
#'
#' @returns ggplot2 object
#' @export
#'
#' @name data-plots
#'
#' @examples
#' ds <- data.frame(g = sample(LETTERS[1:2], 100, TRUE), first = REDCapCAST::as_factor(sample(letters[1:4], 100, TRUE)), last = REDCapCAST::as_factor(sample(letters[1:4], 100, TRUE)))
#' ds |> plot_sankey("first", "last")
#' ds |> plot_sankey("first", "last", color.group = "sec")
#' ds |> plot_sankey("first", "last", ter = "g", color.group = "sec")
#' mtcars |>
#' default_parsing() |>
#' plot_sankey("cyl", "gear", "am", color.group = "pri")
#' ## In this case, the last plot has the secondary variable in the wrong order
#' ## Don't know why...
#' mtcars |>
#' default_parsing() |>
#' plot_sankey("cyl", "gear", "vs", color.group = "pri",color.palette="inferno")
plot_sankey <- function(data,
                        pri,
                        sec,
                        ter = NULL,
                        color.group = "pri",
                        colors = NULL,
                        color.palette = "viridis",
                        default.color = "#2986cc",
                        box.color = "#1E4B66",
                        na.color = "grey80",
                        missing.level = "Missing",
                        ...) {
  if (!is.null(ter)) {
    ## drop = TRUE: unused levels of `ter` would otherwise give empty subsets
    ds <- split(data, data[ter], drop = TRUE)
  } else {
    ds <- list(data)
  }
  out <- lapply(ds, \(.ds) {
    plot_sankey_single(
      .ds,
      pri = pri,
      sec = sec,
      color.palette = color.palette,
      color.group = color.group,
      colors = colors,
      missing.level = missing.level,
      ## Forward the colour arguments; they were previously accepted here
      ## but silently ignored.
      default.color = default.color,
      box.color = box.color,
      na.color = na.color,
      ...
    )
  })
  patchwork::wrap_plots(out)
}
#' Beautiful sankey plot
#'
#' Draws a two-axis alluvial (sankey) plot of `pri` against `sec`. Flows and
#' strata are filled by the levels of the variable chosen in `color.group`;
#' strata of the other variable use `default.color`.
#'
#' @param data data.frame
#' @param pri name of the primary (first axis) variable
#' @param sec name of the secondary (second axis) variable
#' @param color.group set group to colour by: "pri" (default) or "sec".
#' @param color.palette palette used to generate colours when `colors` is NULL.
#' @param colors optionally specify colors. The first colour is used for
#' missing (NA) fills; the remaining colours are handed to
#' `ggplot2::scale_fill_manual()` in level order and should give one colour
#' per stratum (levels of the primary group, then of the secondary group).
#' @param missing.level label given to missing values.
#' @param default.color fill colour for strata of the variable not coloured.
#' @param box.color currently not used.
#' @param na.color colour used for missing values.
#' @param ... passed to sankey_ready()
#'
#' @returns ggplot2 object
#' @export
#'
#' @examples
#' ds <- data.frame(g = sample(LETTERS[1:2], 100, TRUE), first = REDCapCAST::as_factor(sample(letters[1:4], 100, TRUE)), last = REDCapCAST::as_factor(sample(letters[1:4], 100, TRUE)))
#' ds |> plot_sankey_single("first", "last")
#' ds |> plot_sankey_single("first", "last", color.group = "sec")
#' data.frame(
#' g = sample(LETTERS[1:2], 100, TRUE),
#' first = REDCapCAST::as_factor(sample(letters[1:4], 100, TRUE)),
#' last = sample(c(TRUE, FALSE, FALSE), 100, TRUE)
#' ) |>
#' plot_sankey_single("first", "last", color.group = "pri")
#' mtcars |>
#' default_parsing() |>
#' plot_sankey_single("cyl", "vs", color.group = "pri")
#' stRoke::trial |>
#' default_parsing() |>
#' plot_sankey_single("diabetes", "hypertension")
#'
#'
#' # stRoke::trial |> plot_sankey_single("mrs_1", "mrs_6", color.palette="magma")
#' # stRoke::trial |> plot_sankey_single("active", "male")
#' # stRoke::trial |> plot_sankey_single("diabetes", "active", color.group="sec")
#' # stRoke::trial |> plot_sankey_single("active", "diabetes", color.group="sec", color.palette="topo")
plot_sankey_single <- function(data,
                               pri,
                               sec,
                               color.group = c("pri", "sec"),
                               color.palette = "viridis",
                               colors = NULL,
                               missing.level = "Missing",
                               default.color = "#2986cc",
                               box.color = "#1E4B66",
                               na.color = "grey80",
                               ...) {
  color.group <- match.arg(color.group)
  data_orig <- data
  ## Drop unused levels, turn logicals into factors and missing values into
  ## an explicit `missing.level` level
  data[c(pri, sec)] <- with_labels(data, {
    data[c(pri, sec)] |>
      to_clean_levels() |>
      missing_to_text_levels(missing.text = missing.level)
  })
  ## Aggregate data
  data_aggr <- data |> sankey_ready(pri = pri, sec = sec, ...)

  ## Variable whose levels get palette colours, and the other one
  fill_var <- if (color.group == "sec") sec else pri
  other_var <- if (color.group == "sec") pri else sec

  if (is.null(colors)) {
    ## na.color is forwarded so the missing level gets the requested colour
    main.colors <- color_levels_gen(
      data_orig[[fill_var]],
      na.color = na.color,
      palette = color.palette
    )
    secondary.colors <- rep(default.color, length(levels(data[[other_var]])))
    ## Label colours: primary axis first, then secondary axis. The coloured
    ## axis uses the palette in reverse level order (presumably the order in
    ## which the strata are emitted - verify if the axis order changes).
    if (color.group == "sec") {
      label.colors <- c(
        contrast_text(secondary.colors),
        contrast_text(rev(main.colors))
      )
    } else {
      label.colors <- c(
        contrast_text(rev(main.colors)),
        contrast_text(secondary.colors)
      )
    }
    colors <- c(default.color, main.colors, secondary.colors)
    colors[is.na(colors)] <- na.color
  } else {
    ## colors[1] is the NA colour and is not a stratum; one label colour is
    ## needed per stratum.
    label.colors <- contrast_text(colors[-1])
  }

  group_labels <- c(get_label(data, pri), get_label(data, sec)) |>
    sapply(line_break) |>
    unname()

  p <- ggplot2::ggplot(data_aggr, ggplot2::aes(y = n, axis1 = lx, axis2 = ly)) +
    ggalluvial::geom_alluvium(
      ## Mapping `color` as well would print strings when levels are empty
      ggplot2::aes(fill = !!dplyr::sym(fill_var)),
      width = 1 / 16,
      alpha = .8,
      knot.pos = 0.4,
      curve_type = "sigmoid"
    ) +
    ggalluvial::geom_stratum(
      ggplot2::aes(fill = !!dplyr::sym(fill_var)),
      size = 2,
      width = 1 / 3.4
    )

  p +
    ggplot2::geom_text(
      ## Passing the ggproto object (instead of "stratum") means ggalluvial
      ## does not have to be attached with library() for the stat lookup.
      stat = ggalluvial::StatStratum,
      ggplot2::aes(label = ggplot2::after_stat(stratum)),
      colour = label.colors,
      size = 6,
      lineheight = 1
    ) +
    ggplot2::scale_x_continuous(breaks = 1:2, labels = group_labels) +
    ggplot2::scale_fill_manual(values = colors[-1], na.value = colors[1]) +
    ggplot2::theme_void() +
    ggplot2::theme(
      legend.position = "none",
      axis.text.x = ggplot2::element_text(size = 20),
      plot.background = ggplot2::element_rect(fill = "white"),
      panel.border = ggplot2::element_blank()
    )
}
# stRoke::trial["male"] |> to_clean_levels()
# Apply all_levels_clean() to a single vector, or to every column of a data
# frame (returning the columns re-bound with dplyr::bind_cols()).
# `missing.text` is accepted but not used.
to_clean_levels <- function(data, missing.text = "Missing") {
  if (!is.data.frame(data)) {
    return(all_levels_clean(data))
  }
  cleaned_columns <- lapply(data, all_levels_clean)
  dplyr::bind_cols(cleaned_columns)
}
# stRoke::trial["mrs_1"] |> missing_to_text_levels()
# In every factor column of `data`, turn NA values into an explicit level
# named `missing.text`. Factors without NA and non-factor columns are
# returned unchanged.
missing_to_text_levels <- function(data, missing.text = "Missing") {
  na_to_level <- function(f) {
    if (!anyNA(f)) {
      return(f)
    }
    forcats::fct_na_value_to_level(f, level = missing.text)
  }
  dplyr::mutate(data, dplyr::across(dplyr::where(is.factor), na_to_level))
}
# Convert a logical vector to a factor and drop unused levels of a factor;
# any other input is returned unchanged.
all_levels_clean <- function(data) {
  cleaned <- data
  if (is.logical(cleaned)) {
    cleaned <- as.factor(cleaned)
  }
  if (is.factor(cleaned)) {
    cleaned <- forcats::fct_drop(cleaned)
  }
  cleaned
}
# stRoke::trial$mrs_1 |> color_levels_gen()
# One palette colour per level of the cleaned `data` (see to_clean_levels()),
# followed by `na.color` when `data` contains missing values.
color_levels_gen <- function(data, na.color = "grey80", palette = "viridis") {
  n_levels <- length(levels(to_clean_levels(data)))
  palette_colors <- generate_colors(n = n_levels, palette = palette)
  if (!anyNA(data)) {
    return(palette_colors)
  }
  c(palette_colors, na.color)
}