@@ -16,6 +16,8 @@ ctrl <- fit_control(verbosity = 1, catch = FALSE)
1616caught_ctrl <- fit_control(verbosity = 1 , catch = TRUE )
1717quiet_ctrl <- fit_control(verbosity = 0 , catch = TRUE )
1818
19+ nn_dat <- read.csv(" nnet_test.txt" )
20+
1921# ------------------------------------------------------------------------------
2022
2123test_that(' keras execution, classification' , {
@@ -102,8 +104,8 @@ test_that('keras classification probabilities', {
102104 )
103105
104106 xy_pred <- keras :: predict_proba(xy_fit $ fit , x = as.matrix(iris [1 : 8 , num_pred ]))
105- xy_pred <- as_tibble(xy_pred )
106107 colnames(xy_pred ) <- paste0(" .pred_" , levels(iris $ Species ))
108+ xy_pred <- as_tibble(xy_pred )
107109 expect_equal(xy_pred , predict(xy_fit , new_data = iris [1 : 8 , num_pred ], type = " prob" ))
108110
109111 keras :: backend()$ clear_session()
@@ -116,8 +118,8 @@ test_that('keras classification probabilities', {
116118 )
117119
118120 form_pred <- keras :: predict_proba(form_fit $ fit , x = as.matrix(iris [1 : 8 , num_pred ]))
119- form_pred <- as_tibble(form_pred )
120121 colnames(form_pred ) <- paste0(" .pred_" , levels(iris $ Species ))
122+ form_pred <- as_tibble(form_pred )
121123 expect_equal(form_pred , predict(form_fit , new_data = iris [1 : 8 , num_pred ], type = " prob" ))
122124
123125 keras :: backend()$ clear_session()
@@ -204,8 +206,6 @@ test_that('keras regression prediction', {
204206
205207# ------------------------------------------------------------------------------
206208
207- nn_dat <- read.csv(" nnet_test.txt" )
208-
209209test_that(' multivariate nnet formula' , {
210210 skip_on_cran()
211211 skip_if_not_installed(" keras" )
@@ -218,6 +218,7 @@ test_that('multivariate nnet formula', {
218218 data = nn_dat [- (1 : 5 ),]
219219 )
220220 expect_equal(length(unlist(keras :: get_weights(nnet_form $ fit ))), 24 )
221+
221222 nnet_form_pred <- predict(nnet_form , new_data = nn_dat [1 : 5 , - (1 : 3 )])
222223 expect_equal(names(nnet_form_pred ), paste0(" .pred_" , c(" V1" , " V2" , " V3" )))
223224
0 commit comments