feat: ♻️ Add Classic algorithm as #2 to compare performance

This commit is contained in:
Ricardo Montañana Gómez
2023-01-13 11:44:17 +01:00
parent 7b20bde428
commit 1b89f5927c
2 changed files with 39 additions and 1 deletions

View File

@@ -30,8 +30,13 @@ namespace mdlp {
case 1:
computeCutPointsAlternative(0, X.size());
break;
case 2:
indices = sortIndices1(X_);
metrics.setData(y, indices);
computeCutPointsClassic(0, X.size());
break;
default:
throw invalid_argument("algorithm must be 0 or 1");
throw invalid_argument("algorithm must be 0, 1 or 2");
}
return *this;
}
@@ -93,6 +98,25 @@ namespace mdlp {
computeCutPointsAlternative(cut, end);
}
}
// Recursively searches the interval of sorted positions delimited by start
// and end for cut points ("Classic" variant, selected with algorithm == 2).
//
// start, end: positions into `indices`; end may be equal to X.size().
//
// A candidate position is requested from getCandidate(). When there is none
// (getCandidate() returns numeric_limits<size_t>::max()) or mdlp() rejects the
// split, the recursion stops and the limits of the interval are stored as cut
// points. Otherwise both sub-intervals are explored with this same variant.
void CPPFImdlp::computeCutPointsClassic(size_t start, size_t end)
{
    // An empty interval holds nothing to split: leave before asking for a
    // candidate instead of after, so the helpers never see an empty range.
    if (start == end)
        return;
    const size_t cut = getCandidate(start, end);
    if (cut == numeric_limits<size_t>::max() || !mdlp(start, cut, end)) {
        // cut == numeric_limits<size_t>::max() means that there is no candidate in the interval
        // No boundary found, so we add both ends of the interval as cutpoints
        // because they were selected by the algorithm before
        // Position 0 and position X.size() have no neighbour on one side, so
        // no midpoint can be computed for them and they are skipped.
        if (start != 0)
            cutPoints.push_back((X[indices[start]] + X[indices[start - 1]]) / 2);
        if (end != X.size())
            cutPoints.push_back((X[indices[end]] + X[indices[end - 1]]) / 2);
        return;
    }
    // Recurse with the Classic variant itself; delegating to computeCutPoints
    // here would apply the Classic rules to the first split only.
    computeCutPointsClassic(start, cut);
    computeCutPointsClassic(cut, end);
}
size_t CPPFImdlp::getCandidate(size_t start, size_t end)
{
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
@@ -148,6 +172,18 @@ namespace mdlp {
});
return idx;
}
// Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
// Returns the positions 0..X_.size()-1 ordered by ascending sample value.
// Only the sample values are compared; positions holding equal values keep
// their original relative order because the sort is stable.
indices_t CPPFImdlp::sortIndices1(samples_t& X_)
{
    indices_t idx(X_.size());
    iota(idx.begin(), idx.end(), 0);
    // A single pass is enough: stable-sorting an already sorted sequence
    // leaves it untouched, so repeating the sort once per sample only cost time.
    stable_sort(idx.begin(), idx.end(), [&X_](size_t i1, size_t i2)
        {
            return X_[i1] < X_[i2];
        });
    return idx;
}
cutPoints_t CPPFImdlp::getCutPoints()
{
// Remove duplicates and sort

View File

@@ -16,8 +16,10 @@ namespace mdlp {
cutPoints_t cutPoints;
static indices_t sortIndices(samples_t&, labels_t&);
// Argsort of the samples by value only (stable); used when algorithm == 2.
static indices_t sortIndices1(samples_t&);
// Recursive cut-point search over a (start, end) range of sorted positions.
void computeCutPoints(size_t, size_t);
// Cut-point search selected when algorithm == 1.
void computeCutPointsAlternative(size_t, size_t);
// Cut-point search selected when algorithm == 2.
void computeCutPointsClassic(size_t, size_t);
// Called as mdlp(start, cut, end); a false result makes the caller reject the split at cut.
bool mdlp(size_t, size_t, size_t);
// Called as getCandidate(start, end); numeric_limits<size_t>::max() signals that there is no candidate.
size_t getCandidate(size_t, size_t);
precision_t halfWayValueCutPoint(size_t, size_t);