mirror of
https://github.com/rmontanana/mdlp.git
synced 2025-08-16 07:55:58 +00:00
refactor computing max_cuts
This commit is contained in:
@@ -50,6 +50,13 @@ namespace mdlp {
|
||||
indices = sortIndices(X_, y_);
|
||||
metrics.setData(y, indices);
|
||||
computeCutPoints(0, X.size(), 1);
|
||||
sort(cutPoints.begin(), cutPoints.end());
|
||||
if (num_cut_points > 0) {
|
||||
// Select the best (with lower entropy) cut points
|
||||
while (cutPoints.size() > num_cut_points) {
|
||||
resizeCutPoints();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pair<precision_t, size_t> CPPFImdlp::valueCutPoint(size_t start, size_t cut, size_t end) {
|
||||
@@ -87,8 +94,6 @@ namespace mdlp {
|
||||
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_) {
|
||||
size_t cut;
|
||||
pair<precision_t, size_t> result;
|
||||
if (cutPoints.size() == num_cut_points)
|
||||
return;
|
||||
// Check if the interval length and the depth are Ok
|
||||
if (end - start < min_length || depth_ > max_depth)
|
||||
return;
|
||||
@@ -174,12 +179,24 @@ namespace mdlp {
|
||||
return idx;
|
||||
}
|
||||
|
||||
cutPoints_t CPPFImdlp::getCutPoints() {
|
||||
sort(cutPoints.begin(), cutPoints.end());
|
||||
return cutPoints;
|
||||
}
|
||||
|
||||
int CPPFImdlp::get_depth() const {
|
||||
return depth;
|
||||
void CPPFImdlp::resizeCutPoints() {
|
||||
//Compute entropy of each of the whole cutpoint set and discards the biggest value
|
||||
precision_t maxEntropy = 0;
|
||||
precision_t entropy;
|
||||
size_t maxEntropyIdx = 0;
|
||||
size_t begin = 0;
|
||||
size_t end;
|
||||
for (size_t idx = 0; idx < cutPoints.size(); idx++) {
|
||||
end = begin;
|
||||
while (X[indices[end]] < cutPoints[idx] && end < X.size())
|
||||
end++;
|
||||
entropy = metrics.entropy(begin, end);
|
||||
if (entropy > maxEntropy) {
|
||||
maxEntropy = entropy;
|
||||
maxEntropyIdx = idx;
|
||||
}
|
||||
begin = end;
|
||||
}
|
||||
cutPoints.erase(cutPoints.begin() + static_cast<long>(maxEntropyIdx));
|
||||
}
|
||||
}
|
||||
|
10
CPPFImdlp.h
10
CPPFImdlp.h
@@ -21,10 +21,12 @@ namespace mdlp {
|
||||
cutPoints_t cutPoints;
|
||||
size_t num_cut_points = numeric_limits<size_t>::max();
|
||||
|
||||
static indices_t sortIndices(samples_t&, labels_t&);
|
||||
static indices_t sortIndices(samples_t &, labels_t &);
|
||||
|
||||
void computeCutPoints(size_t, size_t, int);
|
||||
|
||||
void resizeCutPoints();
|
||||
|
||||
bool mdlp(size_t, size_t, size_t);
|
||||
|
||||
size_t getCandidate(size_t, size_t);
|
||||
@@ -40,11 +42,11 @@ namespace mdlp {
|
||||
|
||||
~CPPFImdlp();
|
||||
|
||||
void fit(samples_t&, labels_t&);
|
||||
void fit(samples_t &, labels_t &);
|
||||
|
||||
cutPoints_t getCutPoints();
|
||||
inline cutPoints_t getCutPoints() const { return cutPoints; };
|
||||
|
||||
int get_depth() const;
|
||||
inline int get_depth() const { return depth; };
|
||||
|
||||
static inline string version() { return "1.1.1"; };
|
||||
};
|
||||
|
@@ -269,12 +269,13 @@ namespace mdlp {
|
||||
auto test = CPPFImdlp(75, 2, 1);
|
||||
vector<cutPoints_t> expected = {
|
||||
{5.45f},
|
||||
{3.35f},
|
||||
{2.85f},
|
||||
{2.45f},
|
||||
{0.8f}
|
||||
};
|
||||
vector<int> depths = {1, 1, 1, 1};
|
||||
vector<int> depths = {2, 2, 2, 2};
|
||||
test_dataset(test, "iris", expected, depths);
|
||||
|
||||
}
|
||||
|
||||
TEST_F(TestFImdlp, MaxCutPointsFloat) {
|
||||
|
Reference in New Issue
Block a user