Skip to content

Commit 90e1514

Browse files
committed
missing probability module
1 parent fe24a99 commit 90e1514

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

R/mlp_data.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,17 @@ mlp_keras_data <-
7272
)
7373
)
7474

75+
76+
nnet_softmax <- function(results, object) {
77+
if (ncol(results) == 1)
78+
results <- cbind(1 - results, results)
79+
80+
results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
81+
results <- as_tibble(t(results))
82+
names(results) <- paste0(".pred_", object$lvl)
83+
results
84+
}
85+
7586
mlp_nnet_data <-
7687
list(
7788
libs = "nnet",
@@ -103,6 +114,17 @@ mlp_nnet_data <-
103114
type = "class"
104115
)
105116
),
117+
classprob = list(
118+
pre = NULL,
119+
post = nnet_softmax,
120+
func = c(fun = "predict"),
121+
args =
122+
list(
123+
object = quote(object$fit),
124+
newdata = quote(new_data),
125+
type = "raw"
126+
)
127+
),
106128
raw = list(
107129
pre = NULL,
108130
func = c(fun = "predict"),
@@ -114,6 +136,7 @@ mlp_nnet_data <-
114136
)
115137
)
116138

139+
117140
# ------------------------------------------------------------------------------
118141

119142
# keras wrapper for feed-forward nnet

0 commit comments

Comments
 (0)