mirror of
https://github.com/rmontanana/mdlp.git
synced 2025-08-17 16:35:57 +00:00
feat: ♻️ Add Classic algorithm as #2 to compare performance
This commit is contained in:
@@ -30,8 +30,13 @@ namespace mdlp {
|
|||||||
case 1:
|
case 1:
|
||||||
computeCutPointsAlternative(0, X.size());
|
computeCutPointsAlternative(0, X.size());
|
||||||
break;
|
break;
|
||||||
|
case 2:
|
||||||
|
indices = sortIndices1(X_);
|
||||||
|
metrics.setData(y, indices);
|
||||||
|
computeCutPointsClassic(0, X.size());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw invalid_argument("algorithm must be 0 or 1");
|
throw invalid_argument("algorithm must be 0, 1 or 2");
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@@ -93,6 +98,25 @@ namespace mdlp {
|
|||||||
computeCutPointsAlternative(cut, end);
|
computeCutPointsAlternative(cut, end);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void CPPFImdlp::computeCutPointsClassic(size_t start, size_t end)
|
||||||
|
{
|
||||||
|
size_t cut;
|
||||||
|
cut = getCandidate(start, end);
|
||||||
|
if (cut == numeric_limits<size_t>::max() || !mdlp(start, cut, end)) {
|
||||||
|
// cut.value == -1 means that there is no candidate in the interval
|
||||||
|
// No boundary found, so we add both ends of the interval as cutpoints
|
||||||
|
// because they were selected by the algorithm before
|
||||||
|
if (start == end)
|
||||||
|
return;
|
||||||
|
if (start != 0)
|
||||||
|
cutPoints.push_back((X[indices[start]] + X[indices[start - 1]]) / 2);
|
||||||
|
if (end != X.size())
|
||||||
|
cutPoints.push_back((X[indices[end]] + X[indices[end - 1]]) / 2);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
computeCutPoints(start, cut);
|
||||||
|
computeCutPoints(cut, end);
|
||||||
|
}
|
||||||
size_t CPPFImdlp::getCandidate(size_t start, size_t end)
|
size_t CPPFImdlp::getCandidate(size_t start, size_t end)
|
||||||
{
|
{
|
||||||
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
|
/* Definition 1: A binary discretization for A is determined by selecting the cut point TA for which
|
||||||
@@ -148,6 +172,18 @@ namespace mdlp {
|
|||||||
});
|
});
|
||||||
return idx;
|
return idx;
|
||||||
}
|
}
|
||||||
|
// Argsort from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
|
||||||
|
indices_t CPPFImdlp::sortIndices1(samples_t& X_)
|
||||||
|
{
|
||||||
|
indices_t idx(X_.size());
|
||||||
|
iota(idx.begin(), idx.end(), 0);
|
||||||
|
for (size_t i = 0; i < X_.size(); i++)
|
||||||
|
stable_sort(idx.begin(), idx.end(), [&X_](size_t i1, size_t i2)
|
||||||
|
{
|
||||||
|
return X_[i1] < X_[i2];
|
||||||
|
});
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
cutPoints_t CPPFImdlp::getCutPoints()
|
cutPoints_t CPPFImdlp::getCutPoints()
|
||||||
{
|
{
|
||||||
// Remove duplicates and sort
|
// Remove duplicates and sort
|
||||||
|
@@ -16,8 +16,10 @@ namespace mdlp {
|
|||||||
cutPoints_t cutPoints;
|
cutPoints_t cutPoints;
|
||||||
|
|
||||||
static indices_t sortIndices(samples_t&, labels_t&);
|
static indices_t sortIndices(samples_t&, labels_t&);
|
||||||
|
static indices_t sortIndices1(samples_t&);
|
||||||
void computeCutPoints(size_t, size_t);
|
void computeCutPoints(size_t, size_t);
|
||||||
void computeCutPointsAlternative(size_t, size_t);
|
void computeCutPointsAlternative(size_t, size_t);
|
||||||
|
void computeCutPointsClassic(size_t, size_t);
|
||||||
bool mdlp(size_t, size_t, size_t);
|
bool mdlp(size_t, size_t, size_t);
|
||||||
size_t getCandidate(size_t, size_t);
|
size_t getCandidate(size_t, size_t);
|
||||||
precision_t halfWayValueCutPoint(size_t, size_t);
|
precision_t halfWayValueCutPoint(size_t, size_t);
|
||||||
|
Reference in New Issue
Block a user