From e689d1f69c02982c12cb13bf58de2b4f91892635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Fri, 14 Apr 2023 11:53:16 +0200 Subject: [PATCH] refactor computing max_cuts --- CPPFImdlp.cpp | 35 ++++++++++++++++++++++++++--------- CPPFImdlp.h | 10 ++++++---- tests/FImdlp_unittest.cpp | 5 +++-- 3 files changed, 35 insertions(+), 15 deletions(-) diff --git a/CPPFImdlp.cpp b/CPPFImdlp.cpp index a57c9a9..79f9152 100644 --- a/CPPFImdlp.cpp +++ b/CPPFImdlp.cpp @@ -50,6 +50,13 @@ namespace mdlp { indices = sortIndices(X_, y_); metrics.setData(y, indices); computeCutPoints(0, X.size(), 1); + sort(cutPoints.begin(), cutPoints.end()); + if (num_cut_points > 0) { + // Select the best (with lower entropy) cut points + while (cutPoints.size() > num_cut_points) { + resizeCutPoints(); + } + } } pair CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) { @@ -87,8 +94,6 @@ namespace mdlp { void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) { size_t cut; pair result; - if (cutPoints.size() == num_cut_points) - return; // Check if the interval length and the depth are Ok if (end - start < min_length || depth_ > max_depth) return; @@ -174,12 +179,24 @@ namespace mdlp { return idx; } - cutPoints_t CPPFImdlp::getCutPoints() { - sort(cutPoints.begin(), cutPoints.end()); - return cutPoints; - } - - int CPPFImdlp::get_depth() const { - return depth; + void CPPFImdlp::resizeCutPoints() { + //Compute entropy of each of the whole cutpoint set and discards the biggest value + precision_t maxEntropy = 0; + precision_t entropy; + size_t maxEntropyIdx = 0; + size_t begin = 0; + size_t end; + for (size_t idx = 0; idx < cutPoints.size(); idx++) { + end = begin; + while (X[indices[end]] < cutPoints[idx] && end < X.size()) + end++; + entropy = metrics.entropy(begin, end); + if (entropy > maxEntropy) { + maxEntropy = entropy; + maxEntropyIdx = idx; + } + begin = end; + } + cutPoints.erase(cutPoints.begin() + static_cast(maxEntropyIdx)); } } diff --git a/CPPFImdlp.h b/CPPFImdlp.h index c205719..29b76c6 100644 --- a/CPPFImdlp.h +++ b/CPPFImdlp.h @@ -21,10 +21,12 @@ namespace mdlp { cutPoints_t cutPoints; size_t num_cut_points = numeric_limits::max(); - static indices_t sortIndices(samples_t&, labels_t&); + static indices_t sortIndices(samples_t &, labels_t &); void computeCutPoints(size_t, size_t, int); + void resizeCutPoints(); + bool mdlp(size_t, size_t, size_t); size_t getCandidate(size_t, size_t); @@ -40,11 +42,11 @@ namespace mdlp { ~CPPFImdlp(); - void fit(samples_t&, labels_t&); + void fit(samples_t &, labels_t &); - cutPoints_t getCutPoints(); + inline cutPoints_t getCutPoints() const { return cutPoints; }; - int get_depth() const; + inline int get_depth() const { return depth; }; static inline string version() { return "1.1.1"; }; }; diff --git a/tests/FImdlp_unittest.cpp b/tests/FImdlp_unittest.cpp index 1559fa3..d0da40e 100644 --- a/tests/FImdlp_unittest.cpp +++ b/tests/FImdlp_unittest.cpp @@ -269,12 +269,13 @@ namespace mdlp { auto test = CPPFImdlp(75, 2, 1); vector expected = { {5.45f}, - {3.35f}, + {2.85f}, {2.45f}, {0.8f} }; - vector depths = {1, 1, 1, 1}; + vector depths = {2, 2, 2, 2}; test_dataset(test, "iris", expected, depths); + } TEST_F(TestFImdlp, MaxCutPointsFloat) {