@@ -323,6 +323,11 @@ check_interface_val <- function(x) {
323323# ' below, depending on context.
324324# ' @param pre,post Optional functions for pre- and post-processing of prediction
325325# ' results.
326+ # ' @param options A list of options for engine-specific encodings. Currently,
327+ # ' the option implemented is `predictor_indicators` which tells `parsnip`
328+ # ' whether the pre-processing should make indicator/dummy variables from factor
329+ # ' predictors. This only affects cases when [fit.model_spec()] is used and the
330+ # ' underlying model has an x/y interface.
326331# ' @param ... Optional arguments that should be passed into the `args` slot for
327332# ' prediction objects.
328333# ' @keywords internal
@@ -780,3 +785,77 @@ pred_value_template <- function(pre = NULL, post = NULL, func, ...) {
780785 list (pre = pre , post = post , func = func , args = list (... ))
781786}
782787
788+ # ------------------------------------------------------------------------------
789+
790+ check_encodings <- function (x ) {
791+ if (! is.list(x )) {
792+ rlang :: abort(" `values` should be a list." )
793+ }
794+ req_args <- list (predictor_indicators = TRUE )
795+
796+ missing_args <- setdiff(names(req_args ), names(x ))
797+ if (length(missing_args ) > 0 ) {
798+ rlang :: abort(
799+ glue :: glue(
800+ " The values passed to `set_encoding()` are missing arguments: " ,
801+ paste0(" '" , missing_args , " '" , collapse = " , " )
802+ )
803+ )
804+ }
805+ extra_args <- setdiff(names(x ), names(req_args ))
806+ if (length(extra_args ) > 0 ) {
807+ rlang :: abort(
808+ glue :: glue(
809+ " The values passed to `set_encoding()` had extra arguments: " ,
810+ paste0(" '" , extra_args , " '" , collapse = " , " )
811+ )
812+ )
813+ }
814+ invisible (x )
815+ }
816+
817+ # ' @export
818+ # ' @rdname set_new_model
819+ # ' @keywords internal
820+ set_encoding <- function (model , mode , eng , options ) {
821+ check_model_exists(model )
822+ check_eng_val(eng )
823+ check_mode_val(mode )
824+ check_encodings(options )
825+
826+ keys <- tibble :: tibble(model = model , engine = eng , mode = mode )
827+ options <- tibble :: as_tibble(options )
828+ new_values <- dplyr :: bind_cols(keys , options )
829+
830+
831+ current_db_list <- ls(envir = get_model_env())
832+ nm <- paste(model , " encoding" , sep = " _" )
833+ if (any(current_db_list == nm )) {
834+ current <- get_from_env(nm )
835+ dup_check <-
836+ current %> %
837+ dplyr :: inner_join(new_values , by = c(" model" , " engine" , " mode" , " predictor_indicators" ))
838+ if (nrow(dup_check )) {
839+ rlang :: abort(glue :: glue(" Engine '{eng}' and mode '{mode}' already have defined encodings." ))
840+ }
841+
842+ } else {
843+ current <- NULL
844+ }
845+
846+ db_values <- dplyr :: bind_rows(current , new_values )
847+ set_env_val(nm , db_values )
848+
849+ invisible (NULL )
850+ }
851+
852+
853+ # ' @rdname set_new_model
854+ # ' @keywords internal
855+ # ' @export
856+ get_encoding <- function (model ) {
857+ check_model_exists(model )
858+ nm <- paste0(model , " _encoding" )
859+ rlang :: env_get(get_model_env(), nm )
860+ }
861+
0 commit comments