diff --git a/R/LearnerClassifSpatial.R b/R/LearnerClassifSpatial.R index 62f6192..4b2bfd0 100644 --- a/R/LearnerClassifSpatial.R +++ b/R/LearnerClassifSpatial.R @@ -6,7 +6,7 @@ LearnerClassifSpatial = R6::R6Class("LearnerClassifSpatial", initialize = function(learner) { self$learner = assert_learner(learner) super$initialize( - id = "classif.ranger", + id = learner$id, param_set = learner$param_set, predict_types = learner$predict_types, feature_types = learner$feature_types, @@ -14,6 +14,7 @@ LearnerClassifSpatial = R6::R6Class("LearnerClassifSpatial", packages = learner$packages, man = "mlr3learners::mlr_learners_classif.spatial" ) + self$predict_type = learner$predict_type }, predict = function(task, row_ids = NULL) { @@ -25,6 +26,13 @@ LearnerClassifSpatial = R6::R6Class("LearnerClassifSpatial", pred$data$row_ids = seq_len(nrow(data)) pred$data$response = response pred$data$truth = rep(NaN, nrow(data)) + if (self$learner$predict_type == "prob") { + prob = matrix(NaN, nrow = nrow(data), ncol = 2) + prob[ids, 1] = pred$data$prob[, 1] + prob[ids, 2] = pred$data$prob[, 2] + attributes(prob) = attributes(pred$prob) + pred$data$prob = prob + } pred } ) diff --git a/R/predict_spatial.R b/R/predict_spatial.R index 1f6cb65..a628fa2 100644 --- a/R/predict_spatial.R +++ b/R/predict_spatial.R @@ -15,6 +15,9 @@ #' For vector data only `"sf"` is supported. #' @param filename (`character(1)`)\cr #' Path where the spatial object should be written to. +#' @param predict_type (`character(1)`)\cr +#' Type of prediction to return. +#' Accepted values are `"response"` and `"prob"`. #' #' @return Spatial object of class given in argument `format`. #' @examples @@ -31,10 +34,11 @@ #' # predict land cover classes #' pred = predict_spatial(stack, learner, chunksize = 1L) #' @export -predict_spatial = function(newdata, learner, chunksize = 200L, format = "terra", filename = NULL) { +predict_spatial = function(newdata, learner, chunksize = 200L, format = "terra", filename = NULL, predict_type = "response") { task = as_task_unsupervised(newdata) assert_multi_class(task$backend, c("DataBackendRaster", "DataBackendVector")) assert_learner(learner) + assert_choice(predict_type, c("response", "prob")) if (test_class(task$backend, "DataBackendRaster")) { assert_number(chunksize) @@ -63,7 +67,8 @@ predict_spatial = function(newdata, learner, chunksize = 200L, format = "terra", stack = task$backend$stack pred = learner$predict(task, row_ids = cells_seq:((cells_seq + cells_to_read - 1))) - terra::writeValues(x = target_raster, v = pred$response, + vals = if (predict_type == "prob") pred$prob[, learner$learner$state$train_task$positive] else pred$response + terra::writeValues(x = target_raster, v = vals, start = terra::rowFromCell(stack, cells_seq), # start row number nrows = terra::rowFromCell(stack, cells_to_read)) # how many rows lg$info("Chunk %i of %i finished", n, length(bs$cells_seq)) @@ -72,7 +77,7 @@ predict_spatial = function(newdata, learner, chunksize = 200L, format = "terra", terra::writeStop(target_raster) lg$info("Finished raster prediction in %i seconds", as.integer(proc.time()[3] - start_time)) - if (learner$task_type == "classif") { + if (learner$task_type == "classif" && predict_type == "response") { levels = learner$learner$state$train_task$levels()[[learner$learner$state$train_task$target_names]] value = data.table(ID = seq_along(levels), categories = levels) target_raster = terra::categories(target_raster, value = value, index = 2) diff --git a/man/predict_spatial.Rd b/man/predict_spatial.Rd index 8f8afb8..d2a3c55 100644 --- a/man/predict_spatial.Rd +++ b/man/predict_spatial.Rd @@ -9,7 +9,8 @@ predict_spatial( learner, chunksize = 200L, format = "terra", - filename = NULL + filename = NULL, + predict_type = "response" ) } \arguments{ @@ -37,6 +38,10 @@ For vector data only \code{"sf"} is supported.} \item{filename}{(\code{character(1)})\cr Path where the spatial object should be written to.} + +\item{predict_type}{(\code{character(1)})\cr +Type of prediction to return. +Accepted values are \code{"response"} and \code{"prob"}.} } \value{ Spatial object of class given in argument \code{format}.