mirror of
https://github.com/rmontanana/mdlp.git
synced 2025-08-16 07:55:58 +00:00
feat: ♻️ Add Classic algorithm as #2 to compare performance
This commit is contained in:
@@ -30,8 +30,13 @@ namespace mdlp {
|
||||
case 1:
|
||||
computeCutPointsAlternative(0, X.size());
|
||||
break;
|
||||
case 2:
|
||||
indices = sortIndices1(X_);
|
||||
metrics.setData(y, indices);
|
||||
computeCutPointsClassic(0, X.size());
|
||||
break;
|
||||
default:
|
||||
throw invalid_argument("algorithm must be 0 or 1");
|
||||
throw invalid_argument("algorithm must be 0, 1 or 2");
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
@@ -93,6 +98,25 @@ namespace mdlp {
|
||||
computeCutPointsAlternative(cut, end);
|
||||
}
|
||||
}
|
||||
void CPPFImdlp::computeCutPointsClassic(size_t start, size_t end)
|
||||
{
|
||||
size_t cut;
|
||||
cut = getCandidate(start, end);
|
||||
if (cut == numeric_limits<size_t>::max() || !mdlp(start, cut, end)) {
|
||||
// cut.value == -1 means that there is no candidate in the interval
|
||||
// No boundary found, so we add both ends of the interval as cutpoints
|
||||
// because they were selected by the algorithm before
|
||||
if (start == end)
|
||||
return;
|
||||
if (start != 0)
|
||||
cutPoints.push_back((X[indices[start]] + X[indices[start - 1]]) / 2);
|
||||
if (end != X.size())
|
||||
cutPoints.push_back((X[indices[end]] + X[indices[end - 1]]) / 2);
|
||||
return;
|
||||
}
|
||||
computeCutPoints(start, cut);
|
||||
computeCutPoints(cut, end);
|
||||
}
|
||||
size_t CPPFImdlp::getCandidate(size_t start, size_t end)
|
||||
{
|
||||
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
|
||||
@@ -148,6 +172,18 @@ namespace mdlp {
|
||||
});
|
||||
return idx;
|
||||
}
|
||||
// Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
|
||||
indices_t CPPFImdlp::sortIndices1(samples_t& X_)
|
||||
{
|
||||
indices_t idx(X_.size());
|
||||
iota(idx.begin(), idx.end(), 0);
|
||||
for (size_t i = 0; i < X_.size(); i++)
|
||||
stable_sort(idx.begin(), idx.end(), [&X_](size_t i1, size_t i2)
|
||||
{
|
||||
return X_[i1] < X_[i2];
|
||||
});
|
||||
return idx;
|
||||
}
|
||||
cutPoints_t CPPFImdlp::getCutPoints()
|
||||
{
|
||||
// Remove duplicates and sort
|
||||
|
@@ -16,8 +16,10 @@ namespace mdlp {
|
||||
cutPoints_t cutPoints;
|
||||
|
||||
static indices_t sortIndices(samples_t&, labels_t&);
|
||||
static indices_t sortIndices1(samples_t&);
|
||||
void computeCutPoints(size_t, size_t);
|
||||
void computeCutPointsAlternative(size_t, size_t);
|
||||
void computeCutPointsClassic(size_t, size_t);
|
||||
bool mdlp(size_t, size_t, size_t);
|
||||
size_t getCandidate(size_t, size_t);
|
||||
precision_t halfWayValueCutPoint(size_t, size_t);
|
||||
|
Reference in New Issue
Block a user