Add max_depth and min_length as hyperparams

This commit is contained in:
2023-02-25 18:16:20 +01:00
parent e25ca378f0
commit d6cece1006
4 changed files with 105 additions and 36 deletions

View File

@@ -8,7 +8,11 @@
namespace mdlp {
CPPFImdlp::CPPFImdlp(): indices(indices_t()), X(samples_t()), y(labels_t()),
CPPFImdlp::CPPFImdlp():depth(0), max_depth(numeric_limits<int>::max()), min_length(3), indices(indices_t()), X(samples_t()), y(labels_t()),
metrics(Metrics(y, indices))
{
}
CPPFImdlp::CPPFImdlp(int min_length_, int max_depth_): depth(0), max_depth(max_depth_), min_length(min_length_), indices(indices_t()), X(samples_t()), y(labels_t()),
metrics(Metrics(y, indices))
{
}
@@ -25,9 +29,15 @@ namespace mdlp {
if (X.empty() || y.empty()) {
throw invalid_argument("X and y must have at least one element");
}
if (min_length < 3) {
throw invalid_argument("min_length must be greater than 2");
}
if (max_depth < 1) {
throw invalid_argument("max_depth must be greater than 0");
}
indices = sortIndices(X_, y_);
metrics.setData(y, indices);
computeCutPoints(0, X.size());
computeCutPoints(0, X.size(), 1);
return *this;
}
@@ -60,12 +70,14 @@ namespace mdlp {
return { (actual + previous) / 2, cut };
}
void CPPFImdlp::computeCutPoints(size_t start, size_t end)
void CPPFImdlp::computeCutPoints(size_t start, size_t end, int depth_)
{
size_t cut;
pair<precision_t, size_t> result;
if (end - start < 3)
// Check if the interval length and the depth are Ok
if (end - start < min_length || depth_ > max_depth)
return;
depth = depth_ > depth ? depth_ : depth;
cut = getCandidate(start, end);
if (cut == numeric_limits<size_t>::max())
return;
@@ -73,8 +85,8 @@ namespace mdlp {
result = valueCutPoint(start, cut, end);
cut = result.second;
cutPoints.push_back(result.first);
computeCutPoints(start, cut);
computeCutPoints(cut, end);
computeCutPoints(start, cut, depth_ + 1);
computeCutPoints(cut, end, depth_ + 1);
}
}
@@ -158,4 +170,8 @@ namespace mdlp {
sort(output.begin(), output.end());
return output;
}
int CPPFImdlp::get_depth()
{
return depth;
}
}