File tree Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -72,6 +72,17 @@ mlp_keras_data <-
7272 )
7373 )
7474
75+
76+ nnet_softmax <- function (results , object ) {
77+ if (ncol(results ) == 1 )
78+ results <- cbind(1 - results , results )
79+
80+ results <- apply(results , 1 , function (x ) exp(x )/ sum(exp(x )))
81+ results <- as_tibble(t(results ))
82+ names(results ) <- paste0(" .pred_" , object $ lvl )
83+ results
84+ }
85+
7586mlp_nnet_data <-
7687 list (
7788 libs = " nnet" ,
@@ -103,6 +114,17 @@ mlp_nnet_data <-
103114 type = " class"
104115 )
105116 ),
117+ classprob = list (
118+ pre = NULL ,
119+ post = nnet_softmax ,
120+ func = c(fun = " predict" ),
121+ args =
122+ list (
123+ object = quote(object $ fit ),
124+ newdata = quote(new_data ),
125+ type = " raw"
126+ )
127+ ),
106128 raw = list (
107129 pre = NULL ,
108130 func = c(fun = " predict" ),
@@ -114,6 +136,7 @@ mlp_nnet_data <-
114136 )
115137 )
116138
139+
117140# ------------------------------------------------------------------------------
118141
119142# keras wrapper for feed-forward nnet
You can’t perform that action at this time.
0 commit comments