feat: 🎨 Refactor algorithm to accept 3 types of computation

This commit is contained in:
2022-12-15 12:08:33 +01:00
parent e21482900b
commit 50543e4921
3 changed files with 21 additions and 17 deletions

View File

@@ -6,7 +6,7 @@
#include "Metrics.h"
namespace mdlp {
CPPFImdlp::CPPFImdlp(bool proposal):proposal(proposal), indices(indices_t()), X(samples_t()), y(labels_t()), metrics(Metrics(y, indices))
CPPFImdlp::CPPFImdlp(int proposal):proposal(proposal), indices(indices_t()), X(samples_t()), y(labels_t()), metrics(Metrics(y, indices))
{
}
CPPFImdlp::~CPPFImdlp()
@@ -25,10 +25,17 @@ namespace mdlp {
}
indices = sortIndices(X_);
metrics.setData(y, indices);
if (proposal)
computeCutPointsProposal();
else
computeCutPoints(0, X.size());
switch (proposal) {
case 0:
computeCutPoints(0, X.size());
break;
case 1:
computeCutPointsProposal();
break;
case 2:
computeCutPointsAlternative(0, X.size());
break;
}
return *this;
}
void CPPFImdlp::computeCutPoints(size_t start, size_t end)
@@ -50,7 +57,7 @@ namespace mdlp {
computeCutPoints(start, cut);
computeCutPoints(cut, end);
}
void CPPFImdlp::computeCutPointsOriginal(size_t start, size_t end)
void CPPFImdlp::computeCutPointsAlternative(size_t start, size_t end)
{
precision_t cut;
if (end - start < 2)
@@ -61,8 +68,8 @@ namespace mdlp {
if (mdlp(start, cut, end)) {
cutPoints.push_back((X[indices[cut]] + X[indices[cut - 1]]) / 2);
}
computeCutPointsOriginal(start, cut);
computeCutPointsOriginal(cut, end);
computeCutPointsAlternative(start, cut);
computeCutPointsAlternative(cut, end);
}
void CPPFImdlp::computeCutPointsProposal()
{
@@ -102,7 +109,7 @@ namespace mdlp {
long int candidate = -1, elements = end - start;
precision_t entropy_left, entropy_right, minEntropy = numeric_limits<precision_t>::max();
for (auto idx = start + 1; idx < end; idx++) {
// Cutpoints are always on boudndaries
// Cutpoints are always on boundaries
if (y[indices[idx]] == y[indices[idx - 1]])
continue;
entropy_left = precision_t(idx - start) / elements * metrics.entropy(start, idx);

View File

@@ -6,7 +6,7 @@
namespace mdlp {
class CPPFImdlp {
protected:
bool proposal;
int proposal;
indices_t indices; // sorted indices to use with X and y
samples_t X;
labels_t y;
@@ -15,16 +15,13 @@ namespace mdlp {
static indices_t sortIndices(samples_t&);
void computeCutPoints(size_t, size_t);
long int getCandidate(size_t, size_t);
bool mdlp(size_t, size_t, size_t);
// Original algorithm
void computeCutPointsOriginal(size_t, size_t);
bool goodCut(size_t, size_t, size_t);
long int getCandidate(size_t, size_t);
void computeCutPointsAlternative(size_t, size_t);
void computeCutPointsProposal();
public:
CPPFImdlp(bool);
CPPFImdlp(int);
~CPPFImdlp();
CPPFImdlp& fit(samples_t&, labels_t&);
samples_t getCutPoints();

View File

@@ -41,7 +41,7 @@ int main(int argc, char** argv)
}
cout << y[i] << endl;
}
mdlp::CPPFImdlp test = mdlp::CPPFImdlp(false);
mdlp::CPPFImdlp test = mdlp::CPPFImdlp(0);
for (auto i = 0; i < attributes.size(); i++) {
cout << "Cut points for " << get<0>(attributes[i]) << endl;
cout << "--------------------------" << setprecision(3) << endl;