Skip to content

Commit f1a5dde

Browse files
committed
optimised Random forest
1 parent 33d17a6 commit f1a5dde

File tree

2 files changed

+101
-125
lines changed

2 files changed

+101
-125
lines changed

ml_library_include/ml/tree/RandomForestClassifier.hpp

Lines changed: 59 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
#include <algorithm>
66
#include <numeric>
77
#include <limits>
8-
#include <map>
8+
#include <unordered_map>
99
#include <cmath>
1010
#include <random>
11-
#include <ctime>
12-
#include <cstdlib>
11+
#include <memory>
1312

1413
/**
1514
* @file RandomForestClassifier.hpp
@@ -34,7 +33,7 @@ class RandomForestClassifier {
3433
/**
3534
* @brief Destructor for RandomForestClassifier.
3635
*/
37-
~RandomForestClassifier();
36+
~RandomForestClassifier() = default;
3837

3938
/**
4039
* @brief Fits the model to the training data.
@@ -56,81 +55,79 @@ class RandomForestClassifier {
5655
int value; // Class label for leaf nodes
5756
int feature_index;
5857
double threshold;
59-
Node* left;
60-
Node* right;
58+
std::unique_ptr<Node> left;
59+
std::unique_ptr<Node> right;
6160

62-
Node() : is_leaf(false), value(0), feature_index(-1), threshold(0.0), left(nullptr), right(nullptr) {}
61+
Node() : is_leaf(false), value(0), feature_index(-1), threshold(0.0) {}
6362
};
6463

6564
struct DecisionTree {
66-
Node* root;
65+
std::unique_ptr<Node> root;
6766
int max_depth;
6867
int min_samples_split;
6968
int max_features;
69+
std::mt19937 random_engine;
7070

71-
DecisionTree(int max_depth, int min_samples_split, int max_features);
72-
~DecisionTree();
71+
DecisionTree(int max_depth, int min_samples_split, int max_features, std::mt19937::result_type seed);
72+
~DecisionTree() = default;
7373
void fit(const std::vector<std::vector<double>>& X, const std::vector<int>& y);
7474
int predict_sample(const std::vector<double>& x) const;
7575

7676
private:
77-
Node* build_tree(const std::vector<std::vector<double>>& X, const std::vector<int>& y, int depth);
77+
std::unique_ptr<Node> build_tree(const std::vector<std::vector<double>>& X, const std::vector<int>& y, int depth);
7878
double calculate_gini(const std::vector<int>& y) const;
7979
void split_dataset(const std::vector<std::vector<double>>& X, const std::vector<int>& y, int feature_index, double threshold,
8080
std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
8181
std::vector<std::vector<double>>& X_right, std::vector<int>& y_right) const;
82-
void delete_tree(Node* node);
82+
int majority_class(const std::vector<int>& y) const;
8383
};
8484

8585
int n_estimators;
8686
int max_depth;
8787
int min_samples_split;
8888
int max_features;
89-
std::vector<DecisionTree*> trees;
89+
std::vector<std::unique_ptr<DecisionTree>> trees;
90+
std::mt19937 random_engine;
9091

9192
void bootstrap_sample(const std::vector<std::vector<double>>& X, const std::vector<int>& y,
9293
std::vector<std::vector<double>>& X_sample, std::vector<int>& y_sample);
9394
};
9495

9596
/**
 * @brief Constructs the forest with the given hyperparameters.
 *
 * @param n_estimators      Number of trees in the ensemble.
 * @param max_depth         Maximum depth of each tree.
 * @param min_samples_split Minimum samples required to split a node.
 * @param max_features      Features considered per split (-1 = sqrt of total).
 */
RandomForestClassifier::RandomForestClassifier(int n_estimators, int max_depth, int min_samples_split, int max_features)
    : n_estimators(n_estimators), max_depth(max_depth), min_samples_split(min_samples_split), max_features(max_features) {
    // Seed the forest-level RNG nondeterministically; per-tree engines are
    // seeded from this one in fit(), so each tree gets its own stream.
    std::random_device seed_source;
    random_engine.seed(seed_source());
}
105101

