feat: 🎨 Refactor algorithm to accept 3 types of computation

This commit is contained in:
2022-12-15 12:08:33 +01:00
parent e21482900b
commit 50543e4921
3 changed files with 21 additions and 17 deletions

View File

@@ -6,7 +6,7 @@
#include "Metrics.h" #include "Metrics.h"
namespace mdlp { namespace mdlp {
CPPFImdlp::CPPFImdlp(bool proposal):proposal(proposal), indices(indices_t()), X(samples_t()), y(labels_t()), metrics(Metrics(y, indices)) CPPFImdlp::CPPFImdlp(int proposal):proposal(proposal), indices(indices_t()), X(samples_t()), y(labels_t()), metrics(Metrics(y, indices))
{ {
} }
CPPFImdlp::~CPPFImdlp() CPPFImdlp::~CPPFImdlp()
@@ -25,10 +25,17 @@ namespace mdlp {
} }
indices = sortIndices(X_); indices = sortIndices(X_);
metrics.setData(y, indices); metrics.setData(y, indices);
if (proposal) switch (proposal) {
computeCutPointsProposal(); case 0:
else computeCutPoints(0, X.size());
computeCutPoints(0, X.size()); break;
case 1:
computeCutPointsProposal();
break;
case 2:
computeCutPointsAlternative(0, X.size());
break;
}
return *this; return *this;
} }
void CPPFImdlp::computeCutPoints(size_t start, size_t end) void CPPFImdlp::computeCutPoints(size_t start, size_t end)
@@ -50,7 +57,7 @@ namespace mdlp {
computeCutPoints(start, cut); computeCutPoints(start, cut);
computeCutPoints(cut, end); computeCutPoints(cut, end);
} }
void CPPFImdlp::computeCutPointsOriginal(size_t start, size_t end) void CPPFImdlp::computeCutPointsAlternative(size_t start, size_t end)
{ {
precision_t cut; precision_t cut;
if (end - start < 2) if (end - start < 2)
@@ -61,8 +68,8 @@ namespace mdlp {
if (mdlp(start, cut, end)) { if (mdlp(start, cut, end)) {
cutPoints.push_back((X[indices[cut]] + X[indices[cut - 1]]) / 2); cutPoints.push_back((X[indices[cut]] + X[indices[cut - 1]]) / 2);
} }
computeCutPointsOriginal(start, cut); computeCutPointsAlternative(start, cut);
computeCutPointsOriginal(cut, end); computeCutPointsAlternative(cut, end);
} }
void CPPFImdlp::computeCutPointsProposal() void CPPFImdlp::computeCutPointsProposal()
{ {
@@ -102,7 +109,7 @@ namespace mdlp {
long int candidate = -1, elements = end - start; long int candidate = -1, elements = end - start;
precision_t entropy_left, entropy_right, minEntropy = numeric_limits<precision_t>::max(); precision_t entropy_left, entropy_right, minEntropy = numeric_limits<precision_t>::max();
for (auto idx = start + 1; idx < end; idx++) { for (auto idx = start + 1; idx < end; idx++) {
// Cutpoints are always on boudndaries // Cutpoints are always on boundaries
if (y[indices[idx]] == y[indices[idx - 1]]) if (y[indices[idx]] == y[indices[idx - 1]])
continue; continue;
entropy_left = precision_t(idx - start) / elements * metrics.entropy(start, idx); entropy_left = precision_t(idx - start) / elements * metrics.entropy(start, idx);

View File

@@ -6,7 +6,7 @@
namespace mdlp { namespace mdlp {
class CPPFImdlp { class CPPFImdlp {
protected: protected:
bool proposal; int proposal;
indices_t indices; // sorted indices to use with X and y indices_t indices; // sorted indices to use with X and y
samples_t X; samples_t X;
labels_t y; labels_t y;
@@ -15,16 +15,13 @@ namespace mdlp {
static indices_t sortIndices(samples_t&); static indices_t sortIndices(samples_t&);
void computeCutPoints(size_t, size_t); void computeCutPoints(size_t, size_t);
long int getCandidate(size_t, size_t);
bool mdlp(size_t, size_t, size_t); bool mdlp(size_t, size_t, size_t);
long int getCandidate(size_t, size_t);
// Original algorithm void computeCutPointsAlternative(size_t, size_t);
void computeCutPointsOriginal(size_t, size_t);
bool goodCut(size_t, size_t, size_t);
void computeCutPointsProposal(); void computeCutPointsProposal();
public: public:
CPPFImdlp(bool); CPPFImdlp(int);
~CPPFImdlp(); ~CPPFImdlp();
CPPFImdlp& fit(samples_t&, labels_t&); CPPFImdlp& fit(samples_t&, labels_t&);
samples_t getCutPoints(); samples_t getCutPoints();

View File

@@ -41,7 +41,7 @@ int main(int argc, char** argv)
} }
cout << y[i] << endl; cout << y[i] << endl;
} }
mdlp::CPPFImdlp test = mdlp::CPPFImdlp(false); mdlp::CPPFImdlp test = mdlp::CPPFImdlp(0);
for (auto i = 0; i < attributes.size(); i++) { for (auto i = 0; i < attributes.size(); i++) {
cout << "Cut points for " << get<0>(attributes[i]) << endl; cout << "Cut points for " << get<0>(attributes[i]) << endl;
cout << "--------------------------" << setprecision(3) << endl; cout << "--------------------------" << setprecision(3) << endl;