refactor computing max_cuts

This commit is contained in:
2023-04-14 11:53:16 +02:00
parent d77d27459b
commit e689d1f69c
3 changed files with 35 additions and 15 deletions

View File

@@ -50,6 +50,13 @@ namespace mdlp {
indices = sortIndices(X_, y_);
metrics.setData(y, indices);
computeCutPoints(0, X.size(), 1);
sort(cutPoints.begin(), cutPoints.end());
if (num_cut_points > 0) {
// Select the best (with lower entropy) cut points
while (cutPoints.size() > num_cut_points) {
resizeCutPoints();
}
}
}
pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) {
@@ -87,8 +94,6 @@ namespace mdlp {
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) {
size_t cut;
pair<precision_t, size_t> result;
if (cutPoints.size() == num_cut_points)
return;
// Check if the interval length and the depth are Ok
if (end - start < min_length || depth_ > max_depth)
return;
@@ -174,12 +179,24 @@ namespace mdlp {
return idx;
}
cutPoints_t CPPFImdlp::getCutPoints() {
sort(cutPoints.begin(), cutPoints.end());
return cutPoints;
}
int CPPFImdlp::get_depth() const {
return depth;
void CPPFImdlp::resizeCutPoints() {
//Compute entropy of each of the whole cutpoint set and discards the biggest value
precision_t maxEntropy = 0;
precision_t entropy;
size_t maxEntropyIdx = 0;
size_t begin = 0;
size_t end;
for (size_t idx = 0; idx < cutPoints.size(); idx++) {
end = begin;
while (X[indices[end]] < cutPoints[idx] && end < X.size())
end++;
entropy = metrics.entropy(begin, end);
if (entropy > maxEntropy) {
maxEntropy = entropy;
maxEntropyIdx = idx;
}
begin = end;
}
cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx));
}
}