Refactor base proposal

This commit is contained in:
2022-12-08 21:55:15 +01:00
parent 5d930accca
commit c4e5cf1629
3 changed files with 11 additions and 63 deletions

View File

@@ -27,82 +27,38 @@ namespace mdlp {
}
indices = sortIndices(X_);
metrics.setData(y, indices);
computeCutPointsRecursive(0, X.size());
//simulateCutPointsRecursive();
computeCutPoints(0, X.size());
return *this;
}
// NOTE(review): this span appears to be diff output with the +/- markers stripped, so the
// two signatures below are most likely the old and the new one shown side by side -- confirm
// against the repository before relying on it.
//
// Loop-based search for cut points: intervals [start, end) over the sorted `indices` are kept
// in `jobs`, which is used as a stack (push_back / back / pop_back). The first interval pushed
// is the whole range [0, X.size()).
void CPPFImdlp::simulateCutPointsRecursive()
void CPPFImdlp::computeCutPoints(size_t start, size_t end)
{
cutPoints_t jobs = cutPoints_t();
jobs.push_back(cutPoint_t({ 0, X.size() }));
while (jobs.size() > 0) {
// Take the most recently pushed interval.
auto interval = jobs.back();
jobs.pop_back();
//cout << "start: " << interval.start << " end: " << interval.end << endl;
auto cut = getCandidateSimulate(interval.start, interval.end);
// The interval is not split when no candidate exists (-1) or mdlp() rejects the candidate.
if (cut == -1 || !mdlp(interval.start, cut, interval.end)) {
// Record each edge of the interval that is not an edge of the whole range, using the
// midpoint of the two sorted X values on either side of that edge.
if (interval.start != 0)
xCutPoints.push_back(xcutPoint_t({ interval.start, (X[indices[interval.start]] + X[indices[interval.start - 1]]) / 2 }));
if (interval.end != X.size())
xCutPoints.push_back(xcutPoint_t({ interval.end, (X[indices[interval.end]] + X[indices[interval.end - 1]]) / 2 }));
continue;
}
// Accepted cut: queue both halves for further processing.
jobs.push_back(cutPoint_t({ interval.start, size_t(cut) }));
jobs.push_back(cutPoint_t({ size_t(cut), interval.end }));
}
}
// NOTE(review): this span appears to be diff output with the +/- markers stripped; lines that
// use `cut.value` / `cut.index` / computeCutPointsRecursive and lines that use a plain integer
// `cut` / computeCutPoints look like the two revisions of the same statements -- confirm.
//
// Recursive search for cut points in the interval [start, end) of the sorted `indices`.
// Intervals with fewer than two elements are ignored. Otherwise the best candidate is asked
// from getCandidate(); if it is missing or rejected by mdlp(), the interval edges are recorded
// in xCutPoints, else both halves are processed recursively.
void CPPFImdlp::computeCutPointsRecursive(size_t start, size_t end)
{
xcutPoint_t cut;
//cout << "start: " << start << " end: " << end << endl;
int cut;
// Nothing to split in an interval of zero or one element.
if (end - start < 2)
return;
cut = getCandidate(start, end);
if (cut.value == -1 || !mdlp(start, cut.index, end)) {
if (cut == -1 || !mdlp(start, cut, end)) {
// -1 (cut.value == -1 in the xcutPoint_t variant, cut == -1 in the integer variant) is
// the initial value getCandidate returns when it found no position inside the interval
// where the class label changes, i.e. there is no candidate to evaluate.
//cout << "¡Ding! " << cut.value << " " << cut.index << endl;
// Record each edge that is not an edge of the whole range, as the midpoint of the two
// sorted X values on either side of that edge.
if (start != 0)
xCutPoints.push_back(xcutPoint_t({ start, (X[indices[start]] + X[indices[start - 1]]) / 2 }));
if (end != X.size())
xCutPoints.push_back(xcutPoint_t({ end, (X[indices[end]] + X[indices[end - 1]]) / 2 }));
return;
}
// Accepted cut: recurse into the left and right halves.
computeCutPointsRecursive(start, cut.index);
computeCutPointsRecursive(cut.index, end);
computeCutPoints(start, cut);
computeCutPoints(cut, end);
}
// NOTE(review): this span appears to be diff output with the +/- markers stripped; the
// xcutPoint_t lines (candidate.value / candidate.index) and the `long int` lines look like the
// old and new revision of the same function -- confirm.
//
// Scans the positions strictly inside [start, end) and returns the one that minimises the sum
// of metrics.entropy() of the left and right parts, each weighted by its share of the
// interval's elements. Only positions where the label differs from the previous sorted sample
// are considered. The candidate keeps its initial value of -1 when no such position exists.
xcutPoint_t CPPFImdlp::getCandidate(size_t start, size_t end)
long int CPPFImdlp::getCandidate(size_t start, size_t end)
{
xcutPoint_t candidate;
int elements = end - start;
candidate.value = -1;
candidate.index = -1;
long int candidate = -1, elements = end - start;
float entropy_left, entropy_right, minEntropy = numeric_limits<float>::max();
for (auto idx = start + 1; idx < end; idx++) {
// Cutpoints are always on boundaries
if (y[indices[idx]] == y[indices[idx - 1]])
continue;
// Weight each side's entropy by the fraction of the interval it covers.
entropy_left = float(idx - start) / elements * metrics.entropy(start, idx);
entropy_right = float(end - idx) / elements * metrics.entropy(idx, end);
// Strict comparison: on ties the earliest position is kept.
if (entropy_left + entropy_right < minEntropy) {
minEntropy = entropy_left + entropy_right;
candidate.value = (X[indices[idx]] + X[indices[idx - 1]]) / 2;
candidate.index = idx;
}
}
return candidate;
}
int CPPFImdlp::getCandidateSimulate(size_t start, size_t end)
{
int candidate = -1;
int elements = end - start;
float entropy_left, entropy_right, minEntropy = numeric_limits<float>::max();
for (auto idx = start + 1; idx < end; idx++) {
if (y[indices[idx]] == y[indices[idx - 1]])
continue;
entropy_left = float(idx - start) / elements * metrics.entropy(start, idx);
entropy_right = float(end - idx) / elements * metrics.entropy(idx, end);
if (minEntropy > entropy_left + entropy_right) {
minEntropy = entropy_left + entropy_right;
candidate = idx;
}
@@ -127,11 +83,6 @@ namespace mdlp {
ig = metrics.informationGain(start, cut, end);
delta = log2(pow(3, float(k)) - 2) - (float(k) * ent - float(k1) * ent1 - float(k2) * ent2);
float term = 1 / N * (log2(N - 1) + delta);
if (debug) {
cout << "start: " << start << " cut: " << cut << " end: " << end << endl;
cout << "k=" << k << " k1=" << k1 << " k2=" << k2 << " ent=" << ent << " ent1=" << ent1 << " ent2=" << ent2 << endl;
cout << "ig=" << ig << " delta=" << delta << " N " << N << " term " << term << endl;
}
return ig > term;
}
samples CPPFImdlp::getCutPointsx()

View File

@@ -17,17 +17,14 @@ namespace mdlp {
xcutPoints_t xCutPoints;
static indices_t sortIndices(samples&);
void computeCutPointsRecursive(size_t, size_t);
xcutPoint_t getCandidate(size_t, size_t);
void computeCutPoints(size_t, size_t);
long int getCandidate(size_t, size_t);
bool mdlp(size_t, size_t, size_t);
void simulateCutPointsRecursive();
int getCandidateSimulate(size_t, size_t);
public:
CPPFImdlp();
CPPFImdlp(bool, int, bool debug = false);
~CPPFImdlp();
indices_t getIndices();
CPPFImdlp& fitx(samples&, labels&);
samples getCutPointsx();
};