mirror of
https://github.com/rmontanana/mdlp.git
synced 2025-08-17 16:35:57 +00:00
refactor computing max_cuts
This commit is contained in:
@@ -50,6 +50,13 @@ namespace mdlp {
|
|||||||
indices = sortIndices(X_, y_);
|
indices = sortIndices(X_, y_);
|
||||||
metrics.setData(y, indices);
|
metrics.setData(y, indices);
|
||||||
computeCutPoints(0, X.size(), 1);
|
computeCutPoints(0, X.size(), 1);
|
||||||
|
sort(cutPoints.begin(), cutPoints.end());
|
||||||
|
if (num_cut_points > 0) {
|
||||||
|
// Select the best (with lower entropy) cut points
|
||||||
|
while (cutPoints.size() > num_cut_points) {
|
||||||
|
resizeCutPoints();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) {
|
pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) {
|
||||||
@@ -87,8 +94,6 @@ namespace mdlp {
|
|||||||
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) {
|
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) {
|
||||||
size_t cut;
|
size_t cut;
|
||||||
pair<precision_t, size_t> result;
|
pair<precision_t, size_t> result;
|
||||||
if (cutPoints.size() == num_cut_points)
|
|
||||||
return;
|
|
||||||
// Check if the interval length and the depth are Ok
|
// Check if the interval length and the depth are Ok
|
||||||
if (end - start < min_length || depth_ > max_depth)
|
if (end - start < min_length || depth_ > max_depth)
|
||||||
return;
|
return;
|
||||||
@@ -174,12 +179,24 @@ namespace mdlp {
|
|||||||
return idx;
|
return idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
cutPoints_t CPPFImdlp::getCutPoints() {
|
void CPPFImdlp::resizeCutPoints() {
|
||||||
sort(cutPoints.begin(), cutPoints.end());
|
//Compute entropy of each of the whole cutpoint set and discards the biggest value
|
||||||
return cutPoints;
|
precision_t maxEntropy = 0;
|
||||||
|
precision_t entropy;
|
||||||
|
size_t maxEntropyIdx = 0;
|
||||||
|
size_t begin = 0;
|
||||||
|
size_t end;
|
||||||
|
for (size_t idx = 0; idx < cutPoints.size(); idx++) {
|
||||||
|
end = begin;
|
||||||
|
while (X[indices[end]] < cutPoints[idx] && end < X.size())
|
||||||
|
end++;
|
||||||
|
entropy = metrics.entropy(begin, end);
|
||||||
|
if (entropy > maxEntropy) {
|
||||||
|
maxEntropy = entropy;
|
||||||
|
maxEntropyIdx = idx;
|
||||||
}
|
}
|
||||||
|
begin = end;
|
||||||
int CPPFImdlp::get_depth() const {
|
}
|
||||||
return depth;
|
cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
10
CPPFImdlp.h
10
CPPFImdlp.h
@@ -21,10 +21,12 @@ namespace mdlp {
|
|||||||
cutPoints_t cutPoints;
|
cutPoints_t cutPoints;
|
||||||
size_t num_cut_points = numeric_limits<size_t>::max();
|
size_t num_cut_points = numeric_limits<size_t>::max();
|
||||||
|
|
||||||
static indices_t sortIndices(samples_t&, labels_t&);
|
static indices_t sortIndices(samples_t &, labels_t &);
|
||||||
|
|
||||||
void computeCutPoints(size_t, size_t, int);
|
void computeCutPoints(size_t, size_t, int);
|
||||||
|
|
||||||
|
void resizeCutPoints();
|
||||||
|
|
||||||
bool mdlp(size_t, size_t, size_t);
|
bool mdlp(size_t, size_t, size_t);
|
||||||
|
|
||||||
size_t getCandidate(size_t, size_t);
|
size_t getCandidate(size_t, size_t);
|
||||||
@@ -40,11 +42,11 @@ namespace mdlp {
|
|||||||
|
|
||||||
~CPPFImdlp();
|
~CPPFImdlp();
|
||||||
|
|
||||||
void fit(samples_t&, labels_t&);
|
void fit(samples_t &, labels_t &);
|
||||||
|
|
||||||
cutPoints_t getCutPoints();
|
inline cutPoints_t getCutPoints() const { return cutPoints; };
|
||||||
|
|
||||||
int get_depth() const;
|
inline int get_depth() const { return depth; };
|
||||||
|
|
||||||
static inline string version() { return "1.1.1"; };
|
static inline string version() { return "1.1.1"; };
|
||||||
};
|
};
|
||||||
|
@@ -269,12 +269,13 @@ namespace mdlp {
|
|||||||
auto test = CPPFImdlp(75, 2, 1);
|
auto test = CPPFImdlp(75, 2, 1);
|
||||||
vector<cutPoints_t> expected = {
|
vector<cutPoints_t> expected = {
|
||||||
{5.45f},
|
{5.45f},
|
||||||
{3.35f},
|
{2.85f},
|
||||||
{2.45f},
|
{2.45f},
|
||||||
{0.8f}
|
{0.8f}
|
||||||
};
|
};
|
||||||
vector<int> depths = {1, 1, 1, 1};
|
vector<int> depths = {2, 2, 2, 2};
|
||||||
test_dataset(test, "iris", expected, depths);
|
test_dataset(test, "iris", expected, depths);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestFImdlp, MaxCutPointsFloat) {
|
TEST_F(TestFImdlp, MaxCutPointsFloat) {
|
||||||
|
Reference in New Issue
Block a user