Skip to content

Commit 0f235f5

Browse files
topepohfrick
andauthored
new extract functions (#518)
* new extract functions * missing hardhat remote * extract_parsnip_spec -> extract_spec_parsnip * notes about model extraction * small text adjustments * Update R/extract.R Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> * extract test cases * less specific remote Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com>
1 parent 89f8f93 commit 0f235f5

File tree

7 files changed

+166
-2
lines changed

7 files changed

+166
-2
lines changed

DESCRIPTION

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ Imports:
3131
tidyr (>= 1.0.0),
3232
globals,
3333
prettyunits,
34-
vctrs (>= 0.2.0)
34+
vctrs (>= 0.2.0),
35+
hardhat (>= 0.1.5.9000)
3536
Roxygen: list(markdown = TRUE)
3637
RoxygenNote: 7.1.1.9001
3738
Suggests:
@@ -59,4 +60,6 @@ Suggests:
5960
dials (>= 0.0.9.9000)
6061
Remotes:
6162
tidymodels/dials,
62-
topepo/C5.0
63+
topepo/C5.0,
64+
tidymodels/hardhat
65+

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(augment,model_fit)
4+
S3method(extract_fit_engine,model_fit)
5+
S3method(extract_spec_parsnip,model_fit)
46
S3method(fit,model_spec)
57
S3method(fit_xy,gen_additive_mod)
68
S3method(fit_xy,model_spec)
@@ -138,6 +140,8 @@ export(control_parsnip)
138140
export(convert_stan_interval)
139141
export(decision_tree)
140142
export(eval_args)
143+
export(extract_fit_engine)
144+
export(extract_spec_parsnip)
141145
export(find_engine_files)
142146
export(fit)
143147
export(fit.model_spec)
@@ -259,6 +263,8 @@ importFrom(generics,required_pkgs)
259263
importFrom(generics,tidy)
260264
importFrom(generics,varying_args)
261265
importFrom(glue,glue_collapse)
266+
importFrom(hardhat,extract_fit_engine)
267+
importFrom(hardhat,extract_spec_parsnip)
262268
importFrom(magrittr,"%>%")
263269
importFrom(purrr,as_vector)
264270
importFrom(purrr,imap)

R/extract.R

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#' Extract elements of a parsnip model object
2+
#'
3+
#' @description
4+
#' These functions extract various elements from a parsnip object. If they do
5+
#' not exist yet, an error is thrown.
6+
#'
7+
#' - `extract_spec_parsnip()` returns the parsnip model specification.
8+
#'
9+
#' - `extract_fit_engine()` returns the engine specific fit embedded within
10+
#' a parsnip model fit. For example, when using [parsnip::linear_reg()]
11+
#' with the `"lm"` engine, this returns the underlying `lm` object.
12+
#'
13+
#' @param x A parsnip `model_fit` object.
14+
#' @param ... Not currently used.
15+
#' @details
16+
#' Extracting the underlying engine fit can be helpful for describing the
17+
#' model (via `print()`, `summary()`, `plot()`, etc.) or for variable
18+
#' importance/explainers.
19+
#'
20+
#' However, users should not invoke the `predict()` method on an extracted
21+
#' model. There may be preprocessing operations that `parsnip` has executed on
22+
#' the data prior to giving it to the model. Bypassing these can lead to errors
23+
#' or silently generating incorrect predictions.
24+
#'
25+
#' **Good**:
26+
#' ```r
27+
#' parsnip_fit %>% predict(new_data)
28+
#' ```
29+
#'
30+
#' **Bad**:
31+
#' ```r
32+
#' parsnip_fit %>% extract_fit_engine() %>% predict(new_data)
33+
#' ```
34+
#' @return
35+
#' The extracted value from the parsnip object, `x`, as described in the description
36+
#' section.
37+
#'
38+
#' @name extract-parsnip
39+
#' @examples
40+
#' lm_spec <- linear_reg() %>% set_engine("lm")
41+
#' lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars)
42+
#'
43+
#' lm_spec
44+
#' extract_spec_parsnip(lm_fit)
45+
#'
46+
#' extract_fit_engine(lm_fit)
47+
#' lm(mpg ~ ., data = mtcars)
48+
NULL
49+
50+
#' @export
51+
#' @rdname extract-parsnip
52+
extract_spec_parsnip.model_fit <- function(x, ...) {
53+
if (any(names(x) == "spec")) {
54+
return(x$spec)
55+
}
56+
rlang::abort("Internal error: The model fit does not have a model spec.")
57+
}
58+
59+
60+
#' @export
61+
#' @rdname extract-parsnip
62+
extract_fit_engine.model_fit <- function(x, ...) {
63+
if (any(names(x) == "fit")) {
64+
return(x$fit)
65+
}
66+
rlang::abort("Internal error: The model fit does not have an engine fit.")
67+
}

R/reexports.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,11 @@ generics::augment
2626
#' @importFrom generics required_pkgs
2727
#' @export
2828
generics::required_pkgs
29+
30+
#' @importFrom hardhat extract_spec_parsnip
31+
#' @export
32+
hardhat::extract_spec_parsnip
33+
34+
#' @importFrom hardhat extract_fit_engine
35+
#' @export
36+
hardhat::extract_fit_engine

man/extract-parsnip.Rd

Lines changed: 57 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/reexports.Rd

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-extract.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
context("model extraction")
3+
4+
# ------------------------------------------------------------------------------
5+
6+
test_that('extract', {
7+
x <- linear_reg() %>% set_engine("lm") %>% fit(mpg ~ ., data = mtcars)
8+
x_no_spec <- x
9+
x_no_spec$spec <- NULL
10+
x_no_fit <- x
11+
x_no_fit$fit <- NULL
12+
13+
expect_true(inherits(extract_spec_parsnip(x), "model_spec"))
14+
expect_true(inherits(extract_fit_engine(x), "lm"))
15+
16+
expect_error(extract_spec_parsnip(x_no_spec), "Internal error")
17+
expect_error(extract_fit_engine(x_no_fit), "Internal error")
18+
})
19+

0 commit comments

Comments
 (0)