refactor computing max_cuts

This commit is contained in:
2023-04-14 11:53:16 +02:00
parent d77d27459b
commit e689d1f69c
3 changed files with 35 additions and 15 deletions

View File

@@ -50,6 +50,13 @@ namespace mdlp {
indices = sortIndices(X_, y_);
metrics.setData(y, indices);
computeCutPoints(0, X.size(), 1);
sort(cutPoints.begin(), cutPoints.end());
if (num_cut_points > 0) {
// Select the best (with lower entropy) cut points
while (cutPoints.size() > num_cut_points) {
resizeCutPoints();
}
}
}
pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) {
@@ -87,8 +94,6 @@ namespace mdlp {
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) {
size_t cut;
pair<precision_t, size_t> result;
if (cutPoints.size() == num_cut_points)
return;
// Check if the interval length and the depth are Ok
if (end - start < min_length || depth_ > max_depth)
return;
@@ -174,12 +179,24 @@ namespace mdlp {
return idx;
}
cutPoints_t CPPFImdlp::getCutPoints() {
sort(cutPoints.begin(), cutPoints.end());
return cutPoints;
}
int CPPFImdlp::get_depth() const {
return depth;
void CPPFImdlp::resizeCutPoints() {
//Compute entropy of each of the whole cutpoint set and discards the biggest value
precision_t maxEntropy = 0;
precision_t entropy;
size_t maxEntropyIdx = 0;
size_t begin = 0;
size_t end;
for (size_t idx = 0; idx < cutPoints.size(); idx++) {
end = begin;
while (X[indices[end]] < cutPoints[idx] && end < X.size())
end++;
entropy = metrics.entropy(begin, end);
if (entropy > maxEntropy) {
maxEntropy = entropy;
maxEntropyIdx = idx;
}
begin = end;
}
cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx));
}
}

View File

@@ -21,10 +21,12 @@ namespace mdlp {
cutPoints_t cutPoints;
size_t num_cut_points = numeric_limits<size_t>::max();
static indices_t sortIndices(samples_t&, labels_t&);
static indices_t sortIndices(samples_t &, labels_t &);
void computeCutPoints(size_t, size_t, int);
void resizeCutPoints();
bool mdlp(size_t, size_t, size_t);
size_t getCandidate(size_t, size_t);
@@ -40,11 +42,11 @@ namespace mdlp {
~CPPFImdlp();
void fit(samples_t&, labels_t&);
void fit(samples_t &, labels_t &);
cutPoints_t getCutPoints();
inline cutPoints_t getCutPoints() const { return cutPoints; };
int get_depth() const;
inline int get_depth() const { return depth; };
static inline string version() { return "1.1.1"; };
};

View File

@@ -269,12 +269,13 @@ namespace mdlp {
auto test = CPPFImdlp(75, 2, 1);
vector<cutPoints_t> expected = {
{5.45f},
{3.35f},
{2.85f},
{2.45f},
{0.8f}
};
vector<int> depths = {1, 1, 1, 1};
vector<int> depths = {2, 2, 2, 2};
test_dataset(test, "iris", expected, depths);
}
TEST_F(TestFImdlp, MaxCutPointsFloat) {