106102
void RandomForestClassifier::fit(const std::vector<std::vector<double>>& X, const std::vector<int>& y) {
107103
// Set max_features if not set
108-
if (max_features == -1) {
109-
max_features = static_cast<int>(std::sqrt(X[0].size()));
104+
int actual_max_features = max_features;
105+
if (actual_max_features == -1) {
106+
actual_max_features = static_cast<int>(std::sqrt(X[0].size()));
110107
}
111108

112109
for (int i = 0; i < n_estimators; ++i) {
113110
std::vector<std::vector<double>> X_sample;
114111
std::vector<int> y_sample;
115112
bootstrap_sample(X, y, X_sample, y_sample);
116113

117-
DecisionTree* tree = new DecisionTree(max_depth, min_samples_split, max_features);
114+
auto tree = std::make_unique<DecisionTree>(max_depth, min_samples_split, actual_max_features, random_engine());
118115
tree->fit(X_sample, y_sample);
119-
trees.push_back(tree);
116+
trees.push_back(std::move(tree));
120117
}
121118
}
122119

123120
std::vector<int> RandomForestClassifier::predict(const std::vector<std::vector<double>>& X) const {
124121
std::vector<int> predictions(X.size());
125122
for (size_t i = 0; i < X.size(); ++i) {
126-
std::map<int, int> votes;
123+
std::unordered_map<int, int> votes;
127124
for (const auto& tree : trees) {
128125
int vote = tree->predict_sample(X[i]);
129126
votes[vote]++;
130127
}
131128
// Majority vote
132129
predictions[i] = std::max_element(votes.begin(), votes.end(),
133-
[](const std::pair<int, int>& a, const std::pair<int, int>& b) {
130+
[](const auto& a, const auto& b) {
134131
return a.second < b.second;
135132
})->first;
136133
}
@@ -141,54 +138,41 @@ void RandomForestClassifier::bootstrap_sample(const std::vector<std::vector<doub
141138
std::vector<std::vector<double>>& X_sample, std::vector<int>& y_sample) {
142139
size_t n_samples = X.size();
143140
std::uniform_int_distribution<size_t> dist(0, n_samples - 1);
144-
std::default_random_engine engine(static_cast<unsigned long>(std::rand()));
145141

146142
for (size_t i = 0; i < n_samples; ++i) {
147-
size_t index = dist(engine);
143+
size_t index = dist(random_engine);
148144
X_sample.push_back(X[index]);
149145
y_sample.push_back(y[index]);
150146
}
151147
}
152148

153-
/**
 * @brief Constructs a single tree with its stopping criteria and a private
 *        RNG seeded from the forest, so trees draw independent streams.
 */
RandomForestClassifier::DecisionTree::DecisionTree(int max_depth, int min_samples_split, int max_features,
                                                   std::mt19937::result_type seed)
    : root(nullptr),
      max_depth(max_depth),
      min_samples_split(min_samples_split),
      max_features(max_features),
      random_engine(seed) {}
159151

160152
void RandomForestClassifier::DecisionTree::fit(const std::vector<std::vector<double>>& X, const std::vector<int>& y) {
161153
root = build_tree(X, y, 0);
162154
}
163155

164156
int RandomForestClassifier::DecisionTree::predict_sample(const std::vector<double>& x) const {
165-
Node* node = root;
157+
const Node* node = root.get();
166158
while (!node->is_leaf) {
167159
if (x[node->feature_index] <= node->threshold) {
168-
node = node->left;
160+
node = node->left.get();
169161
} else {
170-
node = node->right;
162+
node = node->right.get();
171163
}
172164
}
173165
return node->value;
174166
}
175167

176-
RandomForestClassifier::Node* RandomForestClassifier::DecisionTree::build_tree(const std::vector<std::vector<double>>& X,
177-
const std::vector<int>& y, int depth) {
178-
Node* node = new Node();
168+
std::unique_ptr<RandomForestClassifier::Node> RandomForestClassifier::DecisionTree::build_tree(
169+
const std::vector<std::vector<double>>& X, const std::vector<int>& y, int depth) {
170+
auto node = std::make_unique<Node>();
179171

180172
// Check stopping criteria
181173
if (depth >= max_depth || y.size() < static_cast<size_t>(min_samples_split) || calculate_gini(y) == 0.0) {
182174
node->is_leaf = true;
183-
// Majority class label
184-
std::map<int, int> class_counts;
185-
for (int label : y) {
186-
class_counts[label]++;
187-
}
188-
node->value = std::max_element(class_counts.begin(), class_counts.end(),
189-
[](const std::pair<int, int>& a, const std::pair<int, int>& b) {
190-
return a.second < b.second;
191-
})->first;
175+
node->value = majority_class(y);
192176
return node;
193177
}
194178

@@ -203,19 +187,25 @@ RandomForestClassifier::Node* RandomForestClassifier::DecisionTree::build_tree(c
203187
std::iota(features_indices.begin(), features_indices.end(), 0);
204188

205189
// Randomly select features without replacement
206-
std::shuffle(features_indices.begin(), features_indices.end(), std::default_random_engine(static_cast<unsigned long>(std::rand())));
190+
std::shuffle(features_indices.begin(), features_indices.end(), random_engine);
207191
if (max_features < num_features) {
208192
features_indices.resize(max_features);
209193
}
210194

211195
for (int feature_index : features_indices) {
212196
// Get all possible thresholds
213197
std::vector<double> feature_values;
198+
feature_values.reserve(X.size());
214199
for (const auto& x : X) {
215200
feature_values.push_back(x[feature_index]);
216201
}
217202
std::sort(feature_values.begin(), feature_values.end());
203+
feature_values.erase(std::unique(feature_values.begin(), feature_values.end()), feature_values.end());
204+
205+
if (feature_values.size() <= 1) continue;
206+
218207
std::vector<double> thresholds;
208+
thresholds.reserve(feature_values.size() - 1);
219209
for (size_t i = 1; i < feature_values.size(); ++i) {
220210
thresholds.push_back((feature_values[i - 1] + feature_values[i]) / 2.0);
221211
}
@@ -237,26 +227,18 @@ RandomForestClassifier::Node* RandomForestClassifier::DecisionTree::build_tree(c
237227
best_gini = gini;
238228
best_feature_index = feature_index;
239229
best_threshold = threshold;
240-
best_X_left = X_left;
241-
best_X_right = X_right;
242-
best_y_left = y_left;
243-
best_y_right = y_right;
230+
best_X_left = std::move(X_left);
231+
best_X_right = std::move(X_right);
232+
best_y_left = std::move(y_left);
233+
best_y_right = std::move(y_right);
244234
}
245235
}
246236
}
247237

248238
// If no split improves the Gini impurity, make this a leaf node
249239
if (best_feature_index == -1) {
250240
node->is_leaf = true;
251-
// Majority class label
252-
std::map<int, int> class_counts;
253-
for (int label : y) {
254-
class_counts[label]++;
255-
}
256-
node->value = std::max_element(class_counts.begin(), class_counts.end(),
257-
[](const std::pair<int, int>& a, const std::pair<int, int>& b) {
258-
return a.second < b.second;
259-
})->first;
241+
node->value = majority_class(y);
260242
return node;
261243
}
262244

@@ -269,19 +251,30 @@ RandomForestClassifier::Node* RandomForestClassifier::DecisionTree::build_tree(c
269251
}
270252

271253
/**
 * @brief Gini impurity of a label set: 1 - sum over classes of p(c)^2.
 *
 * @param y Class labels (empty input yields 1.0, same as before).
 * @return Impurity in [0, 1); 0 means the set is pure.
 */
double RandomForestClassifier::DecisionTree::calculate_gini(const std::vector<int>& y) const {
    std::unordered_map<int, int> counts;
    for (const int label : y) {
        ++counts[label];
    }

    const double total = static_cast<double>(y.size());
    double sum_sq = 0.0;
    for (const auto& [label, n] : counts) {
        const double p = n / total;
        sum_sq += p * p;
    }
    return 1.0 - sum_sq;
}
284266

267+
/**
 * @brief Returns the most frequent label in y.
 *
 * @param y Class labels; empty input returns 0 instead of invoking UB.
 * @return The majority class label.
 */
int RandomForestClassifier::DecisionTree::majority_class(const std::vector<int>& y) const {
    // Guard: the previous implementation dereferenced max_element() of an
    // empty map when y was empty, which is undefined behavior.
    if (y.empty()) {
        return 0;
    }

    // Single pass: count and track the running winner at the same time.
    // Ties resolve to whichever label reaches the top count first, which is
    // deterministic (unordered_map iteration order in the old code was not).
    std::unordered_map<int, int> class_counts;
    class_counts.reserve(y.size());
    int best_label = y.front();
    int best_count = 0;
    for (int label : y) {
        const int count = ++class_counts[label];
        if (count > best_count) {
            best_count = count;
            best_label = label;
        }
    }
    return best_label;
}
277+
285278
void RandomForestClassifier::DecisionTree::split_dataset(const std::vector<std::vector<double>>& X, const std::vector<int>& y,
286279
int feature_index, double threshold,
287280
std::vector<std::vector<double>>& X_left, std::vector<int>& y_left,
@@ -297,12 +290,4 @@ void RandomForestClassifier::DecisionTree::split_dataset(const std::vector<std::
297290
}
298291
}
299292

300-
void RandomForestClassifier::DecisionTree::delete_tree(Node* node) {
301-
if (node != nullptr) {
302-
delete_tree(node->left);
303-
delete_tree(node->right);
304-
delete node;
305-
}
306-
}
307-
308293
#endif // RANDOM_FOREST_CLASSIFIER_HPP

0 commit comments

Comments
 (0)