Add same_values to getCandidate and fine tune ValueCutPoint

This commit is contained in:
2023-02-15 13:07:03 +01:00
parent e37702dcb0
commit 04c1772019

View File

@@ -4,43 +4,26 @@
#include <cmath> #include <cmath>
#include "CPPFImdlp.h" #include "CPPFImdlp.h"
#include "Metrics.h" #include "Metrics.h"
#include <iostream>
namespace mdlp { namespace mdlp {
CPPFImdlp::CPPFImdlp(int algorithm):algorithm(algorithm), indices(indices_t()), X(samples_t()), y(labels_t()), metrics(Metrics(y, indices)) CPPFImdlp::CPPFImdlp(int algorithm) : algorithm(algorithm), indices(indices_t()), X(samples_t()), y(labels_t()),
{ metrics(Metrics(y, indices)) {
} }
CPPFImdlp::~CPPFImdlp()
= default; CPPFImdlp::~CPPFImdlp() = default;
CPPFImdlp& CPPFImdlp::fit(samples_t& X_, labels_t& y_)
{ CPPFImdlp &CPPFImdlp::fit(samples_t &X_, labels_t &y_) {
X = X_; X = X_;
y = y_; y = y_;
cutPoints.clear(); cutPoints.clear();
if (X.size() != y.size()) { if (X.size() != y.size()) {
throw invalid_argument("X and y must have the same size"); throw invalid_argument("X and y must have the same size");
} }
if (X.size() == 0 || y.size() == 0) { if (X.empty() || y.empty()) {
throw invalid_argument("X and y must have at least one element"); throw invalid_argument("X and y must have at least one element");
} }
indices = sortIndices(X_, y_); indices = sortIndices(X_, y_);
metrics.setData(y, indices); metrics.setData(y, indices);
for (auto i=0; i< X.size(); i++) {
if (i% 10 ==0) {
cout << " # Idx --X-- y"<<endl;
cout << "--- --- ----- -"<<endl;
}
auto index = indices[i];
cout.width(3);
cout << i << " ";
cout.width(3);
cout << index << " ";
cout.width(5);
cout << X[index];
cout.width(1);
cout << " " << y[index] << endl;
}
switch (algorithm) { switch (algorithm) {
case 0: case 0:
computeCutPoints(0, X.size()); computeCutPoints(0, X.size());
@@ -58,8 +41,8 @@ namespace mdlp {
} }
return *this; return *this;
} }
precision_t CPPFImdlp::halfWayValueCutPoint(size_t start, size_t idx)
{ precision_t CPPFImdlp::halfWayValueCutPoint(size_t start, size_t idx) {
size_t idxPrev = idx - 1; size_t idxPrev = idx - 1;
precision_t previous = X[indices[idxPrev]], actual = X[indices[idx]]; precision_t previous = X[indices[idxPrev]], actual = X[indices[idx]];
// definition 2 of the paper => X[t-1] < X[t] // definition 2 of the paper => X[t-1] < X[t]
@@ -68,9 +51,10 @@ namespace mdlp {
} }
return (previous + actual) / 2; return (previous + actual) / 2;
} }
tuple<precision_t, size_t> CPPFImdlp::completeValueCutPoint(size_t start, size_t cut, size_t end)
{ tuple<precision_t, size_t> CPPFImdlp::completeValueCutPoint(size_t start, size_t cut, size_t end) {
size_t idxPrev = cut - 1; size_t idxPrev = cut - 1;
bool fforward = false;
precision_t previous, actual; precision_t previous, actual;
previous = X[indices[idxPrev]]; previous = X[indices[idxPrev]];
actual = X[indices[cut]]; actual = X[indices[cut]];
@@ -79,14 +63,19 @@ namespace mdlp {
previous = X[indices[idxPrev]]; previous = X[indices[idxPrev]];
} }
// get the last equal value of X in the interval // get the last equal value of X in the interval
while (actual == X[indices[++cut]] && cut < end); while (actual == X[indices[cut]] && cut + 1 < end) {
if (previous == actual && cut < end) cut++;
actual = X[indices[cut]]; fforward = true;
cut--; }
if (fforward)
cut--;
// try to get the next value if it can't be found backwards
if (previous == actual && cut + 1 < end)
actual = X[indices[cut + 1]];
return make_tuple((previous + actual) / 2, cut); return make_tuple((previous + actual) / 2, cut);
} }
void CPPFImdlp::computeCutPoints(size_t start, size_t end)
{ void CPPFImdlp::computeCutPoints(size_t start, size_t end) {
size_t cut; size_t cut;
tuple<precision_t, size_t> result; tuple<precision_t, size_t> result;
if (end - start < 2) if (end - start < 2)
@@ -102,8 +91,8 @@ namespace mdlp {
computeCutPoints(cut, end); computeCutPoints(cut, end);
} }
} }
void CPPFImdlp::computeCutPointsAlternative(size_t start, size_t end)
{ void CPPFImdlp::computeCutPointsAlternative(size_t start, size_t end) {
size_t cut; size_t cut;
if (end - start < 2) if (end - start < 2)
return; return;
@@ -116,8 +105,8 @@ namespace mdlp {
computeCutPointsAlternative(cut, end); computeCutPointsAlternative(cut, end);
} }
} }
void CPPFImdlp::computeCutPointsClassic(size_t start, size_t end)
{ void CPPFImdlp::computeCutPointsClassic(size_t start, size_t end) {
size_t cut; size_t cut;
cut = getCandidate(start, end); cut = getCandidate(start, end);
if (cut == numeric_limits<size_t>::max() || !mdlp(start, cut, end)) { if (cut == numeric_limits<size_t>::max() || !mdlp(start, cut, end)) {
@@ -135,14 +124,17 @@ namespace mdlp {
computeCutPoints(start, cut); computeCutPoints(start, cut);
computeCutPoints(cut, end); computeCutPoints(cut, end);
} }
size_t CPPFImdlp::getCandidate(size_t start, size_t end)
{ size_t CPPFImdlp::getCandidate(size_t start, size_t end) {
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which /* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
E(A, TA; S) is minimal amogst all the candidate cut points. */ E(A, TA; S) is minimal amongst all the candidate cut points. */
size_t candidate = numeric_limits<size_t>::max(), elements = end - start; size_t candidate = numeric_limits<size_t>::max(), elements = end - start;
bool same_values = true;
precision_t entropy_left, entropy_right, minEntropy; precision_t entropy_left, entropy_right, minEntropy;
minEntropy = metrics.entropy(start, end); minEntropy = metrics.entropy(start, end);
for (auto idx = start + 1; idx < end; idx++) { for (auto idx = start + 1; idx < end; idx++) {
if (X[indices[idx]] != X[indices[idx - 1]])
same_values = false;
// Cutpoints are always on boundaries (definition 2) // Cutpoints are always on boundaries (definition 2)
if (y[indices[idx]] == y[indices[idx - 1]]) if (y[indices[idx]] == y[indices[idx - 1]])
continue; continue;
@@ -153,10 +145,13 @@ namespace mdlp {
candidate = idx; candidate = idx;
} }
} }
// If all the values of the variable in the interval are the same, it doesn't consider the cut point
if (same_values)
candidate = numeric_limits<size_t>::max();
return candidate; return candidate;
} }
bool CPPFImdlp::mdlp(size_t start, size_t cut, size_t end)
{ bool CPPFImdlp::mdlp(size_t start, size_t cut, size_t end) {
int k, k1, k2; int k, k1, k2;
precision_t ig, delta; precision_t ig, delta;
precision_t ent, ent1, ent2; precision_t ent, ent1, ent2;
@@ -172,38 +167,37 @@ namespace mdlp {
ent2 = metrics.entropy(cut, end); ent2 = metrics.entropy(cut, end);
ig = metrics.informationGain(start, cut, end); ig = metrics.informationGain(start, cut, end);
delta = log2(pow(3, precision_t(k)) - 2) - delta = log2(pow(3, precision_t(k)) - 2) -
(precision_t(k) * ent - precision_t(k1) * ent1 - precision_t(k2) * ent2); (precision_t(k) * ent - precision_t(k1) * ent1 - precision_t(k2) * ent2);
precision_t term = 1 / N * (log2(N - 1) + delta); precision_t term = 1 / N * (log2(N - 1) + delta);
return ig > term; return ig > term;
} }
// Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes // Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
indices_t CPPFImdlp::sortIndices(samples_t& X_, labels_t& y_) indices_t CPPFImdlp::sortIndices(samples_t &X_, labels_t &y_) {
{
indices_t idx(X_.size()); indices_t idx(X_.size());
iota(idx.begin(), idx.end(), 0); iota(idx.begin(), idx.end(), 0);
for (size_t i = 0; i < X_.size(); i++) for (size_t i = 0; i < X_.size(); i++)
stable_sort(idx.begin(), idx.end(), [&X_, &y_](size_t i1, size_t i2) stable_sort(idx.begin(), idx.end(), [&X_, &y_](size_t i1, size_t i2) {
{ if (X_[i1] == X_[i2])
if (X_[i1] == X_[i2]) return y_[i1] < y_[i2]; return y_[i1] < y_[i2];
else else
return X_[i1] < X_[i2];
});
return idx;
}
// Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
indices_t CPPFImdlp::sortIndices1(samples_t& X_)
{
indices_t idx(X_.size());
iota(idx.begin(), idx.end(), 0);
for (size_t i = 0; i < X_.size(); i++)
stable_sort(idx.begin(), idx.end(), [&X_](size_t i1, size_t i2)
{
return X_[i1] < X_[i2]; return X_[i1] < X_[i2];
}); });
return idx; return idx;
} }
cutPoints_t CPPFImdlp::getCutPoints()
{ // Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
indices_t CPPFImdlp::sortIndices1(samples_t &X_) {
indices_t idx(X_.size());
iota(idx.begin(), idx.end(), 0);
for (size_t i = 0; i < X_.size(); i++)
stable_sort(idx.begin(), idx.end(), [&X_](size_t i1, size_t i2) {
return X_[i1] < X_[i2];
});
return idx;
}
cutPoints_t CPPFImdlp::getCutPoints() {
// Remove duplicates and sort // Remove duplicates and sort
cutPoints_t output(cutPoints.size()); cutPoints_t output(cutPoints.size());
set<precision_t> s; set<precision_t> s;