Skip to content

Commit 4964061

Browse files
committed
svm models
1 parent 1bab926 commit 4964061

File tree

16 files changed

+1898
-6
lines changed

16 files changed

+1898
-6
lines changed

NAMESPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@ S3method(print,multinom_reg)
3838
S3method(print,nearest_neighbor)
3939
S3method(print,rand_forest)
4040
S3method(print,surv_reg)
41+
S3method(print,svm_poly)
42+
S3method(print,svm_rbf)
4143
S3method(translate,boost_tree)
4244
S3method(translate,default)
4345
S3method(translate,mars)
4446
S3method(translate,mlp)
4547
S3method(translate,rand_forest)
4648
S3method(translate,surv_reg)
49+
S3method(translate,svm_poly)
50+
S3method(translate,svm_rbf)
4751
S3method(type_sum,model_fit)
4852
S3method(type_sum,model_spec)
4953
S3method(update,boost_tree)
@@ -55,6 +59,8 @@ S3method(update,multinom_reg)
5559
S3method(update,nearest_neighbor)
5660
S3method(update,rand_forest)
5761
S3method(update,surv_reg)
62+
S3method(update,svm_poly)
63+
S3method(update,svm_rbf)
5864
S3method(varying_args,model_spec)
5965
S3method(varying_args,recipe)
6066
S3method(varying_args,step)
@@ -106,6 +112,8 @@ export(set_engine)
106112
export(set_mode)
107113
export(show_call)
108114
export(surv_reg)
115+
export(svm_poly)
116+
export(svm_rbf)
109117
export(translate)
110118
export(varying)
111119
export(varying_args)

R/svm_poly.R

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
#' General interface for polynomial support vector machines
2+
#'
3+
#' `svm_poly` is a way to generate a _specification_ of a model
4+
#' before fitting and allows the model to be created using
5+
#' different packages in R or via Spark. The main arguments for the
6+
#' model are:
7+
#' \itemize{
8+
#' \item \code{cost}: The cost of predicting a sample within or on the
9+
#' wrong side of the margin.
10+
#' \item \code{degree}: The polynomial degree.
11+
#' \item \code{scale_factor}: A scaling factor for the kernel.
12+
#' \item \code{margin}: The epsilon in the SVM insensitive loss function
13+
#' (regression only)
14+
#' }
15+
#' These arguments are converted to their specific names at the
16+
#' time that the model is fit. Other options and argument can be
17+
#' set using `set_engine`. If left to their defaults
18+
#' here (`NULL`), the values are taken from the underlying model
19+
#' functions. If parameters need to be modified, `update` can be used
20+
#' in lieu of recreating the object from scratch.
21+
#'
22+
#' @inheritParams boost_tree
23+
#' @param mode A single character string for the type of model.
24+
#' Possible values for this model are "unknown", "regression", or
25+
#' "classification".
26+
#' @param cost A positive number for the cost of predicting a sample within
27+
#' or on the wrong side of the margin
28+
#' @param degree A positive number for polynomial degree.
29+
#' @param scale_factor A positive number for the polynomial scaling factor.
30+
#' @param margin A positive number for the epsilon in the SVM insensitive
31+
#' loss function (regression only)
32+
#' @details
33+
#' The model can be created using the `fit()` function using the
34+
#' following _engines_:
35+
#' \itemize{
36+
#' \item \pkg{R}: `"kernlab"`
37+
#' }
38+
#'
39+
#' @section Engine Details:
40+
#'
41+
#' Engines may have pre-set default arguments when executing the
42+
#' model fit call. For this type of
43+
#' model, the template of the fit calls are::
44+
#'
45+
#' \pkg{kernlab} classification
46+
#'
47+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")}
48+
#'
49+
#' \pkg{kernlab} regression
50+
#'
51+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")}
52+
#'
53+
#' @importFrom purrr map_lgl
54+
#' @seealso [varying()], [fit()]
55+
#' @examples
56+
#' svm_poly(mode = "classification", degree = 1.2)
57+
#' # Parameters can be represented by a placeholder:
58+
#' svm_poly(mode = "regression", cost = varying())
59+
#' @export
60+
61+
svm_poly <-
62+
function(mode = "unknown",
63+
cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL) {
64+
65+
args <- list(
66+
cost = enquo(cost),
67+
degree = enquo(degree),
68+
scale_factor = enquo(scale_factor),
69+
margin = enquo(margin)
70+
)
71+
72+
new_model_spec(
73+
"svm_poly",
74+
args = args,
75+
eng_args = NULL,
76+
mode = mode,
77+
method = NULL,
78+
engine = NULL
79+
)
80+
}
81+
82+
#' @export
83+
print.svm_poly <- function(x, ...) {
84+
cat("Polynomial Support Vector Machine Specification (", x$mode, ")\n\n", sep = "")
85+
model_printer(x, ...)
86+
87+
if(!is.null(x$method$fit$args)) {
88+
cat("Model fit template:\n")
89+
print(show_call(x))
90+
}
91+
invisible(x)
92+
}
93+
94+
# ------------------------------------------------------------------------------
95+
96+
#' @export
97+
#' @inheritParams update.boost_tree
98+
#' @param object A polynomial SVM model specification.
99+
#' @examples
100+
#' model <- svm_poly(cost = 10, scale_factor = 0.1)
101+
#' model
102+
#' update(model, cost = 1)
103+
#' update(model, cost = 1, fresh = TRUE)
104+
#' @method update svm_poly
105+
#' @rdname svm_poly
106+
#' @export
107+
update.svm_poly <-
108+
function(object,
109+
cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL,
110+
fresh = FALSE,
111+
...) {
112+
update_dot_check(...)
113+
114+
args <- list(
115+
cost = enquo(cost),
116+
degree = enquo(degree),
117+
scale_factor = enquo(scale_factor),
118+
margin = enquo(margin)
119+
)
120+
121+
if (fresh) {
122+
object$args <- args
123+
} else {
124+
null_args <- map_lgl(args, null_value)
125+
if (any(null_args))
126+
args <- args[!null_args]
127+
if (length(args) > 0)
128+
object$args[names(args)] <- args
129+
}
130+
131+
new_model_spec(
132+
"svm_poly",
133+
args = object$args,
134+
eng_args = object$eng_args,
135+
mode = object$mode,
136+
method = NULL,
137+
engine = object$engine
138+
)
139+
}
140+
141+
# ------------------------------------------------------------------------------
142+
143+
#' @export
144+
translate.svm_poly <- function(x, engine = x$engine, ...) {
145+
x <- translate.default(x, engine = engine, ...)
146+
147+
# slightly cleaner code using
148+
arg_vals <- x$method$fit$args
149+
arg_names <- names(arg_vals)
150+
151+
# add checks to error trap or change things for this method
152+
if (x$engine == "kernlab") {
153+
154+
# unless otherwise specified, classification models predict probabilities
155+
if (x$mode == "classification" && !any(arg_names == "prob.model"))
156+
arg_vals$prob.model <- TRUE
157+
if (x$mode == "classification" && any(arg_names == "epsilon"))
158+
arg_vals$epsilon <- NULL
159+
160+
# convert degree and scale to a `kpar` argument.
161+
if (any(arg_names %in% c("degree", "scale", "offset"))) {
162+
kpar <- expr(list())
163+
if (any(arg_names == "degree")) {
164+
kpar$degree <- arg_vals$degree
165+
arg_vals$degree <- NULL
166+
}
167+
if (any(arg_names == "scale")) {
168+
kpar$scale <- arg_vals$scale
169+
arg_vals$scale <- NULL
170+
}
171+
if (any(arg_names == "offset")) {
172+
kpar$offset <- arg_vals$offset
173+
arg_vals$offset <- NULL
174+
}
175+
arg_vals$kpar <- kpar
176+
}
177+
178+
}
179+
x$method$fit$args <- arg_vals
180+
181+
# worried about people using this to modify the specification
182+
x
183+
}
184+
185+
# ------------------------------------------------------------------------------
186+
187+
check_args.svm_poly <- function(object) {
188+
invisible(object)
189+
}
190+
191+
# ------------------------------------------------------------------------------
192+
193+
svm_reg_post <- function(results, object) {
194+
results[,1]
195+
}
196+

