Basic Sankey

library(tidyr)
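## NOTE: when this page was rendered, library(tidyr) above failed with
## "Error : 'format_warning' is not an exported object from 'namespace:cli'",
## which suggests an outdated cli package in the rendering environment.
## None of the examples below call tidyr functions.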
library(dplyr)

# Sample size and display labels. state_indices are presumably 0-based indices
# into state_labels; stage_labels name the four time points t1-t4.
num_of_sample_data <- 1000
state_labels <- list("Treatment 1", "Treatment 1 & Treatment 2", "Treatment 2", "None")
stage_labels <- list("0 months", "3 months", "6 months", "12 months")
state_indices <- 0:3

# generate patient state data: one row per simulated patient, with a randomly
# drawn state index at each of the four stages t1-t4
patient_states <- tibble::tibble(
  t1 = sample(state_indices, num_of_sample_data, replace = TRUE),
  t2 = sample(state_indices, num_of_sample_data, replace = TRUE),
  t3 = sample(state_indices, num_of_sample_data, replace = TRUE),
  t4 = sample(state_indices, num_of_sample_data, replace = TRUE)
)

# format data: one row per distinct t1-t4 path, with `size` = number of
# patients following that path. count() returns an ungrouped result, so no
# separate ungroup() is needed. dem_category = -1 / dem_value = 0 are the same
# values the segmentation example below uses for its unsegmented rows.
data <- patient_states %>%
  count(t1, t2, t3, t4, name = "size") %>%
  mutate(dem_category = -1, dem_value = 0)
# seq_len() (unlike seq.int(n) / 1:n) yields an empty vector for zero rows
data$id <- seq_len(nrow(data))
data <- data %>% select(dem_category, dem_value, id, size, t1, t2, t3, t4)

# clean-up: convert to an unnamed list with one element per row; each element
# is an unnamed list of that row's values in the column order selected above.
data <- lapply(seq_len(nrow(data)), function(i) unname(as.list(data[i, ])))

# Custom colors for the states can be passed via the state_colors parameter as
# a list of hex or rgb color strings (see the segmentation example below).
nswidgets::create_sankey(
  data = data,
  state_labels = state_labels,
  stage_labels = stage_labels,
  value_label = "Members",
  axis_label = "Months After Treatment",
  caption = "This is a sample caption."
)

Sankey With Segmentation (Work in Progress)

library(tidyr)
library(dplyr)

# Set basic metadata: number of simulated patients, state/stage labels and
# state colors (one per state label).
numEntities <- 200
stateLabels <- list(
  "Treatment 1",
  "Treatment 1 & Treatment 2",
  "Treatment 2",
  "None"
)
stageLabels <- list("Start", "3 months", "6 months", "12 months")
stateColors <- list("#DB543C", "#09B0E6", "#5D5D5F", "#FBAD31")

# Demographic metadata: the first entry lists the category names, and each
# following entry lists the values of one category, in the same order.
demographics <- list(
  list("Gender", "Age Group"),
  list("Female Short TestLongLabel", "Male"),
  list("TestLongLabel 18-39", "40-49", "50-64", "65+")
)
# NOTE(review): these palettes hold 3 and 5 colors, while the Gender / Age Group
# value lists above hold 2 and 4 -- confirm what nswidgets expects here.
demographicColors <- list(
  list("#37FD62", "#1867FB", "#EE8800"),
  list("#1867FB", "#66CB89", "#AE5699", "#EE8800", "#37FD62")
)

# Create clusters: simulate one row per patient with two demographic indices
# (d1 in 0-1, d2 in 0-3 -- presumably indexing the Gender and Age Group value
# lists in `demographics`) and a state index (0-3) at each of four stages.
stateIndices <- 0:3
d1Indices <- 0:1
d2Indices <- 0:3

patientStates <- tibble::tibble(
  pid = seq_len(numEntities),
  d1 = sample(d1Indices, numEntities, replace = TRUE),
  d2 = sample(d2Indices, numEntities, replace = TRUE),
  t1 = sample(stateIndices, numEntities, replace = TRUE),
  t2 = sample(stateIndices, numEntities, replace = TRUE),
  t3 = sample(stateIndices, numEntities, replace = TRUE),
  t4 = sample(stateIndices, numEntities, replace = TRUE)
)

# Unsegmented flows: patients per distinct t1-t4 path, tagged demCat = -1.
# count() keeps only the counted columns, so unused columns need no select().
flowTableNoD <- patientStates %>%
  count(t1, t2, t3, t4, name = "size") %>%
  mutate(demCat = -1, demVal = 0) %>%
  select(demCat, demVal, size, t1, t2, t3, t4)

# Flows split by demographic d1 (demCat = 0, demVal = d1 value).
flowTableD1 <- patientStates %>%
  count(d1, t1, t2, t3, t4, name = "size") %>%
  mutate(demCat = 0) %>%
  select(demCat, demVal = d1, size, t1, t2, t3, t4)

# Flows split by demographic d2 (demCat = 1, demVal = d2 value).
flowTableD2 <- patientStates %>%
  count(d2, t1, t2, t3, t4, name = "size") %>%
  mutate(demCat = 1) %>%
  select(demCat, demVal = d2, size, t1, t2, t3, t4)

# Stack all flow tables and give every row a 1-based clusterID.
allForks <- rbind(flowTableNoD, flowTableD1, flowTableD2) %>% as.data.frame()
allForks$clusterID <- seq_len(nrow(allForks))
allForks <- allForks %>% select(demCat, demVal, clusterID, size, t1, t2, t3, t4)

# One unnamed list per row (values in the column order above); seq_len()
# avoids iterating over c(1, 0) if allForks were ever empty.
clusters <- lapply(seq_len(nrow(allForks)), function(i) unname(as.list(allForks[i, ])))

nswidgets::create_sankey(
  data = clusters,
  state_labels = stateLabels,
  stage_labels = stageLabels,
  state_colors = stateColors,
  demographics = demographics,
  demographic_colors = demographicColors
)

Sankey With Subset (Work in Progress)

library(tidyr)
library(dplyr)

# Set basic metadata shared by every dataset that createDataSet() builds:
# state labels and colors, stage labels, and demographic labels/colors.
states <- list(
  "Treatment 1",
  "Treatment 1 & Treatment 2",
  "Treatment 2",
  "None"
)
stateColors <- list("#DB543C", "#09B0E6", "#5D5D5F", "#FBAD31")
stageLabels <- list("Start", "3 months", "6 months", "12 months")

# First entry: demographic category names; following entries: the values of
# each category, in the same order.
demographics <- list(
  list("Gender", "Age Group"),
  list("Female Short TestLongLabel", "Male"),
  list("TestLongLabel 18-39", "40-49", "50-64", "65+")
)
demographicColors <- list(
  list("#37FD62", "#1867FB", "#EE8800"),
  list("#1867FB", "#66CB89", "#AE5699", "#EE8800", "#37FD62")
)

# Simulate one random patient-flow dataset for nswidgets::create_subset_sankey().
#
# numEntities: number of simulated patients. 0 now gives an empty `data` list;
#   previously 1:0 produced a 2-element pid column and tibble() errored.
# label: display label stored in the returned list.
#
# Returns a list with `label`, the metadata read from the global variables
# `states`, `stateColors`, `stageLabels`, `demographics` and
# `demographicColors` (these must exist when the function is called), and
# `data`: one unnamed list per cluster holding
# (demCat, demVal, clusterID, size, t1, t2, t3, t4).
createDataSet <- function(numEntities, label) {
  # Create clusters
  stateIndices <- 0:3
  d1Indices <- 0:1
  d2Indices <- 0:3

  patientStates <- tibble::tibble(
    pid = seq_len(numEntities),
    d1 = sample(d1Indices, numEntities, replace = TRUE),
    d2 = sample(d2Indices, numEntities, replace = TRUE),
    t1 = sample(stateIndices, numEntities, replace = TRUE),
    t2 = sample(stateIndices, numEntities, replace = TRUE),
    t3 = sample(stateIndices, numEntities, replace = TRUE),
    t4 = sample(stateIndices, numEntities, replace = TRUE)
  )

  # Unsegmented flows: patients per distinct t1-t4 path, tagged demCat = -1.
  flowTableNoD <- patientStates %>%
    count(t1, t2, t3, t4, name = "size") %>%
    mutate(demCat = -1, demVal = 0) %>%
    select(demCat, demVal, size, t1, t2, t3, t4)

  # Flows split by demographic d1 (demCat = 0, demVal = d1 value).
  flowTableD1 <- patientStates %>%
    count(d1, t1, t2, t3, t4, name = "size") %>%
    mutate(demCat = 0) %>%
    select(demCat, demVal = d1, size, t1, t2, t3, t4)

  # Flows split by demographic d2 (demCat = 1, demVal = d2 value).
  flowTableD2 <- patientStates %>%
    count(d2, t1, t2, t3, t4, name = "size") %>%
    mutate(demCat = 1) %>%
    select(demCat, demVal = d2, size, t1, t2, t3, t4)

  allForks <- rbind(flowTableNoD, flowTableD1, flowTableD2) %>% as.data.frame()
  allForks$clusterID <- seq_len(nrow(allForks))
  allForks <- allForks %>% select(demCat, demVal, clusterID, size, t1, t2, t3, t4)

  # seq_len() iterates zero times for an empty table (1:0 would visit 1 and 0)
  clusters <- lapply(seq_len(nrow(allForks)), function(i) unname(as.list(allForks[i, ])))

  list(
    label = label,
    stateLabels = states,
    stateColors = stateColors,
    stageLabels = stageLabels,
    demographics = demographics,
    demographicColors = demographicColors,
    data = clusters
  )
}

# Two independently simulated datasets the user can switch between in the
# subset Sankey widget.
dataSets <- list(
  option1 = createDataSet(200, "Dataset 1"),
  option2 = createDataSet(250, "Dataset 2")
)

nswidgets::create_subset_sankey(dataSets = dataSets)