From 5d442d85987828f357aa17d39e22a739c6a1b283 Mon Sep 17 00:00:00 2001
From: Fredrik Bagge Carlson
Date: Thu, 6 Feb 2020 19:09:29 +0800
Subject: [PATCH] expose weights to user

---
 src/classification/main.jl |  7 +++++--
 src/classification/tree.jl |  4 ++--
 src/regression/main.jl     |  5 ++++-
 src/scikitlearnAPI.jl      | 12 ++++++++----
 4 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/src/classification/main.jl b/src/classification/main.jl
index 875c5213..2899c1f6 100644
--- a/src/classification/main.jl
+++ b/src/classification/main.jl
@@ -79,8 +79,9 @@ function build_tree(
         min_samples_leaf     = 1,
         min_samples_split    = 2,
         min_purity_increase  = 0.0;
+        weights::Union{Nothing,AbstractVector{U}} = nothing,
         loss                 = util.entropy :: Function,
-        rng                  = Random.GLOBAL_RNG) where {S, T}
+        rng                  = Random.GLOBAL_RNG) where {S, T, U <: Integer}
 
     if max_depth == -1
         max_depth = typemax(Int)
@@ -93,7 +94,7 @@ function build_tree(
     t = treeclassifier.fit(
         X                   = features,
         Y                   = labels,
-        W                   = nothing,
+        W                   = weights,
         loss                = loss,
         max_features        = Int(n_subfeatures),
         max_depth           = Int(max_depth),
@@ -195,6 +196,7 @@ function build_forest(
         min_samples_leaf     = 1,
         min_samples_split    = 2,
         min_purity_increase  = 0.0;
+        weights              = nothing,
         rng                  = Random.GLOBAL_RNG) where {S, T}
 
     if n_trees < 1
@@ -229,6 +231,7 @@ function build_forest(
             min_samples_leaf,
             min_samples_split,
             min_purity_increase,
+            weights = (weights === nothing ? nothing : weights[inds]),
             loss = loss,
             rng = rngs)
     end
diff --git a/src/classification/tree.jl b/src/classification/tree.jl
index b581b72c..b8bbdb77 100644
--- a/src/classification/tree.jl
+++ b/src/classification/tree.jl
@@ -307,7 +307,7 @@ module treeclassifier
     function fit(;
             X                   :: Matrix{S},
             Y                   :: Vector{T},
-            W                   :: Union{Nothing, Vector{U}},
+            W                   :: Union{Nothing, AbstractVector{U}},
             loss=util.entropy   :: Function,
             max_features        :: Int,
             max_depth           :: Int,
@@ -318,7 +318,7 @@ module treeclassifier
 
         n_samples, n_features = size(X)
         list, Y_ = util.assign(Y)
-        if W == nothing
+        if W === nothing
             W = fill(1, n_samples)
         end
 
diff --git a/src/regression/main.jl b/src/regression/main.jl
index 45569ff5..838ba614 100644
--- a/src/regression/main.jl
+++ b/src/regression/main.jl
@@ -22,6 +22,7 @@ function build_tree(
         min_samples_leaf     = 5,
         min_samples_split    = 2,
         min_purity_increase  = 0.0;
+        weights              = nothing,
         rng                  = Random.GLOBAL_RNG) where {S, T <: Float64}
 
     if max_depth == -1
@@ -35,7 +36,7 @@ function build_tree(
     t = treeregressor.fit(
         X                   = features,
         Y                   = labels,
-        W                   = nothing,
+        W                   = weights,
         max_features        = Int(n_subfeatures),
         max_depth           = Int(max_depth),
         min_samples_leaf    = Int(min_samples_leaf),
@@ -56,6 +57,7 @@ function build_forest(
         min_samples_leaf     = 5,
         min_samples_split    = 2,
         min_purity_increase  = 0.0;
+        weights              = nothing,
         rng                  = Random.GLOBAL_RNG) where {S, T <: Float64}
 
     if n_trees < 1
@@ -86,6 +88,7 @@ function build_forest(
             min_samples_leaf,
             min_samples_split,
             min_purity_increase,
+            weights = (weights === nothing ? nothing : weights[inds]),
             rng = rngs)
     end
 
diff --git a/src/scikitlearnAPI.jl b/src/scikitlearnAPI.jl
index 64503301..095b9497 100644
--- a/src/scikitlearnAPI.jl
+++ b/src/scikitlearnAPI.jl
@@ -49,7 +49,7 @@ get_classes(dt::DecisionTreeClassifier) = dt.classes
                    [:pruning_purity_threshold, :max_depth, :min_samples_leaf,
                     :min_samples_split, :min_purity_increase, :rng])
 
-function fit!(dt::DecisionTreeClassifier, X, y)
+function fit!(dt::DecisionTreeClassifier, X, y, weights=nothing)
     n_samples, n_features = size(X)
     dt.root = build_tree(
         y, X,
@@ -58,6 +58,7 @@ get_classes(dt::DecisionTreeClassifier) = dt.classes
         dt.min_samples_leaf,
         dt.min_samples_split,
         dt.min_purity_increase;
+        weights = weights,
         rng     = dt.rng)
 
     dt.root = prune_tree(dt.root, dt.pruning_purity_threshold)
@@ -136,7 +137,7 @@ end
                    [:pruning_purity_threshold, :min_samples_leaf, :n_subfeatures,
                     :max_depth, :min_samples_split, :min_purity_increase, :rng])
 
-function fit!(dt::DecisionTreeRegressor, X::Matrix, y::Vector)
+function fit!(dt::DecisionTreeRegressor, X::Matrix, y::Vector, weights=nothing)
     n_samples, n_features = size(X)
     dt.root = build_tree(
         float.(y), X,
@@ -145,6 +146,7 @@ end
         dt.min_samples_leaf,
         dt.min_samples_split,
         dt.min_purity_increase;
+        weights = weights,
         rng     = dt.rng)
     dt.pruning_purity_threshold
     dt.root = prune_tree(dt.root, dt.pruning_purity_threshold)
@@ -213,7 +215,7 @@ get_classes(rf::RandomForestClassifier) = rf.classes
                     :min_samples_leaf, :min_samples_split, :min_purity_increase,
                     :rng])
 
-function fit!(rf::RandomForestClassifier, X::Matrix, y::Vector)
+function fit!(rf::RandomForestClassifier, X::Matrix, y::Vector, weights=nothing)
     n_samples, n_features = size(X)
     rf.ensemble = build_forest(
         y, X,
@@ -224,6 +226,7 @@ get_classes(rf::RandomForestClassifier) = rf.classes
         rf.min_samples_leaf,
         rf.min_samples_split,
         rf.min_purity_increase;
+        weights = weights,
         rng     = rf.rng)
     rf.classes = sort(unique(y))
     rf
@@ -297,7 +300,7 @@
                     # since it'll change throughout fitting, but it works
                     :max_depth, :rng])
 
-function fit!(rf::RandomForestRegressor, X::Matrix, y::Vector)
+function fit!(rf::RandomForestRegressor, X::Matrix, y::Vector, weights=nothing)
     n_samples, n_features = size(X)
     rf.ensemble = build_forest(
         float.(y), X,
@@ -308,6 +311,7 @@
         rf.min_samples_leaf,
         rf.min_samples_split,
         rf.min_purity_increase;
+        weights = weights,
         rng     = rf.rng)
     rf
 end