R/svm_poly_data.R

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
svm_poly_arg_key <- data.frame(
2+
kernlab = c( "C", "degree", "scale", "epsilon"),
3+
row.names = c("cost", "degree", "scale_factor", "margin"),
4+
stringsAsFactors = FALSE
5+
)
6+
7+
svm_poly_modes <- c("classification", "regression", "unknown")
8+
9+
svm_poly_engines <- data.frame(
10+
kernlab = c(TRUE, TRUE, FALSE),
11+
row.names = c("classification", "regression", "unknown")
12+
)
13+
14+
# ------------------------------------------------------------------------------
15+
16+
svm_poly_kernlab_data <-
17+
list(
18+
libs = "kernlab",
19+
fit = list(
20+
interface = "matrix",
21+
protect = c("x", "y"),
22+
func = c(pkg = "kernlab", fun = "ksvm"),
23+
defaults = list(
24+
kernel = "polydot"
25+
)
26+
),
27+
numeric = list(
28+
pre = NULL,
29+
post = svm_reg_post,
30+
func = c(pkg = "kernlab", fun = "predict"),
31+
args =
32+
list(
33+
object = quote(object$fit),
34+
newdata = quote(new_data),
35+
type = "response"
36+
)
37+
),
38+
class = list(
39+
pre = NULL,
40+
post = NULL,
41+
func = c(pkg = "kernlab", fun = "predict"),
42+
args =
43+
list(
44+
object = quote(object$fit),
45+
newdata = quote(new_data),
46+
type = "response"
47+
)
48+
),
49+
classprob = list(
50+
pre = NULL,
51+
post = function(result, object) as_tibble(result),
52+
func = c(pkg = "kernlab", fun = "predict"),
53+
args =
54+
list(
55+
object = quote(object$fit),
56+
newdata = quote(new_data),
57+
type = "probabilities"
58+
)
59+
),
60+
raw = list(
61+
pre = NULL,
62+
func = c(pkg = "kernlab", fun = "predict"),
63+
args =
64+
list(
65+
object = quote(object$fit),
66+
newdata = quote(new_data)
67+
)
68+
)
69+
)

0 commit comments

Comments
 (0)