@@ -342,9 +342,18 @@ check_interface <- function(formula, data, cl, model) {
342342}
343343
344344check_xy_interface <- function (x , y , cl , model ) {
345- # TODO Do we need a model spec attribute that is something like
346- # 'allow_sparse' to make this conditional on that?
347- inher(x , c(" data.frame" , " matrix" , " dgCMatrix" ), cl )
345+
346+ sparse_ok <- allow_sparse(model )
347+ sparse_x <- inherits(x , " dgCMatrix" )
348+ if (! sparse_ok & sparse_x ) {
349+ rlang :: abort(" Sparse matrices not supported by this model/engine combination." )
350+ }
351+
352+ if (sparse_ok ) {
353+ inher(x , c(" data.frame" , " matrix" , " dgCMatrix" ), cl )
354+ } else {
355+ inher(x , c(" data.frame" , " matrix" ), cl )
356+ }
348357
349358 # `y` can be a vector (which is not a class), or a factor (which is not a vector)
350359 if (! is.null(y ) && ! is.vector(y ))
@@ -359,22 +368,33 @@ check_xy_interface <- function(x, y, cl, model) {
359368 )
360369 )
361370
362- # Determine the `fit()` interface
363- # TODO conditional here too?
364- matrix_interface <- ! is.null(x ) & ! is.null(y ) && (is.matrix(x ) | inherits(x , " dgCMatrix" ))
371+
372+ if (sparse_ok ) {
373+ matrix_interface <- ! is.null(x ) & ! is.null(y ) && (is.matrix(x ) | sparse_x )
374+ } else {
375+ matrix_interface <- ! is.null(x ) & ! is.null(y ) && is.matrix(x )
376+ }
377+
365378 df_interface <- ! is.null(x ) & ! is.null(y ) && is.data.frame(x )
366379
367- if (inherits(model , " surv_reg" ) &&
368- (matrix_interface | df_interface ))
380+ if (inherits(model , " surv_reg" ) && (matrix_interface | df_interface )) {
369381 rlang :: abort(" Survival models must use the formula interface." )
382+ }
370383
371- if (matrix_interface )
384+ if (matrix_interface ) {
372385 return (" data.frame" )
373- if (df_interface )
386+ }
387+ if (df_interface ) {
374388 return (" data.frame" )
389+ }
375390 rlang :: abort(" Error when checking the interface" )
376391}
377392
393+ allow_sparse <- function (x ) {
394+ res <- get_from_env(paste0(class(x )[1 ], " _encoding" ))
395+ all(res $ allow_sparse_x [res $ engine == x $ engine ])
396+ }
397+
378398# ' @method print model_fit
379399# ' @export
380400print.model_fit <- function (x , ... ) {
0 commit comments