diff --git a/DESCRIPTION b/DESCRIPTION index 2a4a273..38d4384 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -39,6 +39,7 @@ Imports: utils, vctrs (>= 0.5.0) Suggests: + arules, cluster, ClusterR, clustMixType (>= 0.3-5), diff --git a/NAMESPACE b/NAMESPACE index 9960b43..4f1acdd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,12 +3,14 @@ S3method(as_tibble,cluster_metric_set) S3method(augment,cluster_fit) S3method(check_args,default) +S3method(check_args,freq_itemsets) S3method(check_args,hier_clust) S3method(check_args,k_means) S3method(extract_cluster_assignment,KMeansCluster) S3method(extract_cluster_assignment,cluster_fit) S3method(extract_cluster_assignment,cluster_spec) S3method(extract_cluster_assignment,hclust) +S3method(extract_cluster_assignment,itemsets) S3method(extract_cluster_assignment,kmeans) S3method(extract_cluster_assignment,kmodes) S3method(extract_cluster_assignment,kproto) @@ -18,6 +20,7 @@ S3method(extract_fit_summary,KMeansCluster) S3method(extract_fit_summary,cluster_fit) S3method(extract_fit_summary,cluster_spec) S3method(extract_fit_summary,hclust) +S3method(extract_fit_summary,itemsets) S3method(extract_fit_summary,kmeans) S3method(extract_fit_summary,kmodes) S3method(extract_fit_summary,kproto) @@ -37,6 +40,7 @@ S3method(print,cluster_fit) S3method(print,cluster_metric_set) S3method(print,cluster_spec) S3method(print,control_cluster) +S3method(print,freq_itemsets) S3method(print,hier_clust) S3method(print,k_means) S3method(required_pkgs,cluster_fit) @@ -58,23 +62,28 @@ S3method(sse_within_total,cluster_spec) S3method(sse_within_total,workflow) S3method(tidy,cluster_fit) S3method(translate_tidyclust,default) +S3method(translate_tidyclust,freq_itemsets) S3method(translate_tidyclust,hier_clust) S3method(translate_tidyclust,k_means) S3method(tunable,cluster_spec) +S3method(tunable,freq_itemsets) S3method(tunable,k_means) S3method(tune_args,cluster_spec) S3method(tune_cluster,cluster_spec) S3method(tune_cluster,default) 
S3method(tune_cluster,workflow) +S3method(update,freq_itemsets) S3method(update,hier_clust) S3method(update,k_means) export("%>%") +export(.freq_itemsets_fit_arules) export(.hier_clust_fit_stats) export(.k_means_fit_ClusterR) export(.k_means_fit_clustMixType) export(.k_means_fit_klaR) export(.k_means_fit_stats) export(augment) +export(augment_itemset_predict) export(cluster_metric_set) export(control_cluster) export(cut_height) @@ -83,6 +92,7 @@ export(extract_cluster_assignment) export(extract_fit_engine) export(extract_fit_parsnip) export(extract_fit_summary) +export(extract_itemset_predictions) export(extract_parameter_set_dials) export(extract_preprocessor) export(extract_spec_parsnip) @@ -92,6 +102,7 @@ export(fit) export(fit.cluster_spec) export(fit_xy) export(fit_xy.cluster_spec) +export(freq_itemsets) export(get_tidyclust_colors) export(glance) export(hier_clust) diff --git a/R/aaa.R b/R/aaa.R index db45641..b36ee89 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -10,6 +10,7 @@ utils::globalVariables( ".iter_model", ".iter_preprocessor", ".msg_model", + ".pred_item", ".submodels", "call_info", "cluster", @@ -23,6 +24,7 @@ utils::globalVariables( "exposed", "func", "id", + "item", "iteration", "lab", "name", @@ -32,10 +34,14 @@ utils::globalVariables( "orig_label", "original", "predictor_indicators", + "preds", "remove_intercept", + "row_id", "seed", + "setNames", "sil_width", "splits", + "truth_value", "tunable", "type", "value", diff --git a/R/augment_itemset_predict.R b/R/augment_itemset_predict.R new file mode 100644 index 0000000..4031f2d --- /dev/null +++ b/R/augment_itemset_predict.R @@ -0,0 +1,170 @@ +#' Augment Itemset Predictions with Truth Values +#' +#' This function processes the output of a `predict()` call for frequent itemset models +#' and joins it with the corresponding ground truth data. It's designed to prepare +#' the prediction and truth values in a format suitable for calculating evaluation metrics +#' using packages like `yardstick`. 
+#' +#' @param pred_output A data frame that is the output of `predict()` from a `freq_itemsets` model. +#' It is expected to have a column named `.pred_cluster`, where each cell contains +#' a data frame with prediction details (including `.pred_item`, `.obs_item`, and `item`). +#' @param truth_output A data frame representing the ground truth. It should have a similar +#' structure to the input data used for prediction, where columns represent items +#' and rows represent transactions. +#' +#' @details +#' The function first extracts and combines all individual item prediction data frames +#' nested within the `pred_output`. It then filters for items where a prediction was made +#' (i.e., `!is.na(.pred_item)`) and standardizes item names by removing backticks. +#' The `truth_output` is pivoted to a long format to match the structure of the predictions. +#' Finally, an inner join is performed to ensure that only predicted items are included in +#' the final result, aligning predictions with their corresponding true values. +#' +#' @return A data frame with the following columns: +#' \itemize{ +#' \item `item`: The name of the item. +#' \item `row_id`: An identifier for the transaction (row) from which the prediction came. +#' \item `preds`: The predicted value for the item (either raw probability or binary prediction). +#' \item `truth`: The true value for the item from `truth_output`. +#' } +#' This output is suitable for direct use with `yardstick` metric functions. 
+#'
+#' @examples
+#' toy_df <- data.frame(
+#'   "beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE),
+#'   "milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE),
+#'   "bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE),
+#'   "diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE),
+#'   "eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE)
+#' )
+#'
+#' new_data <- data.frame(
+#'   "beer" = NA,
+#'   "milk" = TRUE,
+#'   "bread" = TRUE,
+#'   "diapers" = TRUE,
+#'   "eggs" = FALSE
+#' )
+#'
+#' truth_df <- data.frame(
+#'   "beer" = FALSE,
+#'   "milk" = TRUE,
+#'   "bread" = TRUE,
+#'   "diapers" = TRUE,
+#'   "eggs" = FALSE
+#' )
+#'
+#' fi_spec <- freq_itemsets(
+#'   min_support = 0.05,
+#'   mining_method = "eclat"
+#' ) |>
+#'   set_engine("arules") |>
+#'   set_mode("partition")
+#'
+#' fi_fit <- fi_spec |>
+#'   fit(~ .,
+#'     data = toy_df
+#'   )
+#'
+#' aug_pred <- fi_fit |>
+#'   predict(new_data, type = "raw") |>
+#'   augment_itemset_predict(truth_output = truth_df)
+#'
+#' aug_pred
+#'
+#' # Example use of formatted output
+#' aug_pred |>
+#'   yardstick::rmse(truth, preds)
+#'
+#' @export
+
+augment_itemset_predict <- function(pred_output, truth_output) {
+  # Extract all predictions (bind all .pred_cluster dataframes)
+  preds_df <- dplyr::bind_rows(pred_output$.pred_cluster, .id = "row_id") %>%
+    dplyr::filter(!is.na(.pred_item)) %>% # Keep only rows with predictions
+    dplyr::mutate(
+      item = gsub("`|TRUE|FALSE", "", item) # Remove backticks, TRUE, and FALSE from item names
+    ) %>% # Pipe into select(); as a bare call it would receive no data argument
+    dplyr::select(row_id, item, preds = .pred_item) # Standardize column names
+
+  # Pivot truth data to long format (to match predictions)
+  truth_long <- truth_output %>%
+    tibble::rownames_to_column("row_id") %>%
+    tidyr::pivot_longer(
+      cols = -row_id,
+      names_to = "item",
+      values_to = "truth_value"
+    ) %>%
+    dplyr::mutate(truth_value = as.numeric(truth_value))
+
+  # Join predictions with truth (inner join to keep only predicted items)
+  result <- preds_df %>%
+    dplyr::inner_join(truth_long, by = c("row_id", "item"))
+
+  # Return simplified output (preds vs truth)
+
dplyr::select(result, item, row_id, preds, truth = truth_value) +} + +#' Generate Dataframe with Random NAs and Corresponding Truth +#' +#' @description +#' This helper function creates a new data frame by randomly introducing `NA` values +#' into an input data frame. It also returns the original data frame as a "truth" +#' reference, which can be useful for simulating scenarios with missing data +#' for prediction tasks. +#' +#' @param df The input data frame to which `NA` values will be introduced. +#' It is typically a transactional dataset where columns are items and rows are transactions. +#' @param na_prob The probability (between 0 and 1) that any given cell in the +#' input data frame will be replaced with `NA`. +#' +#' @return A list containing two data frames: +#' \itemize{ +#' \item `na_data`: The data frame with `NA` values randomly introduced. +#' \item `truth`: The original input data frame, serving as the ground truth. +#' } +#' @examples +#' # Create a sample data frame +#' sample_df <- data.frame( +#' itemA = c(1, 0, 1), +#' itemB = c(0, 1, 1), +#' itemC = c(1, 1, 0) +#' ) +#' +#' # Generate NA data and truth with 30% NA probability +#' set.seed(123) +#' na_data_list <- random_na_with_truth(sample_df, na_prob = 0.3) +#' +#' # View the NA data +#' print(na_data_list$na_data) +#' +#' # View the truth data +#' print(na_data_list$truth) +#' +#' @note This function is not exported: it was used to test and provide examples in +#' the vignettes, and may be formally introduced in the future.
+random_na_with_truth <- function(df, na_prob = 0.3) {
+  # Create a copy of the original dataframe to store truth values
+  truth_df <- df
+
+  # Create a mask of NAs (TRUE = becomes NA)
+  na_mask <- matrix(
+    sample(
+      c(TRUE, FALSE),
+      size = nrow(df) * ncol(df),
+      replace = TRUE,
+      prob = c(na_prob, 1 - na_prob)
+    ),
+    nrow = nrow(df)
+  )
+
+  # Apply the mask to create NA values
+  na_df <- df
+  na_df[na_mask] <- NA
+
+  # Return both the NA-filled dataframe and the truth
+  list(
+    na_data = na_df,
+    truth = truth_df
+  )
+}
diff --git a/R/extract_cluster_assignment.R b/R/extract_cluster_assignment.R
index 1a483b9..377c4ef 100644
--- a/R/extract_cluster_assignment.R
+++ b/R/extract_cluster_assignment.R
@@ -159,6 +159,79 @@ extract_cluster_assignment.hclust <- function(
   cluster_assignment_tibble(clusters, length(unique(clusters)), ...)
 }
 
+#' @export
+extract_cluster_assignment.itemsets <- function(object, ...) {
+  max_iter <- 1000
+  items <- attr(object, "item_names")
+  itemsets <- arules::DATAFRAME(object)
+
+  itemset_list <- lapply(strsplit(gsub("[{}]", "", itemsets$items), ","), trimws)
+  support <- itemsets$support
+  clusters <- numeric(length(items))
+  changed <- TRUE # Flag to track convergence
+  iter <- 0 # Initialize iteration counter
+
+  # Repeat full passes over the items until a pass changes nothing (or max_iter is hit)
+  while (changed && iter < max_iter) {
+    changed <- FALSE
+    iter <- iter + 1
+    for (i in seq_along(items)) { # seq_along(): no iterations when `items` is empty
+      current_item <- items[i]
+      relevant_itemsets <- which(vapply(itemset_list, function(x) current_item %in% x, logical(1)))
+
+      if (length(relevant_itemsets) == 0) next # Skip if no itemsets
+
+      # Find the best itemset (largest size, then highest support)
+      best_itemset <- relevant_itemsets[
+        which.max(
+          lengths(itemset_list[relevant_itemsets]) * 1000 + # Size dominates
+            support[relevant_itemsets] # Support breaks ties
+        )
+      ]
+      best_itemset_size <- length(itemset_list[[best_itemset]])
+      best_itemset_support <- support[best_itemset]
+
+      # Current cluster info (if any)
+      current_cluster <-
clusters[i] + current_cluster_size <- if (current_cluster != 0) + length(itemset_list[[current_cluster]]) else 0 + current_cluster_support <- if (current_cluster != 0) + support[current_cluster] else 0 + + # Reassign if: + # 1. No current cluster, OR + # 2. New itemset is larger, OR + # 3. Same size but higher support + if (current_cluster == 0 || + best_itemset_size > current_cluster_size || + (best_itemset_size == current_cluster_size && + best_itemset_support > current_cluster_support)) { + + # Assign all items in the best itemset to its cluster + new_cluster <- best_itemset + items_in_best <- match(itemset_list[[best_itemset]], items) + + if (!all(clusters[items_in_best] == new_cluster)) { + clusters[items_in_best] <- new_cluster + changed <- TRUE # Mark that a change occurred + } + } + } + } + + if (iter == max_iter && changed) { + rlang::warn( + paste0( + "Cluster assignment did not converge within the maximum of ", + max_iter, + " iterations. Returning the current cluster assignments." + ) + ) + } + + item_assignment_tibble_w_outliers(clusters, ...) 
+}
+
 # ------------------------------------------------------------------------------
 
 cluster_assignment_tibble <- function(
@@ -173,3 +246,19 @@ cluster_assignment_tibble <- function(
 
   tibble::tibble(.cluster = factor(res))
 }
+
+item_assignment_tibble_w_outliers <- function(clusters,
+                                              ...,
+                                              prefix = "Cluster_") {
+  # A cluster id of 0 marks an item that was never placed in an itemset; each
+  # such item gets its own label: <prefix>0_1, <prefix>0_2, ...
+  is_outlier <- clusters == 0
+  res <- character(length(clusters))
+  res[is_outlier] <- paste0(prefix, "0_", seq_len(sum(is_outlier)))
+
+  # Non-zero ids are renumbered 1, 2, ... in order of first appearance.
+  assigned <- clusters[!is_outlier]
+  res[!is_outlier] <- paste0(prefix, match(assigned, unique(assigned)))
+
+  tibble::tibble(.cluster = factor(res))
+}
diff --git a/R/extract_fit_summary.R b/R/extract_fit_summary.R
index 7a61466..5e3e9c0 100644
--- a/R/extract_fit_summary.R
+++ b/R/extract_fit_summary.R
@@ -192,3 +192,15 @@ extract_fit_summary.hclust <- function(object, ...) {
     cluster_assignments = clusts
   )
 }
+
+#' @export
+extract_fit_summary.itemsets <- function(object, ...,
+                                         call = rlang::caller_env(n = 0)) {
+  cli::cli_abort(
+    "Centroids are not useful for frequent itemsets, we suggest looking at the
+    frequent itemsets directly.\n Please use arules::inspect() on the fit of
+    your cluster specification.",
+    call = call
+ ) + +} diff --git a/R/extract_itemset_predictions.R b/R/extract_itemset_predictions.R new file mode 100644 index 0000000..916708d --- /dev/null +++ b/R/extract_itemset_predictions.R @@ -0,0 +1,73 @@ +#' Extract Predictions from Observation Data Frames +#' +#' This function processes a data frame containing observation data frames and extracts non-NA values. +#' +#' Returns recommender predictions with predicted values imputed into dataset +#' Notes: currently imputes thresholded probabilities +#' +#' @param pred_output A data frame with one column, where each cell contains a data frame. +#' @return A data frame with items as columns and non-NA values as rows. +#' +#' @examples +#' toy_df <- data.frame( +#' "beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE), +#' "milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE), +#' "bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE), +#' "diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE), +#' "eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE) +#' ) +#' +#' new_data <- data.frame( +#' "beer" = NA, +#' "milk" = TRUE, +#' "bread" = TRUE, +#' "diapers" = TRUE, +#' "eggs" = FALSE +#' ) +#' +#' fi_spec <- freq_itemsets( +#' min_support = 0.05, +#' mining_method = "eclat" +#' ) |> +#' set_engine("arules") |> +#' set_mode("partition") +#' +#' fi_fit <- fi_spec |> +#' fit(~ ., +#' data = toy_df +#' ) +#' +#' fi_fit |> +#' predict(new_data) |> +#' extract_itemset_predictions() +#' +#' @export + +extract_itemset_predictions <- function(pred_output) { + # Extract the list of data frames from the single column + data_frames <- pred_output$.pred_cluster + + # Define the function to be passed to reduce instead of using lambda + processing_function <- function(.x_acc, .y_current) { + # .x_acc is the accumulated result (the first argument to .f) + # .y_current is the current data frame from data_frames (the second argument to .f) + + # Process each data frame + processed <- .y_current %>% + dplyr::mutate(value = ifelse(!is.na(.obs_item), .obs_item, .pred_item)) %>% + 
dplyr::select(item, value) %>%
+      tidyr::pivot_wider(names_from = item, values_from = value)
+
+    # Combine the processed data frame with the results
+    dplyr::bind_rows(.x_acc, processed)
+  }
+
+  # Fold the per-observation rows together; `reduce` is namespace-qualified because nothing in this file defines it
+  result_df <- purrr::reduce(
+    .x = data_frames,
+    .f = processing_function,
+    .init = NULL
+  )
+
+  return(result_df)
+}
diff --git a/R/freq_itemsets.R b/R/freq_itemsets.R
new file mode 100644
index 0000000..8738b9b
--- /dev/null
+++ b/R/freq_itemsets.R
@@ -0,0 +1,196 @@
+#' Frequent Itemsets Mining
+#'
+#' @description
+#'
+#' `freq_itemsets()` defines a model for Frequent Itemset Mining (FIM), a data mining
+#' technique used to discover relationships between items in transactional datasets.
+#' This model finds sets of items (itemsets) that frequently co-occur based on a
+#' user-specified minimum support threshold.
+#'
+#' The method of estimation is chosen by setting the model engine. The
+#' engine-specific pages for this model are listed below.
+#'
+#' - \link[=details_freq_itemsets_arules]{arules}
+#'
+#' @param mode A single character string for the type of model. The only
+#' possible value for this model is "partition".
+#' @param engine A single character string specifying the computational engine
+#' to use for fitting. The default for this model is `"arules"`. Currently,
+#' `"arules"` is the only supported engine.
+#' @param mining_method A single character string specifying the algorithm to use for
+#' fitting. Possible algorithms are `"apriori"` and `"eclat"`. The default for
+#' this model is `"eclat"`.
+#' @param min_support Positive double, minimum support for an itemset (between 0 and 1).
+#'
+#' @details
+#'
+#' ## What does it mean to predict?
+#'
+#' For `freq_itemsets` models, the `predict()` function is implemented as a recommender system.
+#' Given new data with partial transaction information (i.e., some items observed, others `NA`), +#' the model predicts other items likely to be in the transaction. +#' +#' Predictions are based on item-level probabilities derived from the confidence of frequent itemsets. +#' For each missing item, relevant frequent itemsets containing both the missing item and observed items are identified. +#' Confidence (support of itemset / support of observed items) is aggregated across relevant itemsets. +#' If no relevant itemsets are found, the item's global support from the training data is used as a fallback. +#' +#' The `predict()` output provides a nested data frame per transaction, including `item`, +#' `.obs_item` (observed status), and `.pred_item` (predicted values). +#' The `extract_itemset_predictions()` helper function can reformat this nested output into a single data frame. +#' +#' @return A `freq_itemsets` association specification. +#' +#' @examples +#' # Show all engines +#' modelenv::get_from_env("freq_itemsets") +#' +#' freq_itemsets() +#' @export +freq_itemsets <- + function(mode = "partition", # will add other modes + engine = "arules", + min_support = NULL, + mining_method = "eclat") { + args <- list( + min_support = enquo(min_support), + mining_method = enquo(mining_method) + ) + + new_cluster_spec( + "freq_itemsets", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = engine + ) + } + +#' @export +print.freq_itemsets <- function(x, ...) { + cat("Frequent Itemsets Mining Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) 
+ + if (!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + + invisible(x) +} + +# ------------------------------------------------------------------------------ + +#' @method update freq_itemsets +#' @rdname tidyclust_update +#' @export +update.freq_itemsets <- function(object, + parameters = NULL, + min_support = NULL, + mining_method = NULL, + fresh = FALSE, ...) { + eng_args <- parsnip::update_engine_parameters( + object$eng_args, + fresh = fresh, ... + ) + + if (!is.null(parameters)) { + parameters <- parsnip::check_final_param(parameters) + } + args <- list( + min_support = enquo(min_support), + mining_method = enquo(mining_method) + ) + + args <- parsnip::update_main_parameters(args, parameters) + + if (fresh) { + object$args <- args + object$eng_args <- eng_args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) { + args <- args[!null_args] + } + if (length(args) > 0) { + object$args[names(args)] <- args + } + if (length(eng_args) > 0) { + object$eng_args[names(eng_args)] <- eng_args + } + } + + new_cluster_spec( + "freq_itemsets", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) +} + +# # ---------------------------------------------------------------------------- + +#' @export +check_args.freq_itemsets <- function(object) { + args <- lapply(object$args, rlang::eval_tidy) + + if (all(is.numeric(args$min_support)) && (any(args$min_support < 0) || any(args$min_support > 1))) { + cli::cli_abort("The minimum support should be between 0 and 1.") + } + + if (all(is.character(args$mining_method)) && + !all(args$mining_method %in% c("apriori", "eclat"))) { + cli::cli_abort("The mining method should be either 'apriori' or 'eclat'.") + } + + invisible(object) +} + +#' @export +translate_tidyclust.freq_itemsets <- function(x, engine = x$engine, ...) { + x <- translate_tidyclust.default(x, engine, ...) 
+  x
+}
+
+# ------------------------------------------------------------------------------
+
+#' Simple Wrapper around arules functions
+#'
+#' This wrapper prepares the data and parameters to send to either `arules::apriori`
+#' or `arules::eclat` for frequent itemsets mining, depending on the chosen method.
+#'
+#' @param x A transaction data set.
+#' @param min_support Minimum support threshold.
+#' @param mining_method Algorithm to use for mining frequent itemsets. Either "apriori" or "eclat".
+#'
+#' @return A set of frequent itemsets based on the specified parameters.
+#' @keywords internal
+#' @export
+.freq_itemsets_fit_arules <- function(x,
+                                      min_support = NULL,
+                                      mining_method = NULL) {
+  if (is.null(min_support)) {
+    cli::cli_abort(
+      "Please specify `min_support` to be able to fit specification."
+    )
+  }
+  if (is.null(mining_method)) mining_method <- "eclat" # NULL would break the `==` tests below
+
+  if (mining_method == "apriori") {
+    res <- arules::apriori(data = x,
+                           parameter = list(support = min_support, target = "frequent itemsets"),
+                           control = list(verbose = FALSE))
+  } else if (mining_method == "eclat") {
+    res <- arules::eclat(data = x,
+                         parameter = list(support = min_support),
+                         control = list(verbose = FALSE))
+  } else {
+    cli::cli_abort("Invalid mining method specified.
Choose 'apriori' or 'eclat'.") + } + + attr(res, "item_names") <- colnames(x) + return(res) +} diff --git a/R/freq_itemsets_arules.R b/R/freq_itemsets_arules.R new file mode 100644 index 0000000..ca41414 --- /dev/null +++ b/R/freq_itemsets_arules.R @@ -0,0 +1,11 @@ +#' Frequent Itemsets via arules +#' +#' [freq_itemsets()] creates frequent itemset using Apriori or Eclat model +#' +# @includeRmd man/rmd/freq_itemsets_arules.md details +#' +#' @name details_freq_itemsets_arules +#' @keywords internal +NULL + +# See inst/README-DOCS.md for a description of how these files are processed diff --git a/R/freq_itemsets_data.R b/R/freq_itemsets_data.R new file mode 100644 index 0000000..a85117a --- /dev/null +++ b/R/freq_itemsets_data.R @@ -0,0 +1,100 @@ +# nocov start + +make_freq_itemsets <- function() { + modelenv::set_new_model("freq_itemsets") + + modelenv::set_model_mode("freq_itemsets", "partition") + + # ---------------------------------------------------------------------------- + + modelenv::set_model_engine("freq_itemsets", "partition", "arules") + modelenv::set_dependency( + model = "freq_itemsets", + mode = "partition", + eng = "arules", + pkg = "arules" + ) + modelenv::set_dependency( + model = "freq_itemsets", + mode = "partition", + eng = "arules", + pkg = "tidyclust" + ) + + modelenv::set_fit( + model = "freq_itemsets", + eng = "arules", + mode = "partition", + value = list( + interface = "matrix", + protect = c("x"), + func = c(pkg = "tidyclust", fun = ".freq_itemsets_fit_arules"), + defaults = list() + ) + ) + + modelenv::set_encoding( + model = "freq_itemsets", + eng = "arules", + mode = "partition", + options = list( + predictor_indicators = "traditional", + compute_intercept = TRUE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) + ) + + modelenv::set_model_arg( + model = "freq_itemsets", + eng = "arules", + exposed = "min_support", + original = "min_support", + func = list(pkg = "dials", fun = "min_support"), + has_submodel = TRUE + ) + + 
modelenv::set_model_arg(
+    model = "freq_itemsets",
+    eng = "arules",
+    exposed = "mining_method",
+    original = "mining_method",
+    func = list(pkg = "tidyclust", fun = "mining_method"),
+    has_submodel = TRUE
+  )
+
+  modelenv::set_pred(
+    model = "freq_itemsets",
+    eng = "arules",
+    mode = "partition",
+    type = "cluster",
+    value = list(
+      pre = NULL,
+      post = NULL,
+      func = c(fun = ".freq_itemsets_predict_arules"),
+      args =
+        list(
+          object = rlang::expr(object$fit),
+          new_data = rlang::expr(new_data)
+        )
+    )
+  )
+
+  # May want to change to pre and post instead of direct function
+  modelenv::set_pred(
+    model = "freq_itemsets",
+    eng = "arules",
+    mode = "partition",
+    type = "raw",
+    value = list(
+      pre = NULL,
+      post = NULL,
+      func = c(fun = ".freq_itemsets_predict_raw_arules"),
+      args =
+        list(
+          object = rlang::expr(object$fit),
+          new_data = rlang::expr(new_data)
+        )
+    )
+  )
+}
diff --git a/R/predict_helpers.R b/R/predict_helpers.R
index 2e1df4c..814439b 100644
--- a/R/predict_helpers.R
+++ b/R/predict_helpers.R
@@ -177,3 +177,88 @@ make_predictions <- function(x, prefix, n_clusters) {
 
   pred_clusts
 }
+
+itemsets_predict_helper <- function(object, new_data, ..., prefix = "Cluster_") {
+  new_data <- as.data.frame(new_data)
+
+  # Extract frequent itemsets and their supports
+  items <- attr(object, "item_names")
+  itemsets <- arules::DATAFRAME(object)
+  frequent_itemsets <- lapply(strsplit(gsub("[{}]", "", itemsets$items), ","), trimws)
+  supports <- itemsets$support
+
+  # Fallback score per item: mean support of the itemsets containing it, else 0
+  global_supports <- vapply(items, function(item) {
+    containing <- vapply(frequent_itemsets, function(x) item %in% x, logical(1))
+    if (any(containing)) {
+      sum(supports[containing]) / sum(containing)
+    } else {
+      0
+    }
+  }, numeric(1))
+
+  # Process each row of new_data (seq_len() gives no iterations for 0 rows)
+  result_list <- lapply(seq_len(nrow(new_data)), function(i) {
+    row_data <- new_data[i, , drop = FALSE] # stay a data frame even with one column
+    observed <- names(row_data)[which(row_data == 1)] # which() drops NA cells
+    missing <- names(row_data)[is.na(row_data)]
+
+    # Initialize prediction
vector
+    pred_values <- rep(NA, length(items))
+    names(pred_values) <- items
+
+    # Calculate probabilities for missing items
+    for (item in missing) {
+      # Find itemsets containing both the current item and at least one observed item
+      relevant <- vapply(frequent_itemsets, function(x) {
+        item %in% x && any(observed %in% x)
+      }, logical(1))
+
+      if (!any(relevant)) {
+        pred_values[item] <- global_supports[item]
+        next
+      }
+
+      # Calculate confidences
+      confidences <- vapply(which(relevant), function(idx) {
+        itemset <- frequent_itemsets[[idx]]
+        itemset_without <- setdiff(itemset, item)
+
+        # Find support of itemset without the current item
+        if (length(itemset_without) == 0) return(NA_real_)
+
+        matches <- vapply(frequent_itemsets, function(x) identical(x, itemset_without), logical(1))
+        if (!any(matches)) return(NA_real_)
+
+        supports[idx] / supports[matches][1]
+      }, numeric(1))
+
+      pred_values[item] <- mean(confidences, na.rm = TRUE)
+      if (is.nan(pred_values[item])) pred_values[item] <- global_supports[item]
+    }
+
+    # Create result data frame
+    data.frame(
+      item = gsub("`", "", items), # Remove backticks from item names
+      .obs_item = unlist(row_data),
+      .pred_item = pred_values,
+      row.names = NULL
+    )
+  })
+}
+
+.freq_itemsets_predict_raw_arules <- function(object, new_data, ..., prefix = "Cluster_") {
+  res <- itemsets_predict_helper(object, new_data, ..., prefix = prefix)
+  return(tibble::tibble(.pred_cluster = unname(res)))
+}
+
+.freq_itemsets_predict_arules <- function(object, new_data, ..., prefix = "Cluster_") {
+  res <- itemsets_predict_helper(object, new_data, ..., prefix = prefix)
+  # Apply threshold to raw predictions (observed items get NA, missing ones 0/1)
+  lapply(res, function(df) {
+    df$.pred_item <- ifelse(is.na(df$.obs_item),
+                            ifelse(df$.pred_item >= 0.5, 1, 0),
+                            NA)
+    df
+  })
+}
diff --git a/R/tunable.R b/R/tunable.R
index c63bf10..f2b240a 100644
--- a/R/tunable.R
+++ b/R/tunable.R
@@ -75,3 +75,25 @@ stats_k_means_engine_args <-
     component = "k_means",
     component_id = "engine"
   )
+
+#' @export
+tunable.freq_itemsets <-
function(x, ...) { + res <- NextMethod() + if (x$engine == "arules") { + res <- add_engine_parameters(res, arules_freq_itemsets_engine_args) + } + res +} + +arules_freq_itemsets_engine_args <- + tibble::tibble( + name = c( + "support" + ), + call_info = list( + list(pkg = "tidyclust", fun = "min_support") + ), + source = "cluster_spec", + component = "freq_itemsets", + component_id = "engine" + ) diff --git a/R/zzz.R b/R/zzz.R index 44938d0..98b2159 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -3,6 +3,7 @@ .onLoad <- function(libname, pkgname) { make_hier_clust() make_k_means() + make_freq_itemsets() s3_register("generics::required_pkgs", "cluster_fit") s3_register("generics::required_pkgs", "cluster_spec") diff --git a/_pkgdown.yml b/_pkgdown.yml index 7749c25..135ffe3 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -24,6 +24,7 @@ reference: - hier_clust - cluster_spec - cluster_fit + - freq_itemsets - title: Fit and Inspect desc: > These functions are the generics that are supported for specifications @@ -43,9 +44,11 @@ reference: at where the clusters are and which observations are associated with which cluster. contents: + - augment_itemset_predict - predict.cluster_fit - extract_cluster_assignment - extract_centroids + - extract_itemset_predictions - title: Model based performance metrics desc: > These metrics use the fitted clustering model to extract values denoting how diff --git a/man/augment_itemset_predict.Rd b/man/augment_itemset_predict.Rd new file mode 100644 index 0000000..3bf2aa7 --- /dev/null +++ b/man/augment_itemset_predict.Rd @@ -0,0 +1,89 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augment_itemset_predict.R +\name{augment_itemset_predict} +\alias{augment_itemset_predict} +\title{Augment Itemset Predictions with Truth Values} +\usage{ +augment_itemset_predict(pred_output, truth_output) +} +\arguments{ +\item{pred_output}{A data frame that is the output of \code{predict()} from a \code{freq_itemsets} model. 
+It is expected to have a column named \code{.pred_cluster}, where each cell contains +a data frame with prediction details (including \code{.pred_item}, \code{.obs_item}, and \code{item}).} + +\item{truth_output}{A data frame representing the ground truth. It should have a similar +structure to the input data used for prediction, where columns represent items +and rows represent transactions.} +} +\value{ +A data frame with the following columns: +\itemize{ +\item \code{item}: The name of the item. +\item \code{row_id}: An identifier for the transaction (row) from which the prediction came. +\item \code{preds}: The predicted value for the item (either raw probability or binary prediction). +\item \code{truth}: The true value for the item from \code{truth_output}. +} +This output is suitable for direct use with \code{yardstick} metric functions. +} +\description{ +This function processes the output of a \code{predict()} call for frequent itemset models +and joins it with the corresponding ground truth data. It's designed to prepare +the prediction and truth values in a format suitable for calculating evaluation metrics +using packages like \code{yardstick}. +} +\details{ +The function first extracts and combines all individual item prediction data frames +nested within the \code{pred_output}. It then filters for items where a prediction was made +(i.e., \code{!is.na(.pred_item)}) and standardizes item names by removing backticks. +The \code{truth_output} is pivoted to a long format to match the structure of the predictions. +Finally, an inner join is performed to ensure that only predicted items are included in +the final result, aligning predictions with their corresponding true values. 
+} +\examples{ +toy_df <- data.frame( +"beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE), +"milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE), +"bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE), +"diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE), +"eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE) +) + +new_data <- data.frame( +"beer" = NA, +"milk" = TRUE, +"bread" = TRUE, +"diapers" = TRUE, +"eggs" = FALSE +) + +truth_df <- data.frame( +"beer" = FALSE, +"milk" = TRUE, +"bread" = TRUE, +"diapers" = TRUE, +"eggs" = FALSE +) + +fi_spec <- freq_itemsets( + min_support = 0.05, + mining_method = "eclat" + ) |> + set_engine("arules") |> + set_mode("partition") + +fi_fit <- fi_spec |> + fit(~ ., + data = toy_df + ) + +aug_pred <- fi_fit |> + predict(new_data, type = "raw") |> + augment_itemset_predict(truth_output = truth_df) + +aug_pred + +# Example use of formatted output +aug_pred |> + yardstick::rmse(truth, preds) + +} diff --git a/man/details_freq_itemsets_arules.Rd b/man/details_freq_itemsets_arules.Rd new file mode 100644 index 0000000..5e6e264 --- /dev/null +++ b/man/details_freq_itemsets_arules.Rd @@ -0,0 +1,9 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/freq_itemsets_arules.R +\name{details_freq_itemsets_arules} +\alias{details_freq_itemsets_arules} +\title{Frequent Itemsets via arules} +\description{ +\code{\link[=freq_itemsets]{freq_itemsets()}} creates frequent itemset using Apriori or Eclat model +} +\keyword{internal} diff --git a/man/dot-freq_itemsets_fit_arules.Rd b/man/dot-freq_itemsets_fit_arules.Rd new file mode 100644 index 0000000..ed45f7e --- /dev/null +++ b/man/dot-freq_itemsets_fit_arules.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/freq_itemsets.R +\name{.freq_itemsets_fit_arules} +\alias{.freq_itemsets_fit_arules} +\title{Simple Wrapper around arules functions} +\usage{ +.freq_itemsets_fit_arules(x, min_support = NULL, mining_method = NULL) +} +\arguments{ +\item{x}{A transaction data 
set.} + +\item{min_support}{Minimum support threshold.} + +\item{mining_method}{Algorithm to use for mining frequent itemsets. Either "apriori" or "eclat".} +} +\value{ +A set of frequent itemsets based on the specified parameters. +} +\description{ +This wrapper prepares the data and parameters to send to either \code{arules::apriori} +or \code{arules::eclat} for frequent itemsets mining, depending on the chosen method. +} +\keyword{internal} diff --git a/man/extract_itemset_predictions.Rd b/man/extract_itemset_predictions.Rd new file mode 100644 index 0000000..08fc414 --- /dev/null +++ b/man/extract_itemset_predictions.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_itemset_predictions.R +\name{extract_itemset_predictions} +\alias{extract_itemset_predictions} +\title{Extract Predictions from Observation Data Frames} +\usage{ +extract_itemset_predictions(pred_output) +} +\arguments{ +\item{pred_output}{A data frame with one column, where each cell contains a data frame.} +} +\value{ +A data frame with items as columns and non-NA values as rows. +} +\description{ +This function processes a data frame containing observation data frames and extracts non-NA values. 
+} +\details{ +Returns recommender predictions with predicted values imputed into dataset +Notes: currently imputes thresholded probabilities +} +\examples{ +toy_df <- data.frame( +"beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE), +"milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE), +"bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE), +"diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE), +"eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE) +) + +new_data <- data.frame( +"beer" = NA, +"milk" = TRUE, +"bread" = TRUE, +"diapers" = TRUE, +"eggs" = FALSE +) + +fi_spec <- freq_itemsets( + min_support = 0.05, + mining_method = "eclat" + ) |> + set_engine("arules") |> + set_mode("partition") + +fi_fit <- fi_spec |> + fit(~ ., + data = toy_df + ) + +fi_fit |> + predict(new_data) |> + extract_itemset_predictions() + +} diff --git a/man/freq_itemsets.Rd b/man/freq_itemsets.Rd new file mode 100644 index 0000000..6bf8fbb --- /dev/null +++ b/man/freq_itemsets.Rd @@ -0,0 +1,65 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/freq_itemsets.R +\name{freq_itemsets} +\alias{freq_itemsets} +\title{Frequent Itemsets Mining} +\usage{ +freq_itemsets( + mode = "partition", + engine = "arules", + min_support = NULL, + mining_method = "eclat" +) +} +\arguments{ +\item{mode}{A single character string for the type of model. The only +possible value for this model is "partition".} + +\item{engine}{A single character string specifying the computational engine +to use for fitting. The default for this model is \code{"arules"}. Currently, +\code{"arules"} is the only supported engine.} + +\item{min_support}{Positive double, minimum support for an itemset (between 0 and 1).} + +\item{mining_method}{A single character string specifying the algorithm to use for +fitting. Possible algorithms are \code{"apriori"} and \code{"eclat"}. The default for +this model is \code{"eclat"}.} +} +\value{ +A \code{freq_itemsets} association specification. 
+} +\description{ +\code{freq_itemsets()} defines a model for Frequent Itemset Mining (FIM), a data mining +technique used to discover relationships between items in transactional datasets. +This model finds sets of items (itemsets) that frequently co-occur based on a +user-specified minimum support threshold. + +The method of estimation is chosen by setting the model engine. The +engine-specific pages for this model are listed below. +\itemize{ +\item \link[=details_freq_itemsets_arules]{arules} +} +} +\details{ +\subsection{What does it mean to predict?}{ + +For \code{freq_itemsets} models, the \code{predict()} function is implemented as a recommender system. +Given new data with partial transaction information (i.e., some items observed, others \code{NA}), +the model predicts other items likely to be in the transaction. + +Predictions are based on item-level probabilities derived from the confidence of frequent itemsets. +For each missing item, relevant frequent itemsets containing both the missing item and observed items are identified. +Confidence (support of itemset / support of observed items) is aggregated across relevant itemsets. +If no relevant itemsets are found, the item's global support from the training data is used as a fallback. + +The \code{predict()} output provides a nested data frame per transaction, including \code{item}, +\code{.obs_item} (observed status), and \code{.pred_item} (predicted values). +The \code{extract_itemset_predictions()} helper function can reformat this nested output into a single data frame. 
+} +} +\examples{ +# Show all engines +modelenv::get_from_env("freq_itemsets") + +freq_itemsets() +} diff --git a/man/random_na_with_truth.Rd b/man/random_na_with_truth.Rd new file mode 100644 index 0000000..8f385e0 --- /dev/null +++ b/man/random_na_with_truth.Rd @@ -0,0 +1,49 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augment_itemset_predict.R +\name{random_na_with_truth} +\alias{random_na_with_truth} +\title{Generate Dataframe with Random NAs and Corresponding Truth} +\usage{ +random_na_with_truth(df, na_prob = 0.3) +} +\arguments{ +\item{df}{The input data frame to which \code{NA} values will be introduced. +It is typically a transactional dataset where columns are items and rows are transactions.} + +\item{na_prob}{The probability (between 0 and 1) that any given cell in the +input data frame will be replaced with \code{NA}.} +} +\value{ +A list containing two data frames: +\itemize{ +\item \code{na_data}: The data frame with \code{NA} values randomly introduced. +\item \code{truth}: The original input data frame, serving as the ground truth. +} +} +\description{ +This helper function creates a new data frame by randomly introducing \code{NA} values +into an input data frame. It also returns the original data frame as a "truth" +reference, which can be useful for simulating scenarios with missing data +for prediction tasks. +} +\examples{ +# Create a sample data frame +sample_df <- data.frame( + itemA = c(1, 0, 1), + itemB = c(0, 1, 1), + itemC = c(1, 1, 0) +) + +# Generate NA data and truth with 30\% NA probability +set.seed(123) +na_data_list <- random_na_with_truth(sample_df, na_prob = 0.3) + +# View the NA data +print(na_data_list$na_data) + +# View the truth data +print(na_data_list$truth) + +# This function is not exported because it is used to test and provide examples in +# the vignettes; it may be formally introduced in the future.
+} diff --git a/man/set_args.cluster_spec.Rd b/man/set_args.cluster_spec.Rd index d36e8a9..a9143d2 100644 --- a/man/set_args.cluster_spec.Rd +++ b/man/set_args.cluster_spec.Rd @@ -7,7 +7,7 @@ \method{set_args}{cluster_spec}(object, ...) } \arguments{ -\item{object}{A \link[parsnip:model_spec]{model specification}.} +\item{object}{A model specification.} \item{...}{One or more named model arguments.} } diff --git a/man/set_engine.cluster_spec.Rd b/man/set_engine.cluster_spec.Rd index f0600ff..cfd1412 100644 --- a/man/set_engine.cluster_spec.Rd +++ b/man/set_engine.cluster_spec.Rd @@ -7,7 +7,7 @@ \method{set_engine}{cluster_spec}(object, engine, ...) } \arguments{ -\item{object}{A \link[parsnip:model_spec]{model specification}.} +\item{object}{A model specification.} \item{engine}{A character string for the software that should be used to fit the model. This is highly dependent on the type diff --git a/man/set_mode.cluster_spec.Rd b/man/set_mode.cluster_spec.Rd index 03d6b7d..226c64e 100644 --- a/man/set_mode.cluster_spec.Rd +++ b/man/set_mode.cluster_spec.Rd @@ -7,7 +7,7 @@ \method{set_mode}{cluster_spec}(object, mode, ...) } \arguments{ -\item{object}{A \link[parsnip:model_spec]{model specification}.} +\item{object}{A model specification.} \item{mode}{A character string for the model type (e.g. 
"classification" or "regression")} diff --git a/man/tidyclust_update.Rd b/man/tidyclust_update.Rd index a74a99b..00357de 100644 --- a/man/tidyclust_update.Rd +++ b/man/tidyclust_update.Rd @@ -1,11 +1,22 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/hier_clust.R, R/k_means.R, R/update.R -\name{update.hier_clust} +% Please edit documentation in R/freq_itemsets.R, R/hier_clust.R, R/k_means.R, +% R/update.R +\name{update.freq_itemsets} +\alias{update.freq_itemsets} \alias{update.hier_clust} \alias{update.k_means} \alias{tidyclust_update} \title{Update a cluster specification} \usage{ +\method{update}{freq_itemsets}( + object, + parameters = NULL, + min_support = NULL, + mining_method = NULL, + fresh = FALSE, + ... +) + \method{update}{hier_clust}( object, parameters = NULL, @@ -27,6 +38,11 @@ updating. If the main arguments are used, these will supersede the values in \code{parameters}. Also, using engine arguments in this object will result in an error.} +\item{fresh}{A logical for whether the arguments should be modified in-place +or replaced wholesale.} + +\item{...}{Not used for \code{update()}.} + \item{num_clusters}{Positive integer, number of clusters in model.} \item{cut_height}{Positive double, height at which to cut dendrogram to @@ -36,11 +52,6 @@ obtain cluster assignments (only used if \code{num_clusters} is \code{NULL})} unambiguous abbreviation of) one of \code{"ward.D"}, \code{"ward.D2"}, \code{"single"}, \code{"complete"}, \code{"average"} (= UPGMA), \code{"mcquitty"} (= WPGMA), \code{"median"} (= WPGMC) or \code{"centroid"} (= UPGMC).} - -\item{fresh}{A logical for whether the arguments should be modified in-place -or replaced wholesale.} - -\item{...}{Not used for \code{update()}.} } \value{ An updated cluster specification. 
diff --git a/tests/testthat/_snaps/extract_centroids.md b/tests/testthat/_snaps/extract_centroids.md index 06521d4..bb937d4 100644 --- a/tests/testthat/_snaps/extract_centroids.md +++ b/tests/testthat/_snaps/extract_centroids.md @@ -33,3 +33,11 @@ ! Using `h` argument is not supported. i Please use `cut_height` instead. +# extract_centroids errors for freq_itemsets + + Code + extract_centroids(fi_fit) + Condition + Error in `extract_fit_summary()`: + ! Centroids are not usfeul for frequent itemsets, we suggust looking at the frequent itemsets directly. Please use arules::inspect() on the fit of your cluster specification. + diff --git a/tests/testthat/_snaps/extract_cluster_assignment.md b/tests/testthat/_snaps/extract_cluster_assignment.md index 1aa8c4b..aa04341 100644 --- a/tests/testthat/_snaps/extract_cluster_assignment.md +++ b/tests/testthat/_snaps/extract_cluster_assignment.md @@ -33,3 +33,12 @@ ! Using `h` argument is not supported. i Please use `cut_height` instead. +# extract_cluster_assignment() errors for freq_itemsets() cluster spec + + Code + fi_spec %>% extract_cluster_assignment() + Condition + Error in `extract_cluster_assignment()`: + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. + diff --git a/tests/testthat/_snaps/freq_itemsets-arules.md b/tests/testthat/_snaps/freq_itemsets-arules.md new file mode 100644 index 0000000..00ad3ea --- /dev/null +++ b/tests/testthat/_snaps/freq_itemsets-arules.md @@ -0,0 +1,8 @@ +# extract_centroids works + + Code + extract_centroids(fi_fit) + Condition + Error in `extract_fit_summary()`: + ! Centroids are not usfeul for frequent itemsets, we suggust looking at the frequent itemsets directly. Please use arules::inspect() on the fit of your cluster specification. 
+ diff --git a/tests/testthat/_snaps/freq_itemsets.md b/tests/testthat/_snaps/freq_itemsets.md new file mode 100644 index 0000000..365c473 --- /dev/null +++ b/tests/testthat/_snaps/freq_itemsets.md @@ -0,0 +1,100 @@ +# bad input + + Code + freq_itemsets(mode = "bogus") + Condition + Error in `freq_itemsets()`: + ! "bogus" is not a known mode for model `freq_itemsets()`. + +--- + + Code + bt <- freq_itemsets(min_support = 0.05, mining_method = "bogus") + fit(bt, ~., toy_df) + Condition + Error in `check_args()`: + ! The mining method should be either 'apriori' or 'eclat'. + +--- + + Code + bt <- freq_itemsets(min_support = -1, mining_method = "eclat") %>% set_engine( + "arules") + fit(bt, ~., toy_df) + Condition + Error in `check_args()`: + ! The minimum support should be between 0 and 1. + +--- + + Code + translate_tidyclust(freq_itemsets(), engine = NULL) + Condition + Error in `translate_tidyclust.default()`: + ! Please set an engine. + +--- + + Code + translate_tidyclust(freq_itemsets(formula = ~x)) + Condition + Error in `freq_itemsets()`: + ! unused argument (formula = ~x) + +# extract_centroids work + + Code + extract_centroids(fi_fit) + Condition + Error in `extract_fit_summary()`: + ! Centroids are not usfeul for frequent itemsets, we suggust looking at the frequent itemsets directly. Please use arules::inspect() on the fit of your cluster specification. 
+ +# printing + + Code + freq_itemsets() + Output + Frequent Itemsets Mining Specification (partition) + + Main Arguments: + mining_method = eclat + + Computational engine: arules + + +--- + + Code + freq_itemsets(min_support = 0.5) + Output + Frequent Itemsets Mining Specification (partition) + + Main Arguments: + min_support = 0.5 + mining_method = eclat + + Computational engine: arules + + +# updating + + Code + freq_itemsets(min_support = 0.5) %>% update(min_support = tune()) + Output + Frequent Itemsets Mining Specification (partition) + + Main Arguments: + min_support = tune() + mining_method = eclat + + Computational engine: arules + + +# errors if `min_support` isn't specified + + Code + freq_itemsets() %>% set_engine("arules") %>% fit(~., data = toy_df) + Condition + Error in `tidyclust::.freq_itemsets_fit_arules()`: + ! Please specify `min_support` to be able to fit specification. + diff --git a/tests/testthat/_snaps/predict.md b/tests/testthat/_snaps/predict.md index 4c333e0..fcaac50 100644 --- a/tests/testthat/_snaps/predict.md +++ b/tests/testthat/_snaps/predict.md @@ -33,3 +33,12 @@ ! Using `h` argument is not supported. i Please use `cut_height` instead. +# predict() errors for cluster spec for freq_itemsets + + Code + predict(spec) + Condition + Error in `predict()`: + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. + diff --git a/tests/testthat/_snaps/tune_cluster.md b/tests/testthat/_snaps/tune_cluster.md index 78f7054..64be851 100644 --- a/tests/testthat/_snaps/tune_cluster.md +++ b/tests/testthat/_snaps/tune_cluster.md @@ -79,10 +79,10 @@ 1 }, save_pred = TRUE)) Message - x Fold1: preprocessor 1/1: Error in `hardhat::mold()`: - ! The following predictor ... - x Fold2: preprocessor 1/1: Error in `hardhat::mold()`: - ! The following predictor ... + x Fold1: preprocessor 1/1: Error in `get_all_predictors()`: + ! The following predi... + x Fold2: preprocessor 1/1: Error in `get_all_predictors()`: + ! 
The following predi... Condition Warning: All models failed. diff --git a/tests/testthat/test-extract_centroids.R b/tests/testthat/test-extract_centroids.R index 59aeb1b..3608c19 100644 --- a/tests/testthat/test-extract_centroids.R +++ b/tests/testthat/test-extract_centroids.R @@ -1,3 +1,11 @@ +toy_df <- data.frame( + "beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE), + "milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE), + "bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE), + "diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE), + "eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE) +) + test_that("extract_centroids() errors for cluster spec", { spec <- tidyclust::k_means(num_clusters = 4) @@ -64,3 +72,13 @@ test_that("prefix is passed in extract_centroids()", { all(substr(res$.cluster, 1, 2) == "C_") ) }) + +test_that("extract_centroids errors for freq_itemsets", { + set.seed(1234) + skip_if_not_installed("arules") + fi_fit <- freq_itemsets(min_support = 0.5) %>% + set_engine("arules") %>% + fit(~., toy_df %>% dplyr::mutate(across(everything(), as.numeric))) + + expect_snapshot(error = TRUE, extract_centroids(fi_fit)) +}) diff --git a/tests/testthat/test-extract_cluster_assignment.R b/tests/testthat/test-extract_cluster_assignment.R index 6e27ab2..51a2249 100644 --- a/tests/testthat/test-extract_cluster_assignment.R +++ b/tests/testthat/test-extract_cluster_assignment.R @@ -1,3 +1,11 @@ +toy_df <- data.frame( + "beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE), + "milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE), + "bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE), + "diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE), + "eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE) +) + test_that("extract_cluster_assignment() errors for cluster spec", { spec <- tidyclust::k_means(num_clusters = 4) @@ -64,3 +72,14 @@ test_that("prefix is passed in extract_cluster_assignment()", { all(substr(res$.cluster, 1, 2) == "C_") ) }) + +test_that("extract_cluster_assignment() errors for freq_itemsets() cluster spec", { + skip_if_not_installed("arules") + 
fi_spec <- freq_itemsets(min_support = 0.5) + + expect_snapshot( + error = TRUE, + fi_spec %>% + extract_cluster_assignment() + ) +}) diff --git a/tests/testthat/test-freq_itemsets-arules.R b/tests/testthat/test-freq_itemsets-arules.R new file mode 100644 index 0000000..d2dda01 --- /dev/null +++ b/tests/testthat/test-freq_itemsets-arules.R @@ -0,0 +1,79 @@ +toy_df <- data.frame( + "beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE), + "milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE), + "bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE), + "diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE), + "eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE) +) + +toy_pred <- data.frame( + "beer" = FALSE, + "milk" = NA, + "bread" = TRUE, + "diapers" = TRUE, + "eggs" = FALSE +) + +test_that("fitting", { + set.seed(1234) + skip_if_not_installed("arules") + spec <- freq_itemsets(min_support = 0.5) %>% + set_engine("arules") + + expect_no_error( + res <- fit(spec, ~., toy_df) + ) +}) + +test_that("predicting", { + set.seed(1234) + skip_if_not_installed("arules") + spec <- freq_itemsets(min_support = 0.5) %>% + set_engine("arules") + + res <- fit(spec, ~., toy_df) + + preds <- predict(res, toy_pred)$.pred_cluster[[1]]$.pred_item + + expect_identical( + preds, + c(NA, 1, NA, NA, NA) + ) +}) + +test_that("extract_centroids works", { + set.seed(1234) + skip_if_not_installed("arules") + fi_fit <- freq_itemsets(min_support = 0.5) %>% + set_engine("arules") %>% + fit(~., toy_df %>% dplyr::mutate(across(everything(), as.numeric))) + + expect_snapshot(error = TRUE, extract_centroids(fi_fit)) +}) + +test_that("extract_cluster_assignment() works", { + set.seed(1234) + skip_if_not_installed("arules") + fi_fit <- freq_itemsets(min_support = 0.5, mining_method = "eclat") %>% + set_engine("arules") %>% + fit(~., toy_df %>% dplyr::mutate(across(everything(), as.numeric))) + + set.seed(1234) + ref_res <- arules::eclat(data = toy_df, + parameter = list(support = 0.5), + control = list(verbose = FALSE)) + + ref_itemsets <- 
arules::DATAFRAME(ref_res) + ref_clusts <- c(1, 2, 2, 2, 0) + ref_outliers <- "eggs" + + expect_equal( + arules::DATAFRAME(fi_fit$fit), + ref_itemsets + ) + + expect_equal( + ref_clusts, + extract_cluster_assignment(fi_fit)$.cluster %>% as.numeric() - 1 + ) +}) diff --git a/tests/testthat/test-freq_itemsets.R b/tests/testthat/test-freq_itemsets.R new file mode 100644 index 0000000..63bb180 --- /dev/null +++ b/tests/testthat/test-freq_itemsets.R @@ -0,0 +1,148 @@ +toy_df <- data.frame( + "beer" = c(FALSE, TRUE, TRUE, TRUE, FALSE), + "milk" = c(TRUE, FALSE, TRUE, TRUE, TRUE), + "bread" = c(TRUE, TRUE, FALSE, TRUE, TRUE), + "diapers" = c(TRUE, TRUE, TRUE, TRUE, TRUE), + "eggs" = c(FALSE, TRUE, FALSE, FALSE, FALSE) +) + +toy_pred <- data.frame( + "beer" = FALSE, + "milk" = NA, + "bread" = TRUE, + "diapers" = TRUE, + "eggs" = FALSE +) + +test_that("primary arguments", { + skip_if_not_installed("arules") + basic <- freq_itemsets(mode = "partition") + basic_arules <- translate_tidyclust(basic %>% set_engine("arules")) + expect_equal( + basic_arules$method$fit$args, + list( + x = rlang::expr(missing_arg()), + mining_method = new_empty_quosure("eclat") + ) + ) + + fi <- freq_itemsets(min_support = 0.5, mining_method = "apriori", mode = "partition") + fi_arules <- translate_tidyclust(fi %>% set_engine("arules")) + expect_equal( + fi_arules$method$fit$args, + list( + x = rlang::expr(missing_arg()), + min_support = new_empty_quosure(0.5), + mining_method = new_empty_quosure("apriori") + ) + ) +}) + +test_that("bad input", { + skip_if_not_installed("arules") + expect_snapshot(error = TRUE, freq_itemsets(mode = "bogus")) + expect_snapshot(error = TRUE, { + bt <- freq_itemsets(min_support = 0.05, mining_method = "bogus") + fit(bt, ~ ., toy_df) + }) + expect_snapshot(error = TRUE, { + bt <- freq_itemsets(min_support = -1, mining_method = "eclat") %>% set_engine("arules") + fit(bt, ~ ., toy_df) + }) + expect_snapshot(error = TRUE, translate_tidyclust(freq_itemsets(), engine = 
NULL)) + expect_snapshot(error = TRUE, translate_tidyclust(freq_itemsets(formula = ~x))) +}) + +test_that("clusters", { + set.seed(1234) + skip_if_not_installed("arules") + fi_fit <- freq_itemsets(min_support = 0.5, mining_method = "apriori") %>% + set_engine("arules") %>% + fit(~., toy_df %>% dplyr::mutate(across(everything(), as.numeric))) + + set.seed(1234) + ref_res <- arules::apriori(data = toy_df, + parameter = list(support = 0.5, target = "frequent itemsets"), + control = list(verbose = FALSE)) + + ref_itemsets <- arules::DATAFRAME(ref_res) + ref_clusts <- c(1, 2, 2, 2, 0) + ref_outliers <- "eggs" + + expect_equal( + arules::DATAFRAME(fi_fit$fit), + ref_itemsets + ) + + expect_equal( + ref_clusts, + extract_cluster_assignment(fi_fit)$.cluster %>% as.numeric() - 1 + ) +}) + +test_that("predict", { + set.seed(1234) + skip_if_not_installed("arules") + fi_fit <- freq_itemsets(min_support = 0.5, mining_method = "apriori") %>% + set_engine("arules") %>% + fit(~., toy_df) + + ref_pred_raw <- c(NA, 0.766666666667, NA, NA, NA) + ref_pred_thresh <- c(NA, 1, NA, NA, NA) + + expect_equal( + ref_pred_thresh, + predict(fi_fit, toy_pred)$.pred_cluster[[1]]$.pred_item + ) + + expect_equal( + ref_pred_raw, + predict(fi_fit, toy_pred, type = "raw")$.pred_cluster[[1]]$.pred_item + ) +}) + +test_that("extract_centroids work", { + set.seed(1234) + skip_if_not_installed("arules") + fi_fit <- freq_itemsets(min_support = 0.5) %>% + set_engine("arules") %>% + fit(~., toy_df %>% dplyr::mutate(across(everything(), as.numeric))) + + expect_snapshot(error = TRUE, extract_centroids(fi_fit)) +}) + +test_that("Right classes", { + skip_if_not_installed("arules") + expect_equal( + class(freq_itemsets()), + c("freq_itemsets", "cluster_spec", "unsupervised_spec") + ) +}) + +test_that("printing", { + skip_if_not_installed("arules") + expect_snapshot( + freq_itemsets() + ) + expect_snapshot( + freq_itemsets(min_support = 0.5) + ) +}) + +test_that("updating", { + skip_if_not_installed("arules") + 
expect_snapshot( + freq_itemsets(min_support = 0.5) %>% + update(min_support = tune()) + ) +}) + +test_that("errors if `min_support` isn't specified", { + skip_if_not_installed("arules") + expect_snapshot( + error = TRUE, + freq_itemsets() %>% + set_engine("arules") %>% + fit(~ ., data = toy_df) + ) +}) diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index c5688b1..406dad6 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -64,3 +64,13 @@ test_that("prefix is passed in predict()", { all(substr(res$.pred_cluster, 1, 2) == "C_") ) }) + +test_that("predict() errors for cluster spec for freq_itemsets", { + skip_if_not_installed("arules") + spec <- tidyclust::freq_itemsets(min_support = 0.5) + + expect_snapshot( + error = TRUE, + predict(spec) + ) +}) diff --git a/vignettes/articles/freq_itemsets.Rmd b/vignettes/articles/freq_itemsets.Rmd new file mode 100644 index 0000000..b4825c5 --- /dev/null +++ b/vignettes/articles/freq_itemsets.Rmd @@ -0,0 +1,334 @@ +--- +title: "Frequent Itemset Mining" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Frequent Itemset Mining} + %\VignetteEncoding{UTF-8} + %\VignetteEngine{knitr::rmarkdown} +editor_options: + markdown: + wrap: 72 +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +## Setup + +```{r} +library(workflows) +library(parsnip) +``` + +Load libraries: + +```{r setup} +library(tidyclust) +library(arules) +set.seed(838383) +``` + +Load and clean a dataset: + +```{r} +data(Groceries) + +# convert to data frame +groceries <- as.data.frame(as(Groceries, "matrix")) %>% + dplyr::mutate(across(everything(), ~.*1)) +``` + +## A Brief Introduction to Frequent Itemset Mining + +*Frequent Itemset Mining* (FIM) is a fundamental technique in data mining that +identifies sets of items that frequently appear together in +transactional datasets. 
These itemsets are often used to uncover +meaningful patterns, such as associations between items, which can then +be leveraged to generate *association rules*. + +For example, in a supermarket transaction database, frequent itemset +mining can identify groups of products that are commonly purchased +together, such as `{milk, bread, eggs}`. These insights are valuable for +applications like recommendation systems, inventory management, and +targeted marketing. + +The key to frequent itemset mining is determining the sets of items that +satisfy a user-defined threshold called the **minimum support**, where +support is defined as the proportion of transactions in which a +particular itemset appears. + +### Methods of Frequent Itemset Mining + +Efficiently discovering these frequent itemsets is a computational +challenge, and several algorithms have been developed to address this +challenge. The two implemented in `{tidyclust}` are the **Apriori** +algorithm and the **Eclat** algorithm. + +#### Finding Frequent Itemsets with the Apriori Algorithm + +The *Apriori* algorithm is one of the earliest and most widely known +methods for frequent itemset mining. It is based on the **Apriori +Principle** (also known as **Downward Closure Property**): any subset of +a frequent itemset must also be frequent. + +#### Process of the Apriori Algorithm + +1. **Initialization**: Begin by identifying all individual items + (1-itemsets) that satisfy the minimum support threshold. These are + called *frequent 1-itemsets*. + +2. **Candidate Generation**: Use the frequent itemsets from the + previous step to generate candidate itemsets of the next size (e.g. + combine frequent 1-itemsets to create candidate 2-itemsets). + +3. **Prune Candidates**: Eliminate candidate itemsets that have subsets + not found to be frequent. + +4. **Support Counting**: Scan the dataset to count the occurrences of + each candidate itemset. + +5. 
**Iteration**: Repeat steps 2–4 for larger itemsets until no more + frequent itemsets can be generated. + +[![Apriori Principle +Example](images/clipboard-4074102307.png){width="650"}](https://chih-ling-hsu.github.io/2017/03/25/apriori) + +The Apriori algorithm is computationally expensive due to repeated +database scans and the generation of numerous candidates. However, its +pruning strategy significantly reduces the search space compared to a +naïve approach. + +[Source](https://dl.acm.org/doi/pdf/10.1145/170036.170072) + +#### Finding Frequent Itemsets with the Eclat Algorithm + +The *Eclat* (Equivalence Class Transformation) algorithm is an +alternative to Apriori that uses a depth-first search strategy and +vertical data representation. Instead of scanning the dataset +repeatedly, Eclat represents transactions as *tid-lists* (transaction ID +lists), which map each item or itemset to the IDs of transactions in +which it appears. + +#### Process of the Eclat Algorithm + +1. **Vertical Data Representation**: Transform the dataset into a + vertical format, where each item is associated with a list of + transaction IDs. + +2. **Intersect Tid-lists**: Generate frequent itemsets by recursively + intersecting the tid-lists of individual items to form larger + itemsets. The intersection results in a new tid-list, representing + the transactions containing the larger itemset. + +3. **Check Support**: The length of the resulting tid-list determines + the support of the itemset. Remove itemsets not found to be + frequent. + +4. **Recursive Search**: Continue the process for all itemsets until no + further frequent itemsets can be found.
+ +[![Bookstore +database](images/clipboard-2860408154.png){width="325"}](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=846291) + +[![Computing support of itemsets via tid-list +intersections](images/clipboard-3198084587.png){width="650"}](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=846291) + +Eclat is generally more efficient than Apriori for datasets with many +transactions but fewer unique items, as it avoids the need for multiple +scans of the dataset. However, its performance can degrade for datasets +with very large tid-lists. + +[Source](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=846291) + +## **`freq_itemsets` specification in {tidyclust}** + +To specify a frequent itemsets mining model in `tidyclust`, simply +choose a value of `min_support` and (optionally) a mining method: + +```{r} +fi_spec <- freq_itemsets( + min_support = 0.05, + mining_method = "eclat" + ) %>% + set_engine("arules") %>% + set_mode("partition") + +fi_spec +``` + +Currently, the only supported engine is `arules`. The default mining +method is eclat because it is generally faster in most practical cases. A +default `min_support` value is not provided as it varies significantly depending +on the data characteristics. + +## **Fitting `freq_itemsets` models** + +We fit the model to the data in the usual way: + +```{r} +fi_fit <- fi_spec %>% + fit(~ ., + data = groceries + ) + +fi_fit %>% + summary() +``` + +We cannot extract the standard `tidyclust` summary list since centroids are not +useful for frequent itemsets; however, we can extract the frequent itemsets: + +```{r} +arules::inspect(fi_fit$fit) +``` + +Note that, although the frequent itemset algorithm is not focused on clusters +like other unsupervised learning algorithms, we have created clusters based on +the itemsets. For each item, we find all itemsets that include that item: + +- Itemsets with the largest size are selected as the "dominant" itemset for the +item.
If there is a tie in size, the itemset with the highest support is selected. +This prioritization aligns with findings that larger itemsets with higher +support exhibit greater predictive utility. + +- If an item has already been assigned a cluster, the algorithm compares the +current "best" itemset with the itemset under consideration and re-prioritizes +based on size and support. This process repeats until no items are reassigned to +a new cluster (convergence). + +- Items that appear in no frequent itemsets are labeled as outliers +(Cluster_0_X) while items within frequent itemsets are assigned sequential +cluster IDs (Cluster_1, Cluster_2, etc.). + +```{r} +fi_fit %>% + extract_cluster_assignment() +``` + +## Prediction + +Since frequent itemset mining identifies patterns in co-occurring items rather +than learning a predictive function, the notion of "prediction" is not as +straightforward as in supervised learning. However, given a set of frequent +itemsets from historical data, it is possible to estimate the likelihood that a +missing item in new data is present based on observed co-occurring items. + +The `predict()` function utilizes frequent itemsets and their support values to +estimate probabilities for missing items in new transactions. For each row in +`new_data`, the function identifies observed items and missing items. It then +searches for frequent itemsets that contain both the missing item and at least +one observed item. Using the support values of these itemsets, it estimates the +probability that the missing item is present based on the confidence of +association between observed and missing items. If no relevant itemsets are +found, the item's global support (frequency in training data) is used as a +fallback prediction. + +The function fills in missing values with these probability estimates, +effectively "predicting" the likelihood of item presence based on historical +co-occurrence patterns. 
The `type` argument allows returning raw prediction +probabilities ('raw') or binary predictions based on a 0.5 threshold ('cluster'). + +We display the predicted values in the column `.pred_item`, and the observed +values in the column `.obs_item`. + +```{r} +new_data <- groceries[1:5,] %>% + tidyclust:::random_na_with_truth(na_prob = 0.3) + +results <- fi_fit %>% + predict(new_data$na_data) + +results$.pred_cluster[[1]] +``` + +The function `random_na_with_truth()` is used for testing purposes to randomly +assign values to be `NA`; it can still be accessed in `tidyclust` using +the `:::` operator. + +Additionally, we can extract the nested predicted output to be formatted in a +single data frame, filling in the `NA` values with their predicted value using +`extract_itemset_predictions()`. + +```{r} +results %>% + extract_itemset_predictions +``` + +## Evaluation Metrics + +While support values traditionally assess frequent itemset quality, they do not +guarantee predictive performance. Since the `predict()` methodology resembles a +recommender system, the results should be evaluated using similar metrics. +Common metrics such as root mean squared error (RMSE), accuracy, precision, and +recall are implemented in the `yardstick` package. + +To prepare the `predict()` output for metric calculation, we provide a new +function `augment_itemset_predict()`.
+ +```{r} +# Generate data to predict on +na_result <- tidyclust:::random_na_with_truth(groceries[1:5,], na_prob = 0.3) +new_data <- na_result$na_data # In a real scenario, this would be new, untrained on, data +truth_output <- na_result$truth # In a real scenario, this would be a separate holdout set + +# Example for RMSE (using type = 'raw') +fi_fit %>% + predict(new_data = new_data, type = 'raw') %>% + augment_itemset_predict(truth_output = truth_output) %>% + yardstick::rmse(truth, preds) + +# Example for Precision (using type = 'cluster') +fi_fit %>% + predict(new_data = new_data, type = 'cluster') %>% + augment_itemset_predict(truth_output = truth_output) %>% + dplyr::mutate( + truth = factor(truth, levels = c(0, 1)), + preds = factor(preds, levels = c(0, 1)) + ) %>% + yardstick::precision(truth, preds) +``` + +When using RMSE, the raw `predict()` output should be used, while accuracy, +precision, and recall will use the cluster output or user thresholded raw +output. Caution should be used when looking at accuracy, precision, and recall +for imbalanced datasets (where items are infrequently purchased). In such cases, +F1-Score offers a balance between precision and recall, and precision-recall +(PR) curves aid in determining the best threshold value. + +```{r} +fi_fit %>% + predict(new_data = new_data, type = 'raw') %>% + augment_itemset_predict(truth = truth_output) %>% + dplyr::mutate(truth = factor(truth, levels = c(0, 1))) %>% + yardstick::pr_curve(truth, preds) %>% + autoplot() +``` + +Each point on the PR curve represents the precision and recall of the model at a +specific threshold. By varying this threshold, different precision and recall +values are obtained, creating the curve. The ideal curve is close to the +top-right corner, indicating high precision and high recall. 
+ +## Hyperparameter Tuning + +The sole parameter capable of being tuned in a FIM model is `min_support`. +Selecting the correct value is imperative for finding useful frequent itemsets. +The default grid for `min_support` is from 0.1 to 0.5. The lower bound of 0.1 was +chosen to avoid reporting too many frequent itemsets, even for smaller datasets, +while the upper bound of 0.5 was selected since it ensures that each frequent +itemset has a support of at least 50%. + +```{r} +dials::grid_regular( + dials::min_support(), + levels = 10 +) +``` + +Usually, the above object is paired with `tune_cluster`; however, cross-validation +is not currently implemented for FIM. Future work will focus on improving tuning +and implementing cross-validation. diff --git a/vignettes/articles/images/clipboard-2860408154.png b/vignettes/articles/images/clipboard-2860408154.png new file mode 100644 index 0000000..cd01f40 Binary files /dev/null and b/vignettes/articles/images/clipboard-2860408154.png differ diff --git a/vignettes/articles/images/clipboard-3198084587.png b/vignettes/articles/images/clipboard-3198084587.png new file mode 100644 index 0000000..1e2bb24 Binary files /dev/null and b/vignettes/articles/images/clipboard-3198084587.png differ diff --git a/vignettes/articles/images/clipboard-4074102307.png b/vignettes/articles/images/clipboard-4074102307.png new file mode 100644 index 0000000..8a057fd Binary files /dev/null and b/vignettes/articles/images/clipboard-4074102307.png differ